提交 9070f374 编写于 作者: G guosheng

Fix ce with pyreader in Transformer

上级 a6534450
...@@ -344,8 +344,11 @@ def py_reader_provider_wrapper(data_reader): ...@@ -344,8 +344,11 @@ def py_reader_provider_wrapper(data_reader):
def test_context(exe, train_exe, dev_count): def test_context(exe, train_exe, dev_count):
# Context to do validation. # Context to do validation.
startup_prog = fluid.Program()
test_prog = fluid.Program() test_prog = fluid.Program()
startup_prog = fluid.Program()
if args.enable_ce:
test_prog.random_seed = 1000
startup_prog.random_seed = 1000
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
sum_cost, avg_cost, predict, token_num, pyreader = transformer( sum_cost, avg_cost, predict, token_num, pyreader = transformer(
...@@ -509,10 +512,12 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost, ...@@ -509,10 +512,12 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
time_consumed)) time_consumed))
else: else:
print("epoch: %d, consumed %fs" % (pass_id, time_consumed)) print("epoch: %d, consumed %fs" % (pass_id, time_consumed))
if not args.enable_ce:
fluid.io.save_persistables( fluid.io.save_persistables(
exe, exe,
os.path.join(TrainTaskConfig.ckpt_dir, os.path.join(TrainTaskConfig.ckpt_dir,
"pass_" + str(pass_id) + ".checkpoint"), train_prog) "pass_" + str(pass_id) + ".checkpoint"),
train_prog)
if args.enable_ce: # For CE if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost)) print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册