From 0175f566c7fe637c7e7a3f04e4ee631dc4dc34a1 Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Wed, 10 Feb 2021 10:55:00 +0800 Subject: [PATCH] [cherry-pick] Solve inconsistent order in each card in dynamic graph (#30965) * support if else control * fix conflict --- paddle/fluid/imperative/reducer.cc | 83 ++++++++++++++++-------------- paddle/fluid/imperative/reducer.h | 4 +- 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 10e8b398318..fd0fe2d7cf5 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -22,11 +22,6 @@ std::shared_ptr Reducer::s_instance_ = NULL; // context is used to select the stream for concat void Group::ConcatTensors(const platform::CUDADeviceContext &context) { - VLOG(3) << "Before concat, set output tensor size is " << all_length_; - auto tensor = dense_contents_.GetMutable(); - tensor->Resize(framework::make_ddim({all_length_})) - .mutable_data(context.GetPlace(), dtype_); - switch (dtype_) { case framework::proto::VarType::FP16: ConcatTensorsForAllReduce(context, dense_tensors_, @@ -179,6 +174,9 @@ void Reducer::InitializeDenseGroups( p_group->length_.push_back(size); + // for concat operator + p_group->dense_tensors_.push_back(framework::Tensor()); + // check the dtype and place, it must be same. auto dtype = var->DataType(); auto place = var->Place(); @@ -200,6 +198,7 @@ void Reducer::InitializeDenseGroups( place_ = place; } } + p_group->all_length_ = all_length; } // Each parameter will be initialized according to the group information. @@ -234,6 +233,9 @@ void Reducer::InitializeGroups( } else { // process the dense gradient. InitializeDenseGroups(variable_indices_, &group); + auto tensor = group.dense_contents_.GetMutable(); + tensor->Resize(framework::make_ddim({group.all_length_})) + .mutable_data(place_, group.dtype_); } // map variables to this group by VariableLocator @@ -295,9 +297,6 @@ void Reducer::PrepareForBackward( next_group_ = 0; std::for_each(groups_.begin(), groups_.end(), [](Group &group) { group.pending_ = group.variable_indices_.size(); - group.all_length_ = 0; - group.dense_tensors_.clear(); - group.dense_tensors_.reserve(group.pending_); group.sparse_contents_ = nullptr; }); @@ -423,22 +422,35 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) { auto group_index = var_locator.group_index; auto &group = groups_[group_index]; - if (is_used_var) { - auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar(); - if (!group.is_sparse_) { - auto grad = var_warpper->MutableVar(); - auto inside_group_index = var_locator.inside_group_index; - auto length = group.length_[inside_group_index]; - - auto tensor = grad->GetMutable(); - framework::Tensor tmp; - tmp.ShareDataWith(*tensor).Resize({static_cast(length)}); - group.dense_tensors_.push_back(std::move(tmp)); - group.all_length_ += length; + if (!group.is_sparse_) { + // process dense group + auto inside_group_index = var_locator.inside_group_index; + auto length = group.length_[inside_group_index]; + auto &group_tensor = group.dense_tensors_[inside_group_index]; + if (is_used_var) { + auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar(); + auto tensor = + var_warpper->MutableVar()->GetMutable(); + group_tensor.ShareDataWith(*tensor).Resize( + {static_cast(length)}); } else { + if (!group_tensor.IsInitialized()) { + group_tensor.Resize({static_cast(length)}); + group_tensor.mutable_data(place_, group.dtype_); + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_); + operators::math::set_constant(*dev_ctx, &group_tensor, 0.0); + } + } + } else { + // process sparse group + if (is_used_var) { + auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar(); group.sparse_contents_ = var_warpper->MutableVar(); + } else { + group.sparse_contents_ = nullptr; } } + if (--group.pending_ == 0) { // can start allreduce MarkGroupReady(group_index); @@ -478,24 +490,19 @@ void Reducer::MarkGroupReady(size_t group_index) { << "] has no var to allreduce"; } } else { - if (!group.dense_tensors_.empty()) { - VLOG(3) << "dense group [" << next_group_ - << "] start allreduce in ring[" << run_order << "]"; - // Select common commstream to concat tensors - // group.dense_tensors ---> group.dense_contents_ - group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order)); - - // Start allreduce - parallel_ctx_->AllReduceByStream( - group.dense_contents_, &(group.dense_contents_), run_order, false); - - // Select common commstream to split tensors - // group.dense_contents_ ---> group.dense_tensors - group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order)); - } else { - VLOG(3) << "The dense group[" << next_group_ - << "] has no var to allreduce"; - } + VLOG(3) << "dense group [" << next_group_ << "] start allreduce in ring[" + << run_order << "]"; + // Select common commstream to concat tensors + // group.dense_tensors ---> group.dense_contents_ + group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order)); + + // Start allreduce + parallel_ctx_->AllReduceByStream( + group.dense_contents_, &(group.dense_contents_), run_order, false); + + // Select common commstream to split tensors + // group.dense_contents_ ---> group.dense_tensors + group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order)); } } } diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index 62b61616026..1ff10371e44 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -29,6 +29,7 @@ #include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/string/string_helper.h" #if defined(PADDLE_WITH_NCCL) @@ -201,7 +202,8 @@ class Reducer { int nrings_ = 1; // Following variables are to help rebuild group - bool has_rebuilt_group_{false}; + // TODO(shenliang03): Support rebuild in the future. + bool has_rebuilt_group_{true}; std::vector> rebuild_vars_; std::vector rebuild_var_indices_; const std::vector group_size_limits_; -- GitLab