未验证 提交 db7d6808 编写于 作者: J JZ-LIANG 提交者: GitHub

bugfix for pp (#48353)

上级 ea83f898
......@@ -469,14 +469,29 @@ class RecomputePass(PassBase):
ckpt_ops_dict[fwd_op_id][0] = False
if rc_op:
insert_dependencies_for_two_ops(
main_block,
idx,
main_block.ops[rc_op.idx - 1],
rc_op,
self._dist_context,
sync=False,
prior_op = main_block.ops[rc_op.idx - 1]
posterior_op = rc_op
prior_mesh = (
self._dist_context.get_op_dist_attr_for_program(
prior_op
).process_mesh
)
posterior_mesh = (
self._dist_context.get_op_dist_attr_for_program(
posterior_op
).process_mesh
)
# NOTE if two recompute segements across two pipeline stages
# not need dependecies for it
if prior_mesh == posterior_mesh:
insert_dependencies_for_two_ops(
main_block,
idx,
prior_op,
posterior_op,
self._dist_context,
sync=False,
)
main_program._sync_with_cpp()
def reset_op_dist_attr(self, op, var_name_dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册