未验证 提交 c498ff33 编写于 作者: Z zhaoyingli 提交者: GitHub

fix supplement_explicit_dependencies when amp-o2 (#56445)

上级 6acb85a5
......@@ -1406,6 +1406,10 @@ def is_prim_op(op):
return op.type.endswith("_p")
def is_comm_op(op):
    """Report whether ``op`` has a ``ring_id`` attribute.

    The check is purely attribute-based: the result is whatever
    ``op.has_attr("ring_id")`` returns for the given operator.
    """
    has_ring_id = op.has_attr("ring_id")
    return has_ring_id
def get_loss_op(block):
loss_ops = []
for op in block.ops:
......
......@@ -20,6 +20,7 @@ from paddle.distributed.auto_parallel.static.operators.common import (
from paddle.distributed.auto_parallel.static.utils import (
OpRole,
insert_dependencies_for_vars,
is_comm_op,
)
from .auto_parallel_sharding import ShardingPass, _supported_optimizer_type
......@@ -109,7 +110,11 @@ class AutoParalSupplementDepPass(PassBase):
for idx, op in enumerate(main_block.ops):
if op.type == "check_finite_and_unscale":
if first_check_op:
last_backward_op = main_block.ops[idx - 1]
last_backward_op = None
for last_idx in range(idx - 1, 0, -1):
if not is_comm_op(main_block.ops[last_idx]):
last_backward_op = main_block.ops[last_idx]
break
prior_varname = last_backward_op.output_arg_names[0]
first_check_op = False
deps_map[idx] = (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册