diff --git a/paddle/fluid/distributed/collective/ProcessGroup.cc b/paddle/fluid/distributed/collective/ProcessGroup.cc
index 1db8d221aa67d3f5fc4cec027bcb141b8e9dae27..e7942b714e4f613b77f691de0eaefcab76102a9d 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroup.cc
@@ -41,6 +41,8 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
 
 void ProcessGroup::Task::Synchronize() {}
 
+void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}
+
 ProcessGroup::ProcessGroup(int rank,
                            int size,
                            const platform::Place& place,
diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h
index ca1cf7dd48ba707a04e4cbaae56125187d1ecf8a..afe75baeb2a4f7de149545d30c0128fc0e830852 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.h
+++ b/paddle/fluid/distributed/collective/ProcessGroup.h
@@ -66,6 +66,7 @@ class ProcessGroup {
     virtual bool IsCompleted();
    virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
    virtual void Synchronize();
+    virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
     bool IsSync() const { return sync_op_; }
 
    protected:
@@ -92,7 +93,7 @@ class ProcessGroup {
   int GetSize() const { return size_; }
 
   virtual const std::string GetBackendName() const = 0;
-  virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
+  virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Does not support to get device_context from ProcessGroup%s.",
         GetBackendName()));
diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.h b/paddle/fluid/distributed/collective/ProcessGroupGloo.h
index d911da91eb1a32bc25811595f81bee4529de8546..f20f39b31a7a7a218c28cbe558a793f49c4c840c 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupGloo.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.h
@@ -150,6 +150,11 @@ class ProcessGroupGloo : public ProcessGroup {
     return GLOO_BACKEND_NAME;
   }
 
+  const phi::DeviceContext& GetDeviceContext(
+      const Place& place) const override {
+    return *platform::DeviceContextPool::Instance().Get(place);
+  }
+
   // Helper functions for Gloo.
   static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
       const std::string& hostname);
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 2e18dfcc3ba1208f47a4ceeb2529826d46b44c34..76d1d42c7d653f908227f423b36b9a583105ad3f 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -110,6 +110,11 @@ bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
   return true;
 }
 
+void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
+    const phi::DeviceContext& ctx) {
+  control_events_[0].Record(*static_cast<const phi::GPUContext*>(&ctx));
+}
+
 void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
                                        std::vector<int64_t> tensor_shape) {
   int64_t len_size = (*split_sizes).size();
@@ -1591,15 +1596,15 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
   return iter->second[0]->GetNcclComm();
 }
 
-phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
     const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
 
-phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   if (use_calc_stream) {
-    return platform::DeviceContextPool::Instance().Get(place);
+    return *platform::DeviceContextPool::Instance().Get(place);
   } else {
     std::vector<Place> places = {place};
     const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places));
@@ -1607,7 +1612,7 @@ phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
                       places_to_ctx_.end(),
                       platform::errors::InvalidArgument(
                           "Cannot find device context in process group."));
-    return iter->second[0].get();
+    return *iter->second[0];
   }
 }
 
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
index 6427e9e3e2ab1c8a06224cd88b1a60a4cf067c61..a501bf5302350139a12e5ddfaa9f20b3aba36a04 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -75,6 +75,8 @@ class ProcessGroupNCCL : public ProcessGroupStream {
 
     virtual ~NCCLTask();
 
+    void UpdateWaitChain(const phi::DeviceContext& ctx) override;
+
     std::vector<EventManager> control_events_;
     std::vector<phi::DenseTensor> barrierTensors_;
 
@@ -96,10 +98,10 @@ class ProcessGroupNCCL : public ProcessGroupStream {
     return std::string(NCCL_BACKEND_NAME);
   }
 
-  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
+  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
 
-  phi::DeviceContext* GetDeviceContext(const Place& place,
-                                       bool use_calc_stream) const override;
+  const phi::DeviceContext& GetDeviceContext(
+      const Place& place, bool use_calc_stream) const override;
 
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,  // NOLINT
diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc
index b2cfae088b2271ab455430a7b06bd9714d31a1f4..11530ab872d22df004cb0503a2c97d658a8b0ea0 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc
@@ -23,7 +23,7 @@ ProcessGroupStream::ProcessGroupStream(int rank,
                                        int gid)
     : ProcessGroup(rank, size, place, gid) {}
 
-phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   PADDLE_THROW(platform::errors::InvalidArgument(
       "ProcessGroup%s does not support get device_context.",
       GetBackendName()));
diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h
index 2f0aa139104e929a24740f5827ee263648df18fe..56799c4bd3ed8366d1f087e41d77607e536d2ef5 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupStream.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h
@@ -54,8 +54,8 @@ class ProcessGroupStream : public ProcessGroup {
   ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
   virtual ~ProcessGroupStream() = default;
 
-  virtual phi::DeviceContext* GetDeviceContext(const Place& place,
-                                               bool use_calc_stream) const;
+  virtual const phi::DeviceContext& GetDeviceContext(
+      const Place& place, bool use_calc_stream) const;
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
       std::vector<phi::DenseTensor>& in_tensors,  // NOLINT
diff --git a/paddle/fluid/distributed/collective/Utils.h b/paddle/fluid/distributed/collective/Utils.h
index c06c0345163ed7c6d68e7256bc84ee07c183507e..d9260b98dcf44a44aa4f11d721eab9847c01f464 100644
--- a/paddle/fluid/distributed/collective/Utils.h
+++ b/paddle/fluid/distributed/collective/Utils.h
@@ -25,18 +25,18 @@ namespace distributed {
 
 template <typename DeviceContext, typename T>
 struct ConcatDenseTensor {
-  void operator()(const DeviceContext *context,
+  void operator()(const DeviceContext &context,
                   const std::vector<phi::DenseTensor> &in,
                   phi::DenseTensor *out,
                   int axis = 0) {
     phi::funcs::ConcatFunctor<DeviceContext, T> concat_functor;
-    concat_functor(*context, in, axis, out);
+    concat_functor(context, in, axis, out);
   }
 };
 
 template <typename DeviceContext, typename T>
 struct SplitDenseTensor {
-  void operator()(const DeviceContext *context,
+  void operator()(const DeviceContext &context,
                   const phi::DenseTensor &in,
                   std::vector<phi::DenseTensor *> *out,
                   int axis = 0) {
@@ -46,19 +46,19 @@ struct SplitDenseTensor {
       shape_refer.emplace_back(p_tensor);
     }
     phi::funcs::SplitFunctor<DeviceContext, T> split_functor;
-    split_functor(*context, in, shape_refer, axis, out);
+    split_functor(context, in, shape_refer, axis, out);
   }
 };
 
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
 template <typename T>
 struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
-  void operator()(const platform::CustomDeviceContext *context,
+  void operator()(const platform::CustomDeviceContext &context,
                   const std::vector<phi::DenseTensor> &in,
                   phi::DenseTensor *out,
                   int axis = 0) {
     auto *out_data = out->data<T>();
-    auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
+    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
     size_t offset = 0;
     for (const auto &tensor : in) {
       const auto *in_data = tensor.data<T>();
@@ -71,12 +71,12 @@ struct ConcatDenseTensor {
 
 template <typename T>
 struct SplitDenseTensor<platform::CustomDeviceContext, T> {
-  void operator()(const platform::CustomDeviceContext *context,
+  void operator()(const platform::CustomDeviceContext &context,
                   const phi::DenseTensor &in,
                   std::vector<phi::DenseTensor *> *out,
                   int axis = 0) {
     auto *in_data = in.data<T>();
-    auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
+    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
     size_t offset = 0;
     for (auto *p_tensor : *out) {
       auto *out_data = p_tensor->data<T>();
@@ -89,7 +89,7 @@ struct SplitDenseTensor {
 #endif
 
 template <typename DeviceContext>
-void ConcatDenseTensorWithType(const DeviceContext *dev_ctx,
+void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
                                const std::vector<phi::DenseTensor> &t_list,
                                phi::DenseTensor *p_out,
                                phi::DataType type) {
@@ -126,7 +126,7 @@ void ConcatDenseTensorWithType(const DeviceContext *dev_ctx,
   }
 }
 
 template <typename DeviceContext>
-void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
+void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
                               const phi::DenseTensor &t_in,
                               std::vector<phi::DenseTensor *> *p_list,
                               phi::DataType type) {
@@ -162,16 +162,16 @@ void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
   }
 }
 
-void ConcatTensor(const phi::DeviceContext *dev_ctx,
+void ConcatTensor(const phi::DeviceContext &dev_ctx,
                   const std::vector<phi::DenseTensor> &tensor_list,
                   const experimental::Tensor *tensor) {
   auto *dense_tensor =
       std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl()).get();
 
-  const auto &place = dev_ctx->GetPlace();
+  const auto &place = dev_ctx.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    ConcatDenseTensorWithType(static_cast<const phi::GPUContext *>(dev_ctx),
+    ConcatDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                               tensor_list,
                               dense_tensor,
                               tensor->dtype());
@@ -183,7 +183,7 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
     ConcatDenseTensorWithType(
-        static_cast<const platform::CustomDeviceContext *>(dev_ctx),
+        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
         tensor_list,
         dense_tensor,
         tensor->dtype());
@@ -194,7 +194,7 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
         "CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    ConcatDenseTensorWithType(static_cast<const phi::CPUContext *>(dev_ctx),
+    ConcatDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                               tensor_list,
                               dense_tensor,
                               tensor->dtype());
@@ -204,20 +204,20 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
   }
 }
 
-void SplitTensor(const phi::DeviceContext *dev_ctx,
+void SplitTensor(const phi::DeviceContext &dev_ctx,
                  const phi::DenseTensor &tensor,
                  const std::vector<experimental::Tensor> *tensor_list) {
   std::vector<phi::DenseTensor *> dense_list;
   for (auto &tensor : *tensor_list) {
-    auto p_tensor =
+    auto *p_tensor =
         std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()).get();
     dense_list.emplace_back(p_tensor);
   }
 
-  const auto &place = dev_ctx->GetPlace();
+  const auto &place = dev_ctx.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    SplitDenseTensorWithType(static_cast<const phi::GPUContext *>(dev_ctx),
+    SplitDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                              tensor,
                              &dense_list,
                              tensor.dtype());
@@ -229,7 +229,7 @@ void SplitTensor(const phi::DeviceContext *dev_ctx,
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
     SplitDenseTensorWithType(
-        static_cast<const platform::CustomDeviceContext *>(dev_ctx),
+        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
         tensor,
         &dense_list,
         tensor.dtype());
@@ -239,7 +239,7 @@ void SplitTensor(const phi::DeviceContext *dev_ctx,
         "please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    SplitDenseTensorWithType(static_cast<const phi::CPUContext *>(dev_ctx),
+    SplitDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                              tensor,
                              &dense_list,
                              tensor.dtype());
diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc
index 0d46425b2e83274fb5eb62306e11a0b0a6d7221a..2c26828e5e1143382173a87f966d7fbec1a96c79 100644
--- a/paddle/fluid/distributed/collective/reducer.cc
+++ b/paddle/fluid/distributed/collective/reducer.cc
@@ -16,6 +16,8 @@
 #include "paddle/phi/backends/device_guard.h"
 #include "paddle/phi/backends/device_manager.h"
 
+DECLARE_bool(use_stream_safe_cuda_allocator);
+
 namespace paddle {
 namespace distributed {
 
@@ -335,13 +337,20 @@ void EagerGroup::ConcatTensors(const platform::Place &place) {
   }
 }
 
-void EagerGroup::SplitTensors(const platform::Place &place) {
+void EagerGroup::SplitTensorsDev(const platform::DeviceContext &context) {
+  auto place = context.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    auto *default_ctx = static_cast<phi::GPUContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
+    auto &gpu_context = static_cast<const phi::GPUContext &>(context);
     SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+        gpu_context, &dense_contents_, &dense_tensors_, dtype_);
+    if (FLAGS_use_stream_safe_cuda_allocator) {
+      auto dense_tensor =
+          std::dynamic_pointer_cast<phi::DenseTensor>(dense_contents_.impl());
+      VLOG(3) << "Free dense_contents_ " << dense_contents_.numel();
+      memory::RecordStream(dense_tensor->Holder(), gpu_context.stream());
+      dense_contents_.reset();
+    }
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Paddle can't split grad tensor since it's not compiled with NCCL,"
@@ -349,10 +358,11 @@ void EagerGroup::SplitTensors(const platform::Place &place) {
 #endif
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-    auto *default_ctx = static_cast<platform::CustomDeviceContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
     SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+        static_cast<const platform::CustomDeviceContext &>(context),
+        &dense_contents_,
+        &dense_tensors_,
+        dtype_);
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Paddle can't split grad tensor since it's not compiled with "
@@ -360,10 +370,10 @@ void EagerGroup::SplitTensors(const platform::Place &place) {
         "Please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    auto *default_ctx = static_cast<phi::CPUContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
-    SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+    SplitTensorsWithType(static_cast<const phi::CPUContext &>(context),
+                         &dense_contents_,
+                         &dense_tensors_,
+                         dtype_);
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Split grad tensor not supported on place (%s)", place));
@@ -578,9 +588,11 @@ void EagerReducer::TraverseBackwardGraph(const std::vector<Tensor> &outputs) {
   }
 }
 
-void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
+void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs,
+                                      const bool is_sync) {
   VLOG(3) << "after forward, then reset count for backward.";
-  grad_need_hooks_ = true;
+  grad_need_hooks_ = is_sync;
+  next_group_ = 0;
 
   std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) {
     group.pending_ = group.tensor_indices_.size();
@@ -648,9 +660,9 @@ void EagerReducer::AddDistHook(size_t var_index) {
                         var_index));
 
   // gradient synchronization is not required when grad_need_hooks_ is false.
-  if (!grad_need_hooks_) {
-    return;
-  }
+  // if (!grad_need_hooks_) {
+  //   return;
+  // }
 
   VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
           << "@Grad] arrived and triggered disthook";
@@ -816,10 +828,12 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
   for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
        ++next_group_) {
     UNUSED auto &group = groups_[next_group_];
-    if (group.is_sparse_) {
-      AllReduceSparse(&group, next_group_);
-    } else {
-      FusedAllReduceSchedule(&group, next_group_);
+    if (grad_need_hooks_) {
+      if (group.is_sparse_) {
+        AllReduceSparse(&group, next_group_);
+      } else {
+        FusedAllReduceSchedule(&group, next_group_);
+      }
     }
   }
 }
@@ -907,16 +921,14 @@ void EagerReducer::ProcessUnusedDenseVars() {
 
 void EagerReducer::FinalizeBackward() {
   groups_need_finalize_ = false;
-  grad_need_hooks_ = false;
   for (auto &group : groups_) {
-    if (!group.is_sparse_) {
+    if (!group.is_sparse_ && grad_need_hooks_) {
       group.task->Synchronize();
     }
   }
 
   for (auto &group : groups_) {
-    if (!group.is_sparse_) {
-      group.SplitTensors(inner_place_);
+    if (!group.is_sparse_ && grad_need_hooks_) {
       group.dense_contents_.reset();
     }
   }
@@ -928,6 +940,7 @@
     VLOG(3) << "ProcessUnusedDenseVars is finished.";
   }
 
+  grad_need_hooks_ = false;
   VLOG(3) << "In the batch, Reducer is finished.";
 }
 
@@ -954,6 +967,9 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
   }
 
   group->task = process_group_->AllReduce(in_out, in_out, opts);
+  const auto &context = process_group_->GetDeviceContext(inner_place_);
+  group->SplitTensorsDev(context);
+  group->task->UpdateWaitChain(context);
 
   // split in FinalizeBackward()
 }
diff --git a/paddle/fluid/distributed/collective/reducer.h b/paddle/fluid/distributed/collective/reducer.h
index 90848920b7e938f4d50b1b57e34dd22d99fc1eea..74db3db7467298d26d982d3c0cedc4b8b2327705 100644
--- a/paddle/fluid/distributed/collective/reducer.h
+++ b/paddle/fluid/distributed/collective/reducer.h
@@ -74,7 +74,8 @@ class EagerGroup {
   void ConcatTensors(const platform::Place &);
 
   // context is used to select the stream for split
-  void SplitTensors(const platform::Place &);
+
+  void SplitTensorsDev(const platform::DeviceContext &);
 
   friend std::ostream &operator<<(std::ostream &, const EagerGroup &);
 };
@@ -102,7 +103,8 @@ class EagerReducer {
   void InitializeGroups(const std::vector<std::vector<size_t>> &group_indices);
   void InitializeDenseGroups(const std::vector<size_t> &tensor_indices_,
                              EagerGroup *p_group);
-  void PrepareForBackward(const std::vector<Tensor> &outputs);
+  void PrepareForBackward(const std::vector<Tensor> &outputs,
+                          const bool is_sync);
   void AddDistHook(size_t var_index);
   void MarkVarReady(const size_t var_index, const bool is_used_var);
   void MarkGroupReady(const size_t group_index);
diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index f89fe234c201ab73d0262396b9f6b65313f868a1..3225222f617373a5a89656ccc4fd6d25d3557e86 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -675,9 +675,10 @@ void Reducer::TraverseBackwardGraph(
 // After each batch is calculated, the counter of each group(group.pending_)
 // and allreudce sequence counter(next_group_) will be cleaned up again.
 void Reducer::PrepareForBackward(
-    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
+    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs,
+    const bool is_sync) {
   VLOG(3) << "after forward, then reset count for backward.";
-  grad_need_hooks_ = true;
+  grad_need_hooks_ = is_sync;
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
     group.pending_ = group.variable_indices_.size();
@@ -710,7 +711,9 @@ void Reducer::PrepareForBackward(
 
   if (find_unused_vars_once_ || find_unused_vars_each_step_) {
     unused_vars_.clear();
-    TraverseBackwardGraph(outputs);
+    if (grad_need_hooks_) {
+      TraverseBackwardGraph(outputs);
+    }
     // only check once in first step
     find_unused_vars_once_ = false;
   }
diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h
index c455f962788b891e2a039db2990437e5c51302f6..902c3036acc78b1d575f6d4a420682168e9fb000 100644
--- a/paddle/fluid/imperative/reducer.h
+++ b/paddle/fluid/imperative/reducer.h
@@ -146,7 +146,8 @@ class Reducer {
   void PrepareDeps(const std::unordered_set<GradOpNode*>& init_nodes);
 
   void PrepareForBackward(
-      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs);
+      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs,
+      const bool is_sync);
 
   void AddDistHook(size_t var_index);
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 6aa8e19c99c61cc131a1e8610e47e6880a528a5c..fe1d82c766a0e0ad6a4f4e67dfdb1aaf6a931f5c 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -395,9 +395,10 @@ void BindDistributed(py::module *m) {
                     concat_out_tensor.impl());
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
-            const auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
+            const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
             auto task = self.AllGather(in_wrapper, out_wrapper, sync_op);
             distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(dev_ctx);
             return task;
           },
           py::arg("in"),
@@ -495,10 +496,11 @@ void BindDistributed(py::module *m) {
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
             // in_tensor_list should not be empty
-            const auto *dev_ctx =
+            const auto &dev_ctx =
                 self.GetDeviceContext(in_tensor_list.back().place());
             auto task = self.AllToAll(in_wrapper, out_wrapper, sync_op);
             distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(dev_ctx);
             return task;
           },
           py::arg("in"),
@@ -796,7 +798,7 @@ void BindDistributed(py::module *m) {
                     concat_out_tensor.impl());
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
-            const auto *dev_ctx =
+            const auto &dev_ctx =
                 self.GetDeviceContext(in_tensor.place(), true);
             auto task = self.AllGather(in_wrapper,
                                        out_wrapper,
@@ -905,7 +907,7 @@ void BindDistributed(py::module *m) {
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
             // in_tensor_list must not be empty
-            const auto *dev_ctx = self.GetDeviceContext(
+            const auto &dev_ctx = self.GetDeviceContext(
                 in_tensor_list.back().place(), /*use_calc_stream*/ true);
             auto task = self.AllToAll(in_wrapper,
                                       out_wrapper,
@@ -1405,11 +1407,14 @@ void BindDistributed(py::module *m) {
       .def(py::init(&CreateEagerReducer))
       .def(
           "prepare_for_backward",
-          [](distributed::EagerReducer &self, py::handle py_tensors) {
+          [](distributed::EagerReducer &self,
+             py::handle py_tensors,
+             bool is_sync) {
            auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
-            self.PrepareForBackward(params);
+            self.PrepareForBackward(params, is_sync);
           },
           py::arg("tensors"),
+          py::arg("is_sync"),
           py::call_guard<py::gil_scoped_release>());
 }
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 1eb5f8bd4764c52577ff1d3960c6b47c37dd48c4..bd18d4b3319b2e8cf8f435fd29413b2b1f19e73b 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -2569,6 +2569,7 @@ void BindImperative(py::module *m_ptr) {
       .def("prepare_for_backward",
            &imperative::Reducer::PrepareForBackward,
            py::arg("vars"),
+           py::arg("is_sync"),
           py::call_guard<py::gil_scoped_release>());
 
   m.def("assign_group_by_size",
diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py
index 51e0527e4fa99fde78696fc6714cc15b1a6cbcb7..004c21c1346b1493277228ec051e7c79262b7658 100644
--- a/python/paddle/fluid/dygraph/parallel.py
+++ b/python/paddle/fluid/dygraph/parallel.py
@@ -818,13 +818,9 @@ class DataParallel(layers.Layer):
 
     def forward(self, *inputs, **kwargs):
         outputs = self._layers(*inputs, **kwargs)
-        if (
-            self._strategy.nranks > 1
-            and framework._dygraph_tracer()._has_grad
-            and self.grad_need_sync
-        ):
+        if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
             self._reducer.prepare_for_backward(
-                list(self._find_varbase(outputs))
+                list(self._find_varbase(outputs)), self.grad_need_sync
             )
         return outputs
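
A minimal usage sketch (not part of the diff) of the behavior this change enables: `DataParallel.forward` now always calls `prepare_for_backward` and passes `grad_need_sync` as `is_sync`, so the reducer hooks decide at runtime whether to run the fused allreduce and the new `SplitTensorsDev`/`UpdateWaitChain` path. The model, shapes, and launch setup below are illustrative assumptions; the script is expected to run under a distributed launcher such as `python -m paddle.distributed.launch`.

```python
# Hypothetical sketch, assuming a multi-card job started with paddle.distributed.launch.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(8, 8))
x = paddle.randn([4, 8])

# Inside no_sync(), prepare_for_backward receives is_sync=False, so the
# reducer skips the fused allreduce and gradients accumulate locally.
with model.no_sync():
    model(x).sum().backward()

# Outside no_sync(), is_sync=True: the reducer schedules the allreduce and
# splits the fused gradient buffer on the communication stream.
model(x).sum().backward()
```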