diff --git a/paddle/fluid/operators/collective/partial_allgather_op.cc b/paddle/fluid/operators/collective/partial_allgather_op.cc
index bbe537823474162c53e5e0301c4e3ddaa6594ac8..bef2ff94d630853487bb5c04798387113e5567ae 100644
--- a/paddle/fluid/operators/collective/partial_allgather_op.cc
+++ b/paddle/fluid/operators/collective/partial_allgather_op.cc
@@ -68,14 +68,19 @@ reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
   }
 };
 
+DECLARE_INPLACE_OP_INFERER(PartialAllGatherOpInplaceInferer, {"X", "Out"});
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_WITHOUT_GRADIENT(partial_allgather, ops::PartialAllGatherOp,
-                             ops::PartialAllGatherOpMaker);
+REGISTER_OPERATOR(
+    partial_allgather, ops::PartialAllGatherOp, ops::PartialAllGatherOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::PartialAllGatherOpInplaceInferer)
 
 REGISTER_OP_CPU_KERNEL(partial_allgather,
                        ops::PartialAllGatherOpCPUKernel<float>,
diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index b2205391a253c35f1c1e2852ddfe1a28666066b9..d81783c677622a584fb31f8d170552ec3d2ed660 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -126,6 +126,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"accuracy", {"Correct", "Total"}},
     {"fill_constant", {"Out"}},
     {"recv_v2", {"Out"}},
+    {"partial_recv", {"Out"}},
     {"matmul", {"Out"}},
     {"c_broadcast", {"Out"}},
     {"c_sync_calc_stream", {"Out"}},
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index c30167bb7c52b999b105304cbe8c98be06e003bb..9f2a4aaffb4745c087f5d204d04f89e8bd864d27 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 
-from types import MethodType
+import numpy as np
 
 import paddle
 import paddle.fluid as fluid
@@ -39,6 +39,8 @@ class PipelineParallel(MetaParallelBase):
         self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
         self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
 
+        self.is_pipe_partitioned = self.use_model_parallel
+
         self.num_caches = 0
         self.caches = {
             'inputs': [],
@@ -70,6 +72,9 @@ class PipelineParallel(MetaParallelBase):
         self.is_last_stage = (self.stage_id == (self.num_stages - 1))
         self.global_rank = self._hcg.get_global_rank()
 
+        self.mp_degree = self._hcg.get_model_parallel_world_size()
+        self.mp_rank = self._hcg.get_model_parallel_rank()
+
         logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
             self.num_stages, self.stage_id))
 
@@ -159,8 +164,8 @@ class PipelineParallel(MetaParallelBase):
         else:
             inputs = self.caches['inputs'][cache_id]
 
-        outputs = self._layers.forward(inputs)
         self._clear_grads(inputs)
+        outputs = self._layers.forward(inputs)
 
         self.caches['outputs'][cache_id] = outputs
 
@@ -369,6 +374,11 @@ class PipelineParallel(MetaParallelBase):
             caches = tuple(caches)
         return caches
 
+    def _is_valid_send_recv(self, tensor):
+        tensor_numel = np.prod(tensor.shape)
+        assert tensor_numel != 0, "can't send/recv zero element"
+        return tensor_numel % self.mp_degree == 0
+
     def _send_activations(self, cache_id):
         outputs = self.caches['outputs'][cache_id]
 
@@ -377,24 +387,56 @@ class PipelineParallel(MetaParallelBase):
             self._send_meta(outputs, self.next_stage_id)
 
         if isinstance(outputs, paddle.Tensor):
-            p2p.send(outputs, self.next_stage_id)
+            if self.is_pipe_partitioned and self._is_valid_send_recv(outputs):
+                p2p.send_partial(
+                    outputs.detach(),
+                    self.next_stage_id,
+                    mp_degree=self.mp_degree,
+                    mp_rank=self.mp_rank)
+            else:
+                p2p.send(outputs.detach(), self.next_stage_id)
         elif isinstance(outputs, tuple):
             for output in outputs:
-                p2p.send(output, self.next_stage_id)
+                if self.is_pipe_partitioned and self._is_valid_send_recv(
+                        output):
+                    p2p.send_partial(
+                        output.detach(),
+                        self.next_stage_id,
+                        mp_degree=self.mp_degree,
+                        mp_rank=self.mp_rank)
+                else:
+                    p2p.send(output.detach(), self.next_stage_id)
 
     def _send_gradients(self, cache_id):
         inputs = self.caches['inputs'][cache_id]
 
         if isinstance(inputs, paddle.Tensor):
             assert inputs.grad is not None
-            p2p.send(inputs.grad, self.prev_stage_id)
+            if self.is_pipe_partitioned and self._is_valid_send_recv(
+                    inputs.grad):
+                grad = p2p.send_partial(
+                    inputs.grad,
+                    self.prev_stage_id,
+                    mp_degree=self.mp_degree,
+                    mp_rank=self.mp_rank)
+            else:
+                p2p.send(inputs.grad, self.prev_stage_id)
         else:
             for idx, d in enumerate(inputs):
                 # Skip tensors that will not produce a grad
                 if not is_float_tensor(d):
                     assert d.grad is None
                     continue
-                p2p.send(d.grad, self.prev_stage_id)
+
+                if self.is_pipe_partitioned and self._is_valid_send_recv(
+                        d.grad):
+                    grad = p2p.send_partial(
+                        d.grad,
+                        self.prev_stage_id,
+                        mp_degree=self.mp_degree,
+                        mp_rank=self.mp_rank)
+                else:
+                    p2p.send(d.grad, self.prev_stage_id)
 
         self.caches['inputs'][cache_id] = None
 
@@ -404,15 +446,39 @@ class PipelineParallel(MetaParallelBase):
             self.recv_cache = self._recv_meta(self.prev_stage_id)
 
         if isinstance(self.recv_cache, paddle.Tensor):
-            p2p.recv(self.recv_cache, self.prev_stage_id)
+            if self.is_pipe_partitioned and self._is_valid_send_recv(
+                    self.recv_cache):
+                p2p.recv_partial(self.recv_cache, self.prev_stage_id,
+                                 self.mp_degree, self.mp_rank)
+                p2p.partial_allgather_operator(
+                    self.recv_cache,
+                    mp_ranks=self.mp_degree,
+                    mp_rank_id=self.mp_rank,
+                    group=self._hcg.get_model_parallel_group(),
+                    use_calc_stream=True)
+            else:
+                p2p.recv(self.recv_cache, self.prev_stage_id)
+
             inputs = self.recv_cache.clone().detach()
             inputs.stop_gradient = not is_float_tensor(inputs)
+
         else:
             assert isinstance(self.recv_cache, tuple)
             inputs = [None] * len(self.recv_cache)
             for idx, d in enumerate(self.recv_cache):
-                assert isinstance(d, paddle.Tensor)
-                p2p.recv(d, self.prev_stage_id)
+                if self.is_pipe_partitioned and self._is_valid_send_recv(d):
+                    assert isinstance(d, paddle.Tensor)
+                    p2p.recv_partial(d, self.prev_stage_id, self.mp_degree,
+                                     self.mp_rank)
+                    p2p.partial_allgather_operator(
+                        d,
+                        mp_ranks=self.mp_degree,
+                        mp_rank_id=self.mp_rank,
+                        group=self._hcg.get_model_parallel_group(),
+                        use_calc_stream=True)
+                else:
+                    assert isinstance(d, paddle.Tensor)
+                    p2p.recv(d, self.prev_stage_id)
                 inputs[idx] = d.clone().detach()
 
             inputs = tuple(inputs)
@@ -440,11 +506,33 @@ class PipelineParallel(MetaParallelBase):
                 sizes, dtypes, num_caches=1)[0]
 
         if isinstance(self.grad_tensors, paddle.Tensor):
-            p2p.recv(self.grad_tensors, self.next_stage_id)
+            if self.is_pipe_partitioned and self._is_valid_send_recv(
+                    self.grad_tensors):
+                p2p.recv_partial(self.grad_tensors, self.next_stage_id,
+                                 self.mp_degree, self.mp_rank)
+                p2p.partial_allgather_operator(
+                    self.grad_tensors,
+                    mp_ranks=self.mp_degree,
+                    mp_rank_id=self.mp_rank,
+                    group=self._hcg.get_model_parallel_group(),
+                    use_calc_stream=True)
+            else:
+                p2p.recv(self.grad_tensors, self.next_stage_id)
+
         else:
            assert isinstance(outputs, tuple)
            for d in self.grad_tensors:
-                p2p.recv(d, self.next_stage_id)
+                if self.is_pipe_partitioned and self._is_valid_send_recv(d):
+                    p2p.recv_partial(d, self.next_stage_id, self.mp_degree,
+                                     self.mp_rank)
+                    p2p.partial_allgather_operator(
+                        d,
+                        mp_ranks=self.mp_degree,
+                        mp_rank_id=self.mp_rank,
+                        group=self._hcg.get_model_parallel_group(),
+                        use_calc_stream=True)
+                else:
+                    p2p.recv(d, self.next_stage_id)
 
     def _step(self):
         if self.scaler:
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index f81164b778cc27158b44fa573d4dedfcf9a698dd..44090be94f1a7d15df8c937941d32a2fe6c7b54a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -27,15 +27,67 @@ def initialize_p2p_groups(hcg):
     _hcg = hcg
 
 
+def _is_valid_communciate(src_stage, dest_stage):
+    first_stage = 0
+    last_stage = _hcg.get_pipe_parallel_world_size() - 1
+    assert abs(src_stage-dest_stage) == 1 or \
+        (src_stage == first_stage and dest_stage == last_stage) or \
+        (src_stage == last_stage and dest_stage == first_stage)
+
+
+def partial_send_operator(tensor,
+                          dst=0,
+                          mp_ranks=1,
+                          mp_rank_id=0,
+                          group=None,
+                          use_calc_stream=True):
+
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+    return paddle.fluid.core.ops.partial_send(
+        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
+        dst, 'num', mp_ranks, 'id', mp_rank_id)
+
+
+def partial_recv_operator(tensor,
+                          src=0,
+                          mp_ranks=1,
+                          mp_rank_id=0,
+                          group=None,
+                          use_calc_stream=True):
+
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    return paddle.fluid.core.ops.partial_recv(
+        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
+        src, 'num', mp_ranks, 'id', mp_rank_id, 'dtype', tensor.dtype,
+        'out_shape', tensor.shape)
+
+
+def partial_allgather_operator(tensor,
+                               mp_ranks=1,
+                               mp_rank_id=0,
+                               group=None,
+                               use_calc_stream=True):
+    if group is not None and not group.is_member():
+        return
+    ring_id = 0 if group is None else group.id
+
+    return paddle.fluid.core.ops.partial_allgather_(
+        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
+        'nranks', mp_ranks, 'rank', mp_rank_id)
+
+
 def send(tensor, dest_stage):
     global _groups, _hcg
     src_stage = _hcg.get_stage_id()
-    src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
-
     _is_valid_communciate(src_stage, dest_stage)
     group = _get_send_recv_group(src_stage, dest_stage)
-    dst_rank = _hcg.get_rank_from_stage(stage_id=dest_stage)
-    return paddle.distributed.broadcast(tensor, src_rank, group=group)
+    return paddle.distributed.send(
+        tensor, dst=1 if dest_stage > src_stage else 0, group=group)
 
 
 def recv(tensor, src_stage):
@@ -44,16 +96,35 @@ def recv(tensor, src_stage):
 
     _is_valid_communciate(src_stage, dest_stage)
     group = _get_send_recv_group(src_stage, dest_stage)
-    src_rank = _hcg.get_rank_from_stage(stage_id=src_stage)
-    return paddle.distributed.broadcast(tensor, src_rank, group=group)
+    return paddle.distributed.recv(
+        tensor, src=0 if dest_stage > src_stage else 1, group=group)
 
 
-def _is_valid_communciate(src_stage, dest_stage):
-    first_stage = 0
-    last_stage = _hcg.get_pipe_parallel_world_size() - 1
-    assert abs(src_stage-dest_stage) == 1 or \
-        (src_stage == first_stage and dest_stage == last_stage) or \
-        (src_stage == last_stage and dest_stage == first_stage)
+def send_partial(tensor, dest_stage, mp_degree, mp_rank):
+    global _groups, _hcg
+    src_stage = _hcg.get_stage_id()
+    _is_valid_communciate(src_stage, dest_stage)
+    group = _get_send_recv_group(src_stage, dest_stage)
+    return partial_send_operator(
+        tensor,
+        dst=1 if dest_stage > src_stage else 0,
+        mp_ranks=mp_degree,
+        mp_rank_id=mp_rank,
+        group=group)
+
+
+def recv_partial(tensor, src_stage, mp_degree, mp_rank):
+    global _groups, _hcg
+    dest_stage = _hcg.get_stage_id()
+
+    _is_valid_communciate(src_stage, dest_stage)
+    group = _get_send_recv_group(src_stage, dest_stage)
+    return partial_recv_operator(
+        tensor,
+        src=0 if dest_stage > src_stage else 1,
+        mp_ranks=mp_degree,
+        mp_rank_id=mp_rank,
+        group=group)
 
 
 def _get_send_recv_group(src_stage, dest_stage):
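
Note: the sketch below is illustrative only and is not part of the patch. It mimics, with plain numpy, the slicing contract that partial_send/partial_recv plus the inplace partial_allgather_ appear to implement: split the flattened tensor into mp_degree equal chunks, move only the chunk belonging to mp_rank across the pipeline stage boundary, then restore the full tensor with an all-gather over the model parallel group. This is also why _is_valid_send_recv requires numel % mp_degree == 0 and why the pipeline falls back to plain p2p.send/p2p.recv otherwise. The helpers fake_partial_send and fake_partial_allgather are hypothetical names introduced only for this demonstration.

import numpy as np

def fake_partial_send(tensor, num, idx):
    # Stand-in for partial_send: keep only the idx-th of `num` equal
    # chunks of the flattened tensor.
    flat = tensor.reshape(-1)
    assert flat.size % num == 0, "tensor numel must divide evenly by num"
    chunk = flat.size // num
    return flat[idx * chunk:(idx + 1) * chunk].copy()

def fake_partial_allgather(chunks):
    # Stand-in for partial_allgather_: concatenate the chunks gathered
    # from every model-parallel rank back into the full flat tensor.
    return np.concatenate(chunks)

if __name__ == "__main__":
    mp_degree = 4
    activation = np.arange(32, dtype=np.float32).reshape(8, 4)

    # Each model-parallel rank would ship only its own 1/mp_degree slice...
    slices = [fake_partial_send(activation, mp_degree, r) for r in range(mp_degree)]
    # ...and the receiving stage reassembles the full activation.
    restored = fake_partial_allgather(slices).reshape(activation.shape)
    assert np.array_equal(restored, activation)

Under this scheme each pipeline stage sends roughly 1/mp_degree of the activation or gradient bytes over the inter-stage link, and the receiver reconstructs the full tensor locally (via the model-parallel all-gather) before running the next stage.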