diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py
index 80aac2e257530d1523a55bc3297aad6bc78e0779..34419145188220550f693b5290d7ba77bbc2a2da 100644
--- a/python/paddle/distributed/auto_parallel/static/utils.py
+++ b/python/paddle/distributed/auto_parallel/static/utils.py
@@ -1406,6 +1406,10 @@ def is_prim_op(op):
     return op.type.endswith("_p")
 
 
+def is_comm_op(op):
+    return op.has_attr("ring_id")
+
+
 def get_loss_op(block):
     loss_ops = []
     for op in block.ops:
diff --git a/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py b/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py
index a947fe8bb3c5cc71246a0ec4b1bc01168f354c8e..74482f8b2d20b37e68704fa949333879adf76316 100644
--- a/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py
+++ b/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py
@@ -20,6 +20,7 @@ from paddle.distributed.auto_parallel.static.operators.common import (
 from paddle.distributed.auto_parallel.static.utils import (
     OpRole,
     insert_dependencies_for_vars,
+    is_comm_op,
 )
 
 from .auto_parallel_sharding import ShardingPass, _supported_optimizer_type
@@ -109,7 +110,11 @@ class AutoParalSupplementDepPass(PassBase):
         for idx, op in enumerate(main_block.ops):
             if op.type == "check_finite_and_unscale":
                 if first_check_op:
-                    last_backward_op = main_block.ops[idx - 1]
+                    last_backward_op = None
+                    for last_idx in range(idx - 1, 0, -1):
+                        if not is_comm_op(main_block.ops[last_idx]):
+                            last_backward_op = main_block.ops[last_idx]
+                            break
                     prior_varname = last_backward_op.output_arg_names[0]
                     first_check_op = False
                     deps_map[idx] = (
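
For context, a minimal, self-contained sketch of the behavior this patch changes. The old code assumed the op immediately preceding the first check_finite_and_unscale was the last backward op; when a communication op (one carrying a ring_id attribute) sits there instead, prior_varname would name a communication output. The patch scans backwards past comm ops until it hits a compute op. In the sketch below, FakeOp and find_last_backward_op are illustrative stand-ins, not part of the Paddle API; only is_comm_op mirrors the helper the patch adds to static/utils.py.

class FakeOp:
    """Hypothetical stand-in for a Paddle static-graph op, for demo only."""

    def __init__(self, type_, attrs=None, outputs=None):
        self.type = type_
        self._attrs = attrs or {}
        self.output_arg_names = outputs or []

    def has_attr(self, name):
        return name in self._attrs


def is_comm_op(op):
    # Mirrors the helper added in static/utils.py: comm ops carry a ring_id.
    return op.has_attr("ring_id")


def find_last_backward_op(ops, idx):
    # Mirrors the patched loop: walk backwards from idx - 1, skipping comm ops.
    # As in the patch, range(idx - 1, 0, -1) stops before index 0, so ops[0]
    # is never inspected.
    for last_idx in range(idx - 1, 0, -1):
        if not is_comm_op(ops[last_idx]):
            return ops[last_idx]
    return None


if __name__ == "__main__":
    ops = [
        FakeOp("fill_constant"),
        FakeOp("matmul_grad", outputs=["x@GRAD"]),
        FakeOp("c_allreduce_sum", attrs={"ring_id": 0}),  # comm op, skipped
        FakeOp("check_finite_and_unscale"),
    ]
    last = find_last_backward_op(ops, 3)
    assert last.type == "matmul_grad"   # old code would have picked the allreduce
    print(last.output_arg_names[0])     # x@GRAD

With the old indexing, last_backward_op would be the c_allreduce_sum op and the explicit dependency would be anchored on a communication output; after the patch it is anchored on a genuine backward output (x@GRAD here).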