diff --git a/demo/nas/sa_nas_mobilenetv2.py b/demo/nas/sa_nas_mobilenetv2.py index 98153b707d9a2f78e163d0f800501aa79eda3649..142c2c08f09e7888ab255b1d6ce762a50c8e1966 100644 --- a/demo/nas/sa_nas_mobilenetv2.py +++ b/demo/nas/sa_nas_mobilenetv2.py @@ -53,19 +53,27 @@ def build_program(main_program, return data_loader, avg_cost, acc_top1, acc_top5 -def search_mobilenetv2(config, args, image_size): +def search_mobilenetv2(config, args, image_size, is_server=True): factory = SearchSpaceFactory() space = factory.get_search_space(config) - ### start a server and a client - sa_nas = SANAS( - config, - server_addr=("", 8883), - init_temperature=args.init_temperature, - reduce_rate=args.reduce_rate, - search_steps=args.search_steps, - is_server=True) - ### start a client - #sa_nas = SANAS(config, server_addr=("10.255.125.38", 8889), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, is_server=True) + if is_server: + ### start a server and a client + sa_nas = SANAS( + config, + server_addr=("", 8883), + init_temperature=args.init_temperature, + reduce_rate=args.reduce_rate, + search_steps=args.search_steps, + is_server=True) + else: + ### start a client + sa_nas = SANAS( + config, + server_addr=("10.255.125.38", 8883), + init_temperature=args.init_temperature, + reduce_rate=args.reduce_rate, + search_steps=args.search_steps, + is_server=False) image_shape = [3, image_size, image_size] for step in range(args.search_steps): @@ -122,22 +130,22 @@ def search_mobilenetv2(config, args, image_size): places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) test_loader.set_sample_list_generator(test_reader, places=place) - #build_strategy = fluid.BuildStrategy() - #train_compiled_program = fluid.CompiledProgram( - # train_program).with_data_parallel( - # loss_name=avg_cost.name, build_strategy=build_strategy) - #for epoch_id in range(args.retain_epoch): - # for batch_id, data in enumerate(train_loader()): - # fetches = [avg_cost.name] - # s_time = time.time() - # outs = exe.run(train_compiled_program, - # feed=data, - # fetch_list=fetches)[0] - # batch_time = time.time() - s_time - # if batch_id % 10 == 0: - # _logger.info( - # 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'. - # format(step, epoch_id, batch_id, outs[0], batch_time)) + build_strategy = fluid.BuildStrategy() + train_compiled_program = fluid.CompiledProgram( + train_program).with_data_parallel( + loss_name=avg_cost.name, build_strategy=build_strategy) + for epoch_id in range(args.retain_epoch): + for batch_id, data in enumerate(train_loader()): + fetches = [avg_cost.name] + s_time = time.time() + outs = exe.run(train_compiled_program, + feed=data, + fetch_list=fetches)[0] + batch_time = time.time() - s_time + if batch_id % 10 == 0: + _logger.info( + 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'. + format(step, epoch_id, batch_id, outs[0], batch_time)) reward = [] for batch_id, data in enumerate(test_loader()): @@ -158,7 +166,7 @@ def search_mobilenetv2(config, args, image_size): finally_reward = np.mean(np.array(reward), axis=0) _logger.info( 'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format( - step, finally_reward[0], finally_reward[1], finally_reward[2])) + finally_reward[0], finally_reward[1], finally_reward[2])) sa_nas.reward(float(finally_reward[1])) @@ -188,6 +196,11 @@ if __name__ == '__main__': type=float, default=10.24, help='init temperature.') + parser.add_argument( + '--is_server', + type=ast.literal_eval, + default=True, + help='Whether to start a server.') # nas args parser.add_argument( '--max_flops', type=int, default=592948064, help='reduce rate.') @@ -260,4 +273,4 @@ if __name__ == '__main__': } config = [('MobileNetV2Space', config_info)] - search_mobilenetv2(config, args, image_size) + search_mobilenetv2(config, args, image_size, is_server=args.is_server) diff --git a/paddleslim/nas/search_space/mobilenetv1.py b/paddleslim/nas/search_space/mobilenetv1.py index cfa704552ddf9c5ae37851fc628caa5663aa2a69..3976d21df1e3ad2c5ac344dab59ad32adeaedb79 100644 --- a/paddleslim/nas/search_space/mobilenetv1.py +++ b/paddleslim/nas/search_space/mobilenetv1.py @@ -32,6 +32,7 @@ class MobileNetV1Space(SearchSpaceBase): input_size, output_size, block_num, + block_mask, scale=1.0, class_dim=1000): super(MobileNetV1Space, self).__init__(input_size, output_size,