未验证 提交 c7727885 编写于 作者: W Wennie396 提交者: GitHub

fix fetch problem in pass_utils.py and eval_loss in parallelizer_v2.py (#56539)

* fix eval_loss bug in parallelizer_v2.py

* fix fetch problem in pass_utils.py
上级 589588f3
......@@ -414,7 +414,11 @@ class Parallelizer:
pass_manager = PassManager(new_pass_list)
pass_manager.apply([main_program], [startup_program])
if self._strategy.pipeline.enable and use_new_executor():
if (
self.is_train
and self._strategy.pipeline.enable
and use_new_executor()
):
main_program._pipeline_opt = {}
main_program._pipeline_opt["standalone_opt"] = {
"schedule_mode": self._strategy.pipeline.schedule_mode,
......
......@@ -441,7 +441,7 @@ def _program_for_fthenb_and_1f1b(program):
for op in src_block.ops:
if is_lr_sched_op(op):
lr_ops.append(op)
if is_forward_op(op):
elif is_forward_op(op):
fwd_ops.append(op)
elif is_backward_op(op):
bwd_ops.append(op)
......@@ -502,6 +502,17 @@ def _program_for_fthenb_and_1f1b(program):
opt_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, opt_block, opt_ops)
for fetch_op in src_block.ops:
if fetch_op.type in ["fetch", "fetch_v2"]:
in_name = fetch_op.input_arg_names[0]
dst_block = None
for block in [lr_block, fwd_block, bwd_block, opt_block]:
if block._find_var_recursive(in_name):
dst_block = block
break
if dst_block:
_create_program(src_block, dst_block, fetch_op)
lr_prog._sync_with_cpp()
fwd_prog._sync_with_cpp()
bwd_prog._sync_with_cpp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册