diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h
index 09be2ca5e8788e7be31da38780df48ddefb5e225..8cc0cad8a5be010ebb826087bdafd632883b08dc 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.h
+++ b/paddle/fluid/distributed/collective/ProcessGroup.h
@@ -96,7 +96,7 @@ class ProcessGroup {
 
   virtual std::string GetBackendName() const = 0;
 
-  virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
+  virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
     PADDLE_THROW(platform::errors::Unimplemented(
         "ProcessGroup%s does not support get device_context.",
         GetBackendName()));
diff --git a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
index 5c122ce2a3216f69c01be577ddff3cdd51c210cd..42d7d7200edcd689d2991169bba20208dda6c82d 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
@@ -282,12 +282,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Barrier(
   return task;
 }
 
-const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
     const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
 
-const phi::DeviceContext& ProcessGroupBKCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   const std::string& key = GetKeyFromPlace(place);
   if (use_calc_stream) {
diff --git a/paddle/fluid/distributed/collective/ProcessGroupBKCL.h b/paddle/fluid/distributed/collective/ProcessGroupBKCL.h
index f7a95f9e48269f65231d92319c1ba1b0b5bee537..11c0dfbdc6234f026358d377c9c0d44478cda136 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupBKCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupBKCL.h
@@ -77,10 +77,10 @@ class ProcessGroupBKCL : public ProcessGroupStream {
     return std::string(BKCL_BACKEND_NAME);
   }
 
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
 
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place,
+                                       bool use_calc_stream) const override;
 
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       phi::DenseTensor* out_tensor,
diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
index 2a87c789937198b68c05c691d6bbfb09bf65f249..4eee250e48a52f93648d5f391c095c08e7d1cb32 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc
@@ -299,7 +299,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
   return task;
 }
 
-const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
+phi::DeviceContext* ProcessGroupCustom::GetDeviceContext(
     const Place& place) const {
   const std::string key = GetKeyFromPlace(place);
   const auto& iter = places_to_ctx_.find(key);
@@ -308,7 +308,7 @@ const phi::DeviceContext& ProcessGroupCustom::GetDeviceContext(
       places_to_ctx_.end(),
       platform::errors::NotFound(
           "Cannot find the device context in this process group."));
-  return *iter->second[0];
+  return iter->second[0].get();
 }
 
 phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.h b/paddle/fluid/distributed/collective/ProcessGroupCustom.h
index 050e780ae120d57abd7188923577899442e20b7b..6aca3802586444da0fda33a166212eb6e23a577d 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupCustom.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.h
@@ -93,7 +93,7 @@ class ProcessGroupCustom : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) override;
 
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
 
   phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
 
diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.h b/paddle/fluid/distributed/collective/ProcessGroupGloo.h
index fd691e024c4a52c1c07f92caf42fb06accb46077..5e8dc1c5e602cbb4c3eb3330b5f70fa8262a1d5c 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupGloo.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.h
@@ -180,9 +180,8 @@ class ProcessGroupGloo : public ProcessGroup {
 
   std::string GetBackendName() const override { return "GLOO"; }
 
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place) const override {
-    return *platform::DeviceContextPool::Instance().Get(place);
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override {
+    return platform::DeviceContextPool::Instance().Get(place);
   }
 
   // Helper functions for Gloo.
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 3c7bc0ec8429f4feef1abac7c0c8a8328ccc8ee5..f9ceaf089992c98701a62fd9bfbe44d6907cf9c0 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -94,17 +94,17 @@ void ProcessGroupNCCL::GroupEnd() {
   NCCL_CHECK(platform::dynload::ncclGroupEnd());
 }
 
-const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
    const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
 
-const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
+phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   const std::string& key = GetKeyFromPlace(place);
   if (use_calc_stream) {
     const auto& iter = place_to_calc_ctx_.find(key);
-    return *iter->second;
+    return iter->second;
   } else {
     const auto& iter = place_to_comm_ctx_.find(key);
     PADDLE_ENFORCE_NE(
@@ -112,7 +112,7 @@ const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
         place_to_comm_ctx_.end(),
         platform::errors::NotFound(
            "Cannot find the device context in this process group."));
-    return *iter->second;
+    return iter->second.get();
   }
 }
 
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
index a52e5e61cd29559f2e9fe3e079ccd9a028c28d83..d50003ba5a7029f6511efe1871a08e5cc625a3a5 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -34,7 +34,7 @@
 
 #ifdef PADDLE_WITH_RCCL
 #include "paddle/fluid/platform/dynload/rccl.h"
-#else
+#elif PADDLE_WITH_NCCL
 #include "paddle/fluid/platform/dynload/nccl.h"
 #endif
 
@@ -83,10 +83,10 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
 
   std::string GetBackendName() const override { return "NCCL"; }
 
-  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place,
+                                       bool use_calc_stream) const override;
 
-  const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const override;
+  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
       phi::DenseTensor* out_tensor,
diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc
index cd1e617a89e4cc70626517fb97ae904c3ee42429..332298ecfd4a24b708cce822f44745aa3f9adf4c 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc
@@ -20,7 +20,7 @@ namespace distributed {
 ProcessGroupStream::ProcessGroupStream(int rank, int size, int gid)
     : ProcessGroup(rank, size, gid) {}
 
-const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
+phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   PADDLE_THROW(platform::errors::Unimplemented(
       "ProcessGroup%s does not support get device_context.", GetBackendName()));
diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h
index be76429580d100ead1f455f73e41309cbe4a46c8..fcdbd88562edf785ddf557cd6a186c39fbaaf566 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupStream.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h
@@ -57,9 +57,10 @@ class ProcessGroupStream : public ProcessGroup {
  public:
   ProcessGroupStream(int rank, int size, int gid);
   virtual ~ProcessGroupStream() = default;
+  using ProcessGroup::GetDeviceContext;
 
-  virtual const phi::DeviceContext& GetDeviceContext(
-      const Place& place, bool use_calc_stream) const;
+  virtual phi::DeviceContext* GetDeviceContext(const Place& place,
+                                               bool use_calc_stream) const;
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
       phi::DenseTensor* out_tensor,
diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc
index f8850660640c3f229da0ee7a282b211bdc6d5bfe..cd8c8ed2e0cc9c6aa3ab5e3405688175e74566c5 100644
--- a/paddle/fluid/distributed/collective/reducer.cc
+++ b/paddle/fluid/distributed/collective/reducer.cc
@@ -1053,9 +1053,9 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
   }
 
   group->task = process_group_->AllReduce(in_out, in_out, opts);
-  const auto &context = process_group_->GetDeviceContext(inner_place_);
-  group->SplitTensorsDev(context);
-  group->task->UpdateWaitChain(context);
+  auto *context = process_group_->GetDeviceContext(inner_place_);
+  group->SplitTensorsDev(*context);
+  group->task->UpdateWaitChain(*context);
 
   // split in FinalizeBackward()
 }
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 0634f825a01102d7bbbc64d9ac524f0c4634ca1b..7fdf4a0930ebf959b7d54e44b6a144305fdb7f99 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -271,14 +271,14 @@ void BindDistributed(py::module *m) {
                 in_tensor.impl());
             auto in_dense = *p_in_tensor;
 
-            const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
+            auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
             auto task = self.AllGather(out_dense,
                                        in_dense,
                                        /*offset*/ 0,
                                        /*numel*/ -1,
                                        sync_op);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
-            task->UpdateWaitChain(dev_ctx);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(*dev_ctx);
             return task;
           },
           py::arg("out"),
@@ -334,7 +334,7 @@ void BindDistributed(py::module *m) {
            auto in_dense = *p_in_tensor;
 
             // in_tensor_list should not be empty
-            const auto &dev_ctx =
+            auto *dev_ctx =
                 self.GetDeviceContext(in_tensor_list.back().place());
             int world_size = self.GetSize();
             auto task =
@@ -343,8 +343,8 @@ void BindDistributed(py::module *m) {
                 GetDefaultSplitSizes(*out_dense, world_size),
                 GetDefaultSplitSizes(in_dense, world_size),
                 sync_op);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
-            task->UpdateWaitChain(dev_ctx);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(*dev_ctx);
             return task;
           },
           py::arg("out"),
@@ -770,15 +770,14 @@ void BindDistributed(py::module *m) {
                 in_tensor.impl());
             auto in_dense = *p_in_tensor;
 
-            const auto &dev_ctx =
-                self.GetDeviceContext(in_tensor.place(), true);
+            auto *dev_ctx = self.GetDeviceContext(in_tensor.place(), true);
             auto task = self.AllGather(out_dense,
                                        in_dense,
                                        /*offset*/ 0,
                                        /*numel*/ -1,
                                        /*sync_op*/ true,
                                        /*use_calc_stream*/ true);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
             return task;
           },
           py::arg("out"),
@@ -886,7 +885,7 @@ void BindDistributed(py::module *m) {
             auto in_dense = *p_in_tensor;
 
             // in_tensor_list should not be empty
-            const auto &dev_ctx = self.GetDeviceContext(
+            auto *dev_ctx = self.GetDeviceContext(
                 in_tensor_list.back().place(), /*use_calc_stream*/ true);
             int world_size = self.GetSize();
             auto task =
@@ -896,7 +895,7 @@ void BindDistributed(py::module *m) {
                 GetDefaultSplitSizes(in_dense, world_size),
                 /*sync_op*/ true,
                 /*use_calc_stream*/ true);
-            SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
             return task;
           },
           py::arg("out"),
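
Note for reviewers: every hunk above follows the same caller-side pattern. GetDeviceContext now hands back a phi::DeviceContext* instead of a const phi::DeviceContext&, map-backed implementations return iter->second.get(), ProcessGroupGloo forwards the DeviceContextPool pointer directly, and call sites dereference the pointer wherever an API still takes a reference (SplitTensorsDev, UpdateWaitChain, SplitTensor). The following is a minimal, self-contained sketch of that refactoring pattern only; the Group/Context types are hypothetical stand-ins for the Paddle classes and are not part of the patch.

    #include <cassert>
    #include <map>
    #include <memory>
    #include <string>

    // Stand-in for phi::DeviceContext.
    struct Context {
      int device_id;
    };

    // Stand-in for an API that still takes a reference (e.g. UpdateWaitChain).
    void UseContext(const Context& ctx) { assert(ctx.device_id >= 0); }

    class Group {
     public:
      explicit Group(int device_id) {
        ctxs_["default"] = std::make_unique<Context>(Context{device_id});
      }

      // After the change: return a raw, non-owning pointer into the owning map,
      // mirroring `return iter->second.get();` in the hunks above.
      Context* GetContext(const std::string& key) const {
        auto iter = ctxs_.find(key);
        return iter == ctxs_.end() ? nullptr : iter->second.get();
      }

     private:
      std::map<std::string, std::unique_ptr<Context>> ctxs_;
    };

    int main() {
      Group group(/*device_id*/ 0);
      Context* ctx = group.GetContext("default");  // pointer instead of reference
      assert(ctx != nullptr);
      UseContext(*ctx);  // dereference where a reference is still expected
      return 0;
    }

A related detail: the `using ProcessGroup::GetDeviceContext;` added in ProcessGroupStream.h keeps the single-argument base overload visible alongside the new two-argument virtual, since a declaration of GetDeviceContext(const Place&, bool) in the derived class would otherwise hide the base-class overload from name lookup.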