// Patch summary (commentary only; it sits before the first `diff --git` header).
//
// Files touched:
//   * paddle/fluid/framework/details/multi_devices_graph_builder.cc
//   * paddle/fluid/framework/details/multi_devices_graph_builder.h
//   * paddle/fluid/framework/details/var_handle.h
//
// multi_devices_graph_builder.cc
//   * Build(): before the SSAGraph is created, fills a local `var_types` map
//     from every variable returned by `program.Block(0).AllVars()`, keyed by
//     `var->Name()` and holding `var->GetType()`.
//   * Build(): renames the round-robin device counter from
//     `cur_update_sparse_gp_dev_id` to `cur_dev_id`; it is still advanced with
//     `(cur_dev_id + 1) % places_.size()` after each sparse gradient.
//   * IsSparseGradient(): the old version looked the gradient up with
//     `local_scopes_[0]->FindVar(og)` and called `IsType` on the result. The new
//     version takes the `var_types` map, enforces that `og` is present in it,
//     and returns true only when the recorded type equals
//     `proto::VarType::SELECTED_ROWS`.
//
// multi_devices_graph_builder.h
//   * Updates the IsSparseGradient declaration to the new two-parameter form.
//
// var_handle.h
//   * Only re-wraps the NOTE(zcd) comment above VarHandle::operator==; the
//     visible code lines of the operator are unchanged context.
//
// NOTE(review): this copy of the patch has lost its original line breaks, and
//   template argument lists appear to be missing (e.g.
//   `std::unordered_map var_types;`, `std::unique_ptr MultiDevSSAGraphBuilder::Build(`,
//   `og_var->IsType();`). Presumably they were stripped during extraction;
//   confirm against the original commit before applying.
// NOTE(review): the new IsSparseGradient does two hash lookups
//   (`var_types.count(og)` then `var_types.at(og)`) and uses
//   `if (...) return true; return false;`. A single `find` plus a direct return
//   of the comparison would be equivalent.
// NOTE(review): `var_types` is built from Block(0) only, so a gradient name
//   missing from that block would fail the PADDLE_ENFORCE. Whether that can
//   happen is not visible from here; verify against callers.
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index da524cc79286472f6a18ae5c9fa783a2e03d03be..37100b529d0677a010afe11a1d659088ebbbc54d 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -107,6 +107,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { + std::unordered_map var_types; + for (auto *var : program.Block(0).AllVars()) { + var_types[var->Name()] = var->GetType(); + } auto graph = new SSAGraph(); SSAGraph &result = *graph; std::unordered_set og_has_been_broadcast; @@ -116,7 +120,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( std::unordered_map>>>( places_.size()); - size_t cur_update_sparse_gp_dev_id = 0; + size_t cur_dev_id = 0; std::vector> sparse_var_name_on_devices; std::vector> bcast_sparse_var_name_set; @@ -156,14 +160,12 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( // broadcast, and each gradient is only broadcast once. 
for (auto &og : op->OutputArgumentNames()) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { - if (IsSparseGradient(og)) { - CreateReduceOp(&result, cur_update_sparse_gp_dev_id, og); - sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace( - og); - bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace( + if (IsSparseGradient(var_types, og)) { + CreateReduceOp(&result, cur_dev_id, og); + sparse_var_name_on_devices[cur_dev_id].emplace(og); + bcast_sparse_var_name_set[cur_dev_id].emplace( og.substr(0, og.size() - strlen(kGradVarSuffix))); - cur_update_sparse_gp_dev_id = - (cur_update_sparse_gp_dev_id + 1) % places_.size(); + cur_dev_id = (cur_dev_id + 1) % places_.size(); } else { InsertNCCLAllReduceOp(&result, og); } @@ -201,10 +203,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( return std::unique_ptr(graph); } -bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { - auto og_var = local_scopes_[0]->FindVar(og); - PADDLE_ENFORCE_NOT_NULL(og_var); - return og_var->IsType(); +bool MultiDevSSAGraphBuilder::IsSparseGradient( + const std::unordered_map &var_types, + const std::string &og) const { + PADDLE_ENFORCE(var_types.count(og) != 0); + if (var_types.at(og) == proto::VarType::SELECTED_ROWS) { + return true; + } + return false; } int MultiDevSSAGraphBuilder::GetOpDeviceID( diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index cf40ea5278676db2cda56d06685fd45f00392cc0..1672958b223625dd2d9cd3e10e2db91b9db0094e 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -99,7 +99,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { */ OpDesc *GetSendOpDesc(const ProgramDesc &program) const; - bool IsSparseGradient(const std::string &og) const; + bool IsSparseGradient( + const std::unordered_map &var_types, + const std::string &og) 
const; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index 2ccd76df8528efaa096311224d4ff412396c9f71..7f30a6573b3f50b10a2b7a67a7d4655a963a6837 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -63,15 +63,12 @@ struct VarHandle : public VarHandleBase { platform::Place place_; // NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four - // member - // variables(version_, scope_id_, name_, place_) must be equal. But sometimes - // judging whether the two var_handle is equal is actually to determine - // whether - // the two Variables that represented by var_handle is the same. And the same - // Variable may have many different var_handles, the version_ of these - // var_handles - // is different. So I don't take care of version_ temporarily when overloading - // equal. + // member variables(version_, scope_id_, name_, place_) must be equal. But + // sometimes judging whether the two var_handle is equal is actually to + // determine whether the two Variables that represented by var_handle is the + // same. And the same Variable may have many different var_handles, the + // version_ of these var_handles is different. So I don't take care of + // version_ temporarily when overloading equal. bool operator==(const VarHandle& o) const { return o.generated_op_ == generated_op_ && o.name_ == name_ && o.scope_idx_ == scope_idx_;