From b2e3824e4149e592635e1938188415b663446a8d Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 14 Aug 2017 15:34:38 +0800 Subject: [PATCH] change operator --- paddle/framework/op_registry.h | 25 ++++++++++++------------- paddle/framework/operator.h | 6 ++++-- paddle/operators/net_op.cc | 4 ++-- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index e93ee144254..55cf7fbe31f 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -120,8 +120,10 @@ class OpProtoAndCheckerMaker { }; class OpRegistry { - using OpCreator = std::function; using VarNameMap = OperatorBase::VarNameMap; + using OpCreator = std::function; public: template @@ -153,14 +155,9 @@ class OpRegistry { PADDLE_ENFORCE(op_create_it != op_creators().end(), "Operator %s cannot be found.", type); - auto op = op_create_it->second(); - op->type_ = type; - op->inputs_ = inputs; - op->outputs_ = outputs; - - op->attrs_ = attrs; - op_checkers().at(type).Check(op->attrs_); - + auto attrMap = attrs; + op_checkers().at(type).Check(attrMap); + auto op = op_create_it->second(type, inputs, outputs, attrMap); GenerateTempVariableName(op); op->Init(); @@ -217,12 +214,14 @@ class OpRegistry { static void GenerateTempVariableName(OperatorBase* op) { static std::atomic gUniqId(0UL); - for (auto& output : op->outputs_) { + for (auto& output : op->Outputs()) { for (auto& output_name : output.second) { if (output_name == kTempVarName) { - output_name += op->type_; - output_name += "@"; - output_name += std::to_string(gUniqId.fetch_add(1)); + auto new_name = output_name; + new_name += op->Type(); + new_name += "@"; + new_name += std::to_string(gUniqId.fetch_add(1)); + op->Rename(output_name, new_name); } } } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index e145649d300..038e6fe7a25 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -105,6 +105,8 @@ class OperatorBase { /// rename inputs outputs name void Rename(const std::string& old_name, const std::string& new_name); + const VarNameMap& Inputs() const { return inputs_; } + const VarNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; //! Get a input which has multiple variables. @@ -118,10 +120,10 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; - std::string Type() const { return type_; } + const std::string& Type() const { return type_; } const AttributeMap& Attrs() const { return attrs_; } - public: + protected: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: // I (Inputs) diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 6a118087a73..61e1377af8a 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -29,7 +29,7 @@ void NetOp::CompleteAddOp(bool calc) { std::set input_set; std::set output_set; for (auto& op : ops_) { - for (auto& ipt : op->inputs_) { + for (auto& ipt : op->Inputs()) { for (auto& var_name : ipt.second) { if (!Contains(output_set, var_name)) { // Not other op's output input_set.insert(var_name); @@ -39,7 +39,7 @@ void NetOp::CompleteAddOp(bool calc) { } } - for (auto& opt : op->outputs_) { + for (auto& opt : op->Outputs()) { for (auto& var_name : opt.second) { output_set.insert(var_name); } -- GitLab