提交 4be8333e 编写于 作者: C ceci3

update demo

上级 5f44fc78
......@@ -120,26 +120,24 @@ def search_mobilenetv2(config, args, image_size):
train_loader.set_sample_list_generator(
train_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
test_loader.set_sample_list_generator(
test_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
test_loader.set_sample_list_generator(test_reader, places=place)
build_strategy = fluid.BuildStrategy()
train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name]
s_time = time.time()
outs = exe.run(train_compiled_program,
feed=data,
fetch_list=fetches)[0]
batch_time = time.time() - s_time
if batch_id % 10 == 0:
_logger.info(
'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
format(step, epoch_id, batch_id, outs[0], batch_time))
#build_strategy = fluid.BuildStrategy()
#train_compiled_program = fluid.CompiledProgram(
# train_program).with_data_parallel(
# loss_name=avg_cost.name, build_strategy=build_strategy)
#for epoch_id in range(args.retain_epoch):
# for batch_id, data in enumerate(train_loader()):
# fetches = [avg_cost.name]
# s_time = time.time()
# outs = exe.run(train_compiled_program,
# feed=data,
# fetch_list=fetches)[0]
# batch_time = time.time() - s_time
# if batch_id % 10 == 0:
# _logger.info(
# 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
# format(step, epoch_id, batch_id, outs[0], batch_time))
reward = []
for batch_id, data in enumerate(test_loader()):
......@@ -154,7 +152,8 @@ def search_mobilenetv2(config, args, image_size):
_logger.info(
'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
format(step, test_outs[0], test_outs[1], test_outs[2]))
format(step, batch_id, batch_reward[0], batch_reward[1],
batch_reward[2]))
finally_reward = np.mean(np.array(reward), axis=0)
_logger.info(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册