提交 5f44fc78 编写于 作者: C ceci3

update demo

上级 de55821b
...@@ -31,13 +31,35 @@ def create_data_loader(image_shape): ...@@ -31,13 +31,35 @@ def create_data_loader(image_shape):
return data_loader, data, label return data_loader, data, label
def build_program(main_program,
startup_program,
image_shape,
archs,
args,
is_test=False):
with fluid.program_guard(main_program, startup_program):
data_loader, data, label = create_data_loader(image_shape)
output = archs(data)
softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)
if is_test == False:
optimizer = create_optimizer(args)
optimizer.minimize(avg_cost)
return data_loader, avg_cost, acc_top1, acc_top5
def search_mobilenetv2(config, args, image_size): def search_mobilenetv2(config, args, image_size):
factory = SearchSpaceFactory() factory = SearchSpaceFactory()
space = factory.get_search_space(config) space = factory.get_search_space(config)
### start a server and a client ### start a server and a client
sa_nas = SANAS( sa_nas = SANAS(
config, config,
server_addr=("", 8889), server_addr=("", 8883),
init_temperature=args.init_temperature, init_temperature=args.init_temperature,
reduce_rate=args.reduce_rate, reduce_rate=args.reduce_rate,
search_steps=args.search_steps, search_steps=args.search_steps,
...@@ -52,26 +74,22 @@ def search_mobilenetv2(config, args, image_size): ...@@ -52,26 +74,22 @@ def search_mobilenetv2(config, args, image_size):
train_program = fluid.Program() train_program = fluid.Program()
test_program = fluid.Program() test_program = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program): train_loader, avg_cost, acc_top1, acc_top5 = build_program(
train_loader, data, label = create_data_loader(image_shape) train_program, startup_program, image_shape, archs, args)
output = archs(data)
current_flops = flops(train_program)
print('step: {}, current_flops: {}'.format(step, current_flops))
if current_flops > args.max_flops:
continue
softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
acc_top1 = fluid.layers.accuracy(
input=softmax_out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(
input=softmax_out, label=label, k=5)
test_program = train_program.clone(for_test=True)
optimizer = create_optimizer(args) current_flops = flops(train_program)
optimizer.minimize(avg_cost) print('step: {}, current_flops: {}'.format(step, current_flops))
if current_flops > args.max_flops:
continue
test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
test_program,
startup_program,
image_shape,
archs,
args,
is_test=True)
test_program = test_program.clone(for_test=True)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -98,7 +116,7 @@ def search_mobilenetv2(config, args, image_size): ...@@ -98,7 +116,7 @@ def search_mobilenetv2(config, args, image_size):
batch_size=args.batch_size, batch_size=args.batch_size,
drop_last=False) drop_last=False)
test_loader, _, _ = create_data_loader(image_shape) #test_loader, _, _ = create_data_loader(image_shape)
train_loader.set_sample_list_generator( train_loader.set_sample_list_generator(
train_reader, train_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
...@@ -106,25 +124,44 @@ def search_mobilenetv2(config, args, image_size): ...@@ -106,25 +124,44 @@ def search_mobilenetv2(config, args, image_size):
test_reader, test_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
build_strategy = fluid.BuildStrategy()
train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(args.retain_epoch): for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()): for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name] fetches = [avg_cost.name]
s_time = time.time() s_time = time.time()
outs = exe.run(train_program, feed=data, fetch_list=fetches)[0] outs = exe.run(train_compiled_program,
feed=data,
fetch_list=fetches)[0]
batch_time = time.time() - s_time batch_time = time.time() - s_time
if batch_id % 10 == 0: if batch_id % 10 == 0:
_logger.info( _logger.info(
'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'. 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
format(step, epoch_id, batch_id, outs[0], batch_time)) format(step, epoch_id, batch_id, outs[0], batch_time))
for data in test_loader(): reward = []
test_fetches = [avg_cost.name, acc_top1.name, acc_top5.name] for batch_id, data in enumerate(test_loader()):
reward = exe.run(test_program, feed=data, fetch_list=fetches)[0] test_fetches = [
test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
]
batch_reward = exe.run(test_program,
feed=data,
fetch_list=test_fetches)
reward_avg = np.mean(np.array(batch_reward), axis=1)
reward.append(reward_avg)
_logger.info(
'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
format(step, test_outs[0], test_outs[1], test_outs[2]))
finally_reward = np.mean(np.array(reward), axis=0)
_logger.info( _logger.info(
'TEST: step: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.format( 'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
step, test_outs[0], test_outs[1], test_outs[2])) step, finally_reward[0], finally_reward[1], finally_reward[2]))
sa_nas.reward(float(avg_cost)) sa_nas.reward(float(finally_reward[1]))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册