提交 2e07714f 编写于 作者: L lvmengsi

Merge branch 'fix_nas' into 'develop'

fix nas demo

See merge request !47
...@@ -53,19 +53,27 @@ def build_program(main_program, ...@@ -53,19 +53,27 @@ def build_program(main_program,
return data_loader, avg_cost, acc_top1, acc_top5 return data_loader, avg_cost, acc_top1, acc_top5
def search_mobilenetv2(config, args, image_size): def search_mobilenetv2(config, args, image_size, is_server=True):
factory = SearchSpaceFactory() factory = SearchSpaceFactory()
space = factory.get_search_space(config) space = factory.get_search_space(config)
### start a server and a client if is_server:
sa_nas = SANAS( ### start a server and a client
config, sa_nas = SANAS(
server_addr=("", 8883), config,
init_temperature=args.init_temperature, server_addr=("", 8883),
reduce_rate=args.reduce_rate, init_temperature=args.init_temperature,
search_steps=args.search_steps, reduce_rate=args.reduce_rate,
is_server=True) search_steps=args.search_steps,
### start a client is_server=True)
#sa_nas = SANAS(config, server_addr=("10.255.125.38", 8889), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, is_server=True) else:
### start a client
sa_nas = SANAS(
config,
server_addr=("10.255.125.38", 8883),
init_temperature=args.init_temperature,
reduce_rate=args.reduce_rate,
search_steps=args.search_steps,
is_server=False)
image_shape = [3, image_size, image_size] image_shape = [3, image_size, image_size]
for step in range(args.search_steps): for step in range(args.search_steps):
...@@ -122,22 +130,22 @@ def search_mobilenetv2(config, args, image_size): ...@@ -122,22 +130,22 @@ def search_mobilenetv2(config, args, image_size):
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
test_loader.set_sample_list_generator(test_reader, places=place) test_loader.set_sample_list_generator(test_reader, places=place)
#build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
#train_compiled_program = fluid.CompiledProgram( train_compiled_program = fluid.CompiledProgram(
# train_program).with_data_parallel( train_program).with_data_parallel(
# loss_name=avg_cost.name, build_strategy=build_strategy) loss_name=avg_cost.name, build_strategy=build_strategy)
#for epoch_id in range(args.retain_epoch): for epoch_id in range(args.retain_epoch):
# for batch_id, data in enumerate(train_loader()): for batch_id, data in enumerate(train_loader()):
# fetches = [avg_cost.name] fetches = [avg_cost.name]
# s_time = time.time() s_time = time.time()
# outs = exe.run(train_compiled_program, outs = exe.run(train_compiled_program,
# feed=data, feed=data,
# fetch_list=fetches)[0] fetch_list=fetches)[0]
# batch_time = time.time() - s_time batch_time = time.time() - s_time
# if batch_id % 10 == 0: if batch_id % 10 == 0:
# _logger.info( _logger.info(
# 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'. 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
# format(step, epoch_id, batch_id, outs[0], batch_time)) format(step, epoch_id, batch_id, outs[0], batch_time))
reward = [] reward = []
for batch_id, data in enumerate(test_loader()): for batch_id, data in enumerate(test_loader()):
...@@ -158,7 +166,7 @@ def search_mobilenetv2(config, args, image_size): ...@@ -158,7 +166,7 @@ def search_mobilenetv2(config, args, image_size):
finally_reward = np.mean(np.array(reward), axis=0) finally_reward = np.mean(np.array(reward), axis=0)
_logger.info( _logger.info(
'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format( 'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
step, finally_reward[0], finally_reward[1], finally_reward[2])) finally_reward[0], finally_reward[1], finally_reward[2]))
sa_nas.reward(float(finally_reward[1])) sa_nas.reward(float(finally_reward[1]))
...@@ -188,6 +196,11 @@ if __name__ == '__main__': ...@@ -188,6 +196,11 @@ if __name__ == '__main__':
type=float, type=float,
default=10.24, default=10.24,
help='init temperature.') help='init temperature.')
parser.add_argument(
'--is_server',
type=ast.literal_eval,
default=True,
help='Whether to start a server.')
# nas args # nas args
parser.add_argument( parser.add_argument(
'--max_flops', type=int, default=592948064, help='reduce rate.') '--max_flops', type=int, default=592948064, help='reduce rate.')
...@@ -260,4 +273,4 @@ if __name__ == '__main__': ...@@ -260,4 +273,4 @@ if __name__ == '__main__':
} }
config = [('MobileNetV2Space', config_info)] config = [('MobileNetV2Space', config_info)]
search_mobilenetv2(config, args, image_size) search_mobilenetv2(config, args, image_size, is_server=args.is_server)
...@@ -32,6 +32,7 @@ class MobileNetV1Space(SearchSpaceBase): ...@@ -32,6 +32,7 @@ class MobileNetV1Space(SearchSpaceBase):
input_size, input_size,
output_size, output_size,
block_num, block_num,
block_mask,
scale=1.0, scale=1.0,
class_dim=1000): class_dim=1000):
super(MobileNetV1Space, self).__init__(input_size, output_size, super(MobileNetV1Space, self).__init__(input_size, output_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册