diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 0937b2674613288e32975e1aa9827fce583caad5..10b1686ddb85fe11951db0812da1c5fc9c7ef0e7 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -89,6 +89,11 @@ class ProcessGroup { int GetSize() const { return size_; } virtual const std::string GetBackendName() const = 0; + virtual phi::DeviceContext* GetDeviceContext(const Place& place) const { + PADDLE_THROW(platform::errors::InvalidArgument( + "Does not support to get device_context from ProcessGroup%s.", + GetBackendName())); + } // TODO(liyurui): This API will be moved later virtual std::shared_ptr AllReduce( diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index a5260ac3b2ef1ba81f1fbbea202d226590ec1f62..239114ae6188cea31d1be5d94521d6c461d3c356 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/core/device_context.h" DECLARE_bool(nccl_blocking_wait); DECLARE_bool(use_stream_safe_cuda_allocator); @@ -1041,5 +1042,16 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const { return iter->second[0]->GetNcclComm(); } +phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext( + const Place& place) const { + std::vector places = {place}; + const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places)); + PADDLE_ENFORCE_NE(iter, + places_to_ctx_.end(), + platform::errors::InvalidArgument( + "Cannot find device context in process group.")); + return iter->second[0].get(); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index 50ef0b1f1ac28e31ffd7f91ab3fefd3b3cb87e8c..e0e298e9113e9e9a9d73d0339e562b0b39998e51 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -96,6 +96,8 @@ class ProcessGroupNCCL : public ProcessGroupStream { return std::string(NCCL_BACKEND_NAME); } + phi::DeviceContext* GetDeviceContext(const Place& place) const override; + std::shared_ptr AllReduce( std::vector& in_tensors, // NOLINT std::vector& out_tensors, // NOLINT