提交 267f9a2c 编写于 作者: Y Yu Yang 提交者: Qiao Longfei

Move static variable defined in .cc (#2782)

* Move static variable defined in .cc

We cannot define static variable in .h, because it will be multi-defined
errors.

Also fix some cpp syntax, like:

* Prefer to use algorithm not manually for-loop, to make code more
  readable.
* Remove unused `()`.
* Enforce take a bool. It is no need `xxx==true`.
* Use range-based for-loop iterator from op_desc.attrs

* Fix a protential static variable init order error
上级 27b196ba
...@@ -11,7 +11,8 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type) ...@@ -11,7 +11,8 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type) proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
......
#include <paddle/framework/op_registry.h>
namespace paddle {
namespace framework {
template <>
void AttrTypeHelper::SetAttrType<int>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INT);
}
template <>
void AttrTypeHelper::SetAttrType<float>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOAT);
}
template <>
void AttrTypeHelper::SetAttrType<std::string>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRING);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<int>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INTS);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<float>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOATS);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRINGS);
}
}
}
\ No newline at end of file
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "paddle/framework/attr_checker.h" #include "paddle/framework/attr_checker.h"
//#include "paddle/framework/op_base.h" //#include "paddle/framework/op_base.h"
#include <algorithm>
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
...@@ -64,36 +65,6 @@ struct AttrTypeHelper { ...@@ -64,36 +65,6 @@ struct AttrTypeHelper {
} }
}; };
template <>
void AttrTypeHelper::SetAttrType<int>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INT);
}
template <>
void AttrTypeHelper::SetAttrType<float>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOAT);
}
template <>
void AttrTypeHelper::SetAttrType<std::string>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRING);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<int>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INTS);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<float>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOATS);
}
template <>
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRINGS);
}
// this class not only make proto but also init attribute checkers. // this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker { class OpProtoAndCheckerMaker {
public: public:
...@@ -103,22 +74,22 @@ class OpProtoAndCheckerMaker { ...@@ -103,22 +74,22 @@ class OpProtoAndCheckerMaker {
protected: protected:
void AddInput(const std::string& name, const std::string& comment) { void AddInput(const std::string& name, const std::string& comment) {
auto input = proto_->mutable_inputs()->Add(); auto input = proto_->mutable_inputs()->Add();
*(input->mutable_name()) = name; *input->mutable_name() = name;
*(input->mutable_comment()) = comment; *input->mutable_comment() = comment;
} }
void AddOutput(const std::string& name, const std::string& comment) { void AddOutput(const std::string& name, const std::string& comment) {
auto output = proto_->mutable_outputs()->Add(); auto output = proto_->mutable_outputs()->Add();
*(output->mutable_name()) = name; *output->mutable_name() = name;
*(output->mutable_comment()) = comment; *output->mutable_comment() = comment;
} }
template <typename T> template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name, TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment) { const std::string& comment) {
auto attr = proto_->mutable_attrs()->Add(); auto attr = proto_->mutable_attrs()->Add();
*(attr->mutable_name()) = name; *attr->mutable_name() = name;
*(attr->mutable_comment()) = comment; *attr->mutable_comment() = comment;
AttrTypeHelper::SetAttrType<T>(attr); AttrTypeHelper::SetAttrType<T>(attr);
return op_checker_->AddAttrChecker<T>(name); return op_checker_->AddAttrChecker<T>(name);
} }
...@@ -134,49 +105,51 @@ class OpProtoAndCheckerMaker { ...@@ -134,49 +105,51 @@ class OpProtoAndCheckerMaker {
}; };
class OpRegistry { class OpRegistry {
typedef std::function<OpBase*()> OpCreator; using OpCreator = std::function<OpBase*()>;
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) { static void RegisterOp(const std::string& op_type) {
creators_[op_type] = []() { return new OpType; }; creators()[op_type] = [] { return new OpType; };
OpProto& op_proto = protos_[op_type]; OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers_[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
ProtoMakerType(&op_proto, &op_checker); ProtoMakerType(&op_proto, &op_checker);
PADDLE_ENFORCE(op_proto.IsInitialized() == true, PADDLE_ENFORCE(op_proto.IsInitialized(),
"Fail to initialize %s's OpProto !", op_type); "Fail to initialize %s's OpProto !", op_type);
} }
static OpBase* CreateOp(const OpDesc& op_desc) { static OpBase* CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type(); std::string op_type = op_desc.type();
OpBase* op = (creators_.at(op_type))(); OpBase* op = creators().at(op_type)();
(op->inputs_).resize(op_desc.inputs_size()); op->inputs_.reserve((size_t)op_desc.inputs_size());
for (int i = 0; i < op_desc.inputs_size(); ++i) { std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
(op->inputs_)[i] = op_desc.inputs(i); std::back_inserter(op->inputs_));
} op->outputs_.reserve((size_t)op_desc.outputs_size());
(op->outputs_).resize(op_desc.outputs_size()); std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
for (int i = 0; i < op_desc.outputs_size(); ++i) { std::back_inserter(op->outputs_));
(op->outputs_)[i] = op_desc.outputs(i); for (auto& attr : op_desc.attrs()) {
} op->attr_map_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
for (int i = 0; i < op_desc.attrs_size(); ++i) {
const AttrDesc& ith_attr = op_desc.attrs(i);
std::string name = ith_attr.name();
(op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr);
} }
const OpAttrChecker& op_checker = op_checkers_.at(op_type); op_checkers().at(op_type).Check(op->attr_map_);
op_checker.Check(op->attr_map_);
return op; return op;
} }
private: private:
static std::unordered_map<std::string, OpCreator> creators_; static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpProto> protos_; static std::unordered_map<std::string, OpCreator> creators_;
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; return creators_;
}; }
std::unordered_map<std::string, std::function<OpBase*()>> OpRegistry::creators_; static std::unordered_map<std::string, OpProto>& protos() {
std::unordered_map<std::string, OpProto> OpRegistry::protos_; static std::unordered_map<std::string, OpProto> protos_;
std::unordered_map<std::string, OpAttrChecker> OpRegistry::op_checkers_; return protos_;
};
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
};
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper { class OpRegisterHelper {
...@@ -194,60 +167,5 @@ class OpRegisterHelper { ...@@ -194,60 +167,5 @@ class OpRegisterHelper {
const OpRegisterHelper<__op_class, __op_maker_class> \ const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type); __op_class##Register::reg(#__op_type);
// Demos
class CosineOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg = "CosineOp runs! scale = " +
std::to_string(boost::get<float>(attr_map_.at("scale")));
return msg;
}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
class MyTestOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg =
"MyTestOp runs! test_attr = " +
std::to_string(boost::get<int>(attr_map_.at("test_attr")));
return msg;
}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace paddle {
namespace framework {
class CosineOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg = "CosineOp runs! scale = " +
std::to_string(boost::get<float>(attr_map_.at("scale")));
return msg;
}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
class MyTestOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg =
"MyTestOp runs! test_attr = " +
std::to_string(boost::get<int>(attr_map_.at("test_attr")));
return msg;
}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
} // namespace framework
} // namespace paddle
TEST(OpRegistry, CreateOp) { TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
...@@ -120,3 +177,8 @@ TEST(OpRegistry, CustomChecker) { ...@@ -120,3 +177,8 @@ TEST(OpRegistry, CustomChecker) {
ASSERT_EQ(debug_str[i], str[i]); ASSERT_EQ(debug_str[i], str[i]);
} }
} }
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册