#pragma once #include "paddle/framework/attr_checker.h" //#include "paddle/framework/op_base.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" namespace paddle { namespace framework { //==================For test================// class OpBase { public: std::vector inputs_; std::vector outputs_; AttributeMap attr_map_; virtual std::string Run() const = 0; virtual ~OpBase() {} }; //=========================================// // helper class to set attribute type struct AttrTypeHelper { template static void SetAttrType(AttrProto* attr); static Attribute GetAttrValue(const AttrDesc& attr_desc) { switch (attr_desc.type()) { case paddle::framework::AttrType::INT: { return attr_desc.i(); } case paddle::framework::AttrType::FLOAT: { return attr_desc.f(); } case paddle::framework::AttrType::STRING: { return attr_desc.s(); } case paddle::framework::AttrType::INTS: { std::vector val(attr_desc.ints_size()); for (int i = 0; i < attr_desc.ints_size(); ++i) { val[i] = attr_desc.ints(i); } return val; } case paddle::framework::AttrType::FLOATS: { std::vector val(attr_desc.floats_size()); for (int i = 0; i < attr_desc.floats_size(); ++i) { val[i] = attr_desc.floats(i); } return val; } case paddle::framework::AttrType::STRINGS: { std::vector val(attr_desc.strings_size()); for (int i = 0; i < attr_desc.strings_size(); ++i) { val[i] = attr_desc.strings(i); } return val; } } PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); return boost::blank(); } }; template <> void AttrTypeHelper::SetAttrType(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::INT); } template <> void AttrTypeHelper::SetAttrType(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::FLOAT); } template <> void AttrTypeHelper::SetAttrType(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::STRING); } template <> void AttrTypeHelper::SetAttrType>(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::INTS); } template <> void AttrTypeHelper::SetAttrType>(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::FLOATS); } template <> void AttrTypeHelper::SetAttrType>(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::STRINGS); } // this class not only make proto but also init attribute checkers. class OpProtoAndCheckerMaker { public: OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : proto_(proto), op_checker_(op_checker) {} protected: void AddInput(const std::string& name, const std::string& comment) { auto input = proto_->mutable_inputs()->Add(); *(input->mutable_name()) = name; *(input->mutable_comment()) = comment; } void AddOutput(const std::string& name, const std::string& comment) { auto output = proto_->mutable_outputs()->Add(); *(output->mutable_name()) = name; *(output->mutable_comment()) = comment; } template TypedAttrChecker& AddAttr(const std::string& name, const std::string& comment) { auto attr = proto_->mutable_attrs()->Add(); *(attr->mutable_name()) = name; *(attr->mutable_comment()) = comment; AttrTypeHelper::SetAttrType(attr); return op_checker_->AddAttrChecker(name); } void AddType(const std::string& op_type) { proto_->set_type(op_type); } void AddComment(const std::string& comment) { *(proto_->mutable_comment()) = comment; } OpProto* proto_; OpAttrChecker* op_checker_; }; class OpRegistry { typedef std::function OpCreator; public: template static void RegisterOp(const std::string& op_type) { creators_[op_type] = []() { return new OpType; }; OpProto& op_proto = protos_[op_type]; OpAttrChecker& op_checker = op_checkers_[op_type]; ProtoMakerType(&op_proto, &op_checker); PADDLE_ENFORCE(op_proto.IsInitialized() == true, "Fail to initialize %s's OpProto !", op_type); } static OpBase* CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); OpBase* op = (creators_.at(op_type))(); (op->inputs_).resize(op_desc.inputs_size()); for (int i = 0; i < op_desc.inputs_size(); ++i) { (op->inputs_)[i] = op_desc.inputs(i); } (op->outputs_).resize(op_desc.outputs_size()); for (int i = 0; i < op_desc.outputs_size(); ++i) { (op->outputs_)[i] = op_desc.outputs(i); } for (int i = 0; i < op_desc.attrs_size(); ++i) { const AttrDesc& ith_attr = op_desc.attrs(i); std::string name = ith_attr.name(); (op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); } const OpAttrChecker& op_checker = op_checkers_.at(op_type); op_checker.Check(op->attr_map_); return op; } private: static std::unordered_map creators_; static std::unordered_map protos_; static std::unordered_map op_checkers_; }; std::unordered_map> OpRegistry::creators_; std::unordered_map OpRegistry::protos_; std::unordered_map OpRegistry::op_checkers_; template class OpRegisterHelper { public: OpRegisterHelper(std::string op_type) { OpRegistry::RegisterOp(op_type); } }; #define REGISTER_OP(__op_class, __op_maker_class, __op_type) \ class __op_class##Register { \ private: \ const static OpRegisterHelper<__op_class, __op_maker_class> reg; \ }; \ const OpRegisterHelper<__op_class, __op_maker_class> \ __op_class##Register::reg(#__op_type); // Demos class CosineOp : public OpBase { public: virtual std::string Run() const { std::string msg = "CosineOp runs! scale = " + std::to_string(boost::get(attr_map_.at("scale"))); return msg; } }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("input", "input of cosine op"); AddOutput("output", "output of cosine op"); AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .LargerThan(0.0); AddType("cos"); AddComment("This is cos op"); } }; REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) class MyTestOp : public OpBase { public: virtual std::string Run() const { std::string msg = "MyTestOp runs! test_attr = " + std::to_string(boost::get(attr_map_.at("test_attr"))); return msg; } }; class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("input", "input of cosine op"); AddOutput("output", "output of cosine op"); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; AddAttr("test_attr", "a simple test attribute") .AddCustomChecker(my_checker); AddType("my_test_op"); AddComment("This is my_test op"); } }; REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) } // namespace framework } // namespace paddle