diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index b316f2d535cc97183a28f39eec6a5a7b73e7c2d2..cb491ec95f6d3cbb7283366ad3286f2a4f5dabee 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -25,8 +25,9 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, const auto& src_inout = src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; auto& dst_inout = *vars; + const OpProto* proto = OpRegistry::op_info_map().at(src_op->type_).proto_; const auto& src_arg_list = - src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); + src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); for (const auto& arg : src_arg_list) { if (arg.no_gradient() && !is_grad) continue; const std::string src_name = arg.name(); @@ -43,6 +44,8 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { auto it = OpRegistry::op_info_map().find(op->type_); PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), "'%s' has not been registered.", op->type_); + PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.", + op->type_); std::string grad_op_type = it->second.grad_op_type_; PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", op->type_); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 6dd5f4af224f3c0c26e907f1a2eb977e6597d5dc..120f4ede6ba21e31683a8d19f0b39072c3f5c309 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -126,13 +126,6 @@ class NOPMaker : public OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) {} }; -struct OpInfo { - std::function creator_; - std::string grad_op_type_; - OpProto* proto_; - OpAttrChecker* checker_; -}; - class OpRegistry { using VarNameMap = OperatorBase::VarNameMap; using OpCreator = std::function; public: + struct OpInfo { + OpCreator creator_; + std::string grad_op_type_; + OpProto* proto_; + OpAttrChecker* checker_; + }; + template static void RegisterOp(const std::string& op_type, const std::string& grad_op_type) { @@ -175,9 +175,9 @@ class OpRegistry { } static std::shared_ptr CreateOp(const std::string& type, - const VarNameList& inputs, - const VarNameList& outputs, - const AttributeMap& attrs) { + const VarNameMap& inputs, + const VarNameMap& outputs, + AttributeMap attrs) { auto it = op_info_map().find(type); PADDLE_ENFORCE(it != op_info_map().end(), "Operator '%s' has not been registered.", type); diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index c054804477f6058de7e0c7ea7efb46a9668a8fd1..0daf12e7f5f3539d460ce67d39ca1c06f5aa2237 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -152,7 +152,7 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { type_); // get all OpProto::Var for outputs - for (auto& o : it->second.proto_.outputs()) { + for (auto& o : it->second.proto_->outputs()) { // ignore all intermediate output if (o.intermediate()) continue; auto out = outputs_.find(o.name());