Unverified commit 403d58bb, authored by LiYuRio, committed by GitHub

return pointer rather than reference (#48152)

Parent 02c51f3b
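In brief, every GetDeviceContext interface in the collective process groups now returns phi::DeviceContext* instead of const phi::DeviceContext&, and call sites dereference the pointer where a reference is still required. A minimal sketch of the caller-side effect, using simplified stand-in types rather than the real Paddle classes:

// Stand-ins for phi::DeviceContext and ProcessGroup (not Paddle code).
struct DeviceContext {};

struct ProcessGroup {
  virtual ~ProcessGroup() = default;
  // Before this commit: virtual const DeviceContext& GetDeviceContext() const;
  // After: a non-owning pointer is returned instead.
  virtual DeviceContext* GetDeviceContext() const = 0;
};

void Consume(const DeviceContext& ctx) {}  // an API still taking a reference

void CallSite(const ProcessGroup& pg) {
  auto* ctx = pg.GetDeviceContext();  // was: const auto& ctx = ...
  Consume(*ctx);                      // dereference where a reference is needed
}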
@@ -96,7 +96,7 @@ class ProcessGroup {
   virtual std::string GetBackendName() const = 0;
 
-  virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
+  virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
     PADDLE_THROW(platform::errors::Unimplemented(
         "ProcessGroup%s does not support get device_context.",
         GetBackendName()));
......
@@ -282,12 +282,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier(
   return task;
 }
 
-const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
     const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
 
-const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   const std::string& key = GetKeyFromPlace(place);
   if (use_calc_stream) {
......
@@ -77,10 +77,10 @@ class ProcessGroupBKCL : public ProcessGroupStream {
     return std::string(BKCL_BACKEND_NAME);
   }
 
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
 
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place,
+                                       bool use_calc_stream) const override;
 
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       phi::DenseTensor* out_tensor,
......
@@ -299,7 +299,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
   return task;
 }
 
-const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
+phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
     const Place& place) const {
   const std::string key = GetKeyFromPlace(place);
   const auto& iter = places_to_ctx_.find(key);
@@ -308,7 +308,7 @@ const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
       places_to_ctx_.end(),
       platform::errors::NotFound(
           "Cannot find the device context in this process group."));
-  return *iter->second[0];
+  return iter->second[0].get();
 }
 
 phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
......
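The Custom (and, below, NCCL) implementations store their device contexts behind smart pointers, so returning iter->second[0].get() hands out a non-owning raw pointer while the map keeps ownership. A minimal sketch of that pattern, with stand-in types rather than the real places_to_ctx_ structure:

#include <map>
#include <memory>
#include <string>

struct DeviceContext { int device_id = 0; };

class ContextRegistry {
 public:
  void Emplace(const std::string& key, int device_id) {
    ctxs_[key] = std::make_unique<DeviceContext>(DeviceContext{device_id});
  }

  // Non-owning raw pointer: the registry keeps ownership, so callers
  // must neither delete the result nor use it after the registry dies.
  DeviceContext* Get(const std::string& key) const {
    auto iter = ctxs_.find(key);
    return iter == ctxs_.end() ? nullptr : iter->second.get();
  }

 private:
  std::map<std::string, std::unique_ptr<DeviceContext>> ctxs_;
};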
@@ -93,7 +93,7 @@ class ProcessGroupCustom : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) override;
 
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
 
   phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
......
@@ -180,9 +180,8 @@ class ProcessGroupGloo : public ProcessGroup {
   std::string GetBackendName() const override { return "GLOO"; }
 
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place) const override {
-    return *platform::DeviceContextPool::Instance().Get(place);
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override {
+    return platform::DeviceContextPool::Instance().Get(place);
   }
 
   // Helper functions for Gloo.
......
@@ -94,17 +94,17 @@ void ProcessGroupNCCL::GroupEnd() {
   NCCL_CHECK(platform::dynload::ncclGroupEnd());
 }
 
-const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
     const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
 
-const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   const std::string& key = GetKeyFromPlace(place);
   if (use_calc_stream) {
     const auto& iter = place_to_calc_ctx_.find(key);
-    return *iter->second;
+    return iter->second;
   } else {
     const auto& iter = place_to_comm_ctx_.find(key);
     PADDLE_ENFORCE_NE(
@@ -112,7 +112,7 @@ const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
         place_to_comm_ctx_.end(),
         platform::errors::NotFound(
             "Cannot find the device context in this process group."));
-    return *iter->second;
+    return iter->second.get();
   }
 }
......
@@ -34,7 +34,7 @@
 #ifdef PADDLE_WITH_RCCL
 #include "paddle/fluid/platform/dynload/rccl.h"
-#else
+#elif PADDLE_WITH_NCCL
 #include "paddle/fluid/platform/dynload/nccl.h"
 #endif
@@ -83,10 +83,10 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
   std::string GetBackendName() const override { return "NCCL"; }
 
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place,
+                                       bool use_calc_stream) const override;
 
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
       phi::DenseTensor* out_tensor,
......
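A side note on the #else → #elif PADDLE_WITH_NCCL change above: in #if/#elif conditions, an undefined macro evaluates to 0, so the nccl.h include is now taken only when the build defines PADDLE_WITH_NCCL to a nonzero value (e.g. -DPADDLE_WITH_NCCL=1), rather than in every non-RCCL build. A compilable sketch of that preprocessor behavior, with hypothetical macro names:

#include <cstdio>

// Define one of these at build time, e.g. -DWITH_RCCL or -DWITH_NCCL=1.
#ifdef WITH_RCCL
#define BACKEND "rccl"
#elif WITH_NCCL  // undefined macros evaluate to 0 here, so this branch
#define BACKEND "nccl"  // is taken only when WITH_NCCL is defined nonzero
#else
#define BACKEND "none"
#endif

int main() { std::printf("backend: %s\n", BACKEND); }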
@@ -20,7 +20,7 @@ namespace distributed {
 ProcessGroupStream::ProcessGroupStream(int rank, int size, int gid)
     : ProcessGroup(rank, size, gid) {}
 
-const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
+phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   PADDLE_THROW(platform::errors::Unimplemented(
       "ProcessGroup%s does not support get device_context.", GetBackendName()));
......
@@ -57,9 +57,10 @@ class ProcessGroupStream : public ProcessGroup {
  public:
   ProcessGroupStream(int rank, int size, int gid);
   virtual ~ProcessGroupStream() = default;
+  using ProcessGroup::GetDeviceContext;
 
-  virtual const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const;
+  virtual phi::DeviceContext* GetDeviceContext(const Place& place,
+                                               bool use_calc_stream) const;
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
       phi::DenseTensor* out_tensor,
......
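The new using ProcessGroup::GetDeviceContext; line matters because of C++ name hiding: declaring the two-argument overload in ProcessGroupStream would otherwise hide the inherited one-argument version from lookup on the derived class. A minimal sketch of the rule, again with simplified stand-in types:

struct Place {};
struct DeviceContext {};

struct ProcessGroup {
  virtual ~ProcessGroup() = default;
  virtual DeviceContext* GetDeviceContext(const Place&) const {
    static DeviceContext ctx;
    return &ctx;
  }
};

struct ProcessGroupStream : ProcessGroup {
  // Without this using-declaration, the overload below would hide the
  // inherited one-argument GetDeviceContext during name lookup.
  using ProcessGroup::GetDeviceContext;

  virtual DeviceContext* GetDeviceContext(const Place&,
                                          bool use_calc_stream) const {
    static DeviceContext ctx;
    return &ctx;
  }
};

int main() {
  ProcessGroupStream pg;
  Place place;
  pg.GetDeviceContext(place);        // compiles only thanks to the using-declaration
  pg.GetDeviceContext(place, true);  // the stream-aware overload
}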
@@ -1053,9 +1053,9 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
   }
 
   group->task = process_group_->AllReduce(in_out, in_out, opts);
 
-  const auto &context = process_group_->GetDeviceContext(inner_place_);
-  group->SplitTensorsDev(context);
-  group->task->UpdateWaitChain(context);
+  auto *context = process_group_->GetDeviceContext(inner_place_);
+  group->SplitTensorsDev(*context);
+  group->task->UpdateWaitChain(*context);
 
   // split in FinalizeBackward()
 }
......
@@ -271,14 +271,14 @@ void BindDistributed(py::module *m) {
                     in_tensor.impl());
             auto in_dense = *p_in_tensor;
 
-            const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
+            auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
             auto task = self.AllGather(out_dense,
                                        in_dense,
                                        /*offset*/ 0,
                                        /*numel*/ -1,
                                        sync_op);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
-            task->UpdateWaitChain(dev_ctx);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(*dev_ctx);
             return task;
           },
           py::arg("out"),
@@ -334,7 +334,7 @@ void BindDistributed(py::module *m) {
             auto in_dense = *p_in_tensor;
 
             // in_tensor_list should not be empty
-            const auto &dev_ctx =
+            auto *dev_ctx =
                 self.GetDeviceContext(in_tensor_list.back().place());
             int world_size = self.GetSize();
             auto task =
@@ -343,8 +343,8 @@ void BindDistributed(py::module *m) {
                                GetDefaultSplitSizes(*out_dense, world_size),
                                GetDefaultSplitSizes(in_dense, world_size),
                                sync_op);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
-            task->UpdateWaitChain(dev_ctx);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(*dev_ctx);
             return task;
           },
           py::arg("out"),
@@ -770,15 +770,14 @@ void BindDistributed(py::module *m) {
                     in_tensor.impl());
             auto in_dense = *p_in_tensor;
 
-            const auto &dev_ctx =
-                self.GetDeviceContext(in_tensor.place(), true);
+            auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true);
             auto task = self.AllGather(out_dense,
                                        in_dense,
                                        /*offset*/ 0,
                                        /*numel*/ -1,
                                        /*sync_op*/ true,
                                        /*use_calc_stream*/ true);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
             return task;
           },
           py::arg("out"),
@@ -886,7 +885,7 @@ void BindDistributed(py::module *m) {
             auto in_dense = *p_in_tensor;
 
             // in_tensor_list should not be empty
-            const auto &dev_ctx = self.GetDeviceContext(
+            auto *dev_ctx = self.GetDeviceContext(
                 in_tensor_list.back().place(), /*use_calc_stream*/ true);
             int world_size = self.GetSize();
             auto task =
@@ -896,7 +895,7 @@ void BindDistributed(py::module *m) {
                                GetDefaultSplitSizes(in_dense, world_size),
                                /*sync_op*/ true,
                                /*use_calc_stream*/ true);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
             return task;
           },
           py::arg("out"),
......