未验证 提交 24a2bd5c 编写于 作者: 王明冬 提交者: GitHub

[pass enhance] make the attribute check apply only to attributes defined in op...

[pass enhance] make the attribute check apply only to attributes defined in the op proto. test=develop (#34146)
上级 f05098b5
......@@ -72,19 +72,29 @@ AttrCompat& AttrCompat::IsIntIn(const std::set<int>& candidates) {
AttrCompat& AttrCompat::IsLeftDefault() {
const std::string& op_name = op_compat_->Name();
if (!OpInfoMap::Instance().Has(op_name)) {
LOG(WARNING) << "Op (" << op_name << ") is not registered!";
conditions_.emplace_back([](const Attribute& attr) { return false; });
conditions_.emplace_back([=](const Attribute& attr) {
LOG(WARNING) << "Op (" << op_name << ") is not find in op library!";
return false;
});
return *this;
}
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap();
if (attrs.find(attr_name_) == attrs.end()) {
LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_;
conditions_.emplace_back([](const Attribute& attr) { return false; });
conditions_.emplace_back([=](const Attribute& attr) {
LOG(WARNING) << "Op (" << op_name
<< ") has no default attr:" << attr_name_;
return false;
});
} else {
Attribute default_attr = attrs.at(attr_name_);
conditions_.emplace_back([default_attr](const Attribute& attr) -> bool {
return attr == default_attr;
conditions_.emplace_back([=](const Attribute& attr) -> bool {
if (attr == default_attr) {
return true;
}
LOG(WARNING) << "Attribute:(" << attr_name_ << ") of Op (" << op_name
<< ") not equal to default value!";
return false;
});
}
return *this;
......@@ -167,35 +177,39 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) {
return output_compats_.at(name);
}
bool OpCompat::Judge(const OpDesc& op_desc) {
bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) {
if (is_first_judge_) {
is_first_judge_ = false;
if (OpInfoMap::Instance().Has(op_name_)) {
auto& info = OpInfoMap::Instance().Get(op_name_);
if (info.proto_) {
for (const proto::OpProto::Attr& attr : info.proto_->attrs()) {
attrs_set_.emplace(attr.name());
}
}
}
const proto::OpDef& op_def = GetOpDef(op_name_);
if (op_def.has_extra()) {
for (const proto::OpDef_AttrDef& attr : op_def.extra().attrs()) {
extra_attrs_.emplace(attr.name());
attrs_set_.erase(attr.name());
}
}
for (const std::string& attr : global_extra_attrs) {
attrs_set_.erase(attr);
}
for (auto& attr_map : op_desc.GetAttrMap()) {
const std::string& name = attr_map.first;
if (name.size() >= 10u &&
0 == name.compare(name.size() - 10u, 10u, "_threshold")) {
continue; // skip the attribute ends with "_threshold", it used for
// quantization.
}
if (attr_compats_.find(attr_map.first) == attr_compats_.end()) {
if (global_extra_attrs.find(attr_map.first) != global_extra_attrs.end() ||
extra_attrs_.find(attr_map.first) != extra_attrs_.end()) {
continue;
for (const std::string& attr : attrs_set_) {
if (attr_compats_.find(attr) == attr_compats_.end()) {
attr_compats_.emplace(attr, AttrCompat(attr, this).IsLeftDefault());
}
if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) {
LOG(WARNING)
<< "The Attr(" << attr_map.first << ") of Op (" << op_name_
<< ") not reigistered in OpCompat, not in extra attribute, not "
"equal to default value!";
return false;
}
for (auto& attr_compat : attr_compats_) {
if (attrs_set_.find(attr_compat.first) == attrs_set_.end()) {
LOG(WARNING) << " Attribute(" << attr_compat.first << ") of Op("
<< op_name_
<< ") is not defined in opProto or is in extra set!"
<< "The compatable check for this attribute is not use."
<< " Please remove it from the precondition of pass: "
<< pass_name.c_str();
}
}
}
......@@ -203,7 +217,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) {
for (auto& attr_compat : attr_compats_) {
if (!attr_compat.second(op_desc)) {
LOG(WARNING) << " Check the Attr(" << attr_compat.first << ") of Op("
<< op_name_ << ") failed!";
<< op_name_ << ") in pass(" << pass_name.c_str()
<< ") failed!";
return false;
}
}
......@@ -289,7 +304,7 @@ bool OpCompatSensiblePass::IsCompat(
continue;
}
auto& judger = *op_compat_judgers_.at(op_type);
if (!judger.Judge(*(node_pair.second->Op()))) {
if (!judger.Judge(*(node_pair.second->Op()), Type())) {
return false;
}
}
......
......@@ -138,7 +138,7 @@ class OpCompat {
InputOrOutputCompat& AddOutput(const std::string& name);
//! Judge whether an OpDesc match the defined Op compatibility.
bool Judge(const OpDesc& op_desc);
bool Judge(const OpDesc& op_desc, const std::string& pass_name);
const std::string& Name() const { return op_name_; }
private:
......@@ -146,7 +146,7 @@ class OpCompat {
std::unordered_map<std::string, AttrCompat> attr_compats_;
std::unordered_map<std::string, InputOrOutputCompat> input_compats_;
std::unordered_map<std::string, InputOrOutputCompat> output_compats_;
std::unordered_set<std::string> extra_attrs_;
std::unordered_set<std::string> attrs_set_;
bool is_first_judge_ = true;
};
......@@ -206,7 +206,7 @@ class OpCompatSensiblePass : public Pass {
//! Tell the op compatibility of a single Op.
bool IsCompat(const OpDesc& op_desc) const {
if (!op_compat_judgers_.count(op_desc.Type())) return false;
return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc);
return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc, Type());
}
private:
......
......@@ -23,7 +23,7 @@ namespace ir {
TEST(OpCompatSensiblePass, compatOp) {
auto lambda = [](const std::string& str) { return str == "tanh"; };
OpCompat compat("fc");
OpCompat compat("fc_test");
compat.AddAttr("in_num_col_dims")
.IsIntIn({1, 2})
.IsNumLE(1)
......@@ -66,74 +66,118 @@ TEST(OpCompatSensiblePass, compatOp) {
fc_op.SetInput("Bias", std::vector<std::string>{"test_input_1"});
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_STREQ(compat.Name().c_str(), "fc");
EXPECT_TRUE(compat.Judge(fc_op));
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("in_num_col_dims");
attr = info.proto_->add_attrs();
attr->set_name("test_attr");
OpInfoMap::Instance().Insert("fc_test", info);
EXPECT_STREQ(compat.Name().c_str(), "fc_test");
EXPECT_TRUE(compat.Judge(fc_op, "test_pass"));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
}
TEST(OpCompatSensiblePass, compatOpAttribute) {
OpCompat compat("fc");
OpCompat compat("fc_test");
OpDesc fc_op;
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
fc_op.SetAttrMap(attr_map);
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("in_num_col_dims");
info.checker_ = new OpAttrChecker();
OpInfoMap::Instance().Insert("fc", info);
EXPECT_FALSE(compat.Judge(fc_op));
OpInfoMap::Instance().Insert("fc_test", info);
EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
OpCompat compat_1("fc_test");
info.checker_->AddAttrChecker<int>("in_num_col_dims").SetDefault(1);
EXPECT_TRUE(compat.Judge(fc_op));
EXPECT_TRUE(compat_1.Judge(fc_op, "test_pass"));
delete info.checker_;
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
}
TEST(OpCompatSensiblePass, opDefNotFound) {
OpCompat compat("fc_1");
OpCompat compat("fc_test");
OpDesc fc_op;
compat.Judge(fc_op);
OpCompat compat_1("");
compat_1.Judge(fc_op);
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
OpInfoMap::Instance().Insert("fc_test", info);
compat.Judge(fc_op, "test_pass");
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
}
TEST(OpCompatSensiblePass, compatOpAttributeOptional) {
OpCompat compat("fc");
OpCompat compat("fc_test");
compat.AddAttr("activation_type")
.IsOptional()
.IsStringIn({"tanh", "sigmoid"});
OpDesc fc_op;
EXPECT_TRUE(compat.Judge(fc_op));
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("activation_type");
OpInfoMap::Instance().Insert("fc_test", info);
EXPECT_TRUE(compat.Judge(fc_op, "test_pass"));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
}
TEST(OpCompatSensiblePass, compatOpInput) {
OpCompat compat("fc");
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
OpInfoMap::Instance().Insert("fc_test", info);
OpCompat compat("fc_test");
OpDesc fc_op;
fc_op.SetInput("Input", std::vector<std::string>{"test_input"});
EXPECT_FALSE(compat.Judge(fc_op));
EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
compat.AddInput("Input").IsTensor().End().AddInput("Bias").IsTensor().End();
EXPECT_FALSE(compat.Judge(fc_op));
EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
fc_op.SetInput("Bias", std::vector<std::string>{"test_input", ""});
EXPECT_FALSE(compat.Judge(fc_op));
EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
}
TEST(OpCompatSensiblePass, compatOutput) {
OpCompat compat("fc");
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
OpInfoMap::Instance().Insert("fc_test", info);
OpCompat compat("fc_test");
OpDesc fc_op;
fc_op.SetOutput("Output", std::vector<std::string>{"test_output"});
EXPECT_FALSE(compat.Judge(fc_op));
EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
compat.AddOutput("Output")
.IsTensor()
......@@ -141,10 +185,13 @@ TEST(OpCompatSensiblePass, compatOutput) {
.AddOutput("Output_2")
.IsTensor()
.End();
EXPECT_FALSE(compat.Judge(fc_op));
EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
fc_op.SetOutput("Output_2", std::vector<std::string>{"test_output", ""});
EXPECT_FALSE(compat.Judge(fc_op));
EXPECT_FALSE(compat.Judge(fc_op, "test_pass"));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
}
class OpCompatSensiblePassTest : public OpCompatSensiblePass {
......@@ -158,7 +205,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass {
};
OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
AddOpCompat(OpCompat("fc"))
AddOpCompat(OpCompat("fc_test"))
.AddAttr("in_num_col_dims")
.IsNumLE(1)
.End()
......@@ -180,9 +227,19 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
}
TEST(OpCompatSensiblePass, IsCompat) {
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("in_num_col_dims");
attr = info.proto_->add_attrs();
attr->set_name("activation_type");
OpInfoMap::Instance().Insert("fc_test", info);
OpCompatSensiblePassTest test;
OpDesc fc_op;
fc_op.SetType("fc");
fc_op.SetType("fc_test");
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh");
......@@ -194,9 +251,23 @@ TEST(OpCompatSensiblePass, IsCompat) {
fc_op.SetOutput("Out", std::vector<std::string>{"test_output"});
EXPECT_TRUE(test.TestIsCompat(fc_op));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
}
TEST(OpCompatSensiblePass, IsCompatFail) {
OpInfo info;
info.proto_ = new proto::OpProto;
info.proto_->set_type("fc_test");
info.proto_->set_comment("");
auto* attr = info.proto_->add_attrs();
attr->set_name("activation_type");
attr = info.proto_->add_attrs();
attr->set_name("in_num_col_dims");
OpInfoMap::Instance().Insert("fc_test", info);
OpInfoMap::Instance().Insert("op2", info);
OpCompatSensiblePassTest test;
GraphPatternDetector::subgraph_t subgraph;
PDPattern pattern;
......@@ -204,13 +275,21 @@ TEST(OpCompatSensiblePass, IsCompatFail) {
ProgramDesc prog;
Graph g(prog);
OpDesc fc_op;
fc_op.SetType("op1");
std::unordered_map<std::string, Attribute> attr_map;
attr_map["in_num_col_dims"] = 1;
attr_map["activation_type"] = std::string("tanh");
fc_op.SetAttrMap(attr_map);
fc_op.SetType("fc_test");
subgraph[pd_node] = g.CreateOpNode(&fc_op);
EXPECT_TRUE(test.TestIsCompat(subgraph, &g));
EXPECT_FALSE(test.TestIsCompat(subgraph, &g));
fc_op.SetType("mul");
fc_op.SetType("op2");
subgraph[pd_node] = g.CreateOpNode(&fc_op);
EXPECT_FALSE(test.TestIsCompat(subgraph, &g));
EXPECT_TRUE(test.TestIsCompat(subgraph, &g));
delete info.proto_;
OpInfoMap::Instance().mutable_map()->erase("fc_test");
OpInfoMap::Instance().mutable_map()->erase("op2");
}
} // namespace ir
......
......@@ -22,10 +22,6 @@ def {
}
}
extra {
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "padding_weights"
type: BOOLEAN
......@@ -35,63 +31,19 @@ extra {
type: BOOLEAN
}
attrs {
name: "use_quantizer"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "weight_scale"
type: FLOATS
}
attrs {
name: "Input_scale"
name: "Scale_in"
type: FLOAT
}
attrs {
name: "out_scale"
name: "Scale_weights"
type: FLOAT
}
attrs {
name: "out_threshold"
name: "Scale_out"
type: FLOAT
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
}
attrs {
name: "enable_int8"
type: BOOLEAN
}
attrs {
name: "use_fc_padding"
type: BOOLEAN
}
attrs {
name: "use_gpu"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册