“82b8a3c5d9a567e687b11fa03a55e1caffe1bceb”上不存在“python/paddle/incubate/hapi/datasets/__init__.py”
提交 4be8333e 编写于 作者: C ceci3

update demo

上级 5f44fc78
...@@ -120,26 +120,24 @@ def search_mobilenetv2(config, args, image_size): ...@@ -120,26 +120,24 @@ def search_mobilenetv2(config, args, image_size):
train_loader.set_sample_list_generator( train_loader.set_sample_list_generator(
train_reader, train_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
test_loader.set_sample_list_generator( test_loader.set_sample_list_generator(test_reader, places=place)
test_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
build_strategy = fluid.BuildStrategy() #build_strategy = fluid.BuildStrategy()
train_compiled_program = fluid.CompiledProgram( #train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel( # train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy) # loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(args.retain_epoch): #for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()): # for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name] # fetches = [avg_cost.name]
s_time = time.time() # s_time = time.time()
outs = exe.run(train_compiled_program, # outs = exe.run(train_compiled_program,
feed=data, # feed=data,
fetch_list=fetches)[0] # fetch_list=fetches)[0]
batch_time = time.time() - s_time # batch_time = time.time() - s_time
if batch_id % 10 == 0: # if batch_id % 10 == 0:
_logger.info( # _logger.info(
'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'. # 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
format(step, epoch_id, batch_id, outs[0], batch_time)) # format(step, epoch_id, batch_id, outs[0], batch_time))
reward = [] reward = []
for batch_id, data in enumerate(test_loader()): for batch_id, data in enumerate(test_loader()):
...@@ -154,7 +152,8 @@ def search_mobilenetv2(config, args, image_size): ...@@ -154,7 +152,8 @@ def search_mobilenetv2(config, args, image_size):
_logger.info( _logger.info(
'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'. 'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
format(step, test_outs[0], test_outs[1], test_outs[2])) format(step, batch_id, batch_reward[0], batch_reward[1],
batch_reward[2]))
finally_reward = np.mean(np.array(reward), axis=0) finally_reward = np.mean(np.array(reward), axis=0)
_logger.info( _logger.info(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册