diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 1b712a857c37cb62f934e1121c4f7ee9c00d4b72..7abecd36e3d00f5f5e27c627764dbdcea6a3d5ea 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -185,11 +185,12 @@ class ProcessGroup { GetBackendName())); } - virtual std::shared_ptr Send(phi::DenseTensor*, - int dst_rank, - int64_t offset, - int64_t numel, - bool sync_op) { + virtual std::shared_ptr Send( + const phi::DenseTensor& tensor, + int dst_rank, + int64_t offset, + int64_t numel, + bool sync_op) { PADDLE_THROW(platform::errors::Unimplemented( "ProcessGroup%s does not support send with sync_op flag.", GetBackendName())); diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 016249963d79b682dad98b8ba1800821eeb0678e..96666f50c91ef76bc43cea8f5506c1c02b2949f7 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -137,21 +137,17 @@ std::shared_ptr ProcessGroupNCCL::AllGather( // numel > 0 indicates the tensor need to be sliced const phi::DenseTensor& in_tensor_maybe_partial = numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor; - return Collective( - out_tensor, - in_tensor_maybe_partial, - [](phi::DenseTensor* output, - const phi::DenseTensor& input, - ncclComm_t comm, - gpuStream_t stream) { + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { NCCL_CHECK(platform::dynload::ncclAllGather( - input.data(), - output->data(), - input.numel(), - platform::ToNCCLDataType(input.dtype()), + in_tensor_maybe_partial.data(), + out_tensor->data(), + in_tensor_maybe_partial.numel(), + platform::ToNCCLDataType(in_tensor_maybe_partial.dtype()), comm, stream)); }, + in_tensor_maybe_partial, CommType::ALLGATHER, sync_op, use_calc_stream); @@ -163,22 +159,18 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( const AllreduceOptions& opts, bool sync_op, bool use_calc_stream) { - return Collective( - out_tensor, - in_tensor, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - ncclComm_t comm, - gpuStream_t stream) { + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { NCCL_CHECK(platform::dynload::ncclAllReduce( - input.data(), - output->data(), - input.numel(), - platform::ToNCCLDataType(input.type()), + in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + platform::ToNCCLDataType(in_tensor.dtype()), ToNCCLRedType(opts.reduce_op), comm, stream)); }, + in_tensor, CommType::ALLREDUCE, sync_op, use_calc_stream); @@ -215,37 +207,32 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( CheckSizeOnEachRank(out_dim, out_size_each_rank, size_); CheckSizeOnEachRank(in_dim, in_size_each_rank, size_); - return Collective( - out_tensor, - in_tensor, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - ncclComm_t comm, - gpuStream_t stream) { - int64_t in_row_size = input.numel() / in_dim[0], - out_row_size = output->numel() / out_dim[0]; + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { + int64_t in_row_size = in_tensor.numel() / in_dim[0], + out_row_size = out_tensor->numel() / out_dim[0]; int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; phi::DenseTensor input_partial, output_partial; GroupStart(); for (auto i = 0; i < size_; i++) { in_numel = in_size_each_rank[i] * in_row_size; - input_partial = GetPartialTensor(input, in_offset, in_numel); + input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); NCCL_CHECK(platform::dynload::ncclSend( input_partial.data(), in_numel, - platform::ToNCCLDataType(input.dtype()), + platform::ToNCCLDataType(input_partial.dtype()), i, comm, stream)); in_offset += in_numel; out_numel = out_size_each_rank[i] * out_row_size; - output_partial = GetPartialTensor(*output, out_offset, out_numel); + output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel); NCCL_CHECK(platform::dynload::ncclRecv( output_partial.data(), out_numel, - platform::ToNCCLDataType(output->dtype()), + platform::ToNCCLDataType(output_partial.dtype()), i, comm, stream)); @@ -253,6 +240,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( } GroupEnd(); }, + in_tensor, CommType::ALLTOALL, sync_op, use_calc_stream); @@ -286,23 +274,19 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( const BroadcastOptions& opts, bool sync_op, bool use_calc_stream) { - return Collective( - out_tensor, - in_tensor, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - ncclComm_t comm, - gpuStream_t stream) { + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { int root = opts.source_rank + opts.source_root; NCCL_CHECK(platform::dynload::ncclBroadcast( - input.data(), - output->data(), - input.numel(), - platform::ToNCCLDataType(input.type()), + in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + platform::ToNCCLDataType(in_tensor.dtype()), root, comm, stream)); }, + in_tensor, CommType::BROADCAST, sync_op, use_calc_stream); @@ -314,23 +298,19 @@ std::shared_ptr ProcessGroupNCCL::Reduce( const ReduceOptions& opts, bool sync_op, bool use_calc_stream) { - return Collective( - out_tensor, - in_tensor, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - ncclComm_t comm, - gpuStream_t stream) { + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { NCCL_CHECK(platform::dynload::ncclReduce( - input.data(), - output->data(), - input.numel(), - platform::ToNCCLDataType(input.dtype()), + in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + platform::ToNCCLDataType(in_tensor.dtype()), ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, stream)); }, + in_tensor, CommType::REDUCE, sync_op, use_calc_stream); @@ -342,22 +322,18 @@ std::shared_ptr ProcessGroupNCCL::ReduceScatter( const ReduceScatterOptions& opts, bool sync_op, bool use_calc_stream) { - return Collective( - out_tensor, - in_tensor, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - ncclComm_t comm, - gpuStream_t stream) { + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { NCCL_CHECK(platform::dynload::ncclReduceScatter( - input.data(), - output->data(), - output->numel(), - platform::ToNCCLDataType(input.dtype()), + in_tensor.data(), + out_tensor->data(), + out_tensor->numel(), + platform::ToNCCLDataType(in_tensor.dtype()), ToNCCLRedType(opts.reduce_op), comm, stream)); }, + in_tensor, CommType::REDUCE_SCATTER, sync_op, use_calc_stream); @@ -369,47 +345,43 @@ std::shared_ptr ProcessGroupNCCL::Scatter( const ScatterOptions& opts, bool sync_op, bool use_calc_stream) { - return Collective( - out_tensor, - in_tensor, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - ncclComm_t comm, - gpuStream_t stream) { - int64_t numel = input.numel() / size_; + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { + int64_t numel = in_tensor.numel() / size_; if (rank_ == opts.root_rank) { int64_t offset = 0; phi::DenseTensor partial_tensor; GroupStart(); for (auto i = 0; i < size_; i++) { - partial_tensor = GetPartialTensor(input, offset, numel); + partial_tensor = GetPartialTensor(in_tensor, offset, numel); NCCL_CHECK(platform::dynload::ncclSend( partial_tensor.data(), numel, - platform::ToNCCLDataType(input.dtype()), + platform::ToNCCLDataType(partial_tensor.dtype()), i, comm, stream)); offset += numel; } NCCL_CHECK(platform::dynload::ncclRecv( - output->data(), + out_tensor->data(), numel, - platform::ToNCCLDataType(output->dtype()), + platform::ToNCCLDataType(out_tensor->dtype()), opts.root_rank, comm, stream)); GroupEnd(); } else { NCCL_CHECK(platform::dynload::ncclRecv( - output->data(), + out_tensor->data(), numel, - platform::ToNCCLDataType(output->dtype()), + platform::ToNCCLDataType(out_tensor->dtype()), opts.root_rank, comm, stream)); } }, + in_tensor, CommType::SCATTER, sync_op, use_calc_stream); @@ -428,54 +400,43 @@ std::shared_ptr ProcessGroupNCCL::Recv( partial_tensor = GetPartialTensor(*tensor, offset, numel); tensor = &partial_tensor; } - return PointToPoint( - tensor, - src_rank, - [](phi::DenseTensor* output, - int src, - ncclComm_t comm, - gpuStream_t stream) { + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { NCCL_CHECK(platform::dynload::ncclRecv( - output->data(), - output->numel(), - platform::ToNCCLDataType(output->dtype()), - src, + tensor->data(), + tensor->numel(), + platform::ToNCCLDataType(tensor->dtype()), + src_rank, comm, stream)); }, + *tensor, CommType::RECV, sync_op, use_calc_stream); } std::shared_ptr ProcessGroupNCCL::Send( - phi::DenseTensor* tensor, + const phi::DenseTensor& tensor, int dst_rank, int64_t offset, int64_t numel, bool sync_op, bool use_calc_stream) { // numel > 0 indicates the tensor need to be sliced - phi::DenseTensor partial_tensor; - if (numel > 0) { - partial_tensor = GetPartialTensor(*tensor, offset, numel); - tensor = &partial_tensor; - } - return PointToPoint( - tensor, - dst_rank, - [](phi::DenseTensor* input, - int dst, - ncclComm_t comm, - gpuStream_t stream) { + const phi::DenseTensor& tensor_maybe_partial = + numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor; + return RunFnInNCCLEnv( + [&](ncclComm_t comm, gpuStream_t stream) { NCCL_CHECK(platform::dynload::ncclSend( - input->data(), - input->numel(), - platform::ToNCCLDataType(input->dtype()), - dst, + tensor_maybe_partial.data(), + tensor_maybe_partial.numel(), + platform::ToNCCLDataType(tensor_maybe_partial.dtype()), + dst_rank, comm, stream)); }, + tensor_maybe_partial, CommType::SEND, sync_op, use_calc_stream); @@ -548,54 +509,13 @@ void ProcessGroupNCCL::SyncCalcStream(const Place& place) { calc_event.Wait(platform::Place2DeviceType(place), comm_ctx); } -template -std::shared_ptr ProcessGroupNCCL::Collective( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - Fn fn, - CommType comm_type, - bool sync_op, - bool use_calc_stream) { - const auto& place = in_tensor.place(); - const auto& key = GetKeyFromPlace(place); - - platform::CUDADeviceGuard cuda_guard(place); - - if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { - CreateNCCLEnvCache(place, key); - } - - if (!use_calc_stream) { - SyncCalcStream(place); - } - - auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); - - const auto* calc_ctx = place_to_calc_ctx_.at(key); - const auto& comm_ctx = place_to_comm_ctx_.at(key); - auto nccl_comm = comm_ctx->nccl_comm(); - auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); - fn(out_tensor, in_tensor, nccl_comm, nccl_stream); - - if (!use_calc_stream) { - if (FLAGS_use_stream_safe_cuda_allocator) { - memory::RecordStream(in_tensor.Holder(), nccl_stream); - } - task->UpdateWaitChain(*comm_ctx); - } - - return task; -} - -template -std::shared_ptr ProcessGroupNCCL::PointToPoint( - phi::DenseTensor* tensor, - int rank, - Fn fn, +std::shared_ptr ProcessGroupNCCL::RunFnInNCCLEnv( + std::function fn, + const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, bool use_calc_stream) { - const auto& place = tensor->place(); + const auto& place = tensor.place(); const auto& key = GetKeyFromPlace(place); platform::CUDADeviceGuard cuda_guard(place); @@ -614,11 +534,11 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( const auto& comm_ctx = place_to_comm_ctx_.at(key); auto nccl_comm = comm_ctx->nccl_comm(); auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); - fn(tensor, rank, nccl_comm, nccl_stream); + fn(nccl_comm, nccl_stream); if (!use_calc_stream) { if (FLAGS_use_stream_safe_cuda_allocator) { - memory::RecordStream(tensor->Holder(), nccl_stream); + memory::RecordStream(tensor.Holder(), nccl_stream); } task->UpdateWaitChain(*comm_ctx); } diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index 73d484caca16995aadb9285264cfe7c94542bef1..5153b7a678dd4dd770cf8d1fbac9b2c8b4a16606 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -150,7 +150,7 @@ class ProcessGroupNCCL final : public ProcessGroupStream { bool sync_op, bool use_calc_stream) override; - std::shared_ptr Send(phi::DenseTensor* tensor, + std::shared_ptr Send(const phi::DenseTensor& tensor, int dst_rank, int64_t offset, int64_t numel, @@ -210,23 +210,13 @@ class ProcessGroupNCCL final : public ProcessGroupStream { void CreateNCCLEnvCache(const Place& place, const std::string& place_key); - template - std::shared_ptr Collective( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - Fn fn, + std::shared_ptr RunFnInNCCLEnv( + std::function fn, + const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, bool use_calc_stream); - template - std::shared_ptr PointToPoint(phi::DenseTensor* tensor, - int rank, - Fn fn, - CommType op_type, - bool sync_op, - bool use_calc_stream); - void SyncCalcStream(const Place& place); // TODO(sunyilun): methods below will be removed later diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc index 332298ecfd4a24b708cce822f44745aa3f9adf4c..e1ee425f3f8888da27c966b6ef81058294e352db 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc @@ -212,7 +212,7 @@ std::shared_ptr ProcessGroupStream::Recv( } std::shared_ptr ProcessGroupStream::Send( - phi::DenseTensor* tensor, + const phi::DenseTensor& tensor, int dst_rank, int64_t offset, int64_t numel, @@ -226,7 +226,7 @@ std::shared_ptr ProcessGroupStream::Send( } std::shared_ptr ProcessGroupStream::Send( - phi::DenseTensor*, + const phi::DenseTensor& tensor, int dst_rank, int64_t offset, int64_t numel, diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h index fcdbd88562edf785ddf557cd6a186c39fbaaf566..4ad75be3658b97ed28aa3de91d5ceb4652fe6d1a 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.h +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h @@ -168,18 +168,19 @@ class ProcessGroupStream : public ProcessGroup { bool sync_op, bool use_calc_stream); - std::shared_ptr Send(phi::DenseTensor* tensor, + std::shared_ptr Send(const phi::DenseTensor& tensor, int dst_rank, int64_t offset, int64_t numel, bool sync_op) override; - virtual std::shared_ptr Send(phi::DenseTensor* tensor, - int dst_rank, - int64_t offset, - int64_t numel, - bool sync_op, - bool use_calc_stream); + virtual std::shared_ptr Send( + const phi::DenseTensor& tensor, + int dst_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream); }; } // namespace distributed diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc index 580f815c9ab6ebf5e35b5a8acc98e7160f777916..439630a7f1dd7c305e985fe0410e8af91bbaddd0 100644 --- a/paddle/fluid/operators/collective/global_gather_op.cu.cc +++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc @@ -226,7 +226,7 @@ struct GlobalGatherProcessGroupFunctor { int idx = i + j * n_expert; if (cpu_global_count_data[idx]) { phi::DenseTensor tmp = *x; - pg->Send(&tmp, + pg->Send(tmp, j, send_ptr * in_feat, cpu_global_count_data[idx] * in_feat, diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc index a6eb714662200045dcc88109383ece4772c99df4..4ccf9dee2631f294b2a0ae23763b828cc2fe0d8d 100644 --- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc @@ -224,7 +224,7 @@ struct GlobalScatterProcessGroupFunctor { int idx = i + j * n_expert; if (cpu_local_count_data[idx]) { phi::DenseTensor tmp = *x; - pg->Send(&tmp, + pg->Send(tmp, j, expert_ptr[idx] * in_feat, cpu_local_count_data[idx] * in_feat, diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc index b7196473c9ac1fc674cba1e9ead19042a0c20ee9..7d4125be8d32e740a29376f54a74cefb8d1812c3 100644 --- a/paddle/fluid/operators/collective/partial_send_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc @@ -70,7 +70,7 @@ class PartialSendCUDAKernel : public framework::OpKernel { // Use ProcessGroup distributed::ProcessGroup* pg = map->get(rid); phi::DenseTensor tmp = *x; - auto task = pg->Send(&tmp, peer, offset, send_numel, /*sync_op*/ true); + auto task = pg->Send(tmp, peer, offset, send_numel, /*sync_op*/ true); task->Wait(); } else { gpuStream_t stream = nullptr; diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 6b612fda5337ee2be9a1ce3e9208c526278bfe33..c5d03ce8853e312540d5d6f5a7f090e5d3852992 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -168,7 +168,7 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto p_dense = std::dynamic_pointer_cast(tensor.impl()); - auto *out_dense = p_dense.get(); + auto out_dense = *p_dense; // numel == -1 indicates sending the whole tensor return self.Send( out_dense, dst, /*offset*/ 0, /*numel*/ -1, sync_op); @@ -189,7 +189,7 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto p_dense = std::dynamic_pointer_cast(tensor.impl()); - auto *out_dense = p_dense.get(); + auto out_dense = *p_dense; int64_t numel = p_dense->numel(); int64_t send_numel = numel / nranks; @@ -1126,7 +1126,7 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto p_dense = std::dynamic_pointer_cast(tensor.impl()); - auto *out_dense = p_dense.get(); + auto out_dense = *p_dense; // numel == -1 indicates sending the whole tensor return self.Send(out_dense, dst, @@ -1149,7 +1149,7 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto p_dense = std::dynamic_pointer_cast(tensor.impl()); - auto *out_dense = p_dense.get(); + auto out_dense = *p_dense; int64_t numel = p_dense->numel(); int64_t send_numel = numel / nranks;