From 9070f374ba4f84a01121bd55c50d919f0bfbeebf Mon Sep 17 00:00:00 2001 From: guosheng Date: Tue, 9 Oct 2018 18:58:14 +0800 Subject: [PATCH] Fix ce with pyreader in Transformer --- .../transformer/train.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index f3432ccb..60530331 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -344,8 +344,11 @@ def py_reader_provider_wrapper(data_reader): def test_context(exe, train_exe, dev_count): # Context to do validation. - startup_prog = fluid.Program() test_prog = fluid.Program() + startup_prog = fluid.Program() + if args.enable_ce: + test_prog.random_seed = 1000 + startup_prog.random_seed = 1000 with fluid.program_guard(test_prog, startup_prog): with fluid.unique_name.guard(): sum_cost, avg_cost, predict, token_num, pyreader = transformer( @@ -509,10 +512,12 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost, time_consumed)) else: print("epoch: %d, consumed %fs" % (pass_id, time_consumed)) - fluid.io.save_persistables( - exe, - os.path.join(TrainTaskConfig.ckpt_dir, - "pass_" + str(pass_id) + ".checkpoint"), train_prog) + if not args.enable_ce: + fluid.io.save_persistables( + exe, + os.path.join(TrainTaskConfig.ckpt_dir, + "pass_" + str(pass_id) + ".checkpoint"), + train_prog) if args.enable_ce: # For CE print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost)) -- GitLab