From 44c52a8c1a7a310057da6c4a004be665e9f3dd99 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Tue, 15 May 2018 15:47:12 +0800 Subject: [PATCH] Polish op_proto_maker --- .../details/multi_devices_graph_builder.cc | 17 +++++++++------ paddle/fluid/framework/op_desc.cc | 8 +++++++ paddle/fluid/framework/op_proto_maker.cc | 6 ++++-- python/paddle/fluid/backward.py | 5 ++++- python/paddle/fluid/framework.py | 21 +++++++++++-------- 5 files changed, 39 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 4755559f8d0..428efb4ace8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -163,8 +163,13 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( if (!is_forwarding && places_.size() > 1) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. - for (auto &og : op->OutputArgumentNames()) { - if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { + if (static_cast(boost::get(op->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) & + static_cast(OpRole::kBackward))) { + auto &backward_vars = boost::get>( + op->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + + for (auto &og : backward_vars) { if (balance_parameter_opt_between_cards_) { CreateReduceOp(&result, og, cur_device_id); var_name_on_devices[cur_device_id].emplace(og); @@ -399,11 +404,11 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, } bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { - // FIXME(yy): Do not hard code like this - return op.OutputArgumentNames().size() == 1 && - op.OutputArgumentNames()[0] == GradVarName(loss_var_name_); + return boost::get( + op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss)); } - } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 076c4571301..b68421afed9 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include #include "glog/logging.h" #include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/shape_inference.h" @@ -249,6 +250,13 @@ void OpDesc::RenameOutput(const std::string &old_name, std::replace(output.second.begin(), output.second.end(), old_name, new_name); } + + auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName()); + if (it != attrs_.end()) { + auto &op_vars = boost::get>(it->second); + std::replace(op_vars.begin(), op_vars.end(), old_name, new_name); + } + need_update_ = true; } diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index a2e46c7a597..6070ade7e03 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_proto_maker.h" #include +#include namespace paddle { namespace framework { @@ -69,8 +70,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, static_cast(OpRole::kLoss) | static_cast(OpRole::kForward), static_cast(OpRole::kLoss) | static_cast(OpRole::kBackward)}); - AddAttr(OpRoleVarAttrName(), "Optimized for variable") - .SetDefault(""); + AddAttr>(OpRoleVarAttrName(), + "Optimized for variable") + .SetDefault({}); Validate(); } diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 59940636e5e..fea509874d2 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -534,7 +534,10 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, if g.op is None: raise ValueError("Unexpected branch") - g.op.set_attr(op_role_var_attr_name, p.name) + attr_val = [p.name] + if g.op.has_attr(op_role_var_attr_name): + attr_val.extend(g.op.attr(op_role_var_attr_name)) + g.op.set_attr(op_role_var_attr_name, attr_val) return params_and_grads diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 9e7c8509b1d..5b222513c1f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -410,10 +410,14 @@ class Operator(object): if op_maker.kOpRoleAttrName() not in self.attrs: self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role - if len(self.block.program.op_role_var - ) != 0 and op_maker.kOpRoleVarAttrName() not in self.attrs: - self.attrs[op_maker.kOpRoleVarAttrName( - )] = self.block.program.op_role_var + + role_var_name = op_maker.kOpRoleVarAttrName() + if len(self.block.program. + op_role_var) != 0 and role_var_name not in self.attrs: + self.attrs[role_var_name] = self.block.program.op_role_var + + if role_var_name in self.attrs and len(self.attrs[role_var_name]) == 0: + del self.attrs[role_var_name] if len(self.desc.type()) != 0: return @@ -497,7 +501,6 @@ class Operator(object): attr_name, self.attrs[attr_name].serialize_to_string()) else: self.desc.set_attr(attr_name, self.attrs[attr_name]) - self.desc.check_attrs() no_kernel_op_set = { 'feed', 'fetch', 'save', 'load', 'recurrent', 'go', @@ -1020,7 +1023,7 @@ class Program(object): self.current_block_idx = 0 self._seed = 0 self._current_role = core.op_proto_and_checker_maker.OpRole.Forward - self._op_role_var = "" + self._op_role_var = [] @property def op_role(self): @@ -1036,15 +1039,15 @@ class Program(object): @op_role_var.setter def set_op_role_var(self, var_name): - self._op_role_var = var_name + self._op_role_var = [var_name] @contextlib.contextmanager def optimized_guard(self, var): OpRole = core.op_proto_and_checker_maker.OpRole self._current_role = OpRole.Optimize - self._op_role_var = var.name if isinstance(var, Variable) else var + self._op_role_var = [var.name if isinstance(var, Variable) else var] yield - self._op_role_var = "" + self._op_role_var = [] self._current_role = OpRole.Forward def __str__(self): -- GitLab