From c498ff3323073f296a7ea1042049612a887609b6 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Tue, 22 Aug 2023 20:38:46 +0800
Subject: [PATCH] fix supplement_explicit_dependencies when amp-o2 (#56445)

---
Review notes (ignored by `git am`):
  * The backward search for the last non-communication op now includes
    index 0 (range stop is -1 instead of 0).
  * A missing predecessor raises ValueError instead of failing later with
    AttributeError on None.
  * Blob hashes on the `index` lines are those of the upstream commit.

 python/paddle/distributed/auto_parallel/static/utils.py |  5 +++++
 .../auto_parallel_supplement_explicit_dependencies.py   | 14 +++++++++++++-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py
index 80aac2e2575..34419145188 100644
--- a/python/paddle/distributed/auto_parallel/static/utils.py
+++ b/python/paddle/distributed/auto_parallel/static/utils.py
@@ -1406,6 +1406,11 @@ def is_prim_op(op):
     return op.type.endswith("_p")
 
 
+def is_comm_op(op):
+    """Return True if ``op`` has a ``ring_id`` attribute."""
+    return op.has_attr("ring_id")
+
+
 def get_loss_op(block):
     loss_ops = []
     for op in block.ops:
diff --git a/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py b/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py
index a947fe8bb3c..74482f8b2d2 100644
--- a/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py
+++ b/python/paddle/distributed/passes/auto_parallel_supplement_explicit_dependencies.py
@@ -20,6 +20,7 @@ from paddle.distributed.auto_parallel.static.operators.common import (
 from paddle.distributed.auto_parallel.static.utils import (
     OpRole,
     insert_dependencies_for_vars,
+    is_comm_op,
 )
 
 from .auto_parallel_sharding import ShardingPass, _supported_optimizer_type
@@ -109,7 +110,18 @@ class AutoParalSupplementDepPass(PassBase):
         for idx, op in enumerate(main_block.ops):
             if op.type == "check_finite_and_unscale":
                 if first_check_op:
-                    last_backward_op = main_block.ops[idx - 1]
+                    # Walk back to the nearest preceding non-communication
+                    # op; the stop value of -1 keeps index 0 in the search.
+                    last_backward_op = None
+                    for last_idx in range(idx - 1, -1, -1):
+                        if not is_comm_op(main_block.ops[last_idx]):
+                            last_backward_op = main_block.ops[last_idx]
+                            break
+                    if last_backward_op is None:
+                        raise ValueError(
+                            "no non-communication op precedes the first "
+                            "check_finite_and_unscale op"
+                        )
                     prior_varname = last_backward_op.output_arg_names[0]
                     first_check_op = False
                     deps_map[idx] = (
-- 
GitLab