未验证 提交 0f4229c5 编写于 作者: K kangguangli 提交者: GitHub

[Perf] remove sync_calc_stream and sync_comm_stream (#51989)

* remove sync_calc_stream and sync_comm_stream

* fix ci bug

* fix

* fix

* fix
上级 d1c7b386
......@@ -301,17 +301,6 @@ class RawProgramOptimizer(MetaOptimizerBase):
if param.is_distributed:
continue
grad_vars.append(grad)
block._insert_op(
idx + offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
OP_ROLE_KEY: OpRole.Backward,
},
)
offset += 1
block._insert_op(
idx + offset,
type='c_allreduce_sum',
......@@ -326,17 +315,6 @@ class RawProgramOptimizer(MetaOptimizerBase):
if grad is None:
return
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
block._insert_op(
idx,
type='c_sync_comm_stream',
inputs={'X': grad_vars},
outputs={'Out': grad_vars},
attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
)
break
# This function helps reduce the number of allreduce by integrating op, which can save communication time.
# to use allreduce fuse, follow these codes:
# strategy = paddle.distributed.fleet.DistributedStrategy()
......@@ -424,14 +402,20 @@ class RawProgramOptimizer(MetaOptimizerBase):
OP_ROLE_KEY: OpRole.Backward,
},
)
if not self.calc_comm_same_stream:
block._insert_op_without_sync(
after_idx + 1,
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: OpRole.Backward},
)
idx = 0
if not self.calc_comm_same_stream:
for i in range(len(grad_param_segments)):
while block.ops[idx].type != 'c_allreduce_sum':
idx += 1
grad_segment, param_segment = grad_param_segments[i]
for grad in grad_segment:
block._insert_op_without_sync(
idx + 1,
type='depend',
inputs={'X': grad, 'Dep': fused_var},
outputs={'Out': grad},
)
idx += 1
# update the outputs_name_to_idx after insertion of sync/allreduce ops
outputs_name_to_idx = self.__get_ouputs_name_to_idx(
......@@ -471,21 +455,6 @@ class RawProgramOptimizer(MetaOptimizerBase):
},
)
if self.calc_comm_same_stream:
block._sync_with_cpp()
return
# insert the sync comm op
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
block._insert_op_without_sync(
idx,
type='c_sync_comm_stream',
inputs={'X': fused_vars},
outputs={'Out': fused_vars},
attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
)
break
block._sync_with_cpp()
def __get_ouputs_name_to_idx(self, first_backward_idx, block):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册