未验证 提交 1d015f12 编写于 作者: G Ghost Screaming 提交者: GitHub

Add enable_partial_send_recv switch in pipeline_configs (#46992) (#47083)

* Fix a bug in the reduce_sum op. When input.numel() > INT32_MAX, its result
is wrong.

* Support allow_partial switch, which can be configured in
pipeline_configs. If the sent tensors are not the same across
different hosts, they shouldn't be sent partially and
then concatenated as a whole tensor.

* Change name allow_partial to enable_partial_send_recv.

* Add global variable _enable_partial_send_recv
上级 69515e90
......@@ -177,6 +177,7 @@ message PipelineConfig {
optional int32 accumulate_steps = 2 [ default = 1 ];
optional string schedule_mode = 3 [ default = '1F1B' ];
optional bool p2p_cache_shape = 4 [ default = true ];
optional bool enable_partial_send_recv = 5 [ default = true ];
}
message TensorParallelConfig {
......
......@@ -46,7 +46,10 @@ class PipelineParallel(MetaParallelBase):
'micro_batch_size']
self.accumulate_steps = self._strategy.pipeline_configs[
'accumulate_steps']
# If the sent tensors are not the same across different hosts,
# they shouldn't be sent partially and then concatenated as a whole tensor.
self._enable_partial_send_recv = self._strategy.pipeline_configs[
'enable_partial_send_recv']
self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
self.num_stages = self._hcg.get_pipe_parallel_world_size()
......@@ -58,7 +61,8 @@ class PipelineParallel(MetaParallelBase):
self._real_pp_world_size = self.num_stages
self._real_pp_rank = self.stage_id
p2p.initialize_p2p_groups(hcg, self._using_cache)
p2p.initialize_p2p_groups(hcg, self._using_cache,
self._enable_partial_send_recv)
self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0
......
......@@ -22,12 +22,14 @@ from .utils import paddle_2_number, paddle_2_number, number_2_dtype
_hcg = None
_use_cache = False
_enable_partial_send_recv = True
def initialize_p2p_groups(hcg, use_cache=True):
global _hcg, _use_cache
def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
global _hcg, _use_cache, _enable_partial_send_recv
_hcg = hcg
_use_cache = use_cache
_enable_partial_send_recv = enable_partial_send_recv
send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
)
......@@ -157,7 +159,8 @@ _send_recv_meta = SendRecvMeta()
def _is_valid_send_recv_partial(tensor, mp_degree):
    """Return True when `tensor` may be transferred partially across mp ranks.

    Partial send/recv is only valid when the feature is globally enabled,
    model parallelism is actually in use (mp_degree > 1), and the tensor's
    element count divides evenly among the mp ranks.
    """
    # Respect the module-level switch set by initialize_p2p_groups().
    if not _enable_partial_send_recv:
        return False
    element_count = np.prod(tensor.shape)
    assert element_count != 0, "can't send/recv zero element"
    # Without model parallelism there is nothing to split.
    if mp_degree <= 1:
        return False
    return element_count % mp_degree == 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册