From 267f9a2cdfad6b627eb6094a28cf5db41bc4f1a4 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 11 Jul 2017 04:21:37 -0500 Subject: [PATCH] Move static variable defined in .cc (#2782) * Move static variable defined in .cc We cannot define static variable in .h, because it will be multi-defined errors. Also fix some cpp syntax, like: * Prefer to use algorithm not manually for-loop, to make code more readable. * Remove unused `()`. * Enforce take a bool. It is no need `xxx==true`. * Use range-based for-loop iterator from op_desc.attrs * Fix a protential static variable init order error --- paddle/framework/CMakeLists.txt | 3 +- paddle/framework/op_registry.cc | 36 +++++++ paddle/framework/op_registry.h | 154 +++++++-------------------- paddle/framework/op_registry_test.cc | 62 +++++++++++ 4 files changed, 136 insertions(+), 119 deletions(-) create mode 100644 paddle/framework/op_registry.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index aecc97d4a86..0a5edba6ef3 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -11,7 +11,8 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) -cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc) +cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) +cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc new file mode 100644 index 00000000000..bc6a0dda57d --- /dev/null +++ b/paddle/framework/op_registry.cc @@ -0,0 +1,36 @@ +#include + +namespace paddle { +namespace framework { + +template <> +void AttrTypeHelper::SetAttrType(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::INT); +} + +template <> +void AttrTypeHelper::SetAttrType(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::FLOAT); +} + +template <> +void AttrTypeHelper::SetAttrType(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::STRING); +} + +template <> +void AttrTypeHelper::SetAttrType>(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::INTS); +} + +template <> +void AttrTypeHelper::SetAttrType>(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::FLOATS); +} + +template <> +void AttrTypeHelper::SetAttrType>(AttrProto* attr) { + attr->set_type(paddle::framework::AttrType::STRINGS); +} +} +} \ No newline at end of file diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 81241b5342d..a782834693c 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -3,6 +3,7 @@ #include "paddle/framework/attr_checker.h" //#include "paddle/framework/op_base.h" +#include #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" @@ -64,36 +65,6 @@ struct AttrTypeHelper { } }; -template <> -void AttrTypeHelper::SetAttrType(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::INT); -} - -template <> -void AttrTypeHelper::SetAttrType(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::FLOAT); -} - -template <> -void AttrTypeHelper::SetAttrType(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::STRING); -} - -template <> -void AttrTypeHelper::SetAttrType>(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::INTS); -} - -template <> -void AttrTypeHelper::SetAttrType>(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::FLOATS); -} - -template <> -void AttrTypeHelper::SetAttrType>(AttrProto* attr) { - attr->set_type(paddle::framework::AttrType::STRINGS); -} - // this class not only make proto but also init attribute checkers. class OpProtoAndCheckerMaker { public: @@ -103,22 +74,22 @@ class OpProtoAndCheckerMaker { protected: void AddInput(const std::string& name, const std::string& comment) { auto input = proto_->mutable_inputs()->Add(); - *(input->mutable_name()) = name; - *(input->mutable_comment()) = comment; + *input->mutable_name() = name; + *input->mutable_comment() = comment; } void AddOutput(const std::string& name, const std::string& comment) { auto output = proto_->mutable_outputs()->Add(); - *(output->mutable_name()) = name; - *(output->mutable_comment()) = comment; + *output->mutable_name() = name; + *output->mutable_comment() = comment; } template TypedAttrChecker& AddAttr(const std::string& name, const std::string& comment) { auto attr = proto_->mutable_attrs()->Add(); - *(attr->mutable_name()) = name; - *(attr->mutable_comment()) = comment; + *attr->mutable_name() = name; + *attr->mutable_comment() = comment; AttrTypeHelper::SetAttrType(attr); return op_checker_->AddAttrChecker(name); } @@ -134,49 +105,51 @@ class OpProtoAndCheckerMaker { }; class OpRegistry { - typedef std::function OpCreator; + using OpCreator = std::function; public: template static void RegisterOp(const std::string& op_type) { - creators_[op_type] = []() { return new OpType; }; - OpProto& op_proto = protos_[op_type]; - OpAttrChecker& op_checker = op_checkers_[op_type]; + creators()[op_type] = [] { return new OpType; }; + OpProto& op_proto = protos()[op_type]; + OpAttrChecker& op_checker = op_checkers()[op_type]; ProtoMakerType(&op_proto, &op_checker); - PADDLE_ENFORCE(op_proto.IsInitialized() == true, + PADDLE_ENFORCE(op_proto.IsInitialized(), "Fail to initialize %s's OpProto !", op_type); } static OpBase* CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); - OpBase* op = (creators_.at(op_type))(); - (op->inputs_).resize(op_desc.inputs_size()); - for (int i = 0; i < op_desc.inputs_size(); ++i) { - (op->inputs_)[i] = op_desc.inputs(i); - } - (op->outputs_).resize(op_desc.outputs_size()); - for (int i = 0; i < op_desc.outputs_size(); ++i) { - (op->outputs_)[i] = op_desc.outputs(i); - } - for (int i = 0; i < op_desc.attrs_size(); ++i) { - const AttrDesc& ith_attr = op_desc.attrs(i); - std::string name = ith_attr.name(); - (op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr); + OpBase* op = creators().at(op_type)(); + op->inputs_.reserve((size_t)op_desc.inputs_size()); + std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), + std::back_inserter(op->inputs_)); + op->outputs_.reserve((size_t)op_desc.outputs_size()); + std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), + std::back_inserter(op->outputs_)); + for (auto& attr : op_desc.attrs()) { + op->attr_map_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); } - const OpAttrChecker& op_checker = op_checkers_.at(op_type); - op_checker.Check(op->attr_map_); + op_checkers().at(op_type).Check(op->attr_map_); return op; } private: - static std::unordered_map creators_; - static std::unordered_map protos_; - static std::unordered_map op_checkers_; -}; + static std::unordered_map& creators() { + static std::unordered_map creators_; + return creators_; + } -std::unordered_map> OpRegistry::creators_; -std::unordered_map OpRegistry::protos_; -std::unordered_map OpRegistry::op_checkers_; + static std::unordered_map& protos() { + static std::unordered_map protos_; + return protos_; + }; + + static std::unordered_map& op_checkers() { + static std::unordered_map op_checkers_; + return op_checkers_; + }; +}; template class OpRegisterHelper { @@ -194,60 +167,5 @@ class OpRegisterHelper { const OpRegisterHelper<__op_class, __op_maker_class> \ __op_class##Register::reg(#__op_type); -// Demos - -class CosineOp : public OpBase { - public: - virtual std::string Run() const { - std::string msg = "CosineOp runs! scale = " + - std::to_string(boost::get(attr_map_.at("scale"))); - return msg; - } -}; - -class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - AddAttr("scale", "scale of cosine op") - .SetDefault(1.0) - .LargerThan(0.0); - AddType("cos"); - AddComment("This is cos op"); - } -}; - -REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) - -class MyTestOp : public OpBase { - public: - virtual std::string Run() const { - std::string msg = - "MyTestOp runs! test_attr = " + - std::to_string(boost::get(attr_map_.at("test_attr"))); - return msg; - } -}; - -class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - auto my_checker = [](int i) { - PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); - }; - AddAttr("test_attr", "a simple test attribute") - .AddCustomChecker(my_checker); - AddType("my_test_op"); - AddComment("This is my_test op"); - } -}; - -REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) - } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index ae6b7387129..a92f1feb476 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -1,6 +1,63 @@ #include "paddle/framework/op_registry.h" #include +namespace paddle { +namespace framework { +class CosineOp : public OpBase { + public: + virtual std::string Run() const { + std::string msg = "CosineOp runs! scale = " + + std::to_string(boost::get(attr_map_.at("scale"))); + return msg; + } +}; + +class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of cosine op"); + AddOutput("output", "output of cosine op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddType("cos"); + AddComment("This is cos op"); + } +}; + +REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) + +class MyTestOp : public OpBase { + public: + virtual std::string Run() const { + std::string msg = + "MyTestOp runs! test_attr = " + + std::to_string(boost::get(attr_map_.at("test_attr"))); + return msg; + } +}; + +class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of cosine op"); + AddOutput("output", "output of cosine op"); + auto my_checker = [](int i) { + PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); + }; + AddAttr("test_attr", "a simple test attribute") + .AddCustomChecker(my_checker); + AddType("my_test_op"); + AddComment("This is my_test op"); + } +}; + +REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) +} // namespace framework +} // namespace paddle + TEST(OpRegistry, CreateOp) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); @@ -120,3 +177,8 @@ TEST(OpRegistry, CustomChecker) { ASSERT_EQ(debug_str[i], str[i]); } } + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file -- GitLab