未验证 提交 b7d219be 编写于 作者: L LiYuRio 提交者: GitHub

add device context getter (#45790)

上级 1a372bd1
......@@ -89,6 +89,11 @@ class ProcessGroup {
int GetSize() const { return size_; }
virtual const std::string GetBackendName() const = 0;
virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
PADDLE_THROW(platform::errors::InvalidArgument(
"Does not support to get device_context from ProcessGroup%s.",
GetBackendName()));
}
// TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
......
......@@ -20,6 +20,7 @@
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);
......@@ -1041,5 +1042,16 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
return iter->second[0]->GetNcclComm();
}
phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
const Place& place) const {
std::vector<Place> places = {place};
const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places));
PADDLE_ENFORCE_NE(iter,
places_to_ctx_.end(),
platform::errors::InvalidArgument(
"Cannot find device context in process group."));
return iter->second[0].get();
}
} // namespace distributed
} // namespace paddle
......@@ -96,6 +96,8 @@ class ProcessGroupNCCL : public ProcessGroupStream {
return std::string(NCCL_BACKEND_NAME);
}
phi::DeviceContext* GetDeviceContext(const Place& place) const override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册