未验证 提交 c3faafbb 编写于 作者: B buchongyu 提交者: GitHub

[release/2.4] fix train_multi_machine bug in static ppdet (#6398)

上级 67767d25
...@@ -192,7 +192,6 @@ def main(): ...@@ -192,7 +192,6 @@ def main():
extra_keys) extra_keys)
exe.run(startup_prog) exe.run(startup_prog)
compiled_train_prog = fleet.main_program
if FLAGS.eval: if FLAGS.eval:
compiled_eval_prog = fluid.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
...@@ -253,7 +252,7 @@ def main(): ...@@ -253,7 +252,7 @@ def main():
time_cost = np.mean(time_stat) time_cost = np.mean(time_stat)
eta_sec = (cfg.max_iters - it) * time_cost eta_sec = (cfg.max_iters - it) * time_cost
eta = str(datetime.timedelta(seconds=int(eta_sec))) eta = str(datetime.timedelta(seconds=int(eta_sec)))
outs = exe.run(compiled_train_prog, fetch_list=train_values) outs = exe.run(train_prog, fetch_list=train_values)
stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])} stats = {k: np.array(v).mean() for k, v in zip(train_keys, outs[:-1])}
# use vdl-paddle to log loss # use vdl-paddle to log loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册