Unverified commit 0f4229c5, authored by kangguangli, committed by GitHub

[Perf] remove sync_calc_stream and sync_comm_stream (#51989)

* remove sync_calc_stream and sync_comm_stream

* fix CI bug

* fix

* fix

* fix
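
At a glance (an illustrative summary derived from the diff below, not code taken from the repository), the per-gradient op pattern in the non-fused allreduce path changes roughly like this; the fused path drops its c_sync_calc_stream in favour of depend ops and likewise loses the trailing c_sync_comm_stream:

    # Illustrative only: op sequences before and after this commit for one
    # gradient in the non-fused path.
    before = [
        'grad op',             # gradient produced on the calc stream
        'c_sync_calc_stream',  # calc stream drained before communication starts
        'c_allreduce_sum',     # runs on the comm stream
        # ... and, once per program, just before the first optimizer op:
        'c_sync_comm_stream',  # comm stream drained before the optimizer reads grads
    ]
    after = [
        'grad op',
        'c_allreduce_sum',     # ordering handled without explicit sync ops
    ]
    print(before)
    print(after)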
Parent d1c7b386
@@ -301,17 +301,6 @@ class RawProgramOptimizer(MetaOptimizerBase):
                     if param.is_distributed:
                         continue
-                    grad_vars.append(grad)
-                    block._insert_op(
-                        idx + offset,
-                        type='c_sync_calc_stream',
-                        inputs={'X': grad},
-                        outputs={'Out': grad},
-                        attrs={
-                            OP_ROLE_KEY: OpRole.Backward,
-                        },
-                    )
-                    offset += 1
                     block._insert_op(
                         idx + offset,
                         type='c_allreduce_sum',
@@ -326,17 +315,6 @@ class RawProgramOptimizer(MetaOptimizerBase):
         if grad is None:
             return
-        for idx, op in enumerate(block.ops):
-            if is_optimizer_op(op):
-                block._insert_op(
-                    idx,
-                    type='c_sync_comm_stream',
-                    inputs={'X': grad_vars},
-                    outputs={'Out': grad_vars},
-                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
-                )
-                break
     # This function helps reduce the number of allreduce by integrating op, which can save communication time.
     # to use allreduce fuse, follow these codes:
     # strategy = paddle.distributed.fleet.DistributedStrategy()
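
The comment above is the file's own how-to for enabling fused allreduce under this optimizer; the snippet is truncated in this diff view, so here is a hedged usage sketch built from the public DistributedStrategy API (the exact flag set is an assumption, not confirmed by this diff):

    import paddle.distributed.fleet as fleet

    # Hedged sketch: these flags are the usual way to route a static-graph job
    # through RawProgramOptimizer with fused allreduce; check the full comment
    # in raw_program_optimizer.py, which this diff view truncates.
    strategy = fleet.DistributedStrategy()
    strategy.without_graph_optimization = True  # take the raw-program path
    strategy.fuse_all_reduce_ops = True         # fuse gradients into fewer allreduce ops

    fleet.init(is_collective=True, strategy=strategy)
    # `inner_optimizer` is hypothetical: any paddle.optimizer instance.
    # optimizer = fleet.distributed_optimizer(inner_optimizer, strategy=strategy)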
@@ -424,14 +402,20 @@ class RawProgramOptimizer(MetaOptimizerBase):
                     OP_ROLE_KEY: OpRole.Backward,
                 },
             )
-            if not self.calc_comm_same_stream:
-                block._insert_op_without_sync(
-                    after_idx + 1,
-                    type='c_sync_calc_stream',
-                    inputs={'X': fused_var},
-                    outputs={'Out': fused_var},
-                    attrs={OP_ROLE_KEY: OpRole.Backward},
-                )
+        idx = 0
+        if not self.calc_comm_same_stream:
+            for i in range(len(grad_param_segments)):
+                while block.ops[idx].type != 'c_allreduce_sum':
+                    idx += 1
+                grad_segment, param_segment = grad_param_segments[i]
+                for grad in grad_segment:
+                    block._insert_op_without_sync(
+                        idx + 1,
+                        type='depend',
+                        inputs={'X': grad, 'Dep': fused_var},
+                        outputs={'Out': grad},
+                    )
+                idx += 1
 
         # update the outputs_name_to_idx after insertion of sync/allreduce ops
         outputs_name_to_idx = self.__get_ouputs_name_to_idx(
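
To make the new insertion logic above easier to follow, here is a small framework-free sketch (plain Python lists stand in for block.ops; this models the loop added in this hunk, it is not the Paddle implementation):

    # Each gradient segment gets a 'depend' marker spliced in right after the
    # 'c_allreduce_sum' that communicates its fused buffer.
    ops = ['grad_a', 'grad_b', 'c_allreduce_sum', 'grad_c', 'c_allreduce_sum', 'sgd']
    grad_param_segments = [(['grad_a', 'grad_b'], ['p0', 'p1']), (['grad_c'], ['p2'])]

    idx = 0
    for grad_segment, param_segment in grad_param_segments:
        while ops[idx] != 'c_allreduce_sum':  # find this segment's allreduce
            idx += 1
        for grad in grad_segment:
            # the real op ties `grad` to the fused var via a 'depend' input;
            # here we only splice a marker at the same position
            ops.insert(idx + 1, f'depend({grad})')
        idx += 1  # step past the current allreduce before scanning for the next one

    print(ops)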
@@ -471,21 +455,6 @@ class RawProgramOptimizer(MetaOptimizerBase):
                 },
             )
-        if self.calc_comm_same_stream:
-            block._sync_with_cpp()
-            return
-
-        # insert the sync comm op
-        for idx, op in enumerate(block.ops):
-            if is_optimizer_op(op):
-                block._insert_op_without_sync(
-                    idx,
-                    type='c_sync_comm_stream',
-                    inputs={'X': fused_vars},
-                    outputs={'Out': fused_vars},
-                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
-                )
-                break
         block._sync_with_cpp()
 
     def __get_ouputs_name_to_idx(self, first_backward_idx, block):
...
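
A hedged way to sanity-check the effect of this change on a built program (this helper is illustrative and not part of the commit; pass it whatever paddle.static.Program the fleet optimizer produced):

    def assert_no_stream_sync_ops(program):
        """Assert that the transpiled program contains none of the removed sync ops."""
        removed = {'c_sync_calc_stream', 'c_sync_comm_stream'}
        leftover = removed & {op.type for op in program.global_block().ops}
        assert not leftover, f"unexpected sync ops remain: {sorted(leftover)}"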