From db7d6808fd4eba8375ed45b5fe9cccbf915de333 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 25 Nov 2022 10:52:09 +0800 Subject: [PATCH] bugfix for pp (#48353) --- .../passes/auto_parallel_recompute.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 44e02fb3ff..72a116a5eb 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -469,14 +469,29 @@ class RecomputePass(PassBase): ckpt_ops_dict[fwd_op_id][0] = False if rc_op: - insert_dependencies_for_two_ops( - main_block, - idx, - main_block.ops[rc_op.idx - 1], - rc_op, - self._dist_context, - sync=False, + prior_op = main_block.ops[rc_op.idx - 1] + posterior_op = rc_op + prior_mesh = ( + self._dist_context.get_op_dist_attr_for_program( + prior_op + ).process_mesh + ) + posterior_mesh = ( + self._dist_context.get_op_dist_attr_for_program( + posterior_op + ).process_mesh ) + # NOTE if two recompute segements across two pipeline stages + # not need dependecies for it + if prior_mesh == posterior_mesh: + insert_dependencies_for_two_ops( + main_block, + idx, + prior_op, + posterior_op, + self._dist_context, + sync=False, + ) main_program._sync_with_cpp() def reset_op_dist_attr(self, op, var_name_dict): -- GitLab