Unverified · Commit 403d58bb authored by LiYuRio, committed by GitHub

return pointer rather than reference (#48152)

Parent: 02c51f3b
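In short: ProcessGroup::GetDeviceContext and all of its overrides now return a non-owning phi::DeviceContext* instead of a const phi::DeviceContext&, and every call site dereferences the returned pointer. A minimal caller-side sketch of that pattern, using simplified stand-in types rather than the real phi/distributed classes:

#include <memory>
#include <vector>

struct DeviceContext {};  // stand-in for phi::DeviceContext

class ProcessGroup {  // stand-in for the real ProcessGroup interface
 public:
  ProcessGroup() { ctxs_.emplace_back(new DeviceContext); }

  // Old style (before this commit): return a const reference.
  const DeviceContext& GetDeviceContextRef() const { return *ctxs_[0]; }

  // New style (this commit): return a raw, non-owning pointer; the container
  // keeps ownership through unique_ptr, hence .get().
  DeviceContext* GetDeviceContextPtr() const { return ctxs_[0].get(); }

 private:
  std::vector<std::unique_ptr<DeviceContext>> ctxs_;
};

void SplitTensor(const DeviceContext& ctx) {}  // stand-in for a helper that takes a reference

int main() {
  ProcessGroup pg;
  // Before: const auto& dev_ctx = pg.GetDeviceContextRef(); SplitTensor(dev_ctx);
  // After: receive a pointer and dereference where a reference is still expected.
  auto* dev_ctx = pg.GetDeviceContextPtr();
  SplitTensor(*dev_ctx);
  return 0;
}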
@@ -96,7 +96,7 @@ class ProcessGroup {
   virtual std::string GetBackendName() const = 0;
-  virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
+  virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
     PADDLE_THROW(platform::errors::Unimplemented(
         "ProcessGroup%s does not support get device_context.",
         GetBackendName()));
......
@@ -282,12 +282,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier(
   return task;
 }
-const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
     const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
-const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   const std::string& key = GetKeyFromPlace(place);
   if (use_calc_stream) {
......
@@ -77,10 +77,10 @@ class ProcessGroupBKCL : public ProcessGroupStream {
     return std::string(BKCL_BACKEND_NAME);
   }
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place,
+                                       bool use_calc_stream) const override;
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       phi::DenseTensor* out_tensor,
......
@@ -299,7 +299,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
   return task;
 }
-const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
+phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
     const Place& place) const {
   const std::string key = GetKeyFromPlace(place);
   const auto& iter = places_to_ctx_.find(key);
@@ -308,7 +308,7 @@ const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
       places_to_ctx_.end(),
       platform::errors::NotFound(
           "Cannot find the device context in this process group."));
-  return *iter->second[0];
+  return iter->second[0].get();
 }
 phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
......
@@ -93,7 +93,7 @@ class ProcessGroupCustom : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) override;
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
   phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
......
@@ -180,9 +180,8 @@ class ProcessGroupGloo : public ProcessGroup {
   std::string GetBackendName() const override { return "GLOO"; }
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place) const override {
-    return *platform::DeviceContextPool::Instance().Get(place);
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override {
+    return platform::DeviceContextPool::Instance().Get(place);
   }
   // Helper functions for Gloo.
......
@@ -94,17 +94,17 @@ void ProcessGroupNCCL::GroupEnd() {
   NCCL_CHECK(platform::dynload::ncclGroupEnd());
 }
-const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
     const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
-const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   const std::string& key = GetKeyFromPlace(place);
   if (use_calc_stream) {
     const auto& iter = place_to_calc_ctx_.find(key);
-    return *iter->second;
+    return iter->second;
   } else {
     const auto& iter = place_to_comm_ctx_.find(key);
     PADDLE_ENFORCE_NE(
@@ -112,7 +112,7 @@ const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
         place_to_comm_ctx_.end(),
         platform::errors::NotFound(
            "Cannot find the device context in this process group."));
-    return *iter->second;
+    return iter->second.get();
   }
 }
......
@@ -34,7 +34,7 @@
 #ifdef PADDLE_WITH_RCCL
 #include "paddle/fluid/platform/dynload/rccl.h"
-#else
+#elif PADDLE_WITH_NCCL
 #include "paddle/fluid/platform/dynload/nccl.h"
 #endif
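A note on the include hunk above (a reading of the change, not stated in the commit message): with a plain #else, any build that lacked PADDLE_WITH_RCCL would still try to include the NCCL dynload header; #elif PADDLE_WITH_NCCL limits that include to builds where the NCCL macro is actually defined, so builds with neither backend include neither header:

#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"  // ROCm / RCCL builds
#elif PADDLE_WITH_NCCL
#include "paddle/fluid/platform/dynload/nccl.h"  // CUDA / NCCL builds
#endif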
@@ -83,10 +83,10 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
   std::string GetBackendName() const override { return "NCCL"; }
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place,
+                                       bool use_calc_stream) const override;
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
   std::shared_ptr<ProcessGroup::Task> AllGather(
       phi::DenseTensor* out_tensor,
......
@@ -20,7 +20,7 @@ namespace distributed {
 ProcessGroupStream::ProcessGroupStream(int rank, int size, int gid)
     : ProcessGroup(rank, size, gid) {}
-const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
+phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   PADDLE_THROW(platform::errors::Unimplemented(
       "ProcessGroup%s does not support get device_context.", GetBackendName()));
......
@@ -57,9 +57,10 @@ class ProcessGroupStream : public ProcessGroup {
  public:
   ProcessGroupStream(int rank, int size, int gid);
   virtual ~ProcessGroupStream() = default;
+  using ProcessGroup::GetDeviceContext;
-  virtual const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const;
+  virtual phi::DeviceContext* GetDeviceContext(const Place& place,
+                                               bool use_calc_stream) const;
   std::shared_ptr<ProcessGroup::Task> AllGather(
       phi::DenseTensor* out_tensor,
......
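The using ProcessGroup::GetDeviceContext; line added in the ProcessGroupStream hunk above is presumably there for overload visibility: an overload declared in a derived class hides every base-class member with the same name unless it is re-exposed with a using-declaration. A self-contained illustration with simplified stand-in types (not the real Paddle classes):

struct Base {
  virtual ~Base() = default;
  virtual int Get(int place) const { return place; }  // one-argument overload
};

struct Derived : Base {
  // Without this using-declaration, Derived::Get(int, bool) would hide
  // Base::Get(int), and the one-argument call in main() would not compile.
  using Base::Get;
  virtual int Get(int place, bool use_calc_stream) const {
    return place + static_cast<int>(use_calc_stream);
  }
};

int main() {
  Derived d;
  return d.Get(1) + d.Get(1, true);  // both overloads stay visible
}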
@@ -1053,9 +1053,9 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
   }
   group->task = process_group_->AllReduce(in_out, in_out, opts);
-  const auto &context = process_group_->GetDeviceContext(inner_place_);
-  group->SplitTensorsDev(context);
-  group->task->UpdateWaitChain(context);
+  auto *context = process_group_->GetDeviceContext(inner_place_);
+  group->SplitTensorsDev(*context);
+  group->task->UpdateWaitChain(*context);
   // split in FinalizeBackward()
 }
......
@@ -271,14 +271,14 @@ void BindDistributed(py::module *m) {
                 in_tensor.impl());
             auto in_dense = *p_in_tensor;
-            const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
+            auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
             auto task = self.AllGather(out_dense,
                                        in_dense,
                                        /*offset*/ 0,
                                        /*numel*/ -1,
                                        sync_op);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
-            task->UpdateWaitChain(dev_ctx);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(*dev_ctx);
             return task;
           },
           py::arg("out"),
@@ -334,7 +334,7 @@ void BindDistributed(py::module *m) {
             auto in_dense = *p_in_tensor;
             // in_tensor_list should not be empty
-            const auto &dev_ctx =
+            auto *dev_ctx =
                 self.GetDeviceContext(in_tensor_list.back().place());
             int world_size = self.GetSize();
             auto task =
@@ -343,8 +343,8 @@ void BindDistributed(py::module *m) {
                     GetDefaultSplitSizes(*out_dense, world_size),
                     GetDefaultSplitSizes(in_dense, world_size),
                     sync_op);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
-            task->UpdateWaitChain(dev_ctx);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(*dev_ctx);
             return task;
           },
           py::arg("out"),
@@ -770,15 +770,14 @@ void BindDistributed(py::module *m) {
                 in_tensor.impl());
             auto in_dense = *p_in_tensor;
-            const auto &dev_ctx =
-                self.GetDeviceContext(in_tensor.place(), true);
+            auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true);
             auto task = self.AllGather(out_dense,
                                        in_dense,
                                        /*offset*/ 0,
                                        /*numel*/ -1,
                                        /*sync_op*/ true,
                                        /*use_calc_stream*/ true);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
             return task;
           },
           py::arg("out"),
@@ -886,7 +885,7 @@ void BindDistributed(py::module *m) {
             auto in_dense = *p_in_tensor;
             // in_tensor_list should not be empty
-            const auto &dev_ctx = self.GetDeviceContext(
+            auto *dev_ctx = self.GetDeviceContext(
                 in_tensor_list.back().place(), /*use_calc_stream*/ true);
             int world_size = self.GetSize();
             auto task =
@@ -896,7 +895,7 @@ void BindDistributed(py::module *m) {
                     GetDefaultSplitSizes(in_dense, world_size),
                     /*sync_op*/ true,
                     /*use_calc_stream*/ true);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
             return task;
           },
           py::arg("out"),
......