提交 80a26a63 编写于 作者: Q Qiao Longfei 提交者: GitHub

check duplicate of ProtoAndCheckerMaker (#2903)

上级 cdec5634
...@@ -61,7 +61,14 @@ class OpProtoAndCheckerMaker { ...@@ -61,7 +61,14 @@ class OpProtoAndCheckerMaker {
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {} : proto_(proto), op_checker_(op_checker) {}
~OpProtoAndCheckerMaker() { CheckNoDuplicatedAttrs(); } ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
protected: protected:
void AddInput(const std::string& name, const std::string& comment, void AddInput(const std::string& name, const std::string& comment,
...@@ -163,19 +170,26 @@ Add a mark to which output is temporary is helpful for future optimization. ...@@ -163,19 +170,26 @@ Add a mark to which output is temporary is helpful for future optimization.
} }
} }
void CheckNoDuplicatedAttrs() { void CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
size_t cnt = 0; auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) { for (auto& attr : proto_->attrs()) {
names.insert(attr.name()); checker(attr.name());
++cnt; }
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
} }
PADDLE_ENFORCE(names.size() == cnt,
"Cannot register two attribute in same name!");
} }
OpProto* proto_; OpProto* proto_;
OpAttrChecker* op_checker_; OpAttrChecker* op_checker_;
bool validated_{false};
bool has_multiple_input_{false}; bool has_multiple_input_{false};
bool has_multiple_output_{false}; bool has_multiple_output_{false};
bool has_temporary_output_{false}; bool has_temporary_output_{false};
...@@ -190,7 +204,8 @@ class OpRegistry { ...@@ -190,7 +204,8 @@ class OpRegistry {
creators()[op_type] = [] { return new OpType; }; creators()[op_type] = [] { return new OpType; };
OpProto& op_proto = protos()[op_type]; OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
ProtoMakerType(&op_proto, &op_checker); auto maker = ProtoMakerType(&op_proto, &op_checker);
maker.Validate();
*op_proto.mutable_type() = op_type; *op_proto.mutable_type() = op_type;
PADDLE_ENFORCE( PADDLE_ENFORCE(
op_proto.IsInitialized(), op_proto.IsInitialized(),
......
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace pd = paddle::framework;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
...@@ -28,8 +30,6 @@ class MyTestOp : public OperatorBase { ...@@ -28,8 +30,6 @@ class MyTestOp : public OperatorBase {
void InferShape(const ScopePtr& scope) const override {} void InferShape(const ScopePtr& scope) const override {}
void Run(const ScopePtr& scope, void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
public:
}; };
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -182,3 +182,35 @@ TEST(OpRegistry, CustomChecker) { ...@@ -182,3 +182,35 @@ TEST(OpRegistry, CustomChecker) {
int test_attr = op->GetAttr<int>("test_attr"); int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
} }
class TestAttrProtoMaker : public pd::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op");
}
};
TEST(ProtoMaker, DuplicatedAttr) {
pd::OpProto op_proto;
pd::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
}
class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddInput("input", "input of test op");
}
};
TEST(ProtoMaker, DuplicatedInOut) {
pd::OpProto op_proto;
pd::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册