From b7d219be656065b07dfb0e8e911e50d749da3a59 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Wed, 7 Sep 2022 11:22:12 +0800 Subject: [PATCH] add device context getter (#45790) --- paddle/fluid/distributed/collective/ProcessGroup.h | 5 +++++ .../fluid/distributed/collective/ProcessGroupNCCL.cc | 12 ++++++++++++ .../fluid/distributed/collective/ProcessGroupNCCL.h | 2 ++ 3 files changed, 19 insertions(+) diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 0937b267461..10b1686ddb8 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -89,6 +89,11 @@ class ProcessGroup { int GetSize() const { return size_; } virtual const std::string GetBackendName() const = 0; + virtual phi::DeviceContext* GetDeviceContext(const Place& place) const { + PADDLE_THROW(platform::errors::InvalidArgument( + "Does not support to get device_context from ProcessGroup%s.", + GetBackendName())); + } // TODO(liyurui): This API will be moved later virtual std::shared_ptr AllReduce( diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index a5260ac3b2e..239114ae618 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/core/device_context.h" DECLARE_bool(nccl_blocking_wait); DECLARE_bool(use_stream_safe_cuda_allocator); @@ -1041,5 +1042,16 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const { return iter->second[0]->GetNcclComm(); } +phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext( + const Place& place) const { + std::vector places = {place}; + const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places)); + PADDLE_ENFORCE_NE(iter, + places_to_ctx_.end(), + platform::errors::InvalidArgument( + "Cannot find device context in process group.")); + return iter->second[0].get(); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index 50ef0b1f1ac..e0e298e9113 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -96,6 +96,8 @@ class ProcessGroupNCCL : public ProcessGroupStream { return std::string(NCCL_BACKEND_NAME); } + phi::DeviceContext* GetDeviceContext(const Place& place) const override; + std::shared_ptr AllReduce( std::vector& in_tensors, // NOLINT std::vector& out_tensors, // NOLINT -- GitLab