diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index b36793507f54bff72b4fc228fd8147e8ac46f24f..04dc51f1b943dcbea2fc21250064fc96c9d31879 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -152,6 +152,7 @@ message DistributedStrategy { optional bool fp16_allreduce = 25 [ default = false ]; optional bool sharding = 26 [ default = false ]; optional float last_comm_group_size_MB = 27 [ default = 1 ]; + optional bool find_unused_parameters = 28 [ default = true ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; diff --git a/paddle/fluid/imperative/bkcl_context.cc b/paddle/fluid/imperative/bkcl_context.cc index 886179feb1974906136ea1b90d7eefd81c658c79..16f9454e9376e4368a478cf8adf9e3f988868785 100644 --- a/paddle/fluid/imperative/bkcl_context.cc +++ b/paddle/fluid/imperative/bkcl_context.cc @@ -167,8 +167,6 @@ void BKCLParallelContext::WaitCompute(int ring_id) { platform::errors::OutOfRange("Ring id expected < nrings," "but got ring id = %d, nrings = %d", ring_id, strategy_.nrings_)); - // TODO(wangxi16): [Performance optimize] Maybe need to put Wait and - // bkcl_allreduce to comm thread, for bkcl_allreduce is blocking now. auto compute_dev_ctx = static_cast( platform::DeviceContextPool::Instance().Get(place_)); compute_dev_ctx->Wait(); @@ -188,6 +186,12 @@ void BKCLParallelContext::WaitComm(int ring_id) { comm_dev_ctx->Wait(); } +void BKCLParallelContext::SynchronizeCompute() { + auto compute_dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place_)); + compute_dev_ctx->Wait(); +} + } // namespace imperative } // namespace paddle #endif diff --git a/paddle/fluid/imperative/bkcl_context.h b/paddle/fluid/imperative/bkcl_context.h index 86e4d97b3c774fe538aafcd7bd16e78523eab4e2..652b7689666c6c66c4efe6edda0c23acfc0cab27 100644 --- a/paddle/fluid/imperative/bkcl_context.h +++ b/paddle/fluid/imperative/bkcl_context.h @@ -47,6 +47,8 @@ class BKCLParallelContext : public ParallelContext { void WaitCompute(int ring_id) override; void WaitComm(int ring_id) override; + + void SynchronizeCompute() override; }; } // namespace imperative diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index 7e7c4ceea0bc23e8b1eb3daaa2140458abaaf765..b91fc460781c795b94b0bc3bdd04ee087c99cda6 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -173,6 +173,12 @@ void NCCLParallelContext::WaitComm(int ring_id) { #endif } +void NCCLParallelContext::SynchronizeCompute() { + auto *compute_dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place_)); + compute_dev_ctx->Wait(); +} + #endif } // namespace imperative diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h index 292ef1661c36fa67ce263882106b7402fddc29b3..bcaeb811b108c5e07bd3710e09c058b03a4a736d 100644 --- a/paddle/fluid/imperative/nccl_context.h +++ b/paddle/fluid/imperative/nccl_context.h @@ -65,6 +65,8 @@ class NCCLParallelContext : public ParallelContext { void WaitComm(int ring_id) override; + void SynchronizeCompute() override; + private: // used for comm wait compute, compute_stream-->event-->comm_stream[ring_id] std::vector> compute_events_; diff --git a/paddle/fluid/imperative/parallel_context.h b/paddle/fluid/imperative/parallel_context.h index 9a76311f2ed6be34ccb4db246604aa8a221adc0e..f537a316014d60ed18250d72de3ec2b7dd95cf05 100644 --- a/paddle/fluid/imperative/parallel_context.h +++ b/paddle/fluid/imperative/parallel_context.h @@ -66,6 +66,9 @@ class ParallelContext { // if CPU, should do nothing. virtual void WaitComm(int ring_id) = 0; + // synchorize compute stream + virtual void SynchronizeCompute() = 0; + inline int GetNRings() const { return strategy_.nrings_; } inline int64_t GetNRanks() const { return strategy_.nranks_; } diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 4b18886821b8e76e9e9a9e7f27972f5543c84e81..5422b7ce9c85528122d7076ca18e78cfc729383d 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -315,6 +315,12 @@ Reducer::Reducer(const std::vector> &vars, VariableWrapper *grad) { this->AddDistHook(global_var_index); })); var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index; } + + // for checking var is ready once + vars_marked_ready_.resize(vars_.size(), false); + + // Initialize local used vars + local_used_vars_.resize(vars_.size(), 0); } void Reducer::InitializeDenseGroups( @@ -323,7 +329,7 @@ void Reducer::InitializeDenseGroups( for (size_t index = 0; index < variable_indices_.size(); ++index) { const auto variable_index = variable_indices_[index]; const auto &var = vars_[variable_index]; - const auto var_name = var->Name(); + const auto &var_name = var->Name(); PADDLE_ENFORCE_EQ(is_sparse_gradient_[variable_index], false, platform::errors::PreconditionNotMet( "Tensor %s's GRAD must be LoDTensor, but received " @@ -334,7 +340,7 @@ void Reducer::InitializeDenseGroups( PADDLE_ENFORCE_EQ(lod_tensor->IsInitialized(), true, platform::errors::PreconditionNotMet( "Tensor %s is not initialized.", var_name)); - auto size = lod_tensor->numel(); + const auto size = lod_tensor->numel(); PADDLE_ENFORCE_GT( size, 0, platform::errors::PreconditionNotMet( "The number of tensor %s's elements is 0.", var_name)); @@ -346,8 +352,8 @@ void Reducer::InitializeDenseGroups( p_group->dense_tensors_.push_back(framework::Tensor()); // check the dtype and place, it must be same. - auto dtype = var->DataType(); - auto place = var->Place(); + const auto &dtype = var->DataType(); + const auto &place = var->Place(); if (index > 0) { PADDLE_ENFORCE_EQ( dtype, p_group->dtype_, @@ -417,8 +423,7 @@ void Reducer::InitializeGroups( group.variable_indices_ = std::move(variable_indices_); groups_.emplace_back(std::move(group)); // Debug Message For Reducer - VLOG(3) << "The Group[" << group_index << "]:"; - VLOG(3) << groups_.back(); + VLOG(3) << "The Group[" << group_index << "]:" << groups_.back(); } } @@ -461,34 +466,38 @@ void Reducer::PrepareDeps(const std::unordered_set &init_nodes) { // and allreudce sequence counter(next_group_) will be cleaned up again. void Reducer::PrepareForBackward( const std::vector> &outputs) { - VLOG(3) << "start reseting count.."; + VLOG(3) << "after forward, then reset count for backward."; next_group_ = 0; std::for_each(groups_.begin(), groups_.end(), [](Group &group) { group.pending_ = group.variable_indices_.size(); group.sparse_contents_ = nullptr; }); + // reinitialize vars_marked_ready_ for next iteration + vars_marked_ready_.clear(); + vars_marked_ready_.resize(vars_.size(), false); + PADDLE_ENFORCE_EQ( - all_group_ready_, false, + groups_need_finalize_, false, platform::errors::PreconditionNotMet( - "Please note that all forward outputs derived from the module " + "A serious error has occurred here. There may be several reasons: " + "1) Please note that all forward outputs derived from the module " "parameters must participate in the calculation of losses and " "subsequent gradient calculations. If not, the wrapper will hang, " "waiting for autograd to generate gradients for these parameters. " "you can use detach or stop_gradient to make the unused parameters " - "detached from the autograd graph.")); + "detached from the autograd graph. " + "2) Used multiple forwards and one backward. You may be able to wrap " + "multiple forwards in a model.")); // The first var to trigger the unused parameter has_marked_unused_vars_ = false; + unused_vars_.clear(); + if (!find_unused_vars_) { return; } - // TODO(shenliang03) "find_unused_vars" interface will be exposed in the - // future to handle control flow to process unused parameters - find_unused_vars_ = false; - - unused_vars_.clear(); node_deps_.clear(); std::queue> q; std::unordered_set var_visited; @@ -551,6 +560,23 @@ void Reducer::PrepareForBackward( << "] is not used"; } } + + if (unused_vars_.empty()) { + LOG_FIRST_N(WARNING, 1) + << "All parameters are involved in the backward pass. " + "It is recommended to set find_unused_parameters to False " + "to improve performance. However, if unused parameters " + "appear in subsequent iterative training, then an error " + "will occur. Please make it clear that in the subsequent " + "training, there will be no parameters that are not used " + "in the backward pass, and then set find_unused_parameters"; + } else if (unused_vars_.size() == vars_.size()) { + LOG_FIRST_N(WARNING, 1) + << "There is no parameter in the device involved " + "in the backward calculation. If there are " + "parameters on other devices involved in the " + "backward, then a serious error will occur here."; + } } // Add hook function to each leaf node. When the gradient of a leaf node is @@ -563,67 +589,133 @@ void Reducer::PrepareForBackward( // concat + allreduce + split is emitted in turn according to next_group_. // 3, FinalizeBackward: after the end, synchronize each stream. void Reducer::AddDistHook(size_t var_index) { + PADDLE_ENFORCE_LT(var_index, variable_locators_.size(), + platform::errors::OutOfRange( + "Out of bounds variable index. it must be less" + "than %d, but it is %d", + variable_locators_.size(), var_index)); + VLOG(3) << "Var[" << var_index << "] [" << vars_[var_index]->GradVarBase()->Name() << "] arrived and triggered disthook"; - if (!has_marked_unused_vars_) { - has_marked_unused_vars_ = true; - for (auto unused_index : unused_vars_) { - if (NeedRebuildGroup()) { - rebuild_vars_.push_back(vars_[unused_index]); - rebuild_var_indices_.push_back(unused_index); - } - MarkVarReady(unused_index, false); - } - } + local_used_vars_[var_index] = 1; + + // rebuild group when find_unused_vars_ is false if (NeedRebuildGroup()) { rebuild_vars_.push_back(vars_[var_index]); rebuild_var_indices_.push_back(var_index); } + + if (!has_marked_unused_vars_ && find_unused_vars_) { + has_marked_unused_vars_ = true; + for (const auto &unused_index : unused_vars_) { + MarkVarReady(unused_index, false); + } + } + MarkVarReady(var_index, true); } void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) { - all_group_ready_ = true; + groups_need_finalize_ = true; + const auto &var_locator = variable_locators_[var_index]; - auto group_index = var_locator.group_index; + const auto group_index = var_locator.group_index; auto &group = groups_[group_index]; + // error happened, if the var is ready before. + if (vars_marked_ready_[var_index]) { + auto error_info = string::Sprintf( + "Error happened, when parameter[%d][%s] has been ready before. " + "There may be several reasons for this error: " + "1) In multiple reentrant backward phase, some parameters are reused." + "2) Using model parameters outside of forward function. Please " + "make sure that model parameters are not shared in concurrent " + "forward-backward passes.", + var_index, vars_[var_index]->GradVarBase()->Name()); + + PADDLE_ENFORCE_EQ(has_marked_unused_vars_, false, + platform::errors::PreconditionNotMet(error_info)); + + error_info += + "3) Unused parameters retrieval is incorrect. " + "The return value of forward will be used to retrieve" + " the unused parameters of the entire model. These " + "gradients of unused parameters will not be synchronized " + "between multiple cards. However, if the unused " + "parameters participate in the backward calculation " + "again at a later time (e.g. after the forward function, " + "the loss calculation uses the unused " + "paramters of the forward and trigger backward), " + "its gradient will be wrong."; + + PADDLE_ENFORCE_EQ(has_marked_unused_vars_, true, + platform::errors::PreconditionNotMet(error_info)); + } else { + vars_marked_ready_[var_index] = true; + } + if (!group.is_sparse_) { // process dense group - auto inside_group_index = var_locator.inside_group_index; - auto length = group.length_[inside_group_index]; + const auto inside_group_index = var_locator.inside_group_index; + const auto length = group.length_[inside_group_index]; auto &group_tensor = group.dense_tensors_[inside_group_index]; + if (is_used_var) { - auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar(); - auto tensor = - var_warpper->MutableVar()->GetMutable(); + auto var_base = vars_[var_index]->GradVarBase(); + auto tensor = var_base->MutableVar()->GetMutable(); group_tensor.ShareDataWith(*tensor).Resize( {static_cast(length)}); } else { + // TODO(shenliang03): maybe save the memory + // by avoiding tensor construction if (!group_tensor.IsInitialized()) { group_tensor.Resize({static_cast(length)}); group_tensor.mutable_data(place_, group.dtype_); + } + #ifdef PADDLE_WITH_XPU_BKCL - if (platform::is_xpu_place(group_tensor.place())) { - // TODO(liuyuhui) support XPU set constant - VLOG(3) << "XPU doesn't support set_constant"; - } + if (platform::is_xpu_place(group_tensor.place())) { + // TODO(liuyuhui) support XPU set constant + VLOG(3) << "XPU doesn't support set_constant"; + } #else - auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_); + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_); + if (HasGrad(var_index)) { + auto var_base = vars_[var_index]->GradVarBase(); + auto tensor = + var_base->MutableVar()->GetMutable(); + TensorCopy(*tensor, place_, *dev_ctx, &group_tensor); + group_tensor.Resize({static_cast(length)}); + } else { + group_tensor.Resize({static_cast(length)}); operators::math::set_constant(*dev_ctx, &group_tensor, 0.0); -#endif } +#endif } } else { // process sparse group - if (is_used_var) { - auto var_warpper = vars_[var_index]->GradVarBase()->SharedVar(); - group.sparse_contents_ = var_warpper->MutableVar(); - } else { - group.sparse_contents_ = nullptr; - } + PADDLE_ENFORCE_EQ(HasGrad(var_index), true, + platform::errors::PreconditionNotMet( + "The sparse parameter[%d][%s] must have a gradient", + var_index, vars_[var_index]->Name())); + auto var_base = vars_[var_index]->GradVarBase(); + // need to check tensor type + PADDLE_ENFORCE_EQ( + var_base->Var().IsType(), true, + platform::errors::PreconditionNotMet( + "The sparse parameter[%d][%s] must have a selectedrows gradient. " + "Before forward pass, the parameter type is inferred to be " + "SelectedRows, but after backward pass, its actual type becomes " + "LodTensor. It is currently not supported by DataParallel. " + "For example, if sparse embedding is used, and the weight of " + "embedding is shared with subsequent dense parameters, then " + "the parameter gradient of the embedding will be converted " + "to dense parameters.", + var_index, vars_[var_index]->Name())); + + group.sparse_contents_ = var_base->MutableVar(); } if (--group.pending_ == 0) { @@ -639,6 +731,14 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) { // TODO(liuyuhui): If BKCL support non-blocking communication, it should be // fixed as same as multi gpus card trainging. void Reducer::MarkGroupReady(size_t group_index) { + PADDLE_ENFORCE_GE( + group_index, next_group_, + platform::errors::PreconditionNotMet( + "The index of the incoming group must be greater " + "than or equal to the previously synchronized group index, " + "expect it to greater than or equal to %d, but got %d.", + next_group_, group_index)); + if (group_index > next_group_) { VLOG(3) << "It will adjust the order of group in next batch automatically"; return; @@ -647,7 +747,7 @@ void Reducer::MarkGroupReady(size_t group_index) { for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0; ++next_group_) { auto &group = groups_[next_group_]; - int run_order = next_group_ % nrings_; + const int run_order = next_group_ % nrings_; // For CUDA or XPU, compute_stream --> comm_stream. // For CPU, do nothing. @@ -666,7 +766,7 @@ void Reducer::MarkGroupReady(size_t group_index) { comm_pool_->enqueue([&] { auto dev_id = BOOST_GET_CONST(platform::XPUPlace, place_).device; platform::SetXPUDeviceId(dev_id); - FusedAllReduceSchedule(run_order, group); + FusedAllReduceSchedule(run_order, group, next_group_); { std::lock_guard lock(mutex_); comm_op_count_ -= 1; // lock @@ -674,7 +774,7 @@ void Reducer::MarkGroupReady(size_t group_index) { } }); #elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) - FusedAllReduceSchedule(run_order, group); + FusedAllReduceSchedule(run_order, group, next_group_); #else PADDLE_THROW(platform::errors::PreconditionNotMet( "Not compiled with BKCL or NCCL.")); @@ -682,24 +782,23 @@ void Reducer::MarkGroupReady(size_t group_index) { } } -void Reducer::FusedAllReduceSchedule(int run_order, Group &group) { +void Reducer::FusedAllReduceSchedule(const int run_order, Group &group, + const int curr_group_index) { + // The overall timeline: concat > div_nranks > allreduce > split + // dev_context is used to select different stream + const auto &dev_context = *parallel_ctx_->GetDeviceContext(run_order); if (group.is_sparse_) { - if (group.sparse_contents_ != nullptr) { - VLOG(3) << "sparse group [" << next_group_ << "] start allreduce in ring[" - << run_order << "]"; - group.DivNRanks(*parallel_ctx_->GetDeviceContext(run_order), nranks_); - parallel_ctx_->AllReduceByStream( - *group.sparse_contents_, group.sparse_contents_, run_order, false); - } else { - VLOG(3) << "The sparse group[" << next_group_ - << "] has no var to allreduce"; - } + VLOG(3) << "sparse group [" << curr_group_index + << "] start allreduce in ring[" << run_order << "]"; + group.DivNRanks(dev_context, nranks_); + parallel_ctx_->AllReduceByStream(*group.sparse_contents_, + group.sparse_contents_, run_order, false); } else { - VLOG(3) << "dense group [" << next_group_ << "] start allreduce in ring[" - << run_order << "]"; + VLOG(3) << "dense group [" << curr_group_index + << "] start allreduce in ring[" << run_order << "]"; // Select common commstream to concat tensors // group.dense_tensors ---> group.dense_contents_ - group.ConcatTensors(*parallel_ctx_->GetDeviceContext(run_order)); + group.ConcatTensors(dev_context); // NOTE(liuyuhui): ConcatTensors use communication stream, but BKCL only support // default stream for communicating, so there exist some problems in @@ -711,15 +810,15 @@ void Reducer::FusedAllReduceSchedule(int run_order, Group &group) { parallel_ctx_->WaitComm(run_order); } #endif - group.DivNRanks(*parallel_ctx_->GetDeviceContext(run_order), nranks_); + group.DivNRanks(dev_context, nranks_); // Start allreduce parallel_ctx_->AllReduceByStream( group.dense_contents_, &(group.dense_contents_), run_order, false); - // Select common commstream to split tensors + // Select communication stream to split tensors // group.dense_contents_ ---> group.dense_tensors - group.SplitTensors(*parallel_ctx_->GetDeviceContext(run_order)); + group.SplitTensors(dev_context); } } @@ -745,14 +844,98 @@ std::vector> Reducer::RebuildGruops() { return rebuild_group_indices; } +void Reducer::ProcessUnusedDenseVars() { + // The calculation stream must be used here to + // avoid conflicts with communication. + VLOG(3) << "Local used vars : " + << string::join_strings(local_used_vars_, ','); + const auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_); + // H2D is to allreduce the local_used_vars_ + auto *global_used_tensor = + global_used_vars_.GetMutable(); + framework::TensorFromVector(local_used_vars_, *dev_ctx, + global_used_tensor); + parallel_ctx_->AllReduceByStream(global_used_vars_, &global_used_vars_, 0, + true); + framework::TensorToVector(*global_used_tensor, *dev_ctx, + &local_used_vars_); + + // sync compute stream to get global used var message, + // but maybe affect speed performance + parallel_ctx_->SynchronizeCompute(); + VLOG(3) << "Global used vars : " + << string::join_strings(local_used_vars_, ','); + + for (const auto var_index : unused_vars_) { + const bool global_unused = (local_used_vars_[var_index] == 0); + + // global used but local unused, set grad + VLOG(3) << "Var [" << var_index << "] [" << vars_[var_index]->Name() + << "] global_unused:" << global_unused + << " has grad: " << HasGrad(var_index); + + if (!global_unused) { + VLOG(3) << "Start process unused Var"; + // 1. source var base + const auto &var_locator = variable_locators_[var_index]; + const auto group_index = var_locator.group_index; + const auto &group = groups_[group_index]; + const auto inside_group_index = var_locator.inside_group_index; + const auto &src_tensor = group.dense_tensors_[inside_group_index]; + // sparse no need to check and no support find_unused_parameters + if (group.is_sparse_) { + continue; + } + // 2. destination var base + auto dest_var_base = vars_[var_index]; + auto *dest_tensor = + dest_var_base->MutableVar()->GetMutable(); + const auto &dest_dims = dest_tensor->dims(); + + // 3. create grad var base or get grad var base + auto grad_var_base_tmp = dest_var_base->MutableGradVarBase(); + + // 4. set grad tensor + auto *dest_grad_tensor = + grad_var_base_tmp->MutableVar()->GetMutable(); + const auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_); + TensorCopy(src_tensor, place_, *dev_ctx, dest_grad_tensor); + dest_grad_tensor->Resize(dest_dims); + } + } +} + +bool Reducer::HasGrad(size_t var_index) { + const auto grad_var = vars_[var_index]->GradVarBase(); + if (!grad_var || !grad_var->Var().IsInitialized()) { + return false; + } + + const auto &var = grad_var->Var(); + if (var.IsType()) { + if (var.Get().IsInitialized()) { + return true; + } + } else if (var.IsType()) { + if (var.Get().value().IsInitialized()) { + return true; + } + } else { + PADDLE_THROW(platform::errors::PermissionDenied( + "Only support LoDTensor and SelectedRows for gradient var")); + } + return false; +} + void Reducer::FinalizeBackward() { - all_group_ready_ = false; + groups_need_finalize_ = false; #ifdef PADDLE_WITH_XPU_BKCL { std::unique_lock lock(mutex_); cv_.wait(lock, [&] { return comm_op_count_ == 0; }); } #endif + // Must prevent compute_stream_ starting until all comm streams have finished for (int i = 0; i < nrings_; ++i) { parallel_ctx_->WaitComm(i); @@ -765,7 +948,18 @@ void Reducer::FinalizeBackward() { InitializeGroups(group_indices_); } - VLOG(3) << "In the batch, Reducer is finished..."; + if (find_unused_vars_) { +// TODO(liuyuhui) support xpu about Tensorcopy/TensorFromVector/TensorToVector +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + ProcessUnusedDenseVars(); +#endif + // Initialize local used vars + local_used_vars_.clear(); + local_used_vars_.resize(vars_.size(), 0); + VLOG(3) << "ProcessUnusedDenseVars is finished."; + } + + VLOG(3) << "In the batch, Reducer is finished."; } // According to the size of each parameter, it is allocated to different groups. diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index b2680d0dea71aa19399242bffedc3d7914cebbb9..0d613dbea896339760d320a6b9937ffcc8ea0dcc 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -27,6 +27,7 @@ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/for_range.h" @@ -153,13 +154,20 @@ class Reducer { void MarkGroupReady(size_t group_index); - void FusedAllReduceSchedule(int run_order, Group& group); // NOLINT + void FusedAllReduceSchedule(const int run_order, Group& group, // NOLINT + const int curr_group_index); void FinalizeBackward(); std::vector> RebuildGruops(); - inline bool NeedRebuildGroup() { return !has_rebuilt_group_; } + inline bool NeedRebuildGroup() { + return !has_rebuilt_group_ && !find_unused_vars_; + } + + void ProcessUnusedDenseVars(); + + bool HasGrad(size_t var_index); private: std::vector> vars_; @@ -188,7 +196,7 @@ class Reducer { std::vector unused_vars_; bool has_marked_unused_vars_{false}; bool find_unused_vars_{false}; - bool all_group_ready_{false}; + bool groups_need_finalize_{false}; #ifdef PADDLE_WITH_XPU_BKCL // comm_pool_ is used for scheduling allreduce in multi Kunlun cards training. std::unique_ptr<::ThreadPool> comm_pool_{nullptr}; @@ -196,6 +204,19 @@ class Reducer { std::mutex mutex_; std::condition_variable cv_; #endif + + // it just for checking hook, each parameter can only trigger one hook + std::vector vars_marked_ready_; + + // Following variables are to help control flow. + // local_used_vars_ uses 0/1 to indicate whether the + // var is used in iteration. After the end of the + // iteration, global_used_vars_ is obtained synchronously + // globally. Choose whether to update the local + // gradient according to the global_used_vars_. + std::vector local_used_vars_; + // global_used_vars_ is used in comm stream to avoid wait + framework::Variable global_used_vars_; }; std::vector> AssignGroupBySize( diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index f79013d7347c00efc36aef17dba5f6d3a1ae3165..626f6a37a982e0310fa6cfffb1d000c163634a89 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -620,6 +620,34 @@ class DistributedStrategy(object): else: raise ValueError("last_comm_group_size_MB should be greater than 0") + @property + def find_unused_parameters(self): + """ + Indicating whether we are using find_unused_parameters to + find unused parameters in DataParallel. + + Default value: True + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.find_unused_parameters = True + """ + + return self.strategy.find_unused_parameters + + @find_unused_parameters.setter + @is_strict_auto + def find_unused_parameters(self, flag): + if isinstance(flag, bool): + self.strategy.find_unused_parameters = flag + else: + print( + "WARNING: find_unused_parameters should have value of bool type") + @property def _fuse_grad_size_in_TFLOPS(self): return self.strategy.fuse_grad_size_in_TFLOPS diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index cf802034cabd85209f8df66e54e82cf765f2c5a5..470d1a2b78f090bde6b7e9c47ad9d7343bc59116 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -706,7 +706,9 @@ class Fleet(object): model, comm_buffer_size=self._user_defined_strategy.fuse_grad_size_in_MB, last_comm_buffer_size=self._user_defined_strategy. - last_comm_group_size_MB) + last_comm_group_size_MB, + find_unused_parameters=self._user_defined_strategy. + find_unused_parameters) return self.model @dygraph_only diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index b157ce81d82fc7f6d2231fc22a5246b65de31035..3df0c60852727624b7c798d857340f3423c4382a 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -22,6 +22,7 @@ import copy import weakref import warnings from copy import deepcopy +import paddle from . import parallel_helper from .. import unique_name @@ -894,9 +895,15 @@ class Layer(core.Layer): if not self._built: with program_desc_tracing_guard(False): self._build_once(*inputs, **kwargs) - if parallel_helper._is_data_parallel_mode(): + + # TODO(liuyuhui) Only xpu broadcast parameters here. + # The other device is to call _sync_params_buffers in DataParallel + # to realize the parameter synchronization among multiply cards. + if parallel_helper._is_data_parallel_mode( + ) and paddle.is_compiled_with_xpu(): parallel_helper._broadcast_parameters( self._parameters.values()) + self._built = True outputs = self.forward(*inputs, **kwargs) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 2ef72f6c5aaf4bc39b5eb71ac6ba64d0829c0475..b80621e21f1c5c430f7f8f0451f4cd6a52a1cb56 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -24,6 +24,7 @@ from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import parallel_helper from paddle.fluid.dygraph import to_variable, no_grad from paddle.utils import deprecated +from ..layers import collective import warnings import paddle import itertools @@ -348,6 +349,18 @@ class DataParallel(layers.Layer): last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication calling. Making the last communication buffer size small is useful to improve performance. Default: 1. + find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from the + all tensors in the return value of the wrapped model's + forward function. For parameters not involved in loss + calculation, their gradients will be marked as ready in + advance to prepare reduce. Please note that all forward + outputs derived from the wrapped model parameters must + participate in the calculation of loss and subsequent + gradient calculations. If not, serious error will occur. + Note that setting the find_unused_parameters to True + will affect computing performance. Therefore, if all parameters + are sure to participate in the loss calculation and the + autograd graph construction, please set it False. Default: True. Returns: Layer: The data paralleled module. @@ -403,11 +416,13 @@ class DataParallel(layers.Layer): layers, strategy=None, comm_buffer_size=25, - last_comm_buffer_size=1): + last_comm_buffer_size=1, + find_unused_parameters=True): super(DataParallel, self).__init__(layers.full_name() + "_data_parallel") self._layers = layers + self.find_unused_parameters = find_unused_parameters # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy. # It just stores some environment variables, which can be constructed by @@ -419,6 +434,17 @@ class DataParallel(layers.Layer): self._strategy = _build_default_parallel_strategy() if self._strategy.nranks > 1: + # check the environment + assert parallel_helper.__parallel_ctx__clz__ is not None, \ + "ParallelContext must be initialized before. You should use init_parallel_env() before" \ + "constructing the DataParallel." + + # sync buffer and params + # TODO(liuyuhui) Currently not support xpu. xpu is + # still broadcasting parameters when calling layer + if not paddle.is_compiled_with_xpu(): + self._sync_params_buffers() + self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024) # NOTE(shenliang03): We can set environment variables to control # the size of the group, Default: 1MB. The role of this small group is: @@ -449,6 +475,10 @@ class DataParallel(layers.Layer): trainable_parameters = [param for _, param in layers_param] + assert len(trainable_parameters) > 0, \ + "This model does not have any parameters to train, and " \ + "does not need to use DataParallel" + # NOTE(shenliang03): Here we can only use the attributes to judge whether # parameter is sparse(or SelectedRows). The reason is that the sparse message # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter, @@ -470,19 +500,12 @@ class DataParallel(layers.Layer): trainable_parameters, is_sparse_gradient, [self.last_comm_buffer_size, self.comm_buffer_size]) - assert parallel_helper.__parallel_ctx__clz__ is not None, \ - "ParallelContext must be initialized before. You should use init_parallel_env() before" \ - "constructing the DataParallel." - - # TODO(shenliang03) "find_unused_vars" interface will be exposed in the future - # to handle control flow to process unused parameters - find_unused_vars = True self._reducer = core.Reducer( trainable_parameters, list(reversed(self.group_indices)), is_sparse_gradient, parallel_helper.__parallel_ctx__clz__, [self.last_comm_buffer_size, self.comm_buffer_size], - find_unused_vars) + self.find_unused_parameters) def _find_varbase(self, obj): if isinstance(obj, core.VarBase): @@ -493,11 +516,54 @@ class DataParallel(layers.Layer): return itertools.chain(*map(self._find_varbase, obj.values())) return [] + def _sync_params_buffers(self): + model_vars = [] + for _, param in self._layers.state_dict().items(): + if not isinstance(param, core.VarBase): + raise TypeError("The data type of '%s' must be Varbase" % + param.name) + model_vars.append(param.detach()) + if len(model_vars) == 0: + return + + mega_bytes = 128 * 1024 * 1024 + group_idx = 0 + memory_counter = 0 + var_groups = OrderedDict() + dtype = model_vars[0].dtype + + for var in model_vars: + bytes = np.prod(var.shape) * core.size_of_dtype(var.dtype) + if memory_counter < mega_bytes and dtype == var.dtype: + memory_counter += bytes + else: + memory_counter = 0 + dtype = var.dtype + group_idx += 1 + var_groups.setdefault(group_idx, []).append(var) + + coalesced_vars = _coalesce_tensors(var_groups) + + for coalesced_var, _, _ in coalesced_vars: + collective._broadcast(coalesced_var, root=0, sync_mode=True) + + for coalesced_var, origin_vars, var_shapes in coalesced_vars: + var_len = [np.prod(v_shape) for v_shape in var_shapes] + framework._dygraph_tracer().trace_op( + type='split', + inputs={'X': coalesced_var}, + outputs={'Out': origin_vars}, + attrs={'sections': var_len, + 'axis': 0}) + def forward(self, *inputs, **kwargs): outputs = self._layers(*inputs, **kwargs) - if self._strategy.nranks > 1: - self._reducer.prepare_for_backward( - list(self._find_varbase(outputs))) + if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad: + if self.find_unused_parameters: + self._reducer.prepare_for_backward( + list(self._find_varbase(outputs))) + else: + self._reducer.prepare_for_backward(list(self._find_varbase([]))) return outputs diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 0abb61d95aa6ee36f0bc7ab65215f44e1d1a1334..28f5177c20486535e131b4679e7d671e3eb47582 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -19,6 +19,8 @@ list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer) list(APPEND DIST_TEST_OPS test_gen_nccl_id_op) list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables) +list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow) +list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests. list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) @@ -160,6 +162,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_over_height) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single) elseif(WITH_GPU) @@ -824,10 +828,12 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120) set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) + set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120) - set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) endif() endif() if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_control_flow_different.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_control_flow_different.py new file mode 100644 index 0000000000000000000000000000000000000000..26c9944abd6c6c97baca8b56caa18644e5615977 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_control_flow_different.py @@ -0,0 +1,122 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import paddle.distributed as dist + +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Embedding +import paddle.nn.functional as F +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + +paddle.seed(123) +np.random.seed(2021) + + +class SimpleNet(fluid.Layer): + def __init__(self, hidden_size, vocab_size, is_sparse=False): + super(SimpleNet, self).__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.embedding = Embedding( + size=[self.vocab_size, self.hidden_size], + dtype='float32', + is_sparse=is_sparse) + + self.lin_a = paddle.nn.Linear(self.hidden_size, self.vocab_size) + self.lin_b = paddle.nn.Linear(self.vocab_size, 1) + + self.unused_net = paddle.nn.Linear(5, 3) + self.phony = self.create_parameter(shape=[1], dtype="float32") + + def forward(self, input, label, conf): + x_emb = self.embedding(input) + fc = self.lin_a(x_emb) + mask = conf > 0 + mask = paddle.cast(mask, dtype="int64") + mask.stop_gradient = True + emb_mask = mask.max(1).flatten() + emb_mask_inds = paddle.nonzero(emb_mask > 0).flatten() + emb_mask_inds.stop_gradient = True + + if emb_mask_inds.numel() == 0: + loss_box = self.phony * 0 + else: + projection = self.lin_b(fc) + projection = paddle.reshape(projection, shape=[-1, 1]) + output = paddle.gather(projection, emb_mask_inds) + target = paddle.gather(label, emb_mask_inds) + loss_box = F.smooth_l1_loss( + output, target, reduction='sum', delta=1.0) + loss_box = loss_box / len(conf) + + return loss_box + + +# global configs +batch_size = 4 +batch_num = 2000 +hidden_size = 5 +vocab_size = 100 + +conf_dataset = [[0], [0], [0], [0], [1], [0], [1], [0], [0], [1], [0], [1], + [1], [1], [1], [1], [1], [1], [1], [1], [1], [0], [0], [1]] + + +def fake_sample_reader(): + def __reader__(): + for i in range(batch_num): + x_data = np.random.randint(0, vocab_size) + y_data = np.random.random_sample((1, )).astype('float32') + conf_data = np.array(conf_dataset[i % len(conf_dataset)]).astype( + 'int64') + yield x_data, y_data, conf_data + + return __reader__ + + +class TestSimpleNet(TestParallelDyGraphRunnerBase): + def get_model(self): + model = SimpleNet( + hidden_size=hidden_size, vocab_size=vocab_size, is_sparse=False) + + train_reader = paddle.batch( + fake_sample_reader(), batch_size=batch_size, drop_last=True) + + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) + + return model, train_reader, optimizer + + def run_one_loop(self, model, optimizer, batch): + x_data = np.array([x[0] for x in batch]).astype('int64') + y_data = np.array([x[1] for x in batch]).astype('float32') + conf_data = np.array([x[2] for x in batch]).astype('int64') + x_data = x_data.reshape((-1, 1)) + y_data = y_data.reshape((-1, 1)) + conf_data = conf_data.reshape((-1, 1)) + + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + conf = paddle.to_tensor(conf_data) + + loss = model(x, y, conf) + return loss + + +if __name__ == "__main__": + runtime_main(TestSimpleNet) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_control_flow_same.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_control_flow_same.py new file mode 100644 index 0000000000000000000000000000000000000000..3157d5e4129eebc5f74d09dd506c34371db009f0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_control_flow_same.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import contextlib +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer +from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.dygraph.base import to_variable + +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + +np.random.seed(2021) +paddle.seed(1024) + +batch_size = 4 +batch_num = 1000 + + +class SimpleNet(fluid.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.net_a = paddle.nn.Sequential( + paddle.nn.Linear(10, 20), + paddle.nn.Linear(20, 20), paddle.nn.Linear(20, 5)) + self.net_b = paddle.nn.Sequential( + paddle.nn.Linear(10, 20), + paddle.nn.Linear(20, 20), paddle.nn.Linear(20, 5)) + self.net_unused = Linear(10, 20) + self.step = 0 + + def forward(self, x): + if self.step % 2 == 0: + return self.net_a(x) + else: + return self.net_b(x) + + self.step = self.step + 1 + + +def fake_sample_reader(): + def __reader__(): + for i in range(batch_num): + x_data = np.random.random_sample((10, )).astype('float32') + yield x_data + + return __reader__ + + +class TestSimpleNet(TestParallelDyGraphRunnerBase): + def get_model(self): + model = SimpleNet() + train_reader = paddle.batch( + fake_sample_reader(), batch_size=batch_size, drop_last=True) + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) + return model, train_reader, optimizer + + def run_one_loop(self, model, optimizer, batch): + x_data = np.array([x for x in batch]) + x_data = x_data.reshape((-1, 10)) + x = to_variable(x_data) + out = model(x) + loss = out.sum() / len(batch) + return loss + + +if __name__ == "__main__": + runtime_main(TestSimpleNet) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2631fa108d28b55f6f9682853f382783bf9721 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py @@ -0,0 +1,136 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest + +import paddle +import numpy as np +import paddle.distributed as dist +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Linear + +paddle.seed(1024) +np.random.seed(2021) + +batch = 5 +in_dim = 10 +out_dim = 20 + + +class SimpleNet(fluid.Layer): + def __init__(self, train_id): + super(SimpleNet, self).__init__() + self.w1 = self.create_parameter( + shape=[in_dim, out_dim], dtype="float32") + self.w2 = self.create_parameter( + shape=[in_dim, out_dim], dtype="float32") + self.share_net = Linear(out_dim, 10) + + self.unused_param = self.create_parameter( + shape=[out_dim, in_dim], dtype="float64") + + # just for test sync_params_buffers + self.register_buffer("queue", paddle.randn([10, 5])) + self.queue = paddle.nn.functional.normalize(self.queue, axis=0) + self.register_buffer("queue_ptr", paddle.zeros([1], 'int64')) + + self.trainer_id = train_id + + def forward(self, x): + is_use = (paddle.equal_all( + x, paddle.ones(shape=(batch, in_dim))).numpy()[0] and + self.trainer_id == 1) + + if is_use: + tmp = paddle.matmul(x, self.w1) + else: + tmp = paddle.matmul(x, self.w2) + + return self.share_net(tmp) + + +class TestDistTraning(unittest.TestCase): + def test_multiple_gpus(self): + dist.init_parallel_env() + self.trainer_id = dist.get_rank() + + model_a = SimpleNet(self.trainer_id) + model_b = SimpleNet(self.trainer_id) + + state_dict = model_a.state_dict() + model_b.set_state_dict(state_dict) + + model_a = paddle.DataParallel(model_a) + model_b = paddle.DataParallel(model_b) + + ones_input = paddle.ones(shape=(batch, in_dim)) + ones_input.stop_gradient = True + + w1_grad_sum = np.zeros((in_dim, out_dim), dtype='float32') + w2_grad_sum = np.zeros((in_dim, out_dim), dtype='float32') + + for step_id in range(5): + random_input = paddle.rand(shape=(batch, in_dim)) + random_input.stop_gradient = True + + if step_id % 2 == 0: + out_a = model_a(random_input) + out_b = model_b(random_input) + else: + out_a = model_a(ones_input) + out_b = model_b(ones_input) + + out_a.sum().backward() + out_b.sum().backward() + + self.check_gradient(model_a.parameters()) + self.check_gradient(model_b.parameters()) + + # test acc gradient + w1_grad_sum = self.check_acc(model_a._layers.w1.grad, w1_grad_sum, + model_b._layers.w1.grad) + w2_grad_sum = self.check_acc(model_a._layers.w2.grad, w2_grad_sum, + model_b._layers.w2.grad) + + model_a.clear_gradients() + + def check_acc(self, grad, grad_sum, acc_grad): + if grad is not None: + grad_sum = grad_sum + grad + np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6) + return grad_sum + + def print_trainer_0(self, *args): + if self.trainer_id == 0: + print(*args) + + def broadcast_param(self, param, root): + paddle.distributed.broadcast(param, root) + return param + + def check_gradient(self, params): + other_param = [] + for param in params: + if param.trainable and (param._grad_ivar() is not None): + grad = param._grad_ivar() + other_grad = self.broadcast_param(grad.clone(), root=1) + if self.trainer_id == 0: + np.testing.assert_allclose(other_grad.numpy(), grad.numpy()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_none_var.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_none_var.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0246a9720bfdbeb740f54632f0bbacd7831479 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_none_var.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import contextlib +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer +from paddle.fluid.dygraph.nn import Linear + +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + +np.random.seed(2021) +paddle.seed(1024) + +batch_size = 4 +batch_num = 1000 + + +class SimpleNet(fluid.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.net_a = paddle.nn.Sequential( + paddle.nn.Linear(10, 20), + paddle.nn.Linear(20, 20), paddle.nn.Linear(20, 5)) + self.net_b = paddle.nn.Sequential( + paddle.nn.Linear(10, 20), + paddle.nn.Linear(20, 20), paddle.nn.Linear(20, 5)) + self.step = 0 + + def forward(self, x): + return paddle.to_tensor(0.0, dtype='float32') + + +def fake_sample_reader(): + def __reader__(): + for i in range(batch_num): + x_data = np.random.random_sample((10, )).astype('float32') + yield x_data + + return __reader__ + + +class TestSimpleNet(TestParallelDyGraphRunnerBase): + def get_model(self): + model = SimpleNet() + train_reader = paddle.batch( + fake_sample_reader(), batch_size=batch_size, drop_last=True) + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) + return model, train_reader, optimizer + + def run_one_loop(self, model, optimizer, batch): + x_data = np.array([x for x in batch]) + x_data = x_data.reshape((-1, 10)) + x = paddle.to_tensor(x_data) + out = model(x) + loss = out.sum() / len(batch) + return loss + + +if __name__ == "__main__": + runtime_main(TestSimpleNet) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_shared_unused_var.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_shared_unused_var.py new file mode 100644 index 0000000000000000000000000000000000000000..facac33e4c60ec884c104a9ab069ab4875490c61 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_shared_unused_var.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid.optimizer import SGDOptimizer +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear +from paddle.fluid.dygraph.base import to_variable +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + +np.random.seed(2021) +paddle.seed(1024) + + +class SimpleNet(fluid.Layer): + def __init__(self): + # bias is unused parameters, and it share with net_a + super(SimpleNet, self).__init__() + self.net_a = Linear(input_dim=10, output_dim=5) + self.net_b = Linear(10, 10) + self.bias = self.net_a.bias + + def forward(self, x): + return self.net_b(x) + + +batch_size = 4 +batch_num = 1000 + + +def fake_sample_reader(): + def __reader__(): + for i in range(batch_num): + x_data = np.random.random_sample((10, )).astype('float32') + yield x_data + + return __reader__ + + +class TestSimpleNet(TestParallelDyGraphRunnerBase): + def get_model(self): + model = SimpleNet() + train_reader = paddle.batch( + fake_sample_reader(), batch_size=batch_size, drop_last=True) + optimizer = paddle.optimizer.SGD(learning_rate=0.001, + parameters=model.parameters()) + return model, train_reader, optimizer + + def run_one_loop(self, model, optimizer, batch): + x_data = np.array([x for x in batch]) + x_data = x_data.reshape((-1, 10)) + x = to_variable(x_data) + out = model(x) + loss = out.sum() / len(batch) + return loss + + +if __name__ == "__main__": + runtime_main(TestSimpleNet) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py index 65c242a7023093413ff0389aa685ddc817eea028..a15b263a295086271efd5095e61f7d4a42857db9 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sparse_embedding_fp64.py @@ -65,8 +65,6 @@ class SimpleNet(Layer): def forward(self, input, label): x_emb = self.embedding(input) fc = paddle.matmul(x_emb, self.softmax_weight) - # use detach to stop gradient - fc = fc.detach() fc = paddle.add(fc, self.softmax_bias) projection = paddle.reshape(fc, shape=[-1, self.vocab_size]) loss = paddle.nn.functional.softmax_with_cross_entropy( diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py index 1884eef15e9a409319820ac61444094755116e86..9f877381101e96fc57dbce127dad18a034e3b2a1 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_unused_variables.py @@ -37,7 +37,7 @@ class SimpleNet(Layer): self.embedding = Embedding( self.vocab_size, self.hidden_size, - sparse=True, + sparse=is_sparse, weight_attr=paddle.ParamAttr( name='embedding_param', initializer=paddle.nn.initializer.Uniform( @@ -105,7 +105,7 @@ class TestSparseEmbeddingUnusedVars(TestParallelDyGraphRunnerBase): vocab_size=vocab_size, num_steps=num_steps, init_scale=init_scale, - is_sparse=True) + is_sparse=False) train_reader = paddle.batch( fake_sample_reader(), batch_size=batch_size, drop_last=True) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index d73698e7e024a8f8508ac67fdcd6f2026be4cb38..fa5ce28398593b7aecf12d58aa88227cc9011f79 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -501,7 +501,12 @@ class TestParallelDyGraphRunnerBase(object): type(self).__name__, "begin to prepare context in dygraph with nccl2") dygraph.parallel.prepare_context(strategy) - model = dygraph.parallel.DataParallel(model, strategy) + if not args.find_unused_parameters: + model = dygraph.parallel.DataParallel( + model, strategy, find_unused_parameters=False) + else: + model = dygraph.parallel.DataParallel( + model, strategy, find_unused_parameters=True) print_to_err(type(self).__name__, "model built in dygraph") out_losses = [] print_to_err(type(self).__name__, "begin to run dygraph training") @@ -574,9 +579,14 @@ class TestParallelDyGraphRunnerBase(object): # get trainer id args.trainer_id = paddle.distributed.get_rank() + # set strategy + strategy = fleet.DistributedStrategy() + if not args.find_unused_parameters: + strategy.find_unused_parameters = False + # 3. init parallel env if args.update_method == "nccl2" or "bkcl": - fleet.init(is_collective=True) + fleet.init(is_collective=True, strategy=strategy) # 4. train model model, train_reader, opt = self.get_model() @@ -628,6 +638,7 @@ def runtime_main(test_class): parser.add_argument('--use_xpu', action='store_true') parser.add_argument('--use_dgc', action='store_true') parser.add_argument('--accumulate_gradient', action='store_true') + parser.add_argument('--find_unused_parameters', action='store_true') parser.add_argument('--use_reduce', action='store_true') parser.add_argument('--dc_asgd', action='store_true') parser.add_argument('--hogwild', action='store_true') @@ -726,6 +737,7 @@ class TestDistBase(unittest.TestCase): self._save_model = False self._fuse_all_reduce = None self._accumulate_gradient = False + self._find_unused_parameters = True self._setup_config() global DIST_UT_PORT @@ -852,6 +864,9 @@ class TestDistBase(unittest.TestCase): if self._accumulate_gradient: cmd += " --accumulate_gradient" + if self._find_unused_parameters: + cmd += " --find_unused_parameters" + env_local.update(envs) print("local_cmd: {}, env: {}".format(cmd, env_local)) @@ -1021,6 +1036,9 @@ class TestDistBase(unittest.TestCase): if self._accumulate_gradient: tr_cmd += " --accumulate_gradient" + if self._find_unused_parameters: + tr_cmd += " --find_unused_parameters" + if self._pipeline_mode: tr_cmd += " --use_pipeline" if self._mp_mode: diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py index 31771ddbd687449fc8c96e60d8524e8f2e5024be..d843e172763fe5506d32d961a68db6dbeabd9d14 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py @@ -179,6 +179,15 @@ class TestStrategyConfig(unittest.TestCase): with self.assertRaises(ValueError): strategy.last_comm_group_size_MB = -1 + def test_find_unused_parameters(self): + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.find_unused_parameters = True + self.assertEqual(strategy.find_unused_parameters, True) + strategy.find_unused_parameters = False + self.assertEqual(strategy.find_unused_parameters, False) + strategy.find_unused_parameters = "True" + self.assertEqual(strategy.find_unused_parameters, False) + def test_fuse_grad_size_in_TFLOPS(self): strategy = paddle.distributed.fleet.DistributedStrategy() strategy._fuse_grad_size_in_TFLOPS = 0.1 diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_control_flow.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_control_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..fa571bde5e43bfb910c8825afae14ef7735f6488 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_control_flow.py @@ -0,0 +1,91 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner + +flag_name = os.path.splitext(__file__)[0] + + +class TestDygraphControlFlowSame(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_control_flow_same.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestFleetDygraphControlFlowSame(TestDygraphControlFlowSame): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + self._use_fleet_api = True + + +class TestFleetDygraphControlFlowSameAccGrad(TestDygraphControlFlowSame): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + self._accumulate_gradient = True + + +class TestDygraphControlFlowDiff(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_control_flow_different.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestFleetDygraphControlFlowDiff(TestDygraphControlFlowDiff): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + self._use_fleet_api = True + + +class TestFleetDygraphControlFlowDiffAccGrad(TestDygraphControlFlowDiff): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + self._accumulate_gradient = True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py new file mode 100644 index 0000000000000000000000000000000000000000..1d2a39751905e24acfc1666cfd22952b673cf698 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py @@ -0,0 +1,75 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import time +import paddle.fluid as fluid + +from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, get_gpus, start_local_trainers + + +def get_cluster_from_args(selected_gpus): + cluster_node_ips = '127.0.0.1' + node_ip = '127.0.0.1' + + node_ips = [x.strip() for x in cluster_node_ips.split(',')] + + node_ips.index(node_ip) + + free_ports = None + + free_ports = find_free_ports(len(selected_gpus)) + if free_ports is not None: + free_ports = list(free_ports) + + trainer_endpoints = [] + for ip in node_ips: + trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports]) + return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) + + +class TestMultipleGpus(unittest.TestCase): + def run_mnist_2gpu(self, target_file_name): + if not fluid.core.is_compiled_with_cuda( + ) or fluid.core.get_cuda_device_count() == 0: + return + + selected_gpus = get_gpus('0,1') + cluster = None + pod = None + + cluster, pod = get_cluster_from_args(selected_gpus) + + procs = start_local_trainers( + cluster, + pod, + training_script=target_file_name, + training_script_args=[]) + + while True: + alive = watch_local_trainers(procs, cluster.trainers_nranks()) + + if not alive: + print("Local procs complete, POD info:{}".format(pod)) + break + time.sleep(3) + + def test_multiple_gpus_dynamic(self): + self.run_mnist_2gpu('parallel_dygraph_gradient_check.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py index a3a3c5bfe3df59b9e0094b67db708d377865e6ed..782d2304619f2a08b772b52f620910764cd4aff5 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py @@ -73,6 +73,7 @@ class TestParallelDygraphMnistAccGrad(TestDistBase): self._dygraph = True self._use_fleet_api = True self._accumulate_gradient = True + self._find_unused_parameters = False def test_mnist(self): if fluid.core.is_compiled_with_cuda(): diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py index bef64385f135b3d2177bce6a65125f23e1f315e3..e0aab8541a542c9adfc3fb6c9323e7d54854578b 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer.py @@ -54,6 +54,7 @@ class TestParallelDygraphTransformerAccGrad(TestDistBase): self._nccl2_mode = True self._dygraph = True self._accumulate_gradient = True + self._find_unused_parameters = False def test_transformer(self): if fluid.core.is_compiled_with_cuda(): diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py index 5906114cd24f328d20a99f67bb8f59c73a97c30b..75fa6f7c71d0a53c238ec4a9f7bebe905017531f 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables.py @@ -26,13 +26,13 @@ from parallel_dygraph_unused_variables import TestSparseEmbeddingUnusedVars flag_name = os.path.splitext(__file__)[0] -class TestParallelDygraphMnist(TestDistBase): +class TestParallelDygraphUnusedVar(TestDistBase): def _setup_config(self): self._sync_mode = False self._nccl2_mode = True self._dygraph = True - def test_mnist(self): + def test_net(self): if fluid.core.is_compiled_with_cuda(): self.check_with_place( "parallel_dygraph_unused_variables.py", @@ -41,6 +41,14 @@ class TestParallelDygraphMnist(TestDistBase): log_name=flag_name) +class TestFleetDygraphUnusedVar(TestParallelDygraphUnusedVar): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + self._use_fleet_api = True + + class TestSparseEmbeddingUnusedVarsSpawn(TestDistSpawnRunner): def test_mnist_with_spawn(self): if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): @@ -48,17 +56,31 @@ class TestSparseEmbeddingUnusedVarsSpawn(TestDistSpawnRunner): test_class=TestSparseEmbeddingUnusedVars, delta=1e-5) -class TestFleetDygraphMnist(TestDistBase): +class TestParallelDygraphNoVar(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_none_var.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSharedUnusedVariables(TestDistBase): def _setup_config(self): self._sync_mode = False self._nccl2_mode = True self._dygraph = True - self._use_fleet_api = True def test_mnist(self): if fluid.core.is_compiled_with_cuda(): self.check_with_place( - "parallel_dygraph_unused_variables.py", + "parallel_dygraph_shared_unused_var.py", delta=1e-5, check_error_log=True, log_name=flag_name)