未验证 提交 75a3c654 编写于 作者: D dyning 提交者: GitHub

Merge pull request #57 from shippingwang/fix_multicard_eval

fix multi card eval bug
......@@ -11,7 +11,7 @@
extension/index
competition_support.md
model_zoo.md
-   change_log.md
+   update_history.md
faq.md
:math:`PaddlePaddle2020`
......@@ -139,8 +139,9 @@ def get_file_list(params):
full_lines = shuffle_lines(full_lines, params["shuffle_seed"])
# use only partial data for each trainer in distributed training
-    img_per_trainer = len(full_lines) // trainers_num
-    full_lines = full_lines[trainer_id::trainers_num][:img_per_trainer]
+    if params['mode'] == 'train':
+        img_per_trainer = len(full_lines) // trainers_num
+        full_lines = full_lines[trainer_id::trainers_num][:img_per_trainer]
return full_lines
......
......@@ -380,6 +380,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
m.reset()
batch_time = AverageMeter('cost', ':6.3f')
tic = time.time()
+    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
for idx, batch in enumerate(dataloader()):
metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
batch_time.update(time.time() - tic)
......@@ -387,6 +388,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
for i, m in enumerate(metrics):
metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)])
-        logger.info("[epoch:%3d][%s][step:%4d]%s" %
+        if trainer_id == 0:
+            logger.info("[epoch:%3d][%s][step:%4d]%s" %
                     (epoch, mode, idx, fetchs_str))
-    logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str))
+    if trainer_id == 0:
+        logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册