提交 77586c67 编写于 作者: C ceci3

update

上级 faa73513
...@@ -16,6 +16,14 @@ import imagenet_reader ...@@ -16,6 +16,14 @@ import imagenet_reader
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
reduce_rate = 0.85
init_temperature = 10.24
max_flops = 321208544
server_address = ""
port = 8909
retain_epoch = 5
def create_data_loader(image_shape): def create_data_loader(image_shape):
data_shape = [-1] + image_shape data_shape = [-1] + image_shape
data = fluid.data(name='data', shape=data_shape, dtype='float32') data = fluid.data(name='data', shape=data_shape, dtype='float32')
...@@ -27,6 +35,7 @@ def create_data_loader(image_shape): ...@@ -27,6 +35,7 @@ def create_data_loader(image_shape):
iterable=True) iterable=True)
return data_loader, data, label return data_loader, data, label
def conv_bn_layer(input, def conv_bn_layer(input,
filter_size, filter_size,
num_filters, num_filters,
...@@ -50,20 +59,31 @@ def conv_bn_layer(input, ...@@ -50,20 +59,31 @@ def conv_bn_layer(input,
bn_name = name + '_bn' bn_name = name + '_bn'
return fluid.layers.batch_norm( return fluid.layers.batch_norm(
input=conv, input=conv,
act = act, act=act,
param_attr=ParamAttr(name=bn_name + '_scale'), param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(name=bn_name + '_offset'), bias_attr=ParamAttr(name=bn_name + '_offset'),
moving_mean_name=bn_name + '_mean', moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance') moving_variance_name=bn_name + '_variance')
def search_mobilenetv2_block(config, args, image_size): def search_mobilenetv2_block(config, args, image_size):
image_shape = [3, image_size, image_size] image_shape = [3, image_size, image_size]
if args.is_server: if args.is_server:
sa_nas = SANAS(config, server_addr=("", args.port), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, is_server=True) sa_nas = SANAS(
config,
server_addr=("", port),
init_temperature=init_temperature,
reduce_rate=reduce_rate,
search_steps=args.search_steps,
is_server=True)
else: else:
sa_nas = SANAS(config, server_addr=(args.server_address, args.port), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, is_server=False) sa_nas = SANAS(
config,
server_addr=(server_address, port),
init_temperature=init_temperature,
reduce_rate=reduce_rate,
search_steps=args.search_steps,
is_server=False)
for step in range(args.search_steps): for step in range(args.search_steps):
archs = sa_nas.next_archs()[0] archs = sa_nas.next_archs()[0]
...@@ -73,10 +93,30 @@ def search_mobilenetv2_block(config, args, image_size): ...@@ -73,10 +93,30 @@ def search_mobilenetv2_block(config, args, image_size):
startup_program = fluid.Program() startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
train_loader, data, label = create_data_loader(image_shape) train_loader, data, label = create_data_loader(image_shape)
data = conv_bn_layer(input=data, num_filters=32, filter_size=3, stride=2, padding='SAME', act='relu6', name='mobilenetv2_conv1') data = conv_bn_layer(
input=data,
num_filters=32,
filter_size=3,
stride=2,
padding='SAME',
act='relu6',
name='mobilenetv2_conv1')
data = archs(data)[0] data = archs(data)[0]
data = conv_bn_layer(input=data, num_filters=1280, filter_size=1, stride=1, padding='SAME', act='relu6', name='mobilenetv2_last_conv') data = conv_bn_layer(
data = fluid.layers.pool2d(input=data, pool_size=7, pool_stride=1, pool_type='avg', global_pooling=True, name='mobilenetv2_last_pool') input=data,
num_filters=1280,
filter_size=1,
stride=1,
padding='SAME',
act='relu6',
name='mobilenetv2_last_conv')
data = fluid.layers.pool2d(
input=data,
pool_size=7,
pool_stride=1,
pool_type='avg',
global_pooling=True,
name='mobilenetv2_last_pool')
output = fluid.layers.fc( output = fluid.layers.fc(
input=data, input=data,
size=args.class_dim, size=args.class_dim,
...@@ -86,8 +126,10 @@ def search_mobilenetv2_block(config, args, image_size): ...@@ -86,8 +126,10 @@ def search_mobilenetv2_block(config, args, image_size):
softmax_out = fluid.layers.softmax(input=output, use_cudnn=False) softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=softmax_out, label=label) cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1) acc_top1 = fluid.layers.accuracy(
acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5) input=softmax_out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(
input=softmax_out, label=label, k=5)
test_program = train_program.clone(for_test=True) test_program = train_program.clone(for_test=True)
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
...@@ -98,7 +140,7 @@ def search_mobilenetv2_block(config, args, image_size): ...@@ -98,7 +140,7 @@ def search_mobilenetv2_block(config, args, image_size):
current_flops = flops(train_program) current_flops = flops(train_program)
print('step: {}, current_flops: {}'.format(step, current_flops)) print('step: {}, current_flops: {}'.format(step, current_flops))
if current_flops > args.max_flops: if current_flops > max_flops:
continue continue
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
...@@ -132,12 +174,11 @@ def search_mobilenetv2_block(config, args, image_size): ...@@ -132,12 +174,11 @@ def search_mobilenetv2_block(config, args, image_size):
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
test_loader.set_sample_list_generator(test_reader, places=place) test_loader.set_sample_list_generator(test_reader, places=place)
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
train_compiled_program = fluid.CompiledProgram( train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel( train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy) loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(args.retain_epoch): for epoch_id in range(retain_epoch):
for batch_id, data in enumerate(train_loader()): for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name] fetches = [avg_cost.name]
s_time = time.time() s_time = time.time()
...@@ -152,9 +193,7 @@ def search_mobilenetv2_block(config, args, image_size): ...@@ -152,9 +193,7 @@ def search_mobilenetv2_block(config, args, image_size):
reward = [] reward = []
for batch_id, data in enumerate(test_loader()): for batch_id, data in enumerate(test_loader()):
test_fetches = [ test_fetches = [avg_cost.name, acc_top1.name, acc_top5.name]
avg_cost.name, acc_top1.name, acc_top5.name
]
batch_reward = exe.run(test_program, batch_reward = exe.run(test_program,
feed=data, feed=data,
fetch_list=test_fetches) fetch_list=test_fetches)
...@@ -173,6 +212,7 @@ def search_mobilenetv2_block(config, args, image_size): ...@@ -173,6 +212,7 @@ def search_mobilenetv2_block(config, args, image_size):
sa_nas.reward(float(finally_reward[1])) sa_nas.reward(float(finally_reward[1]))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -191,36 +231,18 @@ if __name__ == '__main__': ...@@ -191,36 +231,18 @@ if __name__ == '__main__':
type=str, type=str,
default='cifar10', default='cifar10',
choices=['cifar10', 'imagenet'], choices=['cifar10', 'imagenet'],
help='server address.') help='dataset name.')
# controller
parser.add_argument(
'--reduce_rate', type=float, default=0.85, help='reduce rate.')
parser.add_argument(
'--init_temperature',
type=float,
default=10.24,
help='init temperature.')
parser.add_argument( parser.add_argument(
'--is_server', '--is_server',
type=ast.literal_eval, type=ast.literal_eval,
default=True, default=True,
help='Whether to start a server.') help='Whether to start a server.')
# nas args # nas args
parser.add_argument(
'--max_flops', type=int, default=592948064, help='reduce rate.')
parser.add_argument(
'--retain_epoch', type=int, default=5, help='train epoch before val.')
parser.add_argument(
'--end_epoch', type=int, default=500, help='end epoch present client.')
parser.add_argument( parser.add_argument(
'--search_steps', '--search_steps',
type=int, type=int,
default=100, default=100,
help='controller server number.') help='controller server number.')
parser.add_argument(
'--server_address', type=str, default=None, help='server address.')
parser.add_argument(
'--port', type=int, default=8889, help='server port.')
# optimizer args # optimizer args
parser.add_argument( parser.add_argument(
'--lr_strategy', '--lr_strategy',
...@@ -265,17 +287,12 @@ if __name__ == '__main__': ...@@ -265,17 +287,12 @@ if __name__ == '__main__':
elif args.data == 'imagenet': elif args.data == 'imagenet':
image_size = 224 image_size = 224
else: else:
raise NotImplemented( raise NotImplementedError(
'data must in [cifar10, imagenet], but received: {}'.format( 'data must in [cifar10, imagenet], but received: {}'.format(
args.data)) args.data))
# block mask means block number, 1 mean downsample, 0 means the size of feature map don't change after this block # block mask means block number, 1 mean downsample, 0 means the size of feature map don't change after this block
config_info = { config_info = {'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]}
'input_size': None,
'output_size': None,
'block_num': None,
'block_mask': [0, 1, 1, 1, 1, 0, 1, 0]
}
config = [('MobileNetV2BlockSpace', config_info)] config = [('MobileNetV2BlockSpace', config_info)]
search_mobilenetv2_block(config, args, image_size) search_mobilenetv2_block(config, args, image_size)
...@@ -18,6 +18,13 @@ import imagenet_reader ...@@ -18,6 +18,13 @@ import imagenet_reader
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
reduce_rate = 0.85
init_temperature = 10.24
max_flops = 321208544
server_address = ""
port = 8909
retain_epoch = 5
def create_data_loader(image_shape): def create_data_loader(image_shape):
data_shape = [-1] + image_shape data_shape = [-1] + image_shape
...@@ -40,7 +47,11 @@ def build_program(main_program, ...@@ -40,7 +47,11 @@ def build_program(main_program,
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
data_loader, data, label = create_data_loader(image_shape) data_loader, data, label = create_data_loader(image_shape)
output = archs(data) output = archs(data)
output = fluid.layers.fc(input=output, size=args.class_dim, param_attr=ParamAttr(name='mobilenetv2_fc_weights'), bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) output = fluid.layers.fc(
input=output,
size=args.class_dim,
param_attr=ParamAttr(name='mobilenetv2_fc_weights'),
bias_attr=ParamAttr(name='mobilenetv2_fc_offset'))
softmax_out = fluid.layers.softmax(input=output, use_cudnn=False) softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=softmax_out, label=label) cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
...@@ -59,18 +70,18 @@ def search_mobilenetv2(config, args, image_size, is_server=True): ...@@ -59,18 +70,18 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
### start a server and a client ### start a server and a client
sa_nas = SANAS( sa_nas = SANAS(
config, config,
server_addr=("", args.port), server_addr=("", port),
init_temperature=args.init_temperature, init_temperature=init_temperature,
reduce_rate=args.reduce_rate, reduce_rate=reduce_rate,
search_steps=args.search_steps, search_steps=args.search_steps,
is_server=True) is_server=True)
else: else:
### start a client ### start a client
sa_nas = SANAS( sa_nas = SANAS(
config, config,
server_addr=(args.server_address, args.port), server_addr=(server_address, port),
init_temperature=args.init_temperature, init_temperature=init_temperature,
reduce_rate=args.reduce_rate, reduce_rate=reduce_rate,
search_steps=args.search_steps, search_steps=args.search_steps,
is_server=False) is_server=False)
...@@ -86,7 +97,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True): ...@@ -86,7 +97,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
current_flops = flops(train_program) current_flops = flops(train_program)
print('step: {}, current_flops: {}'.format(step, current_flops)) print('step: {}, current_flops: {}'.format(step, current_flops))
if current_flops > args.max_flops: if current_flops > max_flops:
continue continue
test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program( test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
...@@ -123,7 +134,6 @@ def search_mobilenetv2(config, args, image_size, is_server=True): ...@@ -123,7 +134,6 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
batch_size=args.batch_size, batch_size=args.batch_size,
drop_last=False) drop_last=False)
#test_loader, _, _ = create_data_loader(image_shape)
train_loader.set_sample_list_generator( train_loader.set_sample_list_generator(
train_reader, train_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
...@@ -133,7 +143,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True): ...@@ -133,7 +143,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
train_compiled_program = fluid.CompiledProgram( train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel( train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy) loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(args.retain_epoch): for epoch_id in range(retain_epoch):
for batch_id, data in enumerate(train_loader()): for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name] fetches = [avg_cost.name]
s_time = time.time() s_time = time.time()
...@@ -170,6 +180,99 @@ def search_mobilenetv2(config, args, image_size, is_server=True): ...@@ -170,6 +180,99 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
sa_nas.reward(float(finally_reward[1])) sa_nas.reward(float(finally_reward[1]))
def test_search_result(tokens, image_size, args, config):
    """Train and evaluate the architecture described by ``tokens``.

    Builds the network returned by ``SANAS.tokens2arch(tokens)``, trains it
    for ``retain_epoch`` epochs, then runs one pass over the test reader and
    logs the averaged cost / top-1 / top-5 values. Nothing is returned.

    Args:
        tokens: Token sequence handed to ``SANAS.tokens2arch``.
        image_size: Height and width of the (3-channel) input image.
        args: Parsed command-line namespace. This function reads
            ``search_steps``, ``use_gpu``, ``data`` and ``batch_size``;
            the namespace is also forwarded to ``build_program``.
        config: Search-space configuration passed to ``SANAS``.

    Raises:
        NotImplementedError: If ``args.data`` is neither ``'cifar10'`` nor
            ``'imagenet'``.
    """
    # The controller hyper-parameters are module-level constants: the
    # ``--init_temperature`` / ``--reduce_rate`` command-line options no
    # longer exist, so reading them from ``args`` would raise AttributeError.
    sa_nas = SANAS(
        config,
        server_addr=("", 8887),
        init_temperature=init_temperature,
        reduce_rate=reduce_rate,
        search_steps=args.search_steps,
        is_server=True)

    image_shape = [3, image_size, image_size]

    archs = sa_nas.tokens2arch(tokens)

    train_program = fluid.Program()
    test_program = fluid.Program()
    startup_program = fluid.Program()
    train_loader, avg_cost, acc_top1, acc_top5 = build_program(
        train_program, startup_program, image_shape, archs, args)

    current_flops = flops(train_program)
    print('current_flops: {}'.format(current_flops))
    test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
        test_program, startup_program, image_shape, archs, args, is_test=True)

    test_program = test_program.clone(for_test=True)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_program)

    if args.data == 'cifar10':
        train_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.cifar.train10(cycle=False), buf_size=1024),
            batch_size=args.batch_size,
            drop_last=True)

        test_reader = paddle.batch(
            paddle.dataset.cifar.test10(cycle=False),
            batch_size=args.batch_size,
            drop_last=False)
    elif args.data == 'imagenet':
        train_reader = paddle.batch(
            imagenet_reader.train(),
            batch_size=args.batch_size,
            drop_last=True)
        test_reader = paddle.batch(
            imagenet_reader.val(), batch_size=args.batch_size, drop_last=False)
    else:
        # Fail with a clear message instead of an unbound-variable error
        # further down when an unsupported dataset name slips through.
        raise NotImplementedError(
            'data must in [cifar10, imagenet], but received: {}'.format(
                args.data))

    train_loader.set_sample_list_generator(
        train_reader,
        places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
    test_loader.set_sample_list_generator(test_reader, places=place)

    build_strategy = fluid.BuildStrategy()
    train_compiled_program = fluid.CompiledProgram(
        train_program).with_data_parallel(
            loss_name=avg_cost.name, build_strategy=build_strategy)

    # The fetch lists do not change between batches, so build them once.
    fetches = [avg_cost.name]
    test_fetches = [
        test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
    ]

    for epoch_id in range(retain_epoch):
        for batch_id, data in enumerate(train_loader()):
            s_time = time.time()
            outs = exe.run(train_compiled_program,
                           feed=data,
                           fetch_list=fetches)[0]
            # NOTE: time.time() differences are in seconds, although the
            # log message below labels the value as "ms".
            batch_time = time.time() - s_time
            if batch_id % 10 == 0:
                _logger.info(
                    'TRAIN: epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
                    format(epoch_id, batch_id, outs[0], batch_time))

        # Evaluate on the test reader after every training epoch.
        reward = []
        for batch_id, data in enumerate(test_loader()):
            batch_reward = exe.run(test_program,
                                   feed=data,
                                   fetch_list=test_fetches)
            # Average each fetched value over the batch.
            reward_avg = np.mean(np.array(batch_reward), axis=1)
            reward.append(reward_avg)

            _logger.info(
                'TEST: batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
                format(batch_id, batch_reward[0], batch_reward[1],
                       batch_reward[2]))

        # Average the per-batch results over the whole test pass.
        finally_reward = np.mean(np.array(reward), axis=0)
        _logger.info(
            'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
                finally_reward[0], finally_reward[1], finally_reward[2]))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -187,46 +290,26 @@ if __name__ == '__main__': ...@@ -187,46 +290,26 @@ if __name__ == '__main__':
default='cifar10', default='cifar10',
choices=['cifar10', 'imagenet'], choices=['cifar10', 'imagenet'],
help='server address.') help='server address.')
# controller
parser.add_argument(
'--reduce_rate', type=float, default=0.85, help='reduce rate.')
parser.add_argument(
'--init_temperature',
type=float,
default=10.24,
help='init temperature.')
parser.add_argument( parser.add_argument(
'--is_server', '--is_server',
type=ast.literal_eval, type=ast.literal_eval,
default=True, default=True,
help='Whether to start a server.') help='Whether to start a server.')
# nas args
parser.add_argument(
'--max_flops', type=int, default=592948064, help='reduce rate.')
parser.add_argument(
'--retain_epoch', type=int, default=5, help='train epoch before val.')
parser.add_argument(
'--end_epoch', type=int, default=500, help='end epoch present client.')
parser.add_argument( parser.add_argument(
'--search_steps', '--search_steps',
type=int, type=int,
default=100, default=100,
help='controller server number.') help='controller server number.')
parser.add_argument(
'--server_address', type=str, default=None, help='server address.')
parser.add_argument(
'--port', type=int, default=8889, help='server port.')
# optimizer args
parser.add_argument( parser.add_argument(
'--lr_strategy', '--lr_strategy',
type=str, type=str,
default='piecewise_decay', default='cosine_decay',
help='learning rate decay strategy.') help='learning rate decay strategy.')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate.') parser.add_argument('--lr', type=float, default=0.1, help='learning rate.')
parser.add_argument( parser.add_argument(
'--l2_decay', type=float, default=1e-4, help='learning rate decay.') '--l2_decay', type=float, default=1e-4, help='learning rate decay.')
parser.add_argument( parser.add_argument(
'--class_dim', type=int, default=1000, help='classify number.') '--class_dim', type=int, default=100, help='classify number.')
parser.add_argument( parser.add_argument(
'--step_epochs', '--step_epochs',
nargs='+', nargs='+',
...@@ -264,7 +347,7 @@ if __name__ == '__main__': ...@@ -264,7 +347,7 @@ if __name__ == '__main__':
image_size = 224 image_size = 224
block_num = 6 block_num = 6
else: else:
raise NotImplemented( raise NotImplementedError(
'data must in [cifar10, imagenet], but received: {}'.format( 'data must in [cifar10, imagenet], but received: {}'.format(
args.data)) args.data))
......
...@@ -236,9 +236,9 @@ class MobileNetV1Space(SearchSpaceBase): ...@@ -236,9 +236,9 @@ class MobileNetV1Space(SearchSpaceBase):
depthwise_conv = conv_bn_layer( depthwise_conv = conv_bn_layer(
input=input, input=input,
filter_size=kernel_size, filter_size=kernel_size,
num_filters=int(num_filters1 * scale), num_filters=output_channel,
stride=stride, stride=stride,
num_groups=int(num_groups * scale), num_groups=num_groups,
use_cudnn=False, use_cudnn=False,
name=name + '_dw') name=name + '_dw')
pointwise_conv = conv_bn_layer( pointwise_conv = conv_bn_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册