From 57d5ffa5fd340831f90d0a6f15f7cab930cd8842 Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Fri, 28 Oct 2022 12:59:25 +0800
Subject: [PATCH] [Dygraph] Fix memory bugs of no sync and SplitTensors in DataParallel (#47369)

* fix no sync bugs

* update

* update task chain

fix: update wait chain
feat: add `GetDeviceContext` for gloo

* fix oom

* fix dev

* update

* update

Co-authored-by: LiYuRio
Co-authored-by: ForFishes <2282912238@qq.com>
---
 .../distributed/collective/ProcessGroup.cc    |  2 +
 .../distributed/collective/ProcessGroup.h     |  3 +-
 .../distributed/collective/ProcessGroupGloo.h |  5 ++
 .../collective/ProcessGroupNCCL.cc            | 13 ++--
 .../distributed/collective/ProcessGroupNCCL.h |  8 ++-
 .../collective/ProcessGroupStream.cc          |  2 +-
 .../collective/ProcessGroupStream.h           |  4 +-
 paddle/fluid/distributed/collective/Utils.h   | 42 ++++++------
 .../fluid/distributed/collective/reducer.cc   | 64 ++++++++++++-------
 paddle/fluid/distributed/collective/reducer.h |  6 +-
 paddle/fluid/imperative/reducer.cc            |  9 ++-
 paddle/fluid/imperative/reducer.h             |  3 +-
 paddle/fluid/pybind/distributed_py.cc         | 17 +++--
 paddle/fluid/pybind/imperative.cc             |  1 +
 python/paddle/fluid/dygraph/parallel.py       |  8 +--
 15 files changed, 113 insertions(+), 74 deletions(-)

diff --git a/paddle/fluid/distributed/collective/ProcessGroup.cc b/paddle/fluid/distributed/collective/ProcessGroup.cc
index 1db8d221aa6..e7942b714e4 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroup.cc
@@ -41,6 +41,8 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
   return true;
 }
 
 void ProcessGroup::Task::Synchronize() {}
 
+void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}
+
 ProcessGroup::ProcessGroup(int rank,
                            int size,
                            const platform::Place& place,
diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h
index ca1cf7dd48b..afe75baeb2a 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.h
+++ b/paddle/fluid/distributed/collective/ProcessGroup.h
@@ -66,6 +66,7 @@ class ProcessGroup {
     virtual bool IsCompleted();
     virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
     virtual void Synchronize();
+    virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
     bool IsSync() const { return sync_op_; }
 
    protected:
@@ -92,7 +93,7 @@ class ProcessGroup {
   int GetSize() const { return size_; }
 
   virtual const std::string GetBackendName() const = 0;
-  virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
+  virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "Does not support to get device_context from ProcessGroup%s.",
         GetBackendName()));
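The base-class change above gives every ProcessGroup a GetDeviceContext that returns a const phi::DeviceContext& (throwing by default) and a no-op UpdateWaitChain hook, so backends that own device streams can override both. A minimal stand-alone sketch of this override pattern, with placeholder types instead of Paddle's real phi::DeviceContext and platform::Place:

    #include <map>
    #include <stdexcept>
    #include <string>

    // Placeholder stand-ins for phi::DeviceContext and platform::Place.
    struct DeviceContext { std::string name; };
    using Place = std::string;

    class Group {                        // analogous to ProcessGroup
     public:
      virtual ~Group() = default;
      // Default: backends without device contexts refuse the call.
      virtual const DeviceContext& GetDeviceContext(const Place& place) const {
        throw std::runtime_error("no device context for " + place);
      }
      // Default: nothing to record on the wait chain.
      virtual void UpdateWaitChain(const DeviceContext& /*ctx*/) {}
    };

    class GpuLikeGroup : public Group {  // analogous to ProcessGroupNCCL
     public:
      const DeviceContext& GetDeviceContext(const Place& place) const override {
        return contexts_.at(place);      // one context per place/device
      }
     private:
      std::map<Place, DeviceContext> contexts_{{"gpu:0", {"gpu:0-ctx"}}};
    };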
diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.h b/paddle/fluid/distributed/collective/ProcessGroupGloo.h
index d911da91eb1..f20f39b31a7 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupGloo.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.h
@@ -150,6 +150,11 @@ class ProcessGroupGloo : public ProcessGroup {
     return GLOO_BACKEND_NAME;
   }
 
+  const phi::DeviceContext& GetDeviceContext(
+      const Place& place) const override {
+    return *platform::DeviceContextPool::Instance().Get(place);
+  }
+
   // Helper functions for Gloo.
   static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
       const std::string& hostname);
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 2e18dfcc3ba..76d1d42c7d6 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -110,6 +110,11 @@ bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
   return true;
 }
 
+void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
+    const phi::DeviceContext& ctx) {
+  control_events_[0].Record(*static_cast<const phi::GPUContext*>(&ctx));
+}
+
 void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
                                        std::vector<int64_t> tensor_shape) {
   int64_t len_size = (*split_sizes).size();
@@ -1591,15 +1596,15 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
   return iter->second[0]->GetNcclComm();
 }
 
-phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
     const Place& place) const {
   return GetDeviceContext(place, /*use_calc_stream*/ false);
 }
 
-phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   if (use_calc_stream) {
-    return platform::DeviceContextPool::Instance().Get(place);
+    return *platform::DeviceContextPool::Instance().Get(place);
   } else {
     std::vector<Place> places = {place};
     const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places));
@@ -1607,7 +1612,7 @@ phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
                       places_to_ctx_.end(),
                       platform::errors::InvalidArgument(
                           "Cannot find device context in process group."));
-    return iter->second[0].get();
+    return *iter->second[0];
   }
 }
 
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
index 6427e9e3e2a..a501bf53023 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -75,6 +75,8 @@ class ProcessGroupNCCL : public ProcessGroupStream {
 
     virtual ~NCCLTask();
 
+    void UpdateWaitChain(const phi::DeviceContext& ctx) override;
+
     std::vector control_events_;
     std::vector barrierTensors_;
@@ -96,10 +98,10 @@ class ProcessGroupNCCL : public ProcessGroupStream {
     return std::string(NCCL_BACKEND_NAME);
   }
 
-  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
+  const phi::DeviceContext& GetDeviceContext(const Place& place) const override;
 
-  phi::DeviceContext* GetDeviceContext(const Place& place,
-                                       bool use_calc_stream) const override;
+  const phi::DeviceContext& GetDeviceContext(
+      const Place& place, bool use_calc_stream) const override;
 
   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,  // NOLINT
diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc
index b2cfae088b2..11530ab872d 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc
@@ -23,7 +23,7 @@ ProcessGroupStream::ProcessGroupStream(int rank,
                                        int gid)
     : ProcessGroup(rank, size, place, gid) {}
 
-phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
+const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
     const Place& place, bool use_calc_stream) const {
   PADDLE_THROW(platform::errors::InvalidArgument(
       "ProcessGroup%s does not support get device_context.",
       GetBackendName()));
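NCCLTask::UpdateWaitChain records a CUDA event on the supplied context, so a later Wait() on the task also covers work (such as the gradient split) queued on that context after the collective. The same wait-chain idea in a stand-alone CUDA-runtime sketch, with illustrative stream names only:

    #include <cstdio>
    #include <cuda_runtime.h>

    // After enqueueing extra work on a stream, record an event; anyone who
    // later waits on the event also waits for that extra work.
    int main() {
      cudaStream_t comm_stream, calc_stream;
      cudaStreamCreate(&comm_stream);
      cudaStreamCreate(&calc_stream);

      cudaEvent_t done;
      cudaEventCreateWithFlags(&done, cudaEventDisableTiming);

      float* buf = nullptr;
      cudaMalloc(&buf, 1024 * sizeof(float));

      // Pretend this is the allreduce plus the follow-up split on comm_stream.
      cudaMemsetAsync(buf, 0, 1024 * sizeof(float), comm_stream);

      // "UpdateWaitChain": record the event after the extra work is enqueued.
      cudaEventRecord(done, comm_stream);

      // "Task::Wait" as seen from the calculation stream.
      cudaStreamWaitEvent(calc_stream, done, 0);
      cudaStreamSynchronize(calc_stream);

      cudaFree(buf);
      printf("wait chain completed\n");
      return 0;
    }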
diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h
index 2f0aa139104..56799c4bd3e 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupStream.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h
@@ -54,8 +54,8 @@ class ProcessGroupStream : public ProcessGroup {
   ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
   virtual ~ProcessGroupStream() = default;
 
-  virtual phi::DeviceContext* GetDeviceContext(const Place& place,
-                                               bool use_calc_stream) const;
+  virtual const phi::DeviceContext& GetDeviceContext(
+      const Place& place, bool use_calc_stream) const;
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
       std::vector<phi::DenseTensor>& in_tensors,  // NOLINT
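GetDeviceContext(place, use_calc_stream) lets callers pick either the shared calculation-stream context or the process group's dedicated communication context. A bare-bones sketch of that switch (illustrative types, not Paddle's):

    #include <cuda_runtime.h>

    // A group that owns a private communication stream but can also hand back
    // the calculation stream the rest of the program uses.
    struct StreamedGroup {
      cudaStream_t calc_stream;  // shared with computation
      cudaStream_t comm_stream;  // used only for collectives

      cudaStream_t GetStream(bool use_calc_stream) const {
        return use_calc_stream ? calc_stream : comm_stream;
      }
    };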
diff --git a/paddle/fluid/distributed/collective/Utils.h b/paddle/fluid/distributed/collective/Utils.h
index c06c0345163..d9260b98dcf 100644
--- a/paddle/fluid/distributed/collective/Utils.h
+++ b/paddle/fluid/distributed/collective/Utils.h
@@ -25,18 +25,18 @@ namespace distributed {
 template <typename DeviceContext, typename T>
 struct ConcatDenseTensor {
-  void operator()(const DeviceContext *context,
+  void operator()(const DeviceContext &context,
                   const std::vector<phi::DenseTensor> &in,
                   phi::DenseTensor *out,
                   int axis = 0) {
     phi::funcs::ConcatFunctor<DeviceContext, T> concat_functor;
-    concat_functor(*context, in, axis, out);
+    concat_functor(context, in, axis, out);
   }
 };
 
 template <typename DeviceContext, typename T>
 struct SplitDenseTensor {
-  void operator()(const DeviceContext *context,
+  void operator()(const DeviceContext &context,
                   const phi::DenseTensor &in,
                   std::vector<phi::DenseTensor *> *out,
                   int axis = 0) {
@@ -46,19 +46,19 @@ struct SplitDenseTensor {
       shape_refer.emplace_back(p_tensor);
     }
     phi::funcs::SplitFunctor<DeviceContext, T> split_functor;
-    split_functor(*context, in, shape_refer, axis, out);
+    split_functor(context, in, shape_refer, axis, out);
   }
 };
 
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
 template <typename T>
 struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
-  void operator()(const platform::CustomDeviceContext *context,
+  void operator()(const platform::CustomDeviceContext &context,
                   const std::vector<phi::DenseTensor> &in,
                   phi::DenseTensor *out,
                   int axis = 0) {
     auto *out_data = out->data<T>();
-    auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
+    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
     size_t offset = 0;
     for (const auto &tensor : in) {
       const auto *in_data = tensor.data<T>();
@@ -71,12 +71,12 @@ struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
 
 template <typename T>
 struct SplitDenseTensor<platform::CustomDeviceContext, T> {
-  void operator()(const platform::CustomDeviceContext *context,
+  void operator()(const platform::CustomDeviceContext &context,
                   const phi::DenseTensor &in,
                   std::vector<phi::DenseTensor *> *out,
                   int axis = 0) {
     auto *in_data = in.data<T>();
-    auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
+    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
     size_t offset = 0;
     for (auto *p_tensor : *out) {
       auto *out_data = p_tensor->data<T>();
@@ -89,7 +89,7 @@ struct SplitDenseTensor<platform::CustomDeviceContext, T> {
 #endif
 
 template <typename DeviceContext>
-void ConcatDenseTensorWithType(const DeviceContext *dev_ctx,
+void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
                                const std::vector<phi::DenseTensor> &t_list,
                                phi::DenseTensor *p_out,
                                phi::DataType type) {
@@ -126,7 +126,7 @@ void ConcatDenseTensorWithType(const DeviceContext *dev_ctx,
 }
 
 template <typename DeviceContext>
-void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
+void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
                               const phi::DenseTensor &t_in,
                               std::vector<phi::DenseTensor *> *p_list,
                               phi::DataType type) {
@@ -162,16 +162,16 @@ void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
   }
 }
 
-void ConcatTensor(const phi::DeviceContext *dev_ctx,
+void ConcatTensor(const phi::DeviceContext &dev_ctx,
                   const std::vector<phi::DenseTensor> &tensor_list,
                   const experimental::Tensor *tensor) {
   auto *dense_tensor =
       std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl()).get();
 
-  const auto &place = dev_ctx->GetPlace();
+  const auto &place = dev_ctx.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    ConcatDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
+    ConcatDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                               tensor_list,
                               dense_tensor,
                               tensor->dtype());
@@ -183,7 +183,7 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
     ConcatDenseTensorWithType(
-        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
+        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
         tensor_list,
         dense_tensor,
         tensor->dtype());
@@ -194,7 +194,7 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
                                         "CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    ConcatDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
+    ConcatDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                               tensor_list,
                               dense_tensor,
                               tensor->dtype());
@@ -204,20 +204,20 @@ void ConcatTensor(const phi::DeviceContext *dev_ctx,
   }
 }
 
-void SplitTensor(const phi::DeviceContext *dev_ctx,
+void SplitTensor(const phi::DeviceContext &dev_ctx,
                  const phi::DenseTensor &tensor,
                  const std::vector<experimental::Tensor> *tensor_list) {
   std::vector<phi::DenseTensor *> dense_list;
   for (auto &tensor : *tensor_list) {
-    auto p_tensor =
+    auto *p_tensor =
         std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()).get();
     dense_list.emplace_back(p_tensor);
   }
 
-  const auto &place = dev_ctx->GetPlace();
+  const auto &place = dev_ctx.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    SplitDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
+    SplitDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                              tensor,
                              &dense_list,
                              tensor.dtype());
@@ -229,7 +229,7 @@ void SplitTensor(const phi::DeviceContext *dev_ctx,
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
     SplitDenseTensorWithType(
-        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
+        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
         tensor,
         &dense_list,
         tensor.dtype());
@@ -239,7 +239,7 @@ void SplitTensor(const phi::DeviceContext *dev_ctx,
         "please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    SplitDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
+    SplitDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                              tensor,
                              &dense_list,
                              tensor.dtype());
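ConcatTensor and SplitTensor now take the device context by reference; the split itself is just offset-based copies out of one fused buffer. A stand-alone CPU sketch of that offset walk (plain C++, not Paddle's API):

    #include <cassert>
    #include <cstring>
    #include <vector>

    // Split one fused, contiguous buffer back into per-tensor buffers by
    // walking an offset, mirroring what SplitDenseTensor does with
    // device-to-device copies.
    void SplitFused(const std::vector<float>& fused,
                    std::vector<std::vector<float>>* outs) {
      size_t offset = 0;
      for (auto& out : *outs) {
        assert(offset + out.size() <= fused.size());
        std::memcpy(out.data(), fused.data() + offset,
                    out.size() * sizeof(float));
        offset += out.size();
      }
    }

    int main() {
      std::vector<float> fused = {1, 2, 3, 4, 5, 6};  // e.g. two grads fused
      std::vector<std::vector<float>> outs = {std::vector<float>(2),
                                              std::vector<float>(4)};
      SplitFused(fused, &outs);  // outs[0] = {1,2}, outs[1] = {3,4,5,6}
      return 0;
    }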
diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc
index 0d46425b2e8..2c26828e5e1 100644
--- a/paddle/fluid/distributed/collective/reducer.cc
+++ b/paddle/fluid/distributed/collective/reducer.cc
@@ -16,6 +16,8 @@
 #include "paddle/phi/backends/device_guard.h"
 #include "paddle/phi/backends/device_manager.h"
 
+DECLARE_bool(use_stream_safe_cuda_allocator);
+
 namespace paddle {
 namespace distributed {
 
@@ -335,13 +337,20 @@ void EagerGroup::ConcatTensors(const platform::Place &place) {
   }
 }
 
-void EagerGroup::SplitTensors(const platform::Place &place) {
+void EagerGroup::SplitTensorsDev(const platform::DeviceContext &context) {
+  auto place = context.GetPlace();
   if (platform::is_gpu_place(place)) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    auto *default_ctx = static_cast<phi::GPUContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
+    auto &gpu_context = static_cast<const phi::GPUContext &>(context);
     SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+        gpu_context, &dense_contents_, &dense_tensors_, dtype_);
+    if (FLAGS_use_stream_safe_cuda_allocator) {
+      auto dense_tensor =
+          std::dynamic_pointer_cast<phi::DenseTensor>(dense_contents_.impl());
+      VLOG(3) << "Free dense_contents_ " << dense_contents_.numel();
+      memory::RecordStream(dense_tensor->Holder(), gpu_context.stream());
+      dense_contents_.reset();
+    }
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Paddle can't split grad tensor since it's not compiled with NCCL,"
@@ -349,10 +358,11 @@
 #endif
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-    auto *default_ctx = static_cast<platform::CustomDeviceContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
     SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+        static_cast<const platform::CustomDeviceContext &>(context),
+        &dense_contents_,
+        &dense_tensors_,
+        dtype_);
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Paddle can't split grad tensor since it's not compiled with "
         "CUSTOM_DEVICE,"
         "Please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
 #endif
   } else if (platform::is_cpu_place(place)) {
-    auto *default_ctx = static_cast<phi::CPUContext *>(
-        platform::DeviceContextPool::Instance().Get(place));
-    SplitTensorsWithType(
-        *default_ctx, &dense_contents_, &dense_tensors_, dtype_);
+    SplitTensorsWithType(static_cast<const phi::CPUContext &>(context),
+                         &dense_contents_,
+                         &dense_tensors_,
+                         dtype_);
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Split grad tensor not supported on place (%s)", place));
@@ -578,9 +588,11 @@ void EagerReducer::TraverseBackwardGraph(const std::vector<Tensor> &outputs) {
   }
 }
 
-void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
+void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs,
+                                      const bool is_sync) {
   VLOG(3) << "after forward, then reset count for backward.";
-  grad_need_hooks_ = true;
+  grad_need_hooks_ = is_sync;
+
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](EagerGroup &group) {
     group.pending_ = group.tensor_indices_.size();
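Resetting dense_contents_ right after the split is only safe because memory::RecordStream tells the stream-safe allocator which stream still consumes the buffer, so the memory is not recycled until that stream catches up. The same idea with CUDA's stream-ordered allocator (stand-alone sketch, not Paddle's allocator; needs CUDA 11.2+):

    #include <cuda_runtime.h>

    // The free is enqueued on the stream that last used the buffer, so the
    // allocator may only recycle the memory after the pending work finishes.
    int main() {
      cudaStream_t comm_stream;
      cudaStreamCreate(&comm_stream);

      float* fused = nullptr;
      cudaMallocAsync(reinterpret_cast<void**>(&fused), 1 << 20,
                      comm_stream);                    // fused gradient buffer
      cudaMemsetAsync(fused, 0, 1 << 20, comm_stream);  // pretend: allreduce + split

      // Analogue of RecordStream + reset(): release in stream order rather
      // than freeing immediately from the host.
      cudaFreeAsync(fused, comm_stream);

      cudaStreamSynchronize(comm_stream);
      cudaStreamDestroy(comm_stream);
      return 0;
    }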
@@ -648,9 +660,9 @@ void EagerReducer::AddDistHook(size_t var_index) {
                         var_index));
 
   // gradient synchronization is not required when grad_need_hooks_ is false.
-  if (!grad_need_hooks_) {
-    return;
-  }
+  // if (!grad_need_hooks_) {
+  //   return;
+  // }
 
   VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
           << "@Grad] arrived and triggered disthook";
@@ -816,10 +828,12 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
   for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
        ++next_group_) {
     UNUSED auto &group = groups_[next_group_];
-    if (group.is_sparse_) {
-      AllReduceSparse(&group, next_group_);
-    } else {
-      FusedAllReduceSchedule(&group, next_group_);
+    if (grad_need_hooks_) {
+      if (group.is_sparse_) {
+        AllReduceSparse(&group, next_group_);
+      } else {
+        FusedAllReduceSchedule(&group, next_group_);
+      }
     }
   }
 }
@@ -907,16 +921,14 @@ void EagerReducer::ProcessUnusedDenseVars() {
 void EagerReducer::FinalizeBackward() {
   groups_need_finalize_ = false;
-  grad_need_hooks_ = false;
   for (auto &group : groups_) {
-    if (!group.is_sparse_) {
+    if (!group.is_sparse_ && grad_need_hooks_) {
       group.task->Synchronize();
     }
   }
 
   for (auto &group : groups_) {
-    if (!group.is_sparse_) {
-      group.SplitTensors(inner_place_);
+    if (!group.is_sparse_ && grad_need_hooks_) {
       group.dense_contents_.reset();
     }
   }
@@ -928,6 +940,7 @@ void EagerReducer::FinalizeBackward() {
     VLOG(3) << "ProcessUnusedDenseVars is finished.";
   }
 
+  grad_need_hooks_ = false;
   VLOG(3) << "In the batch, Reducer is finished.";
 }
 
@@ -954,6 +967,9 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
   }
 
   group->task = process_group_->AllReduce(in_out, in_out, opts);
+  const auto &context = process_group_->GetDeviceContext(inner_place_);
+  group->SplitTensorsDev(context);
+  group->task->UpdateWaitChain(context);
 
   // split in FinalizeBackward()
 }
diff --git a/paddle/fluid/distributed/collective/reducer.h b/paddle/fluid/distributed/collective/reducer.h
index 90848920b7e..74db3db7467 100644
--- a/paddle/fluid/distributed/collective/reducer.h
+++ b/paddle/fluid/distributed/collective/reducer.h
@@ -74,7 +74,8 @@ class EagerGroup {
   void ConcatTensors(const platform::Place &);
 
   // context is used to select the stream for split
-  void SplitTensors(const platform::Place &);
+
+  void SplitTensorsDev(const platform::DeviceContext &);
 
   friend std::ostream &operator<<(std::ostream &, const EagerGroup &);
 };
@@ -102,7 +103,8 @@ class EagerReducer {
   void InitializeGroups(const std::vector<std::vector<size_t>> &group_indices);
   void InitializeDenseGroups(const std::vector<size_t> &tensor_indices_,
                              EagerGroup *p_group);
-  void PrepareForBackward(const std::vector<Tensor> &outputs);
+  void PrepareForBackward(const std::vector<Tensor> &outputs,
+                          const bool is_sync);
   void AddDistHook(size_t var_index);
   void MarkVarReady(const size_t var_index, const bool is_used_var);
   void MarkGroupReady(const size_t group_index);
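With the is_sync flag, the gradient hooks stay installed but skip communication when a backward pass runs under no_sync; grad_need_hooks_ is only cleared at the end of FinalizeBackward. A toy sketch of that gating (illustrative names, not the real reducer):

    #include <iostream>
    #include <vector>

    class ToyReducer {
     public:
      void PrepareForBackward(bool is_sync) { need_sync_ = is_sync; }

      void GradHook(int var_index) {
        ready_.push_back(var_index);  // bookkeeping always happens ...
        if (need_sync_) {             // ... but communication is gated.
          std::cout << "allreduce for var " << var_index << "\n";
        }
      }

      void FinalizeBackward() {
        need_sync_ = false;           // cleared only once the pass is done
        ready_.clear();
      }

     private:
      bool need_sync_ = false;
      std::vector<int> ready_;
    };

    int main() {
      ToyReducer r;
      r.PrepareForBackward(/*is_sync=*/false);  // e.g. inside no_sync()
      r.GradHook(0);                            // no communication
      r.FinalizeBackward();

      r.PrepareForBackward(/*is_sync=*/true);   // normal step
      r.GradHook(0);                            // allreduce happens
      r.FinalizeBackward();
      return 0;
    }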
diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index f89fe234c20..3225222f617 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -675,9 +675,10 @@ void Reducer::TraverseBackwardGraph(
 
 // After each batch is calculated, the counter of each group(group.pending_)
 // and allreudce sequence counter(next_group_) will be cleaned up again.
 void Reducer::PrepareForBackward(
-    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
+    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs,
+    const bool is_sync) {
   VLOG(3) << "after forward, then reset count for backward.";
-  grad_need_hooks_ = true;
+  grad_need_hooks_ = is_sync;
   next_group_ = 0;
   std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
     group.pending_ = group.variable_indices_.size();
@@ -710,7 +711,9 @@ void Reducer::PrepareForBackward(
 
   if (find_unused_vars_once_ || find_unused_vars_each_step_) {
     unused_vars_.clear();
-    TraverseBackwardGraph(outputs);
+    if (grad_need_hooks_) {
+      TraverseBackwardGraph(outputs);
+    }
     // only check once in first step
     find_unused_vars_once_ = false;
   }
diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h
index c455f962788..902c3036acc 100644
--- a/paddle/fluid/imperative/reducer.h
+++ b/paddle/fluid/imperative/reducer.h
@@ -146,7 +146,8 @@ class Reducer {
   void PrepareDeps(const std::unordered_set& init_nodes);
 
   void PrepareForBackward(
-      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs);
+      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs,
+      const bool is_sync);
 
   void AddDistHook(size_t var_index);
 
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 6aa8e19c99c..fe1d82c766a 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -395,9 +395,10 @@ void BindDistributed(py::module *m) {
                     concat_out_tensor.impl());
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
-            const auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
+            const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
             auto task = self.AllGather(in_wrapper, out_wrapper, sync_op);
             distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(dev_ctx);
             return task;
           },
           py::arg("in"),
@@ -495,10 +496,11 @@ void BindDistributed(py::module *m) {
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
             // in_tensor_list should not be empty
-            const auto *dev_ctx =
+            const auto &dev_ctx =
                 self.GetDeviceContext(in_tensor_list.back().place());
             auto task = self.AllToAll(in_wrapper, out_wrapper, sync_op);
             distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
+            task->UpdateWaitChain(dev_ctx);
             return task;
           },
           py::arg("in"),
@@ -796,7 +798,7 @@ void BindDistributed(py::module *m) {
                     concat_out_tensor.impl());
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
-            const auto *dev_ctx =
+            const auto &dev_ctx =
                 self.GetDeviceContext(in_tensor.place(), true);
             auto task = self.AllGather(in_wrapper,
                                        out_wrapper,
@@ -905,7 +907,7 @@ void BindDistributed(py::module *m) {
             std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
 
             // in_tensor_list must not be empty
-            const auto *dev_ctx = self.GetDeviceContext(
+            const auto &dev_ctx = self.GetDeviceContext(
                 in_tensor_list.back().place(), /*use_calc_stream*/ true);
             auto task = self.AllToAll(in_wrapper,
                                       out_wrapper,
@@ -1405,11 +1407,14 @@ void BindDistributed(py::module *m) {
       .def(py::init(&CreateEagerReducer))
       .def(
           "prepare_for_backward",
-          [](distributed::EagerReducer &self, py::handle py_tensors) {
+          [](distributed::EagerReducer &self,
+             py::handle py_tensors,
+             bool is_sync) {
            auto params = CastPyArg2VectorOfTensor(py_tensors.ptr(), 0);
-            self.PrepareForBackward(params);
+            self.PrepareForBackward(params, is_sync);
          },
          py::arg("tensors"),
+         py::arg("is_sync"),
          py::call_guard<py::gil_scoped_release>());
 }
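The Python binding grows a second keyword argument; the pattern is a plain pybind11 def with an extra py::arg and a GIL release guard. A minimal pybind11 sketch of that pattern (hypothetical Reducer type, not Paddle's binding code):

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    struct Reducer {
      void PrepareForBackward(int num_outputs, bool is_sync) {
        (void)num_outputs;
        (void)is_sync;  // real code would reset counters and store the flag
      }
    };

    PYBIND11_MODULE(example, m) {
      py::class_<Reducer>(m, "Reducer")
          .def(py::init<>())
          .def("prepare_for_backward",
               &Reducer::PrepareForBackward,
               py::arg("num_outputs"),
               py::arg("is_sync"),  // the newly added keyword argument
               py::call_guard<py::gil_scoped_release>());
    }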
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 1eb5f8bd476..bd18d4b3319 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -2569,6 +2569,7 @@ void BindImperative(py::module *m_ptr) {
       .def("prepare_for_backward",
            &imperative::Reducer::PrepareForBackward,
            py::arg("vars"),
+           py::arg("is_sync"),
            py::call_guard<py::gil_scoped_release>());
 
   m.def("assign_group_by_size",
diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py
index 51e0527e4fa..004c21c1346 100644
--- a/python/paddle/fluid/dygraph/parallel.py
+++ b/python/paddle/fluid/dygraph/parallel.py
@@ -818,13 +818,9 @@ class DataParallel(layers.Layer):
 
     def forward(self, *inputs, **kwargs):
         outputs = self._layers(*inputs, **kwargs)
-        if (
-            self._strategy.nranks > 1
-            and framework._dygraph_tracer()._has_grad
-            and self.grad_need_sync
-        ):
+        if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
             self._reducer.prepare_for_backward(
-                list(self._find_varbase(outputs))
+                list(self._find_varbase(outputs)), self.grad_need_sync
             )
         return outputs
-- 
GitLab