提交 0441c2cc 编写于 作者: C chengduoZH

fix ci

上级 f9c680c4
...@@ -107,6 +107,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, ...@@ -107,6 +107,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unordered_map<std::string, proto::VarType::Type> var_types;
for (auto *var : program.Block(0).AllVars()) {
var_types[var->Name()] = var->GetType();
}
auto graph = new SSAGraph(); auto graph = new SSAGraph();
SSAGraph &result = *graph; SSAGraph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast; std::unordered_set<std::string> og_has_been_broadcast;
...@@ -116,7 +120,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -116,7 +120,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size()); places_.size());
size_t cur_update_sparse_gp_dev_id = 0; size_t cur_dev_id = 0;
std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices; std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set; std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
...@@ -156,14 +160,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -156,14 +160,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) { for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
if (IsSparseGradient(og)) { if (IsSparseGradient(var_types, og)) {
CreateReduceOp(&result, cur_update_sparse_gp_dev_id, og); CreateReduceOp(&result, cur_dev_id, og);
sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace( sparse_var_name_on_devices[cur_dev_id].emplace(og);
og); bcast_sparse_var_name_set[cur_dev_id].emplace(
bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace(
og.substr(0, og.size() - strlen(kGradVarSuffix))); og.substr(0, og.size() - strlen(kGradVarSuffix)));
cur_update_sparse_gp_dev_id = cur_dev_id = (cur_dev_id + 1) % places_.size();
(cur_update_sparse_gp_dev_id + 1) % places_.size();
} else { } else {
InsertNCCLAllReduceOp(&result, og); InsertNCCLAllReduceOp(&result, og);
} }
...@@ -201,10 +203,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -201,10 +203,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { bool MultiDevSSAGraphBuilder::IsSparseGradient(
auto og_var = local_scopes_[0]->FindVar(og); const std::unordered_map<std::string, proto::VarType::Type> &var_types,
PADDLE_ENFORCE_NOT_NULL(og_var); const std::string &og) const {
return og_var->IsType<SelectedRows>(); PADDLE_ENFORCE(var_types.count(og) != 0);
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) {
return true;
}
return false;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID( int MultiDevSSAGraphBuilder::GetOpDeviceID(
......
...@@ -99,7 +99,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -99,7 +99,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
*/ */
OpDesc *GetSendOpDesc(const ProgramDesc &program) const; OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -63,15 +63,12 @@ struct VarHandle : public VarHandleBase { ...@@ -63,15 +63,12 @@ struct VarHandle : public VarHandleBase {
platform::Place place_; platform::Place place_;
// NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four // NOTE(zcd): Strictly speaking, if the two var_handle is equal, the four
// member // member variables(version_, scope_id_, name_, place_) must be equal. But
// variables(version_, scope_id_, name_, place_) must be equal. But sometimes // sometimes judging whether the two var_handle is equal is actually to
// judging whether the two var_handle is equal is actually to determine // determine whether the two Variables that represented by var_handle is the
// whether // same. And the same Variable may have many different var_handles, the
// the two Variables that represented by var_handle is the same. And the same // version_ of these var_handles is different. So I don't take care of
// Variable may have many different var_handles, the version_ of these // version_ temporarily when overloading equal.
// var_handles
// is different. So I don't take care of version_ temporarily when overloading
// equal.
bool operator==(const VarHandle& o) const { bool operator==(const VarHandle& o) const {
return o.generated_op_ == generated_op_ && o.name_ == name_ && return o.generated_op_ == generated_op_ && o.name_ == name_ &&
o.scope_idx_ == scope_idx_; o.scope_idx_ == scope_idx_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册