diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index f3432ccbd4c5dc8e9d9d669886945ed37303224a..605303313aeae781fe357b866b77c826ff27a143 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -344,8 +344,11 @@ def py_reader_provider_wrapper(data_reader): def test_context(exe, train_exe, dev_count): # Context to do validation. - startup_prog = fluid.Program() test_prog = fluid.Program() + startup_prog = fluid.Program() + if args.enable_ce: + test_prog.random_seed = 1000 + startup_prog.random_seed = 1000 with fluid.program_guard(test_prog, startup_prog): with fluid.unique_name.guard(): sum_cost, avg_cost, predict, token_num, pyreader = transformer( @@ -509,10 +512,12 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost, time_consumed)) else: print("epoch: %d, consumed %fs" % (pass_id, time_consumed)) - fluid.io.save_persistables( - exe, - os.path.join(TrainTaskConfig.ckpt_dir, - "pass_" + str(pass_id) + ".checkpoint"), train_prog) + if not args.enable_ce: + fluid.io.save_persistables( + exe, + os.path.join(TrainTaskConfig.ckpt_dir, + "pass_" + str(pass_id) + ".checkpoint"), + train_prog) if args.enable_ce: # For CE print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))