diff --git a/demo/nas/sa_nas_mobilenetv2.py b/demo/nas/sa_nas_mobilenetv2.py
index 9c9fe9a40e0d1a18afebc0f1d9c45e16e2e33639..f7c898b1f7f31c55e5e0336739158721235c8c70 100644
--- a/demo/nas/sa_nas_mobilenetv2.py
+++ b/demo/nas/sa_nas_mobilenetv2.py
@@ -31,13 +31,50 @@ def create_data_loader(image_shape):
     return data_loader, data, label
 
 
+def build_program(main_program,
+                  startup_program,
+                  image_shape,
+                  archs,
+                  args,
+                  is_test=False):
+    """Build the classification network for one sampled architecture.
+
+    The data loader, the forward pass through `archs` and the loss/accuracy
+    ops are appended to `main_program`; an optimizer is added only when
+    `is_test` is false.
+
+    Returns:
+        tuple: (data_loader, avg_cost, acc_top1, acc_top5).
+    """
+    with fluid.program_guard(main_program, startup_program):
+        # Restart auto-generated names so the train and the test build get
+        # identical parameter names and share weights -- NOTE(review): this
+        # only matters if `archs` relies on auto-generated names; confirm.
+        with fluid.unique_name.guard():
+            data_loader, data, label = create_data_loader(image_shape)
+            output = archs(data)
+
+            softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
+            cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
+            avg_cost = fluid.layers.mean(cost)
+            acc_top1 = fluid.layers.accuracy(
+                input=softmax_out, label=label, k=1)
+            acc_top5 = fluid.layers.accuracy(
+                input=softmax_out, label=label, k=5)
+
+            if not is_test:
+                optimizer = create_optimizer(args)
+                optimizer.minimize(avg_cost)
+    return data_loader, avg_cost, acc_top1, acc_top5
+
+
 def search_mobilenetv2(config, args, image_size):
     factory = SearchSpaceFactory()
     space = factory.get_search_space(config)
     ### start a server and a client
     sa_nas = SANAS(
         config,
-        server_addr=("", 8889),
+        server_addr=("", 8883),
         init_temperature=args.init_temperature,
         reduce_rate=args.reduce_rate,
         search_steps=args.search_steps,
@@ -52,26 +89,22 @@ def search_mobilenetv2(config, args, image_size):
         train_program = fluid.Program()
         test_program = fluid.Program()
         startup_program = fluid.Program()
-        with fluid.program_guard(train_program, startup_program):
-            train_loader, data, label = create_data_loader(image_shape)
-            output = archs(data)
-            current_flops = flops(train_program)
-            print('step: {}, current_flops: {}'.format(step, current_flops))
-            if current_flops > args.max_flops:
-                continue
-
-            softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
-            cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
-            avg_cost = fluid.layers.mean(cost)
-            acc_top1 = fluid.layers.accuracy(
-                input=softmax_out, label=label, k=1)
-            acc_top5 = fluid.layers.accuracy(
-                input=softmax_out, label=label, k=5)
-
-            test_program = train_program.clone(for_test=True)
+        train_loader, avg_cost, acc_top1, acc_top5 = build_program(
+            train_program, startup_program, image_shape, archs, args)
 
-            optimizer = create_optimizer(args)
-            optimizer.minimize(avg_cost)
+        current_flops = flops(train_program)
+        print('step: {}, current_flops: {}'.format(step, current_flops))
+        if current_flops > args.max_flops:
+            continue
+
+        test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
+            test_program,
+            startup_program,
+            image_shape,
+            archs,
+            args,
+            is_test=True)
+        test_program = test_program.clone(for_test=True)
 
         place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
         exe = fluid.Executor(place)
@@ -98,7 +131,6 @@ def search_mobilenetv2(config, args, image_size):
                 batch_size=args.batch_size,
                 drop_last=False)
 
-        test_loader, _, _ = create_data_loader(image_shape)
         train_loader.set_sample_list_generator(
             train_reader,
             places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
@@ -106,25 +138,50 @@ def search_mobilenetv2(config, args, image_size):
             test_reader,
             places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
 
+        build_strategy = fluid.BuildStrategy()
+        train_compiled_program = fluid.CompiledProgram(
+            train_program).with_data_parallel(
+                loss_name=avg_cost.name, build_strategy=build_strategy)
         for epoch_id in range(args.retain_epoch):
             for batch_id, data in enumerate(train_loader()):
                 fetches = [avg_cost.name]
                 s_time = time.time()
-                outs = exe.run(train_program, feed=data, fetch_list=fetches)[0]
+                outs = exe.run(train_compiled_program,
+                               feed=data,
+                               fetch_list=fetches)[0]
                 batch_time = time.time() - s_time
                 if batch_id % 10 == 0:
                     _logger.info(
                         'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
                         format(step, epoch_id, batch_id, outs[0], batch_time))
 
-        for data in test_loader():
-            test_fetches = [avg_cost.name, acc_top1.name, acc_top5.name]
-            reward = exe.run(test_program, feed=data, fetch_list=fetches)[0]
+        test_fetches = [
+            test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
+        ]
+        reward = []
+        for batch_id, data in enumerate(test_loader()):
+            batch_reward = exe.run(test_program,
+                                   feed=data,
+                                   fetch_list=test_fetches)
+            # Average each fetched array along axis 1, giving one value per
+            # fetch target in the order of `test_fetches`.
+            reward_avg = np.mean(np.array(batch_reward), axis=1)
+            reward.append(reward_avg)
+
+            _logger.info(
+                'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
+                format(step, batch_id, reward_avg[0], reward_avg[1],
+                       reward_avg[2]))
+
+        # Average the per-batch metrics over all test batches.
+        finally_reward = np.mean(np.array(reward), axis=0)
         _logger.info(
-            'TEST: step: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
-                step, test_outs[0], test_outs[1], test_outs[2]))
+            'FINAL TEST: step: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
+            format(step, finally_reward[0], finally_reward[1],
+                   finally_reward[2]))
 
-        sa_nas.reward(float(avg_cost))
+        # Index 1 of `test_fetches` is acc_top1; it is reported as the reward.
+        sa_nas.reward(float(finally_reward[1]))
 
 
 if __name__ == '__main__':