From 6385065664c76880eb8dea39161b8bc1602ddfab Mon Sep 17 00:00:00 2001
From: guoshengCS
Date: Wed, 23 Jan 2019 13:59:23 +0800
Subject: [PATCH] Fix logging and checkpoint loading in Transformer

---
 .../transformer/train.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/fluid/PaddleNLP/neural_machine_translation/transformer/train.py b/fluid/PaddleNLP/neural_machine_translation/transformer/train.py
index 5fc98868..16d48238 100644
--- a/fluid/PaddleNLP/neural_machine_translation/transformer/train.py
+++ b/fluid/PaddleNLP/neural_machine_translation/transformer/train.py
@@ -408,10 +408,19 @@ def test_context(exe, train_exe, dev_count):
     test_data = prepare_data_generator(
         args, is_test=True, count=dev_count, pyreader=pyreader)
 
-    exe.run(startup_prog)
+    exe.run(startup_prog)  # to init pyreader for testing
+    if TrainTaskConfig.ckpt_path:
+        fluid.io.load_persistables(
+            exe, TrainTaskConfig.ckpt_path, main_program=test_prog)
+
+    exec_strategy = fluid.ExecutionStrategy()
+    exec_strategy.use_experimental_executor = True
+    build_strategy = fluid.BuildStrategy()
     test_exe = fluid.ParallelExecutor(
         use_cuda=TrainTaskConfig.use_gpu,
         main_program=test_prog,
+        build_strategy=build_strategy,
+        exec_strategy=exec_strategy,
         share_vars_from=train_exe)
 
     def test(exe=test_exe, pyreader=pyreader):
@@ -457,7 +466,11 @@ def train_loop(exe,
                nccl2_trainer_id=0):
     # Initialize the parameters.
     if TrainTaskConfig.ckpt_path:
-        fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
+        exe.run(startup_prog)  # to init pyreader for training
+        logging.info("load checkpoint from {}".format(
+            TrainTaskConfig.ckpt_path))
+        fluid.io.load_persistables(
+            exe, TrainTaskConfig.ckpt_path, main_program=train_prog)
     else:
         logging.info("init fluid.framework.default_startup_program")
         exe.run(startup_prog)
@@ -741,6 +754,7 @@ if __name__ == "__main__":
     LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(
         stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)
+    logging.getLogger().setLevel(logging.INFO)
 
     args = parse_args()
     train(args)
-- 
GitLab
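
Note (editor's sketch, not part of the patch): the hunks above follow one pattern:
run the startup program first so all variables (including pyreader state) exist,
then restore persistables from the checkpoint with an explicit main_program so only
the variables of that program are loaded. A minimal, self-contained illustration of
that pattern is below; the helper name and arguments are hypothetical, only
fluid.Executor.run and fluid.io.load_persistables come from the patch itself.

    import logging
    import paddle.fluid as fluid

    def init_or_restore(exe, startup_prog, main_prog, ckpt_path):
        # Run the startup program unconditionally so every variable is created
        # before any checkpoint weights are restored.
        exe.run(startup_prog)
        if ckpt_path:
            logging.info("load checkpoint from {}".format(ckpt_path))
            # Passing main_program explicitly restores only the persistable
            # variables that this particular program (train or test) uses.
            fluid.io.load_persistables(exe, ckpt_path, main_program=main_prog)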