未验证 提交 c498ff33 编写于 作者: Z zhaoyingli 提交者: GitHub

fix supplement_explicit_dependencies when amp-o2 (#56445)

上级 6acb85a5
...@@ -1406,6 +1406,10 @@ def is_prim_op(op): ...@@ -1406,6 +1406,10 @@ def is_prim_op(op):
return op.type.endswith("_p") return op.type.endswith("_p")
def is_comm_op(op):
return op.has_attr("ring_id")
def get_loss_op(block): def get_loss_op(block):
loss_ops = [] loss_ops = []
for op in block.ops: for op in block.ops:
......
...@@ -20,6 +20,7 @@ from paddle.distributed.auto_parallel.static.operators.common import ( ...@@ -20,6 +20,7 @@ from paddle.distributed.auto_parallel.static.operators.common import (
from paddle.distributed.auto_parallel.static.utils import ( from paddle.distributed.auto_parallel.static.utils import (
OpRole, OpRole,
insert_dependencies_for_vars, insert_dependencies_for_vars,
is_comm_op,
) )
from .auto_parallel_sharding import ShardingPass, _supported_optimizer_type from .auto_parallel_sharding import ShardingPass, _supported_optimizer_type
...@@ -109,7 +110,11 @@ class AutoParalSupplementDepPass(PassBase): ...@@ -109,7 +110,11 @@ class AutoParalSupplementDepPass(PassBase):
for idx, op in enumerate(main_block.ops): for idx, op in enumerate(main_block.ops):
if op.type == "check_finite_and_unscale": if op.type == "check_finite_and_unscale":
if first_check_op: if first_check_op:
last_backward_op = main_block.ops[idx - 1] last_backward_op = None
for last_idx in range(idx - 1, 0, -1):
if not is_comm_op(main_block.ops[last_idx]):
last_backward_op = main_block.ops[last_idx]
break
prior_varname = last_backward_op.output_arg_names[0] prior_varname = last_backward_op.output_arg_names[0]
first_check_op = False first_check_op = False
deps_map[idx] = ( deps_map[idx] = (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册