diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 491ee21eec93c270cbe405e3bcbcb02b18af8fc7..c41fe10729501698fd07f59456f64ac26df77f08 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -199,6 +199,7 @@ Add a mark to which output is temporary is helpful for future optimization. class OpRegistry { using OpCreator = std::function; using VarIndexMap = std::unordered_map; + using VarNameList = std::vector; public: template @@ -226,42 +227,51 @@ class OpRegistry { } } - static OperatorPtr CreateOp(const OpDesc& op_desc) { - //! Create a OpPtr by type. - std::string op_type = op_desc.type(); - OperatorPtr op(creators().at(op_type)()); - //! Fill op's data member. Not use constructor because it will be noising - //! for Op developer. - op->type_ = op_desc.type(); - // set op's inputs_ from desc. - op->inputs_.reserve((size_t)op_desc.inputs_size()); - std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), - std::back_inserter(op->inputs_)); - // set op's outputs_ from desc. - op->outputs_.reserve((size_t)op_desc.outputs_size()); - std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), - std::back_inserter(op->outputs_)); + static OperatorPtr CreateOp(const std::string& type, + const VarNameList& inputs, + const VarNameList& outputs, + const AttributeMap& attrs) { + auto op_create_it = creators().find(type); + PADDLE_ENFORCE(op_create_it != creators().end(), + "Operator %s cannot be found", type); - //! Fill attrs, and validate attrs. - for (auto& attr : op_desc.attrs()) { - op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); - } - op_checkers().at(op_type).Check(op->attrs_); + auto op = op_create_it->second(); + op->type_ = type; + op->inputs_ = inputs; + op->outputs_ = outputs; + op->attrs_ = attrs; + op_checkers().at(type).Check(op->attrs_); - //! Convert Temporary variable name to an unique variable name. - GenerateTempVariableName(op.get()); + GenerateTempVariableName(op); - //! set argument offsets stored in op. { - auto var_index_it = VarIndexMaps().find(op_type); + auto var_index_it = VarIndexMaps().find(type); if (var_index_it != VarIndexMaps().end()) { op->in_out_idxs_ = var_index_it->second; } } - //! Other op's custom Init for a complex Op. For simple Op, the Init - //! method do nothing. + op->Init(); - return op; + return OperatorPtr(op); + } + + static OperatorPtr CreateOp(const OpDesc& op_desc) { + std::vector inputs; + inputs.reserve((size_t)op_desc.inputs_size()); + std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), + std::back_inserter(inputs)); + + std::vector outputs; + outputs.reserve((size_t)op_desc.outputs_size()); + std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), + std::back_inserter(outputs)); + + AttributeMap attrs; + for (auto& attr : op_desc.attrs()) { + attrs[attr.name()] = AttrTypeHelper::GetAttrValue(attr); + } + + return CreateOp(op_desc.type(), inputs, outputs, attrs); } static std::unordered_map& protos() {