提交 51456e05 编写于 作者: S shippingwang

refine

上级 bf86e09d
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
extension/index extension/index
competition_support.md competition_support.md
model_zoo.md model_zoo.md
change_log.md update_history.md
faq.md faq.md
:math:`PaddlePaddle2020` :math:`PaddlePaddle2020`
...@@ -139,6 +139,7 @@ def get_file_list(params): ...@@ -139,6 +139,7 @@ def get_file_list(params):
full_lines = shuffle_lines(full_lines, params["shuffle_seed"]) full_lines = shuffle_lines(full_lines, params["shuffle_seed"])
# use only partial data for each trainer in distributed training # use only partial data for each trainer in distributed training
if params['mode'] == 'train':
img_per_trainer = len(full_lines) // trainers_num img_per_trainer = len(full_lines) // trainers_num
full_lines = full_lines[trainer_id::trainers_num][:img_per_trainer] full_lines = full_lines[trainer_id::trainers_num][:img_per_trainer]
......
...@@ -380,6 +380,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): ...@@ -380,6 +380,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
m.reset() m.reset()
batch_time = AverageMeter('cost', ':6.3f') batch_time = AverageMeter('cost', ':6.3f')
tic = time.time() tic = time.time()
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
for idx, batch in enumerate(dataloader()): for idx, batch in enumerate(dataloader()):
metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list) metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
batch_time.update(time.time() - tic) batch_time.update(time.time() - tic)
...@@ -387,6 +388,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): ...@@ -387,6 +388,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
for i, m in enumerate(metrics): for i, m in enumerate(metrics):
metric_list[i].update(m[0], len(batch[0])) metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)]) fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)])
if trainer_id == 0:
logger.info("[epoch:%3d][%s][step:%4d]%s" % logger.info("[epoch:%3d][%s][step:%4d]%s" %
(epoch, mode, idx, fetchs_str)) (epoch, mode, idx, fetchs_str))
if trainer_id == 0:
logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str)) logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册