diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 0ca900a0fa737084763468ee69e861e27540c785..09d140c5416538fc6830a60e05a703df2b9611d4 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -214,6 +214,16 @@ class ProcessGroup { "ProcessGroup%s does not support AllGather_Partial", GetBackendName())); } + virtual std::shared_ptr AllGather_Partial( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + int offset, + int length, + bool) { // NOLINT + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support AllGather_Partial", GetBackendName())); + } + virtual std::shared_ptr AllToAll( std::vector&, // NOLINT std::vector&) { // NOLINT diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 12f60faf80053294ca5ac2fdff0a62deab96578a..b406f596401effec68adaf45c61248a0053f64ed 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -1034,6 +1034,41 @@ std::shared_ptr ProcessGroupNCCL::AllGather_Partial( CommType::ALLGATHER); } +std::shared_ptr ProcessGroupNCCL::AllGather_Partial( + std::vector& in_tensors, + std::vector& out_tensors, + int offset, + int length, + bool sync_op, + bool use_calc_stream) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(in_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(out_tensors), + true, + platform::errors::InvalidArgument("All outputs should be in CudaPlace.")); + return Collective( + in_tensors, + out_tensors, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + return platform::dynload::ncclAllGather( + GetPointerByOffset(input.data(), offset, input.dtype()), + output.data(), + length, + platform::ToNCCLDataType(input.dtype()), + comm, + stream); + }, + CommType::ALLGATHER, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupNCCL::AllToAll( std::vector& in_tensors, std::vector& out_tensors) { diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index 24ba7c86b18386b05f18b6dd5ea2fb6cc393406a..6d15f6ebdeff05a5dd5c00f12838debaa9772626 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -182,6 +182,14 @@ class ProcessGroupNCCL : public ProcessGroupStream { int offset, int length) override; + std::shared_ptr AllGather_Partial( + std::vector& in_tensors, + std::vector& out_tensors, + int offset, + int length, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr AllToAll( std::vector& in, std::vector& out) override; diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc index 43ca0bbb36d3f6df4aebbff2337b16df49e35ad6..222fe03b60bf19d3ce41a24d0e69873a8b668b15 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc @@ -154,5 +154,30 @@ std::shared_ptr ProcessGroupStream::Recv_Partial( "ProcessGroup%s does not support do recv_partial", GetBackendName())); } +std::shared_ptr ProcessGroupStream::AllGather_Partial( + std::vector& in_tensors, + std::vector& out_tensors, + int offset, + int length, + bool sync_op) { + return AllGather_Partial(in_tensors, + out_tensors, + offset, + length, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::AllGather_Partial( + std::vector& in_tensors, + std::vector& out_tensors, + int offset, + int length, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do recv_partial", GetBackendName())); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h index f8ab562ad075cb4815a8a655d12647614cdd5fcd..1162c3e050925ab2e525b2ef25b0c94defe02733 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.h +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h @@ -132,6 +132,21 @@ class ProcessGroupStream : public ProcessGroup { int length, bool sync_op, bool use_calc_stream); + + std::shared_ptr AllGather_Partial( + std::vector& in_tensors, + std::vector& out_tensors, + int offset, + int length, + bool sync_op) override; + + virtual std::shared_ptr AllGather_Partial( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + int offset, + int length, + bool sync_op, + bool use_calc_stream); }; } // namespace distributed diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 6a2d099ac177f84a0f2475dcbf1bd7e1531be2c2..39986c604c03630177934e0ff3be5d6c7fb04ccd 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -621,6 +621,37 @@ void BindDistributed(py::module *m) { py::arg("op"), py::call_guard()) + .def( + "all_gather_partial_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + int nranks, + int rank_id) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector in_tensors = {*in_dense}; + std::vector out_tensors = {*out_dense}; + int numel = (*in_dense).numel(); + int send_numel = numel / nranks; + int offset = send_numel * rank_id; + return self.AllGather_Partial(in_tensors, + out_tensors, + offset, + send_numel, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("in"), + py::arg("out"), + py::arg("num"), + py::arg("id"), + py::call_guard()) + .def( "send_on_calc_stream", [](distributed::ProcessGroupStream &self, diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py index dc4dc05c7ba41a53585d77e3d0a6df0dd093b433..3506851e1db347fd810f5caf1fee77ad5f70f073 100644 --- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py +++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py @@ -43,7 +43,26 @@ def _c_identity(tensor, group=None): return ring_id = 0 if group is None else group.id - if _non_static_mode(): + if in_dygraph_mode(): + from paddle.autograd import PyLayer + + class c_identity_eager(PyLayer): + + @staticmethod + def forward(ctx, tensor): + return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True, + 'ring_id', group.id, + 'use_model_parallel', True) + + @staticmethod + def backward(ctx, dy): + op_type = collective._get_reduce_op(ReduceOp.SUM, "_c_identity") + group.process_group.allreduce_on_calc_stream(dy, op_type) + return dy + + return c_identity_eager.apply(tensor) + + elif _in_legacy_dygraph(): return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True, 'ring_id', ring_id, 'use_model_parallel', True) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index e2ca6f8d2a034c4f036bdfee35c87b2cfad5312e..c1cf0527e1b2b25062549bab485acda13da9bd2c 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -173,7 +173,9 @@ def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks, elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - return group.process_group.send_partial(tensor, dst, nranks, rank_id) + comm_op = group.process_group.send_partial_on_calc_stream \ + if use_calc_stream else group.process_group.send_partial + return comm_op(tensor, dst, nranks, rank_id) def send_partial(tensor, @@ -212,12 +214,9 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks, elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - task = group.process_group.recv_partial(tensor, src, nranks, rank_id) - if use_calc_stream: - task.wait() - return None - else: - return task + comm_op = group.process_group.recv_partial_on_calc_stream \ + if use_calc_stream else group.process_group.recv_partial + return comm_op(tensor, src, nranks, rank_id) def recv_partial(tensor, @@ -255,13 +254,9 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks, elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - task = group.process_group.all_gather_partial(tensor, tensor, nranks, - rank_id) - if use_calc_stream: - task.wait() - return None - else: - return task + comm_op = group.process_group.all_gather_partial_on_calc_stream \ + if use_calc_stream else group.process_group.all_gather_partial + return comm_op(tensor, tensor, nranks, rank_id) def allgather_partial(tensor,