diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 3fd7a994a62fb4d73ab06f32b0319d16a535da95..2974295f72fed19d25919c9a4b30d484bc2f1f6e 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -177,6 +177,7 @@ message PipelineConfig { optional int32 accumulate_steps = 2 [ default = 1 ]; optional string schedule_mode = 3 [ default = '1F1B' ]; optional bool p2p_cache_shape = 4 [ default = true ]; + optional bool enable_partial_send_recv = 5 [ default = true ]; } message TensorParallelConfig { diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 56429b748064daeac2780d5414513fffa9003b58..7afda2d6f48db7a6884771c2f1daa1e967fa492f 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -46,7 +46,10 @@ class PipelineParallel(MetaParallelBase): 'micro_batch_size'] self.accumulate_steps = self._strategy.pipeline_configs[ 'accumulate_steps'] - + # If sent tensor are not the same from different hosts, + # they shouldn't been sent partially and then concated as a whole tensor. + self._enable_partial_send_recv = self._strategy.pipeline_configs[ + 'enable_partial_send_recv'] self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape'] self.num_stages = self._hcg.get_pipe_parallel_world_size() @@ -58,7 +61,8 @@ class PipelineParallel(MetaParallelBase): self._real_pp_world_size = self.num_stages self._real_pp_rank = self.stage_id - p2p.initialize_p2p_groups(hcg, self._using_cache) + p2p.initialize_p2p_groups(hcg, self._using_cache, + self._enable_partial_send_recv) self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index c1cf0527e1b2b25062549bab485acda13da9bd2c..f5461b046d27a278848ecb3d65ba197f77506258 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -22,12 +22,14 @@ from .utils import paddle_2_number, paddle_2_number, number_2_dtype _hcg = None _use_cache = False +_enable_partial_send_recv = True -def initialize_p2p_groups(hcg, use_cache=True): - global _hcg, _use_cache +def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True): + global _hcg, _use_cache, _enable_partial_send_recv _hcg = hcg _use_cache = use_cache + _enable_partial_send_recv = enable_partial_send_recv send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups( ) @@ -157,7 +159,8 @@ _send_recv_meta = SendRecvMeta() def _is_valid_send_recv_partial(tensor, mp_degree): - + if not _enable_partial_send_recv: + return False tensor_numel = np.prod(tensor.shape) assert tensor_numel != 0, "can't send/recv zero element" return mp_degree > 1 and tensor_numel % mp_degree == 0