From 51456e05852a394c64c3995edc90636859593c06 Mon Sep 17 00:00:00 2001 From: shippingwang Date: Tue, 21 Apr 2020 13:44:19 +0000 Subject: [PATCH] refine --- docs/zh_CN/index.rst | 2 +- docs/zh_CN/{change_log.md => update_history.md} | 0 ppcls/data/reader.py | 5 +++-- tools/program.py | 8 ++++++-- 4 files changed, 10 insertions(+), 5 deletions(-) rename docs/zh_CN/{change_log.md => update_history.md} (100%) diff --git a/docs/zh_CN/index.rst b/docs/zh_CN/index.rst index 151b4dad..12a486ed 100644 --- a/docs/zh_CN/index.rst +++ b/docs/zh_CN/index.rst @@ -11,7 +11,7 @@ extension/index competition_support.md model_zoo.md - change_log.md + update_history.md faq.md :math:`PaddlePaddle2020` diff --git a/docs/zh_CN/change_log.md b/docs/zh_CN/update_history.md similarity index 100% rename from docs/zh_CN/change_log.md rename to docs/zh_CN/update_history.md diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py index 41ebed42..5bf83c21 100755 --- a/ppcls/data/reader.py +++ b/ppcls/data/reader.py @@ -139,8 +139,9 @@ def get_file_list(params): full_lines = shuffle_lines(full_lines, params["shuffle_seed"]) # use only partial data for each trainer in distributed training - img_per_trainer = len(full_lines) // trainers_num - full_lines = full_lines[trainer_id::trainers_num][:img_per_trainer] + if params['mode'] == 'train': + img_per_trainer = len(full_lines) // trainers_num + full_lines = full_lines[trainer_id::trainers_num][:img_per_trainer] return full_lines diff --git a/tools/program.py b/tools/program.py index 796f09a5..c518555e 100644 --- a/tools/program.py +++ b/tools/program.py @@ -380,6 +380,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): m.reset() batch_time = AverageMeter('cost', ':6.3f') tic = time.time() + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) for idx, batch in enumerate(dataloader()): metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list) batch_time.update(time.time() - tic) @@ -387,6 +388,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): for i, m in enumerate(metrics): metric_list[i].update(m[0], len(batch[0])) fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)]) - logger.info("[epoch:%3d][%s][step:%4d]%s" % + if trainer_id == 0: + + logger.info("[epoch:%3d][%s][step:%4d]%s" % (epoch, mode, idx, fetchs_str)) - logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str)) + if trainer_id == 0: + logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str)) -- GitLab