diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 970b2b9abde7be17a2427911dc9e3a4fa638a327..4409c6feae218222b7c0216760cebe4ae8e235cb 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -11,6 +11,7 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) +cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attr_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..c0c33d81149ac2fc2a9a57d90931ef32375fe1d0 --- /dev/null +++ b/paddle/framework/attr_checker.h @@ -0,0 +1,119 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/framework/enforce.h" + +namespace paddle { +namespace framework { + +typedef boost::variant, + std::vector, std::vector> + Attribute; +typedef std::unordered_map AttributeMap; + +// check whether a value(attribute) fit a certain limit +template +class LargerThanChecker { + public: + LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} + void operator()(T& value) const { + PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail"); + } + + private: + T lower_bound_; +}; + +// we can provide users more common Checker, like 'LessThanChecker', +// 'BetweenChecker'... + +template +class DefaultValueSetter { + public: + DefaultValueSetter(T default_value) : default_value_(default_value) {} + void operator()(T& value) const { value = default_value_; } + + private: + T default_value_; +}; + +// check whether a certain attribute fit its limits +// an attribute can have more than one limits +template +class TypedAttrChecker { + typedef std::function ValueChecker; + + public: + TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} + + TypedAttrChecker& LargerThan(const T& lower_bound) { + value_checkers_.push_back(LargerThanChecker(lower_bound)); + return *this; + } + + // we can add more common limits, like LessThan(), Between()... + + TypedAttrChecker& SetDefault(const T& default_value) { + PADDLE_ENFORCE(default_value_setter_.empty(), + "%s can't have more than one default value!", attr_name_); + default_value_setter_.push_back(DefaultValueSetter(default_value)); + return *this; + } + + // allow users provide their own checker + TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) { + value_checkers_.push_back(checker); + return *this; + } + + void operator()(AttributeMap& attr_map) const { + if (!attr_map.count(attr_name_)) { + // user do not set this attr + PADDLE_ENFORCE(!default_value_setter_.empty(), + "Attribute '%s' is required!", attr_name_); + // default_value_setter_ has no more than one element + T val; + (default_value_setter_[0])(val); + attr_map[attr_name_] = val; + } + Attribute& attr = attr_map.at(attr_name_); + T& attr_value = boost::get(attr); + for (const auto& checker : value_checkers_) { + checker(attr_value); + } + } + + private: + std::string attr_name_; + std::vector value_checkers_; + std::vector default_value_setter_; +}; + +// check whether op's all attributes fit their own limits +class OpAttrChecker { + typedef std::function AttrChecker; + + public: + template + TypedAttrChecker& AddAttrChecker(const std::string& attr_name) { + attr_checkers_.push_back(TypedAttrChecker(attr_name)); + AttrChecker& checker = attr_checkers_.back(); + return *(checker.target>()); + } + + void Check(AttributeMap& attr_map) const { + for (const auto& checker : attr_checkers_) { + checker(attr_map); + } + } + + private: + std::vector attr_checkers_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..81241b5342d8900c205dd62f2a62dc2496010560 --- /dev/null +++ b/paddle/framework/op_registry.h @@ -0,0 +1,253 @@ +#pragma once + +#include "paddle/framework/attr_checker.h" + +//#include "paddle/framework/op_base.h" +#include "paddle/framework/op_desc.pb.h" +#include "paddle/framework/op_proto.pb.h" + +namespace paddle { +namespace framework { + +//==================For test================// +class OpBase { + public: + std::vector inputs_; + std::vector outputs_; + AttributeMap attr_map_; + + virtual std::string Run() const = 0; + virtual ~OpBase() {} +}; +//=========================================// + +// helper class to set attribute type +struct AttrTypeHelper { + template + static void SetAttrType(AttrProto* attr); + + static Attribute GetAttrValue(const AttrDesc& attr_desc) { + switch (attr_desc.type()) { + case paddle::framework::AttrType::INT: { + return attr_desc.i(); + } + case paddle::framework::AttrType::FLOAT: { + return attr_desc.f(); + } + case paddle::framework::AttrType::STRING: { + return attr_desc.s(); + } + case paddle::framework::AttrType::INTS: { + std::vector val(attr_desc.ints_size()); + for (int i = 0; i < attr_desc.ints_size(); ++i) { + val[i] = attr_desc.ints(i); + } + return val; + } + case paddle::framework::AttrType::FLOATS: { + std::vector val(attr_desc.floats_size()); + for (int i = 0; i < attr_desc.floats_size(); ++i) { + val[i] = attr_desc.floats(i); + } + return val; + } + case paddle::framework::AttrType::STRINGS: { + std::vector val(attr_desc.strings_size()); + for (int i = 0; i < attr_desc.strings_size(); ++i) { + val[i] = attr_desc.strings(i); + } + return val; + } + } + PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); + return boost::blank(); + } +}; + +template <> +void AttrTypeHelper::SetAttrType(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::INT); +} + +template <> +void AttrTypeHelper::SetAttrType(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::FLOAT); +} + +template <> +void AttrTypeHelper::SetAttrType(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::STRING); +} + +template <> +void AttrTypeHelper::SetAttrType>(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::INTS); +} + +template <> +void AttrTypeHelper::SetAttrType>(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::FLOATS); +} + +template <> +void AttrTypeHelper::SetAttrType>(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::STRINGS); +} + +// this class not only make proto but also init attribute checkers. +class OpProtoAndCheckerMaker { + public: + OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : proto_(proto), op_checker_(op_checker) {} + + protected: + void AddInput(const std::string& name, const std::string& comment) { + auto input = proto_->mutable_inputs()->Add(); + *(input->mutable_name()) = name; + *(input->mutable_comment()) = comment; + } + + void AddOutput(const std::string& name, const std::string& comment) { + auto output = proto_->mutable_outputs()->Add(); + *(output->mutable_name()) = name; + *(output->mutable_comment()) = comment; + } + + template + TypedAttrChecker& AddAttr(const std::string& name, + const std::string& comment) { + auto attr = proto_->mutable_attrs()->Add(); + *(attr->mutable_name()) = name; + *(attr->mutable_comment()) = comment; + AttrTypeHelper::SetAttrType(attr); + return op_checker_->AddAttrChecker(name); + } + + void AddType(const std::string& op_type) { proto_->set_type(op_type); } + + void AddComment(const std::string& comment) { + *(proto_->mutable_comment()) = comment; + } + + OpProto* proto_; + OpAttrChecker* op_checker_; +}; + +class OpRegistry { + typedef std::function OpCreator; + + public: + template + static void RegisterOp(const std::string& op_type) { + creators_[op_type] = []() { return new OpType; }; + OpProto& op_proto = protos_[op_type]; + OpAttrChecker& op_checker = op_checkers_[op_type]; + ProtoMakerType(&op_proto, &op_checker); + PADDLE_ENFORCE(op_proto.IsInitialized() == true, + "Fail to initialize %s's OpProto !", op_type); + } + + static OpBase* CreateOp(const OpDesc& op_desc) { + std::string op_type = op_desc.type(); + OpBase* op = (creators_.at(op_type))(); + (op->inputs_).resize(op_desc.inputs_size()); + for (int i = 0; i < op_desc.inputs_size(); ++i) { + (op->inputs_)[i] = op_desc.inputs(i); + } + (op->outputs_).resize(op_desc.outputs_size()); + for (int i = 0; i < op_desc.outputs_size(); ++i) { + (op->outputs_)[i] = op_desc.outputs(i); + } + for (int i = 0; i < op_desc.attrs_size(); ++i) { + const AttrDesc& ith_attr = op_desc.attrs(i); + std::string name = ith_attr.name(); + (op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); + } + const OpAttrChecker& op_checker = op_checkers_.at(op_type); + op_checker.Check(op->attr_map_); + return op; + } + + private: + static std::unordered_map creators_; + static std::unordered_map protos_; + static std::unordered_map op_checkers_; +}; + +std::unordered_map> OpRegistry::creators_; +std::unordered_map OpRegistry::protos_; +std::unordered_map OpRegistry::op_checkers_; + +template +class OpRegisterHelper { + public: + OpRegisterHelper(std::string op_type) { + OpRegistry::RegisterOp(op_type); + } +}; + +#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \ + class __op_class##Register { \ + private: \ + const static OpRegisterHelper<__op_class, __op_maker_class> reg; \ + }; \ + const OpRegisterHelper<__op_class, __op_maker_class> \ + __op_class##Register::reg(#__op_type); + +// Demos + +class CosineOp : public OpBase { + public: + virtual std::string Run() const { + std::string msg = "CosineOp runs! scale = " + + std::to_string(boost::get(attr_map_.at("scale"))); + return msg; + } +}; + +class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of cosine op"); + AddOutput("output", "output of cosine op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddType("cos"); + AddComment("This is cos op"); + } +}; + +REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) + +class MyTestOp : public OpBase { + public: + virtual std::string Run() const { + std::string msg = + "MyTestOp runs! test_attr = " + + std::to_string(boost::get(attr_map_.at("test_attr"))); + return msg; + } +}; + +class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of cosine op"); + AddOutput("output", "output of cosine op"); + auto my_checker = [](int i) { + PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); + }; + AddAttr("test_attr", "a simple test attribute") + .AddCustomChecker(my_checker); + AddType("my_test_op"); + AddComment("This is my_test op"); + } +}; + +REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..17849ca0191db644884e766342b30461abf50298 --- /dev/null +++ b/paddle/framework/op_registry_test.cc @@ -0,0 +1,122 @@ +#include "paddle/framework/op_registry.h" +#include + +TEST(OpRegistry, CreateOp) { + paddle::framework::OpDesc op_desc; + op_desc.set_type("cos_sim"); + op_desc.add_inputs("aa"); + op_desc.add_outputs("bb"); + + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + attr->set_f(3.3); + + paddle::framework::OpBase* op = + paddle::framework::OpRegistry::CreateOp(op_desc); + std::string debug_str = op->Run(); + std::string str = "CosineOp runs! scale = " + std::to_string(3.3); + ASSERT_EQ(str.size(), debug_str.size()); + for (size_t i = 0; i < debug_str.length(); ++i) { + ASSERT_EQ(debug_str[i], str[i]); + } +} + +TEST(OpRegistry, IllegalAttr) { + paddle::framework::OpDesc op_desc; + op_desc.set_type("cos_sim"); + op_desc.add_inputs("aa"); + op_desc.add_outputs("bb"); + + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + attr->set_f(-2.0); + + bool caught = false; + try { + paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OpRegistry::CreateOp(op_desc); + } catch (paddle::framework::EnforceNotMet err) { + caught = true; + std::string msg = "larger_than check fail"; + const char* err_msg = err.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(err_msg[i], msg[i]); + } + } + ASSERT_TRUE(caught); +} + +TEST(OpRegistry, DefaultValue) { + paddle::framework::OpDesc op_desc; + op_desc.set_type("cos_sim"); + op_desc.add_inputs("aa"); + op_desc.add_outputs("bb"); + + paddle::framework::OpBase* op = + paddle::framework::OpRegistry::CreateOp(op_desc); + std::string debug_str = op->Run(); + float default_value = 1.0; + std::string str = "CosineOp runs! scale = " + std::to_string(default_value); + ASSERT_EQ(str.size(), debug_str.size()); + for (size_t i = 0; i < debug_str.length(); ++i) { + ASSERT_EQ(debug_str[i], str[i]); + } +} + +TEST(OpRegistry, CustomChecker) { + paddle::framework::OpDesc op_desc; + op_desc.set_type("my_test_op"); + op_desc.add_inputs("ii"); + op_desc.add_outputs("oo"); + + // attr 'test_attr' is not set + bool caught = false; + try { + paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OpRegistry::CreateOp(op_desc); + } catch (paddle::framework::EnforceNotMet err) { + caught = true; + std::string msg = "Attribute 'test_attr' is required!"; + const char* err_msg = err.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(err_msg[i], msg[i]); + } + } + ASSERT_TRUE(caught); + + // set 'test_attr' set to an illegal value + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("test_attr"); + attr->set_type(paddle::framework::AttrType::INT); + attr->set_i(3); + caught = false; + try { + paddle::framework::OpBase* op __attribute__((unused)) = + paddle::framework::OpRegistry::CreateOp(op_desc); + } catch (paddle::framework::EnforceNotMet err) { + caught = true; + std::string msg = "'test_attr' must be even!"; + const char* err_msg = err.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(err_msg[i], msg[i]); + } + } + ASSERT_TRUE(caught); + + // set 'test_attr' set to a legal value + op_desc.mutable_attrs()->Clear(); + attr = op_desc.mutable_attrs()->Add(); + attr->set_name("test_attr"); + attr->set_type(paddle::framework::AttrType::INT); + attr->set_i(4); + paddle::framework::OpBase* op = + paddle::framework::OpRegistry::CreateOp(op_desc); + std::string debug_str = op->Run(); + std::string str = "MyTestOp runs! test_attr = " + std::to_string(4); + ASSERT_EQ(str.size(), debug_str.size()); + for (size_t i = 0; i < debug_str.length(); ++i) { + ASSERT_EQ(debug_str[i], str[i]); + } +} \ No newline at end of file