From fb370f44113c843d5d46a77ea59ec6ec253f0f90 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Tue, 22 May 2018 22:51:54 +0800 Subject: [PATCH] Refine code --- .../details/multi_devices_graph_builder.cc | 47 +++++++++++-------- paddle/fluid/framework/op_desc.cc | 5 +- paddle/fluid/framework/op_desc.h | 3 +- python/paddle/fluid/backward.py | 2 +- 4 files changed, 31 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 447dfa9655f..26879a7cd91 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -163,27 +163,34 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( if (static_cast(boost::get(op->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) & static_cast(OpRole::kBackward))) { - auto backward_vars = boost::get>( - op->GetAttrOrDefault(OpProtoAndCheckerMaker::OpRoleVarAttrName(), - std::vector())); - for (auto &og : backward_vars) { - switch (strategy_.reduce_) { - case BuildStrategy::ReduceStrategy::kReduce: - CreateReduceOp(&result, og, cur_device_id); - var_name_on_devices[cur_device_id].emplace(og); - bcast_var_name_set[cur_device_id].emplace( - og.substr(0, og.size() - strlen(kGradVarSuffix))); - cur_device_id = (cur_device_id + 1) % places_.size(); - break; - case BuildStrategy::ReduceStrategy::kAllReduce: - if (IsSparseGradient(var_types, og)) { - CreateReduceOp(&result, og, 0); - CreateBroadcastOp(&result, og, 0); - } else { - InsertNCCLAllReduceOp(&result, og); - } - break; + try { + auto backward_vars = + boost::get>(op->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + + PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); + + for (size_t i = 0; i < backward_vars.size(); ++i) { + auto &p_name = backward_vars[i]; + auto &g_name = backward_vars[i + 1]; + switch (strategy_.reduce_) { + case BuildStrategy::ReduceStrategy::kReduce: + CreateReduceOp(&result, g_name, cur_device_id); + var_name_on_devices[cur_device_id].emplace(g_name); + bcast_var_name_set[cur_device_id].emplace(p_name); + cur_device_id = (cur_device_id + 1) % places_.size(); + break; + case BuildStrategy::ReduceStrategy::kAllReduce: + if (IsSparseGradient(var_types, g_name)) { + CreateReduceOp(&result, g_name, 0); + CreateBroadcastOp(&result, g_name, 0); + } else { + InsertNCCLAllReduceOp(&result, g_name); + } + break; + } } + } catch (boost::bad_get e) { } } } diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index d14d9cb8ab8..1b9c6858667 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -223,13 +223,12 @@ Attribute OpDesc::GetAttr(const std::string &name) const { return it->second; } -Attribute OpDesc::GetAttrOrDefault( - const std::string &name, paddle::framework::Attribute default_attr) const { +Attribute OpDesc::GetNullableAttr(const std::string &name) const { auto it = attrs_.find(name); if (it != attrs_.end()) { return it->second; } else { - return default_attr; + return Attribute(); } } diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 82542a83c50..1a330db7cc5 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -78,8 +78,7 @@ class OpDesc { Attribute GetAttr(const std::string &name) const; - Attribute GetAttrOrDefault(const std::string &name, - Attribute default_attr) const; + Attribute GetNullableAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const; diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index d90e2782223..bd14eadede9 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -536,7 +536,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, if g.op is None: raise ValueError("Unexpected branch") - attr_val = [p.name] + attr_val = [p.name, g.name] if g.op.has_attr(op_role_var_attr_name): attr_val.extend(g.op.attr(op_role_var_attr_name)) g.op.set_attr(op_role_var_attr_name, attr_val) -- GitLab