未验证 提交 db7d6808 编写于 作者: J JZ-LIANG 提交者: GitHub

bugfix for pp (#48353)

上级 ea83f898
...@@ -469,14 +469,29 @@ class RecomputePass(PassBase): ...@@ -469,14 +469,29 @@ class RecomputePass(PassBase):
ckpt_ops_dict[fwd_op_id][0] = False ckpt_ops_dict[fwd_op_id][0] = False
if rc_op: if rc_op:
insert_dependencies_for_two_ops( prior_op = main_block.ops[rc_op.idx - 1]
main_block, posterior_op = rc_op
idx, prior_mesh = (
main_block.ops[rc_op.idx - 1], self._dist_context.get_op_dist_attr_for_program(
rc_op, prior_op
self._dist_context, ).process_mesh
sync=False, )
posterior_mesh = (
self._dist_context.get_op_dist_attr_for_program(
posterior_op
).process_mesh
) )
# NOTE if two recompute segements across two pipeline stages
# not need dependecies for it
if prior_mesh == posterior_mesh:
insert_dependencies_for_two_ops(
main_block,
idx,
prior_op,
posterior_op,
self._dist_context,
sync=False,
)
main_program._sync_with_cpp() main_program._sync_with_cpp()
def reset_op_dist_attr(self, op, var_name_dict): def reset_op_dist_attr(self, op, var_name_dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册