diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 25f6ff8355d7324ea71b1c877b97b98c7387ebcd..e792d2a38dc7e0767b82f96b59ff09729bcb5cd8 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -178,6 +178,7 @@ message PipelineConfig { optional int32 accumulate_steps = 2 [ default = 1 ]; optional string schedule_mode = 3 [ default = '1F1B' ]; optional bool p2p_cache_shape = 4 [ default = true ]; + optional bool enable_partial_send_recv = 5 [ default = true ]; } message TensorParallelConfig { diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 5488cdf32262b72199e05d8589ec3c0bd33fad64..937122e8a7b33010e96aae6f3f4367f54b7ca2a6 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -46,7 +46,10 @@ class PipelineParallel(MetaParallelBase): 'micro_batch_size'] self.accumulate_steps = self._strategy.pipeline_configs[ 'accumulate_steps'] - + # If sent tensors are not the same on different hosts, + # they shouldn't be sent partially and then concatenated as a whole tensor. 
+ self._enable_partial_send_recv = self._strategy.pipeline_configs[ + 'enable_partial_send_recv'] self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape'] self.num_stages = self._hcg.get_pipe_parallel_world_size() @@ -58,7 +61,8 @@ class PipelineParallel(MetaParallelBase): self._real_pp_world_size = self.num_stages self._real_pp_rank = self.stage_id - p2p.initialize_p2p_groups(hcg, self._using_cache) + p2p.initialize_p2p_groups(hcg, self._using_cache, + self._enable_partial_send_recv) self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 8e048e3db6dab0f5e5370afb493ee8f47d9e4cc1..ca438326386f0dc7a411310d945961eb94ea7926 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -22,12 +22,14 @@ from .utils import paddle_2_number, paddle_2_number, number_2_dtype _hcg = None _use_cache = False +_enable_partial_send_recv = True -def initialize_p2p_groups(hcg, use_cache=True): - global _hcg, _use_cache +def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True): + global _hcg, _use_cache, _enable_partial_send_recv _hcg = hcg _use_cache = use_cache + _enable_partial_send_recv = enable_partial_send_recv send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups( ) @@ -157,7 +159,8 @@ _send_recv_meta = SendRecvMeta() def _is_valid_send_recv_partial(tensor, mp_degree): - + if not _enable_partial_send_recv: + return False tensor_numel = np.prod(tensor.shape) assert tensor_numel != 0, "can't send/recv zero element" return mp_degree > 1 and tensor_numel % mp_degree == 0