diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index e93ee1442544951d1cc972a43980154d33a8602c..55cf7fbe31f86306a5456465b4232bedf525499a 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -120,8 +120,10 @@ class OpProtoAndCheckerMaker { }; class OpRegistry { - using OpCreator = std::function; using VarNameMap = OperatorBase::VarNameMap; + using OpCreator = std::function; public: template @@ -153,14 +155,9 @@ class OpRegistry { PADDLE_ENFORCE(op_create_it != op_creators().end(), "Operator %s cannot be found.", type); - auto op = op_create_it->second(); - op->type_ = type; - op->inputs_ = inputs; - op->outputs_ = outputs; - - op->attrs_ = attrs; - op_checkers().at(type).Check(op->attrs_); - + auto attrMap = attrs; + op_checkers().at(type).Check(attrMap); + auto op = op_create_it->second(type, inputs, outputs, attrMap); GenerateTempVariableName(op); op->Init(); @@ -217,12 +214,14 @@ class OpRegistry { static void GenerateTempVariableName(OperatorBase* op) { static std::atomic gUniqId(0UL); - for (auto& output : op->outputs_) { + for (auto& output : op->Outputs()) { for (auto& output_name : output.second) { if (output_name == kTempVarName) { - output_name += op->type_; - output_name += "@"; - output_name += std::to_string(gUniqId.fetch_add(1)); + auto new_name = output_name; + new_name += op->Type(); + new_name += "@"; + new_name += std::to_string(gUniqId.fetch_add(1)); + op->Rename(output_name, new_name); } } } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index e145649d300d57425b9c83bd7daa4149cb698e2c..038e6fe7a2526ed1972a9c863bbd85f690116b0f 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -105,6 +105,8 @@ class OperatorBase { /// rename inputs outputs name void Rename(const std::string& old_name, const std::string& new_name); + const VarNameMap& Inputs() const { return inputs_; } + const VarNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; //! Get a input which has multiple variables. @@ -118,10 +120,10 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; - std::string Type() const { return type_; } + const std::string& Type() const { return type_; } const AttributeMap& Attrs() const { return attrs_; } - public: + protected: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: // I (Inputs) diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 6a118087a73af29ccfb29d442acff6a0c9501512..61e1377af8ae5b93c08a5920e182e6f8da4376d1 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -29,7 +29,7 @@ void NetOp::CompleteAddOp(bool calc) { std::set input_set; std::set output_set; for (auto& op : ops_) { - for (auto& ipt : op->inputs_) { + for (auto& ipt : op->Inputs()) { for (auto& var_name : ipt.second) { if (!Contains(output_set, var_name)) { // Not other op's output input_set.insert(var_name); @@ -39,7 +39,7 @@ void NetOp::CompleteAddOp(bool calc) { } } - for (auto& opt : op->outputs_) { + for (auto& opt : op->Outputs()) { for (auto& var_name : opt.second) { output_set.insert(var_name); }