From d3bad33f45304468440c5f34e1e1947762dda571 Mon Sep 17 00:00:00 2001 From: shippingwang Date: Tue, 30 Jun 2020 08:00:51 +0000 Subject: [PATCH] add print_interval and refine override --- ppcls/utils/config.py | 7 ++++++- tools/eval.py | 2 +- tools/program.py | 25 +++++++++++++++++++------ tools/run.sh | 3 ++- tools/train.py | 10 +++++----- 5 files changed, 33 insertions(+), 14 deletions(-) diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py index e1712a8a..54ee5175 100644 --- a/ppcls/utils/config.py +++ b/ppcls/utils/config.py @@ -144,9 +144,14 @@ def override(dl, ks, v): override(dl[k], ks[1:], v) else: if len(ks) == 1: - assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl)) + #assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl)) + if not ks[0] in dl: + logger.warning('A new filed ({}) detected!'.format(ks[0], dl)) dl[ks[0]] = str2num(v) else: + assert ks[0] in dl, ( + '({}) doesn\'t exist in {}, a new dict field is invalid'. + format(ks[0], dl)) override(dl[ks[0]], ks[1:], v) diff --git a/tools/eval.py b/tools/eval.py index 291f77f0..7aa370c6 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -74,7 +74,7 @@ def main(args): compiled_valid_prog = program.compile(config, valid_prog) program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1, - 'eval') + 'eval', config) if __name__ == '__main__': diff --git a/tools/program.py b/tools/program.py index 4b27e89a..3c796a47 100644 --- a/tools/program.py +++ b/tools/program.py @@ -410,6 +410,7 @@ def run(dataloader, fetchs, epoch=0, mode='train', + config=None, vdl_writer=None): """ Feed data to the model and fetch the measures and loss @@ -443,16 +444,28 @@ def run(dataloader, logger.scaler('loss', metrics[0][0], total_step, vdl_writer) total_step += 1 if mode == 'eval': - logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) + if idx % config.get('print_interval', 1) == 0: + logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, + fetchs_str)) else: epoch_str = "epoch:{:<3d}".format(epoch) step_str = "{:s} step:{:<4d}".format(mode, idx) - logger.info("{:s} {:s} {:s}".format( - logger.coloring(epoch_str, "HEADER") - if idx == 0 else epoch_str, - logger.coloring(step_str, "PURPLE"), - logger.coloring(fetchs_str, 'OKGREEN'))) + # Keep the first 10 batches statistics, They are important for develop + if epoch == 0 and idx < 10: + logger.info("{:s} {:s} {:s}".format( + logger.coloring(epoch_str, "HEADER") + if idx == 0 else epoch_str, + logger.coloring(step_str, "PURPLE"), + logger.coloring(fetchs_str, 'OKGREEN'))) + + else: + if idx % config.get('print_interval', 1) == 0: + logger.info("{:s} {:s} {:s}".format( + logger.coloring(epoch_str, "HEADER") + if idx == 0 else epoch_str, + logger.coloring(step_str, "PURPLE"), + logger.coloring(fetchs_str, 'OKGREEN'))) end_str = ''.join([str(m.mean) + ' ' for m in metric_list] + [batch_time.total]) + 's' diff --git a/tools/run.sh b/tools/run.sh index 55f2918d..5e8043b1 100755 --- a/tools/run.sh +++ b/tools/run.sh @@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH python -m paddle.distributed.launch \ --selected_gpus="0,1,2,3" \ tools/train.py \ - -c ./configs/ResNet/ResNet50.yaml + -c ./configs/ResNet/ResNet50.yaml \ + -o print_interval=10 diff --git a/tools/train.py b/tools/train.py index 3456e2ae..e9188eb0 100644 --- a/tools/train.py +++ b/tools/train.py @@ -110,21 +110,21 @@ def main(args): for epoch_id in range(config.epochs): # 1. train with train dataset program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, - epoch_id, 'train', vdl_writer) + epoch_id, 'train', config, vdl_writer) if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: # 2. validate with validate dataset if config.validate and epoch_id % config.valid_interval == 0: if config.get('use_ema'): logger.info(logger.coloring("EMA validate start...")) with ema.apply(exe): - top1_acc = program.run(valid_dataloader, exe, - compiled_valid_prog, - valid_fetchs, epoch_id, 'valid') + top1_acc = program.run( + valid_dataloader, exe, compiled_valid_prog, + valid_fetchs, epoch_id, 'valid', config) logger.info(logger.coloring("EMA validate over!")) top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, - epoch_id, 'valid') + epoch_id, 'valid', config) if top1_acc > best_top1_acc: best_top1_acc = top1_acc message = "The best top1 acc {:.5f}, in epoch: {:d}".format( -- GitLab