From 5c9fce0e04a166bedb75d4826850e359d1125cfc Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Thu, 1 Jul 2021 19:19:34 +0800
Subject: [PATCH] add p2p (#33864)

---
 .../paddle/distributed/fleet/base/topology.py |  21 ++++
 .../fleet/meta_parallel/pipeline_parallel.py  | 110 ++++++------------
 .../pp_utils/p2p_communication.py             |  70 +++++++++++
 3 files changed, 126 insertions(+), 75 deletions(-)
 create mode 100644 python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py

diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py
index 3e89e9de181..004b3fb0f66 100644
--- a/python/paddle/distributed/fleet/base/topology.py
+++ b/python/paddle/distributed/fleet/base/topology.py
@@ -164,9 +164,27 @@ class HybridCommunicateGroup(object):
             self._dp_group, self._check_group)
         logger.info(debug_str)

+        # build the p2p rank pairs; no new communication group is created here
+        self._p2p_groups = self._build_p2p_lists()
+
         global _HYBRID_PARALLEL_GROUP
         _HYBRID_PARALLEL_GROUP = self

+    def _build_p2p_lists(self):
+        comm_lists = self._topo.get_comm_list('pipe')
+        p2p_lists = []
+        for rank in range(self.nranks):
+            for comm_ranks in comm_lists:
+                assert len(comm_ranks) == self._pp_degree
+                if rank in comm_ranks:
+                    idx = comm_ranks.index(rank)
+                    next_rank = comm_ranks[(idx + 1) % self._pp_degree]
+                    p2p_lists.append([rank, next_rank])
+                    break
+        assert len(
+            p2p_lists) == self.nranks, "len(p2p_lists) should be equal to nranks"
+        return p2p_lists
+
     def get_parallel_mode(self):
         # there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
         # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
@@ -286,6 +304,9 @@ class HybridCommunicateGroup(object):
         # TODO should the src rank related to the shard rank for each parameter ?
         return self._sharding_comm_group.ranks[0]

+    def get_p2p_groups(self):
+        return self._p2p_groups
+
     # check parallel group
     def get_check_parallel_group(self):
         return self._check_comm_group
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 0bb6315290e..343e6db04c2 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -24,6 +24,7 @@ from ..utils.hybrid_parallel_util import broadcast_mp_parameters
 from ..utils.hybrid_parallel_util import broadcast_dp_parameters
 from ..utils.log_util import logger
 from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
+from .pp_utils import p2p_communication as p2p

 __all__ = []

@@ -63,6 +64,7 @@ class PipelineParallel(MetaParallelBase):
         self.prev_stage_id = self.stage_id - 1
         self.next_stage_id = self.stage_id + 1
         self.pp_group = self._hcg.get_pipe_parallel_group()
+        p2p.initialize_p2p_groups(hcg)

         self.is_first_stage = self.stage_id == 0
         self.is_last_stage = (self.stage_id == (self.num_stages - 1))
@@ -275,97 +277,86 @@ class PipelineParallel(MetaParallelBase):
         if isinstance(data, paddle.Tensor):
             tensor_type = paddle.to_tensor([0])
             # send tensor type
-            paddle.distributed.send(
-                tensor_type, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.send(tensor_type, self.next_stage_id)

             # send len(shape)
             dims = paddle.to_tensor(len(data.shape))
-            paddle.distributed.send(
-                dims, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.send(dims, self.next_stage_id)

             # send shape
             shape = paddle.to_tensor(data.shape)
-            paddle.distributed.send(
-                shape, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.send(shape, self.next_stage_id)

             # send dtype
             dtype = paddle.to_tensor(paddle_2_number(data.dtype))
-            paddle.distributed.send(
-                dtype, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.send(dtype, self.next_stage_id)

         elif isinstance(data, tuple):
             tensor_type = paddle.to_tensor([1])
-            paddle.distributed.send(
-                tensor_type, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.send(tensor_type, self.next_stage_id)
+
             nums = paddle.to_tensor(len(data))
-            paddle.distributed.send(
-                nums, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.send(nums, self.next_stage_id)
+
             for idx, d in enumerate(data):
                 assert isinstance(d, paddle.Tensor)
                 # send len(shape)
                 dims = paddle.to_tensor(len(d.shape))
-                paddle.distributed.send(
-                    dims, peer, use_calc_stream=True, group=self.pp_group)
+                p2p.send(dims, self.next_stage_id)

                 # send shape
                 shape = paddle.to_tensor(d.shape)
-                paddle.distributed.send(
-                    shape, peer, use_calc_stream=True, group=self.pp_group)
+                p2p.send(shape, self.next_stage_id)

                 # send dtype
                 dtype = paddle.to_tensor(paddle_2_number(d.dtype))
-                paddle.distributed.send(
-                    dtype, peer, use_calc_stream=True, group=self.pp_group)
+                p2p.send(dtype, self.next_stage_id)

     def _recv_meta(self, peer):
         tensor_type = paddle.to_tensor([0])
-        paddle.distributed.recv(
-            tensor_type, peer, use_calc_stream=True, group=self.pp_group)
+        p2p.recv(tensor_type, self.prev_stage_id)
+
         tensor_type = tensor_type.item()

         if tensor_type == 0:
             # recv len(shape)
             dims = paddle.to_tensor([0])
-            paddle.distributed.recv(
-                dims, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.recv(dims, self.prev_stage_id)
+
             dims = dims.item()

             # recv shape
             shape = paddle.to_tensor([0] * dims)
-            paddle.distributed.recv(
-                shape, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.recv(shape, self.prev_stage_id)
+
             shape = shape.numpy().tolist()

             # recv dtype
             dtype = paddle.to_tensor([0])
-            paddle.distributed.recv(
-                dtype, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.recv(dtype, self.prev_stage_id)
+
             return self._allocate_cache(
                 shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0]
         elif tensor_type == 1:
             num = paddle.to_tensor([0])
-            paddle.distributed.recv(
-                num, peer, use_calc_stream=True, group=self.pp_group)
+            p2p.recv(num, self.prev_stage_id)
             num = num.item()
             shapes = []
             dtypes = []
             for i in range(num):
                 # recv len(shape)
                 dims = paddle.to_tensor([0])
-                paddle.distributed.recv(
-                    dims, peer, use_calc_stream=True, group=self.pp_group)
+                p2p.recv(dims, self.prev_stage_id)

                 # recv shape
                 dims = dims.item()
                 shape = paddle.to_tensor([0] * dims)
-                paddle.distributed.recv(
-                    shape, peer, use_calc_stream=True, group=self.pp_group)
+                p2p.recv(shape, self.prev_stage_id)
                 shapes.append(shape.numpy().tolist())

                 # recv dtype
                 dtype = paddle.to_tensor([0])
-                paddle.distributed.recv(
-                    dtype, peer, use_calc_stream=True, group=self.pp_group)
+                p2p.recv(dtype, self.prev_stage_id)
                 dtypes.append(number_2_dtype(dtype.item()))

             caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0]
@@ -380,39 +371,25 @@ class PipelineParallel(MetaParallelBase):
             self._send_meta(outputs, self.next_stage_id)

         if isinstance(outputs, paddle.Tensor):
-            paddle.distributed.send(
-                outputs,
-                self.next_stage_id,
-                use_calc_stream=True,
-                group=self.pp_group)
+            p2p.send(outputs, self.next_stage_id)
+
         elif isinstance(outputs, tuple):
             for output in outputs:
-                paddle.distributed.send(
-                    output,
-                    self.next_stage_id,
-                    use_calc_stream=True,
-                    group=self.pp_group)
+                p2p.send(output, self.next_stage_id)

     def _send_gradients(self, cache_id):
         inputs = self.caches['inputs'][cache_id]

         if isinstance(inputs, paddle.Tensor):
             assert inputs.grad is not None
-            paddle.distributed.send(
-                paddle.to_tensor(inputs.grad),
-                self.prev_stage_id,
-                use_calc_stream=True,
-                group=self.pp_group)
+            p2p.send(inputs.grad, self.prev_stage_id)
         else:
             for idx, d in enumerate(inputs):
                 # Skip tensors that will not produce a grad
                 if not is_float_tensor(d):
                     assert d.grad is None
                     continue
-                paddle.distributed.send(
-                    d.grad,
-                    self.prev_stage_id,
-                    use_calc_stream=True,
-                    group=self.pp_group)
+                p2p.send(d.grad, self.prev_stage_id)
+
         self.caches['inputs'][cache_id] = None

     def _recv_activations(self, cache_id):
@@ -421,11 +398,7 @@ class PipelineParallel(MetaParallelBase):
             self.recv_cache = self._recv_meta(self.prev_stage_id)

         if isinstance(self.recv_cache, paddle.Tensor):
-            paddle.distributed.recv(
-                self.recv_cache,
-                self.prev_stage_id,
-                use_calc_stream=True,
-                group=self.pp_group)
+            p2p.recv(self.recv_cache, self.prev_stage_id)
             inputs = self.recv_cache.clone().detach()
             inputs.stop_gradient = not is_float_tensor(inputs)
         else:
@@ -433,12 +406,7 @@ class PipelineParallel(MetaParallelBase):
             inputs = [None] * len(self.recv_cache)
             for idx, d in enumerate(self.recv_cache):
                 assert isinstance(d, paddle.Tensor)
-
-                paddle.distributed.recv(
-                    d,
-                    self.prev_stage_id,
-                    use_calc_stream=True,
-                    group=self.pp_group)
+                p2p.recv(d, self.prev_stage_id)
                 inputs[idx] = d.clone().detach()
             inputs = tuple(inputs)

@@ -466,19 +434,11 @@ class PipelineParallel(MetaParallelBase):
                 sizes, dtypes, num_caches=1)[0]

         if isinstance(self.grad_tensors, paddle.Tensor):
-            paddle.distributed.recv(
-                self.grad_tensors,
-                self.next_stage_id,
-                use_calc_stream=True,
-                group=self.pp_group)
+            p2p.recv(self.grad_tensors, self.next_stage_id)
         else:
             assert isinstance(outputs, tuple)
             for d in self.grad_tensors:
-                paddle.distributed.recv(
-                    d,
-                    self.next_stage_id,
-                    use_calc_stream=True,
-                    group=self.pp_group)
+                p2p.recv(d, self.next_stage_id)

     def _step(self):
         self.optimizer.step()
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
new file mode 100644
index 00000000000..c6131106122
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.distributed as dist
+
+_groups = None
+_hcg = None
+
+
+def initialize_p2p_groups(hcg):
+    global _groups, _hcg
+    _groups = [dist.new_group(ranks=group) for group in hcg.get_p2p_groups()]
+    _hcg = hcg
+
+
+def send(tensor, dest_stage):
+    global _groups, _hcg
+    src_stage = _hcg.get_stage_id()
+    src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
+
+    _is_valid_communicate(src_stage, dest_stage)
+    group = _get_send_recv_group(src_stage, dest_stage)
+    dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
+    return dist.broadcast(tensor, src_rank, group=group)
+
+
+def recv(tensor, src_stage):
+    global _groups, _hcg
+    dest_stage = _hcg.get_stage_id()
+
+    _is_valid_communicate(src_stage, dest_stage)
+    group = _get_send_recv_group(src_stage, dest_stage)
+    src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
+    return dist.broadcast(tensor, src_rank, group=group)
+
+
+def _is_valid_communicate(src_stage, dest_stage):
+    first_stage = 0
+    last_stage = _hcg.get_pipe_parallel_world_size() - 1
+    assert abs(src_stage-dest_stage) == 1 or \
+        (src_stage == first_stage and dest_stage == last_stage) or \
+        (src_stage == last_stage and dest_stage == first_stage)
+
+
+def _get_send_recv_group(src_stage, dest_stage):
+    global _groups, _hcg
+    stage_id = None
+    first_stage = 0
+    last_stage = _hcg.get_pipe_parallel_world_size() - 1
+    if (src_stage == first_stage and dest_stage == last_stage) or \
+            (dest_stage == first_stage and src_stage == last_stage):
+        stage_id = last_stage
+    elif src_stage > dest_stage:
+        stage_id = dest_stage
+    else:
+        stage_id = src_stage
+    group_id = _hcg.get_rank_from_stage(stage_id=stage_id)
+    return _groups[group_id]
-- 
GitLab
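
The new p2p_communication module emulates point-to-point send/recv with dist.broadcast over dedicated two-rank groups: _build_p2p_lists records one [rank, next_rank] pair per global rank along its pipeline ring, initialize_p2p_groups turns every pair into a communication group, and _get_send_recv_group selects the group indexed by the global rank of the lower of the two stages (or of the last stage for the first/last wrap-around edge), so a broadcast from the source rank inside that pair behaves like a directed send/recv. Below is a minimal standalone sketch of the pair-building and group-selection logic; the function names, the toy topology, and the rank_of_stage stand-in are illustrative assumptions rather than part of the patch, and it runs with plain Python, no Paddle required.

def build_p2p_lists(comm_lists, nranks, pp_degree):
    # Mirrors HybridCommunicateGroup._build_p2p_lists: for every global rank,
    # record the [rank, next_rank] pair along its pipeline ring.
    p2p_lists = []
    for rank in range(nranks):
        for comm_ranks in comm_lists:
            assert len(comm_ranks) == pp_degree
            if rank in comm_ranks:
                idx = comm_ranks.index(rank)
                p2p_lists.append([rank, comm_ranks[(idx + 1) % pp_degree]])
                break
    assert len(p2p_lists) == nranks
    return p2p_lists


def send_recv_group_owner(src_stage, dest_stage, pp_degree, rank_of_stage):
    # Mirrors _get_send_recv_group: a transfer uses the pair group of the
    # lower of the two stages, except for the wrap-around edge, which uses
    # the pair group of the last stage.
    first_stage, last_stage = 0, pp_degree - 1
    if (src_stage == first_stage and dest_stage == last_stage) or \
            (dest_stage == first_stage and src_stage == last_stage):
        stage_id = last_stage
    else:
        stage_id = min(src_stage, dest_stage)
    # In the real module this global rank indexes _groups[group_id].
    return rank_of_stage(stage_id)


if __name__ == "__main__":
    # Toy hybrid topology: 8 ranks, pp_degree = 4, two pipeline rings,
    # roughly the shape of what get_comm_list('pipe') returns.
    comm_lists = [[0, 2, 4, 6], [1, 3, 5, 7]]
    nranks, pp_degree = 8, 4

    pairs = build_p2p_lists(comm_lists, nranks, pp_degree)
    print(pairs)  # [[0, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7], [6, 0], [7, 1]]

    # For the pipeline [0, 2, 4, 6], stage s lives on global rank comm_lists[0][s].
    rank_of_stage = lambda s: comm_lists[0][s]
    print(send_recv_group_owner(1, 2, pp_degree, rank_of_stage))  # 2 -> pair [2, 4]
    print(send_recv_group_owner(3, 0, pp_degree, rank_of_stage))  # 6 -> pair [6, 0]

Running the sketch prints the full pair list and shows that a stage-1 to stage-2 transfer on the [0, 2, 4, 6] pipeline uses the pair group built for rank 2, while the stage-3/stage-0 wrap-around uses the one built for rank 6.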