未验证 提交 24a2bd5c 编写于 作者: 王明冬 提交者: GitHub

[pass enhance] make the attribute check only object to which defined in op...

[pass enhance] make the attribute check only object to which defined in op proto. test=develop (#34146)
上级 f05098b5
...@@ -72,19 +72,29 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) { ...@@ -72,19 +72,29 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
AttrCompat& AttrCompat::IsLeftDefault() { AttrCompat& AttrCompat::IsLeftDefault() {
const std::string& op_name = op_compat_->Name(); const std::string& op_name = op_compat_->Name();
if (!OpInfoMap::Instance().Has(op_name)) { if (!OpInfoMap::Instance().Has(op_name)) {
LOG(WARNING) << "Op (" << op_name << ") is not registered!"; conditions_.emplace_back([=](const Attribute& attr) {
conditions_.emplace_back([](const Attribute& attr) { return false; }); LOG(WARNING) << "Op (" << op_name << ") is not find in op library!";
return false;
});
return *this; return *this;
} }
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name); const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap(); const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap();
if (attrs.find(attr_name_) == attrs.end()) { if (attrs.find(attr_name_) == attrs.end()) {
LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_; conditions_.emplace_back([=](const Attribute& attr) {
conditions_.emplace_back([](const Attribute& attr) { return false; }); LOG(WARNING) << "Op (" << op_name
<< ") has no default attr:" << attr_name_;
return false;
});
} else { } else {
Attribute default_attr = attrs.at(attr_name_); Attribute default_attr = attrs.at(attr_name_);
conditions_.emplace_back([default_attr](const Attribute& attr) -> bool { conditions_.emplace_back([=](const Attribute& attr) -> bool {
return attr == default_attr; if (attr == default_attr) {
return true;
}
LOG(WARNING) << "Attribute:(" << attr_name_ << ") of Op (" << op_name
<< ") not equal to default value!";
return false;
}); });
} }
return *this; return *this;
...@@ -167,35 +177,39 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) { ...@@ -167,35 +177,39 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
return output_compats_.at(name); return output_compats_.at(name);
} }
bool OpCompat::Judge(const OpDesc& op_desc) { bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) {
if (is_first_judge_) { if (is_first_judge_) {
is_first_judge_ = false; is_first_judge_ = false;
if (OpInfoMap::Instance().Has(op_name_)) {
auto& info = OpInfoMap::Instance().Get(op_name_);
if (info.proto_) {
for (const proto::OpProto::Attr& attr : info.proto_->attrs()) {
attrs_set_.emplace(attr.name());
}
}
}
const proto::OpDef& op_def = GetOpDef(op_name_); const proto::OpDef& op_def = GetOpDef(op_name_);
if (op_def.has_extra()) { if (op_def.has_extra()) {
for (const proto::OpDef_AttrDef& attr : op_def.extra().attrs()) { for (const proto::OpDef_AttrDef& attr : op_def.extra().attrs()) {
extra_attrs_.emplace(attr.name()); attrs_set_.erase(attr.name());
} }
} }
for (const std::string& attr : global_extra_attrs) {
attrs_set_.erase(attr);
} }
for (const std::string& attr : attrs_set_) {
for (auto& attr_map : op_desc.GetAttrMap()) { if (attr_compats_.find(attr) == attr_compats_.end()) {
const std::string& name = attr_map.first; attr_compats_.emplace(attr, AttrCompat(attr, this).IsLeftDefault());
if (name.size() >= 10u &&
0 == name.compare(name.size() - 10u, 10u, "_threshold")) {
continue; // skip the attribute ends with "_threshold", it used for
// quantization.
}
if (attr_compats_.find(attr_map.first) == attr_compats_.end()) {
if (global_extra_attrs.find(attr_map.first) != global_extra_attrs.end() ||
extra_attrs_.find(attr_map.first) != extra_attrs_.end()) {
continue;
} }
if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) { }
LOG(WARNING) for (auto& attr_compat : attr_compats_) {
<< "The Attr(" << attr_map.first << ") of Op (" << op_name_ if (attrs_set_.find(attr_compat.first) == attrs_set_.end()) {
<< ") not reigistered in OpCompat, not in extra attribute, not " LOG(WARNING) << " Attribute(" << attr_compat.first << ") of Op("
"equal to default value!"; << op_name_
return false; << ") is not defined in opProto or is in extra set!"
<< "The compatable check for this attribute is not use."
<< " Please remove it from the precondition of pass: "
<< pass_name.c_str();
} }
} }
} }
...@@ -203,7 +217,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) { ...@@ -203,7 +217,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& attr_compat : attr_compats_) { for (auto& attr_compat : attr_compats_) {
if (!attr_compat.second(op_desc)) { if (!attr_compat.second(op_desc)) {
LOG(WARNING) << " Check the Attr(" << attr_compat.first << ") of Op(" LOG(WARNING) << " Check the Attr(" << attr_compat.first << ") of Op("
<< op_name_ << ") failed!"; << op_name_ << ") in pass(" << pass_name.c_str()
<< ") failed!";
return false; return false;
} }
} }
...@@ -289,7 +304,7 @@ bool OpCompatSensiblePass::IsCompat( ...@@ -289,7 +304,7 @@ bool OpCompatSensiblePass::IsCompat(
continue; continue;
} }
auto& judger = *op_compat_judgers_.at(op_type); auto& judger = *op_compat_judgers_.at(op_type);
if (!judger.Judge(*(node_pair.second->Op()))) { if (!judger.Judge(*(node_pair.second->Op()), Type())) {
return false; return false;
} }
} }
......
...@@ -138,7 +138,7 @@ class OpCompat { ...@@ -138,7 +138,7 @@ class OpCompat {
InputOrOutputCompat& AddOutput(const std::string& name); InputOrOutputCompat& AddOutput(const std::string& name);
//! Judge whether an OpDesc match the defined Op compatibility. //! Judge whether an OpDesc match the defined Op compatibility.
bool Judge(const OpDesc& op_desc); bool Judge(const OpDesc& op_desc, const std::string& pass_name);
const std::string& Name() const { return op_name_; } const std::string& Name() const { return op_name_; }
private: private:
...@@ -146,7 +146,7 @@ class OpCompat { ...@@ -146,7 +146,7 @@ class OpCompat {
std::unordered_map<std::string, AttrCompat> attr_compats_; std::unordered_map<std::string, AttrCompat> attr_compats_;
std::unordered_map<std::string, InputOrOutputCompat> input_compats_; std::unordered_map<std::string, InputOrOutputCompat> input_compats_;
std::unordered_map<std::string, InputOrOutputCompat> output_compats_; std::unordered_map<std::string, InputOrOutputCompat> output_compats_;
std::unordered_set<std::string> extra_attrs_; std::unordered_set<std::string> attrs_set_;
bool is_first_judge_ = true; bool is_first_judge_ = true;
}; };
...@@ -206,7 +206,7 @@ class OpCompatSensiblePass : public Pass { ...@@ -206,7 +206,7 @@ class OpCompatSensiblePass : public Pass {
//! Tell the op compatibility of a single Op. //! Tell the op compatibility of a single Op.
bool IsCompat(const OpDesc& op_desc) const { bool IsCompat(const OpDesc& op_desc) const {
if (!op_compat_judgers_.count(op_desc.Type())) return false; if (!op_compat_judgers_.count(op_desc.Type())) return false;
return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc); return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc, Type());
} }
private: private:
......
...@@ -23,7 +23,7 @@ namespace ir { ...@@ -23,7 +23,7 @@ namespace ir {
TEST(OpCompatSensiblePass, compatOp) { TEST(OpCompatSensiblePass, compatOp) {
auto lambda = [](const std::string& str) { return str == "tanh"; }; auto lambda = [](const std::string& str) { return str == "tanh"; };
OpCompat compat("fc"); OpCompat compat("fc_test");
compat.AddAttr("in_num_col_dims") compat.AddAttr("in_num_col_dims")
.IsIntIn({1, 2}) .IsIntIn({1, 2})
.IsNumLE(1) .IsNumLE(1)
...@@ -66,74 +66,118 @@ TEST(OpCompatSensiblePass, compatOp) { ...@@ -66,74 +66,118 @@ TEST(OpCompatSensiblePass, compatOp) {
fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"}); fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"});
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"}); fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_STREQ(compat.Name().c_str(), "fc"); OpInfo info;
EXPECT_TRUE(compat.Judge(fc_op)); info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("in_num_col_dims");
attr = info.proto_->add_attrs();
attr->set_name("test_attr");
OpInfoMap::Instance().Insert("fc_test", info);
EXPECT_STREQ(compat.Name().c_str(), "fc_test");
EXPECT_TRUE(compat.Judge(fc_op, "test_pass"));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
} }
TEST(OpCompatSensiblePass, compatOpAttribute) { TEST(OpCompatSensiblePass, compatOpAttribute) {
OpCompat compat("fc"); OpCompat compat("fc_test");
OpDesc fc_op; OpDesc fc_op;
std::unordered_map<std::string, Attribute> attr_map; std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1; attr_map["in_num_col_dims"] = 1;
fc_op.SetAttrMap(attr_map); fc_op.SetAttrMap(attr_map);
OpInfo info; OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("in_num_col_dims");
info.checker_ = new OpAttrChecker(); info.checker_ = new OpAttrChecker();
OpInfoMap::Instance().Insert("fc", info); OpInfoMap::Instance().Insert("fc_test", info);
EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
EXPECT_FALSE(compat.Judge(fc_op));
OpCompat compat_1("fc_test");
info.checker_->AddAttrChecker<int>("in_num_col_dims").SetDefault(1); info.checker_->AddAttrChecker<int>("in_num_col_dims").SetDefault(1);
EXPECT_TRUE(compat_1.Judge(fc_op, "test_pass"));
EXPECT_TRUE(compat.Judge(fc_op));
delete info.checker_; delete info.checker_;
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
} }
TEST(OpCompatSensiblePass, opDefNotFound) { TEST(OpCompatSensiblePass, opDefNotFound) {
OpCompat compat("fc_1"); OpCompat compat("fc_test");
OpDesc fc_op; OpDesc fc_op;
OpInfo info;
compat.Judge(fc_op); info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
OpCompat compat_1(""); info.proto_->set_comment("");
OpInfoMap::Instance().Insert("fc_test", info);
compat_1.Judge(fc_op); compat.Judge(fc_op, "test_pass");
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
} }
TEST(OpCompatSensiblePass, compatOpAttributeOptional) { TEST(OpCompatSensiblePass, compatOpAttributeOptional) {
OpCompat compat("fc"); OpCompat compat("fc_test");
compat.AddAttr("activation_type") compat.AddAttr("activation_type")
.IsOptional() .IsOptional()
.IsStringIn({"tanh", "sigmoid"}); .IsStringIn({"tanh", "sigmoid"});
OpDesc fc_op; OpDesc fc_op;
EXPECT_TRUE(compat.Judge(fc_op)); OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("activation_type");
OpInfoMap::Instance().Insert("fc_test", info);
EXPECT_TRUE(compat.Judge(fc_op, "test_pass"));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
} }
TEST(OpCompatSensiblePass, compatOpInput) { TEST(OpCompatSensiblePass, compatOpInput) {
OpCompat compat("fc"); OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
OpInfoMap::Instance().Insert("fc_test", info);
OpCompat compat("fc_test");
OpDesc fc_op; OpDesc fc_op;
fc_op.SetInput("Input", std::vector<std::string>{"test_input"}); fc_op.SetInput("Input", std::vector<std::string>{"test_input"});
EXPECT_FALSE(compat.Judge(fc_op)); EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
compat.AddInput("Input").IsTensor().End().AddInput("Bias").IsTensor().End(); compat.AddInput("Input").IsTensor().End().AddInput("Bias").IsTensor().End();
EXPECT_FALSE(compat.Judge(fc_op)); EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
fc_op.SetInput("Bias", std::vector<std::string>{"test_input", ""}); fc_op.SetInput("Bias", std::vector<std::string>{"test_input", ""});
EXPECT_FALSE(compat.Judge(fc_op)); EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
} }
TEST(OpCompatSensiblePass, compatOutput) { TEST(OpCompatSensiblePass, compatOutput) {
OpCompat compat("fc"); OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
OpInfoMap::Instance().Insert("fc_test", info);
OpCompat compat("fc_test");
OpDesc fc_op; OpDesc fc_op;
fc_op.SetOutput("Output", std::vector<std::string>{"test_output"}); fc_op.SetOutput("Output", std::vector<std::string>{"test_output"});
EXPECT_FALSE(compat.Judge(fc_op)); EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
compat.AddOutput("Output") compat.AddOutput("Output")
.IsTensor() .IsTensor()
...@@ -141,10 +185,13 @@ TEST(OpCompatSensiblePass, compatOutput) { ...@@ -141,10 +185,13 @@ TEST(OpCompatSensiblePass, compatOutput) {
.AddOutput("Output_2") .AddOutput("Output_2")
.IsTensor() .IsTensor()
.End(); .End();
EXPECT_FALSE(compat.Judge(fc_op)); EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
fc_op.SetOutput("Output_2", std::vector<std::string>{"test_output", ""}); fc_op.SetOutput("Output_2", std::vector<std::string>{"test_output", ""});
EXPECT_FALSE(compat.Judge(fc_op)); EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
} }
class OpCompatSensiblePassTest : public OpCompatSensiblePass { class OpCompatSensiblePassTest : public OpCompatSensiblePass {
...@@ -158,7 +205,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass { ...@@ -158,7 +205,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass {
}; };
OpCompatSensiblePassTest::OpCompatSensiblePassTest() { OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
AddOpCompat(OpCompat("fc")) AddOpCompat(OpCompat("fc_test"))
.AddAttr("in_num_col_dims") .AddAttr("in_num_col_dims")
.IsNumLE(1) .IsNumLE(1)
.End() .End()
...@@ -180,9 +227,19 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() { ...@@ -180,9 +227,19 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
} }
TEST(OpCompatSensiblePass, IsCompat) { TEST(OpCompatSensiblePass, IsCompat) {
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("in_num_col_dims");
attr = info.proto_->add_attrs();
attr->set_name("activation_type");
OpInfoMap::Instance().Insert("fc_test", info);
OpCompatSensiblePassTest test; OpCompatSensiblePassTest test;
OpDesc fc_op; OpDesc fc_op;
fc_op.SetType("fc"); fc_op.SetType("fc_test");
std::unordered_map<std::string, Attribute> attr_map; std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1; attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh"); attr_map["activation_type"] = std::string("tanh");
...@@ -194,9 +251,23 @@ TEST(OpCompatSensiblePass, IsCompat) { ...@@ -194,9 +251,23 @@ TEST(OpCompatSensiblePass, IsCompat) {
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"}); fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_TRUE(test.TestIsCompat(fc_op)); EXPECT_TRUE(test.TestIsCompat(fc_op));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
} }
TEST(OpCompatSensiblePass, IsCompatFail) { TEST(OpCompatSensiblePass, IsCompatFail) {
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("activation_type");
attr = info.proto_->add_attrs();
attr->set_name("in_num_col_dims");
OpInfoMap::Instance().Insert("fc_test", info);
OpInfoMap::Instance().Insert("op2", info);
OpCompatSensiblePassTest test; OpCompatSensiblePassTest test;
GraphPatternDetector::subgraph_t subgraph; GraphPatternDetector::subgraph_t subgraph;
PDPattern pattern; PDPattern pattern;
...@@ -204,13 +275,21 @@ TEST(OpCompatSensiblePass, IsCompatFail) { ...@@ -204,13 +275,21 @@ TEST(OpCompatSensiblePass, IsCompatFail) {
ProgramDesc prog; ProgramDesc prog;
Graph g(prog); Graph g(prog);
OpDesc fc_op; OpDesc fc_op;
fc_op.SetType("op1"); std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh");
fc_op.SetAttrMap(attr_map);
fc_op.SetType("fc_test");
subgraph[pd_node] = g.CreateOpNode(&fc_op); subgraph[pd_node] = g.CreateOpNode(&fc_op);
EXPECT_TRUE(test.TestIsCompat(subgraph, &g)); EXPECT_FALSE(test.TestIsCompat(subgraph, &g));
fc_op.SetType("mul"); fc_op.SetType("op2");
subgraph[pd_node] = g.CreateOpNode(&fc_op); subgraph[pd_node] = g.CreateOpNode(&fc_op);
EXPECT_FALSE(test.TestIsCompat(subgraph, &g)); EXPECT_TRUE(test.TestIsCompat(subgraph, &g));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
OpInfoMap::Instance().mutable_map()->erase("op2");
} }
} // namespace ir } // namespace ir
......
...@@ -22,10 +22,6 @@ def { ...@@ -22,10 +22,6 @@ def {
} }
} }
extra { extra {
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs { attrs {
name: "padding_weights" name: "padding_weights"
type: BOOLEAN type: BOOLEAN
...@@ -35,63 +31,19 @@ extra { ...@@ -35,63 +31,19 @@ extra {
type: BOOLEAN type: BOOLEAN
} }
attrs { attrs {
name: "use_quantizer" name: "Scale_in"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "weight_scale"
type: FLOATS
}
attrs {
name: "Input_scale"
type: FLOAT type: FLOAT
} }
attrs { attrs {
name: "out_scale" name: "Scale_weights"
type: FLOAT type: FLOAT
} }
attrs { attrs {
name: "out_threshold" name: "Scale_out"
type: FLOAT type: FLOAT
} }
attrs { attrs {
name: "force_fp32_output" name: "force_fp32_output"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "enable_int8"
type: BOOLEAN
}
attrs {
name: "use_fc_padding"
type: BOOLEAN
}
attrs {
name: "use_gpu"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册