From 0441c2cc45feab5e5f7cc67fc2c196379b140589 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 5 May 2018 13:16:27 +0800 Subject: [PATCH] fix ci --- .../details/multi_devices_graph_builder.cc | 30 +++++++++++-------- .../details/multi_devices_graph_builder.h | 4 ++- paddle/fluid/framework/details/var_handle.h | 15 ++++------ 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index da524cc7928..37100b529d0 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -107,6 +107,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { + std::unordered_map var_types; + for (auto *var : program.Block(0).AllVars()) { + var_types[var->Name()] = var->GetType(); + } auto graph = new SSAGraph(); SSAGraph &result = *graph; std::unordered_set og_has_been_broadcast; @@ -116,7 +120,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( std::unordered_map>>>( places_.size()); - size_t cur_update_sparse_gp_dev_id = 0; + size_t cur_dev_id = 0; std::vector> sparse_var_name_on_devices; std::vector> bcast_sparse_var_name_set; @@ -156,14 +160,12 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( // broadcast, and each gradient is only broadcast once. for (auto &og : op->OutputArgumentNames()) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { - if (IsSparseGradient(og)) { - CreateReduceOp(&result, cur_update_sparse_gp_dev_id, og); - sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace( - og); - bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace( + if (IsSparseGradient(var_types, og)) { + CreateReduceOp(&result, cur_dev_id, og); + sparse_var_name_on_devices[cur_dev_id].emplace(og); + bcast_sparse_var_name_set[cur_dev_id].emplace( og.substr(0, og.size() - strlen(kGradVarSuffix))); - cur_update_sparse_gp_dev_id = - (cur_update_sparse_gp_dev_id + 1) % places_.size(); + cur_dev_id = (cur_dev_id + 1) % places_.size(); } else { InsertNCCLAllReduceOp(&result, og); } @@ -201,10 +203,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( return std::unique_ptr(graph); } -bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { - auto og_var = local_scopes_[0]->FindVar(og); - PADDLE_ENFORCE_NOT_NULL(og_var); - return og_var->IsType(); +bool MultiDevSSAGraphBuilder::IsSparseGradient( + const std::unordered_map &var_types, + const std::string &og) const { + PADDLE_ENFORCE(var_types.count(og) != 0); + if (var_types.at(og) == proto::VarType::SELECTED_ROWS) { + return true; + } + return false; } int MultiDevSSAGraphBuilder::GetOpDeviceID( diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index cf40ea52786..1672958b223 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -99,7 +99,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { */ OpDesc *GetSendOpDesc(const ProgramDesc &program) const; - bool IsSparseGradient(const std::string &og) const; + bool IsSparseGradient( + const std::unordered_map &var_types, + const std::string &og) const; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index 2ccd76df852..7f30a6573b3 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -63,15 +63,12 @@ struct VarHandle : public VarHandleBase { platform::Place place_; // NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four - // member - // variables(version_, scope_id_, name_, place_) must be equal. But sometimes - // judging whether the two var_handle is equal is actually to determine - // whether - // the two Variables that represented by var_handle is the same. And the same - // Variable may have many different var_handles, the version_ of these - // var_handles - // is different. So I don't take care of version_ temporarily when overloading - // equal. + // member variables(version_, scope_id_, name_, place_) must be equal. But + // sometimes judging whether the two var_handle is equal is actually to + // determine whether the two Variables that represented by var_handle is the + // same. And the same Variable may have many different var_handles, the + // version_ of these var_handles is different. So I don't take care of + // version_ temporarily when overloading equal. bool operator==(const VarHandle& o) const { return o.generated_op_ == generated_op_ && o.name_ == name_ && o.scope_idx_ == scope_idx_; -- GitLab