From 24a2bd5c1bdeb656991a8cc5b62b521ed6ae1213 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Thu, 15 Jul 2021 10:02:10 +0800 Subject: [PATCH] [pass enhance] make the attribute check only object to which defined in op proto. test=develop (#34146) --- .../framework/ir/op_compat_sensible_pass.cc | 71 +++++---- .../framework/ir/op_compat_sensible_pass.h | 6 +- .../ir/op_compat_sensible_pass_tester.cc | 145 ++++++++++++++---- paddle/fluid/operators/compat/fc.pbtxt | 54 +------ 4 files changed, 161 insertions(+), 115 deletions(-) diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index 8f814822b6a..8261a0b8ca7 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -72,19 +72,29 @@ AttrCompat& AttrCompat::IsIntIn(const std::set& candidates) { AttrCompat& AttrCompat::IsLeftDefault() { const std::string& op_name = op_compat_->Name(); if (!OpInfoMap::Instance().Has(op_name)) { - LOG(WARNING) << "Op (" << op_name << ") is not registered!"; - conditions_.emplace_back([](const Attribute& attr) { return false; }); + conditions_.emplace_back([=](const Attribute& attr) { + LOG(WARNING) << "Op (" << op_name << ") is not find in op library!"; + return false; + }); return *this; } const OpInfo& op_info = OpInfoMap::Instance().Get(op_name); const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap(); if (attrs.find(attr_name_) == attrs.end()) { - LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_; - conditions_.emplace_back([](const Attribute& attr) { return false; }); + conditions_.emplace_back([=](const Attribute& attr) { + LOG(WARNING) << "Op (" << op_name + << ") has no default attr:" << attr_name_; + return false; + }); } else { Attribute default_attr = attrs.at(attr_name_); - conditions_.emplace_back([default_attr](const Attribute& attr) -> bool { - return attr == default_attr; + conditions_.emplace_back([=](const Attribute& attr) -> bool { + if (attr == default_attr) { + return true; + } + LOG(WARNING) << "Attribute:(" << attr_name_ << ") of Op (" << op_name + << ") not equal to default value!"; + return false; }); } return *this; @@ -167,35 +177,39 @@ InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) { return output_compats_.at(name); } -bool OpCompat::Judge(const OpDesc& op_desc) { +bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) { if (is_first_judge_) { is_first_judge_ = false; + if (OpInfoMap::Instance().Has(op_name_)) { + auto& info = OpInfoMap::Instance().Get(op_name_); + if (info.proto_) { + for (const proto::OpProto::Attr& attr : info.proto_->attrs()) { + attrs_set_.emplace(attr.name()); + } + } + } const proto::OpDef& op_def = GetOpDef(op_name_); if (op_def.has_extra()) { for (const proto::OpDef_AttrDef& attr : op_def.extra().attrs()) { - extra_attrs_.emplace(attr.name()); + attrs_set_.erase(attr.name()); } } - } - - for (auto& attr_map : op_desc.GetAttrMap()) { - const std::string& name = attr_map.first; - if (name.size() >= 10u && - 0 == name.compare(name.size() - 10u, 10u, "_threshold")) { - continue; // skip the attribute ends with "_threshold", it used for - // quantization. + for (const std::string& attr : global_extra_attrs) { + attrs_set_.erase(attr); } - if (attr_compats_.find(attr_map.first) == attr_compats_.end()) { - if (global_extra_attrs.find(attr_map.first) != global_extra_attrs.end() || - extra_attrs_.find(attr_map.first) != extra_attrs_.end()) { - continue; + for (const std::string& attr : attrs_set_) { + if (attr_compats_.find(attr) == attr_compats_.end()) { + attr_compats_.emplace(attr, AttrCompat(attr, this).IsLeftDefault()); } - if (!AttrCompat(attr_map.first, this).IsLeftDefault()(op_desc)) { - LOG(WARNING) - << "The Attr(" << attr_map.first << ") of Op (" << op_name_ - << ") not reigistered in OpCompat, not in extra attribute, not " - "equal to default value!"; - return false; + } + for (auto& attr_compat : attr_compats_) { + if (attrs_set_.find(attr_compat.first) == attrs_set_.end()) { + LOG(WARNING) << " Attribute(" << attr_compat.first << ") of Op(" + << op_name_ + << ") is not defined in opProto or is in extra set!" + << "The compatable check for this attribute is not use." + << " Please remove it from the precondition of pass: " + << pass_name.c_str(); } } } @@ -203,7 +217,8 @@ bool OpCompat::Judge(const OpDesc& op_desc) { for (auto& attr_compat : attr_compats_) { if (!attr_compat.second(op_desc)) { LOG(WARNING) << " Check the Attr(" << attr_compat.first << ") of Op(" - << op_name_ << ") failed!"; + << op_name_ << ") in pass(" << pass_name.c_str() + << ") failed!"; return false; } } @@ -289,7 +304,7 @@ bool OpCompatSensiblePass::IsCompat( continue; } auto& judger = *op_compat_judgers_.at(op_type); - if (!judger.Judge(*(node_pair.second->Op()))) { + if (!judger.Judge(*(node_pair.second->Op()), Type())) { return false; } } diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.h b/paddle/fluid/framework/ir/op_compat_sensible_pass.h index cfec1f123e2..e24294a03a2 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.h +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.h @@ -138,7 +138,7 @@ class OpCompat { InputOrOutputCompat& AddOutput(const std::string& name); //! Judge whether an OpDesc match the defined Op compatibility. - bool Judge(const OpDesc& op_desc); + bool Judge(const OpDesc& op_desc, const std::string& pass_name); const std::string& Name() const { return op_name_; } private: @@ -146,7 +146,7 @@ class OpCompat { std::unordered_map attr_compats_; std::unordered_map input_compats_; std::unordered_map output_compats_; - std::unordered_set extra_attrs_; + std::unordered_set attrs_set_; bool is_first_judge_ = true; }; @@ -206,7 +206,7 @@ class OpCompatSensiblePass : public Pass { //! Tell the op compatibility of a single Op. bool IsCompat(const OpDesc& op_desc) const { if (!op_compat_judgers_.count(op_desc.Type())) return false; - return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc); + return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc, Type()); } private: diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc index 9074a9876f9..9602cd41131 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc @@ -23,7 +23,7 @@ namespace ir { TEST(OpCompatSensiblePass, compatOp) { auto lambda = [](const std::string& str) { return str == "tanh"; }; - OpCompat compat("fc"); + OpCompat compat("fc_test"); compat.AddAttr("in_num_col_dims") .IsIntIn({1, 2}) .IsNumLE(1) @@ -66,74 +66,118 @@ TEST(OpCompatSensiblePass, compatOp) { fc_op.SetInput("Bias", std::vector{"test_input_1"}); fc_op.SetOutput("Out", std::vector{"test_output"}); - EXPECT_STREQ(compat.Name().c_str(), "fc"); - EXPECT_TRUE(compat.Judge(fc_op)); + OpInfo info; + info.proto_ = new proto::OpProto; + info.proto_->set_type("fc_test"); + info.proto_->set_comment(""); + auto* attr = info.proto_->add_attrs(); + attr->set_name("in_num_col_dims"); + attr = info.proto_->add_attrs(); + attr->set_name("test_attr"); + OpInfoMap::Instance().Insert("fc_test", info); + + EXPECT_STREQ(compat.Name().c_str(), "fc_test"); + EXPECT_TRUE(compat.Judge(fc_op, "test_pass")); + + delete info.proto_; + OpInfoMap::Instance().mutable_map()->erase("fc_test"); } TEST(OpCompatSensiblePass, compatOpAttribute) { - OpCompat compat("fc"); + OpCompat compat("fc_test"); OpDesc fc_op; - std::unordered_map attr_map; attr_map["in_num_col_dims"] = 1; fc_op.SetAttrMap(attr_map); OpInfo info; + info.proto_ = new proto::OpProto; + info.proto_->set_type("fc_test"); + info.proto_->set_comment(""); + auto* attr = info.proto_->add_attrs(); + attr->set_name("in_num_col_dims"); info.checker_ = new OpAttrChecker(); - OpInfoMap::Instance().Insert("fc", info); - - EXPECT_FALSE(compat.Judge(fc_op)); + OpInfoMap::Instance().Insert("fc_test", info); + EXPECT_FALSE(compat.Judge(fc_op, "test_pass")); + OpCompat compat_1("fc_test"); info.checker_->AddAttrChecker("in_num_col_dims").SetDefault(1); - - EXPECT_TRUE(compat.Judge(fc_op)); + EXPECT_TRUE(compat_1.Judge(fc_op, "test_pass")); delete info.checker_; + delete info.proto_; + OpInfoMap::Instance().mutable_map()->erase("fc_test"); } TEST(OpCompatSensiblePass, opDefNotFound) { - OpCompat compat("fc_1"); + OpCompat compat("fc_test"); OpDesc fc_op; - - compat.Judge(fc_op); - - OpCompat compat_1(""); - - compat_1.Judge(fc_op); + OpInfo info; + info.proto_ = new proto::OpProto; + info.proto_->set_type("fc_test"); + info.proto_->set_comment(""); + OpInfoMap::Instance().Insert("fc_test", info); + compat.Judge(fc_op, "test_pass"); + delete info.proto_; + OpInfoMap::Instance().mutable_map()->erase("fc_test"); } TEST(OpCompatSensiblePass, compatOpAttributeOptional) { - OpCompat compat("fc"); + OpCompat compat("fc_test"); compat.AddAttr("activation_type") .IsOptional() .IsStringIn({"tanh", "sigmoid"}); OpDesc fc_op; - EXPECT_TRUE(compat.Judge(fc_op)); + OpInfo info; + info.proto_ = new proto::OpProto; + info.proto_->set_type("fc_test"); + info.proto_->set_comment(""); + auto* attr = info.proto_->add_attrs(); + attr->set_name("activation_type"); + OpInfoMap::Instance().Insert("fc_test", info); + EXPECT_TRUE(compat.Judge(fc_op, "test_pass")); + delete info.proto_; + OpInfoMap::Instance().mutable_map()->erase("fc_test"); } TEST(OpCompatSensiblePass, compatOpInput) { - OpCompat compat("fc"); + OpInfo info; + info.proto_ = new proto::OpProto; + info.proto_->set_type("fc_test"); + info.proto_->set_comment(""); + OpInfoMap::Instance().Insert("fc_test", info); + + OpCompat compat("fc_test"); OpDesc fc_op; fc_op.SetInput("Input", std::vector{"test_input"}); - EXPECT_FALSE(compat.Judge(fc_op)); + EXPECT_FALSE(compat.Judge(fc_op, "test_pass")); compat.AddInput("Input").IsTensor().End().AddInput("Bias").IsTensor().End(); - EXPECT_FALSE(compat.Judge(fc_op)); + EXPECT_FALSE(compat.Judge(fc_op, "test_pass")); fc_op.SetInput("Bias", std::vector{"test_input", ""}); - EXPECT_FALSE(compat.Judge(fc_op)); + EXPECT_FALSE(compat.Judge(fc_op, "test_pass")); + + delete info.proto_; + OpInfoMap::Instance().mutable_map()->erase("fc_test"); } TEST(OpCompatSensiblePass, compatOutput) { - OpCompat compat("fc"); + OpInfo info; + info.proto_ = new proto::OpProto; + info.proto_->set_type("fc_test"); + info.proto_->set_comment(""); + OpInfoMap::Instance().Insert("fc_test", info); + + OpCompat compat("fc_test"); OpDesc fc_op; fc_op.SetOutput("Output", std::vector{"test_output"}); - EXPECT_FALSE(compat.Judge(fc_op)); + EXPECT_FALSE(compat.Judge(fc_op, "test_pass")); compat.AddOutput("Output") .IsTensor() @@ -141,10 +185,13 @@ TEST(OpCompatSensiblePass, compatOutput) { .AddOutput("Output_2") .IsTensor() .End(); - EXPECT_FALSE(compat.Judge(fc_op)); + EXPECT_FALSE(compat.Judge(fc_op, "test_pass")); fc_op.SetOutput("Output_2", std::vector{"test_output", ""}); - EXPECT_FALSE(compat.Judge(fc_op)); + EXPECT_FALSE(compat.Judge(fc_op, "test_pass")); + + delete info.proto_; + OpInfoMap::Instance().mutable_map()->erase("fc_test"); } class OpCompatSensiblePassTest : public OpCompatSensiblePass { @@ -158,7 +205,7 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass { }; OpCompatSensiblePassTest::OpCompatSensiblePassTest() { - AddOpCompat(OpCompat("fc")) + AddOpCompat(OpCompat("fc_test")) .AddAttr("in_num_col_dims") .IsNumLE(1) .End() @@ -180,9 +227,19 @@ OpCompatSensiblePassTest::OpCompatSensiblePassTest() { } TEST(OpCompatSensiblePass, IsCompat) { + OpInfo info; + info.proto_ = new proto::OpProto; + info.proto_->set_type("fc_test"); + info.proto_->set_comment(""); + auto* attr = info.proto_->add_attrs(); + attr->set_name("in_num_col_dims"); + attr = info.proto_->add_attrs(); + attr->set_name("activation_type"); + OpInfoMap::Instance().Insert("fc_test", info); + OpCompatSensiblePassTest test; OpDesc fc_op; - fc_op.SetType("fc"); + fc_op.SetType("fc_test"); std::unordered_map attr_map; attr_map["in_num_col_dims"] = 1; attr_map["activation_type"] = std::string("tanh"); @@ -194,9 +251,23 @@ TEST(OpCompatSensiblePass, IsCompat) { fc_op.SetOutput("Out", std::vector{"test_output"}); EXPECT_TRUE(test.TestIsCompat(fc_op)); + + delete info.proto_; + OpInfoMap::Instance().mutable_map()->erase("fc_test"); } TEST(OpCompatSensiblePass, IsCompatFail) { + OpInfo info; + info.proto_ = new proto::OpProto; + info.proto_->set_type("fc_test"); + info.proto_->set_comment(""); + auto* attr = info.proto_->add_attrs(); + attr->set_name("activation_type"); + attr = info.proto_->add_attrs(); + attr->set_name("in_num_col_dims"); + OpInfoMap::Instance().Insert("fc_test", info); + OpInfoMap::Instance().Insert("op2", info); + OpCompatSensiblePassTest test; GraphPatternDetector::subgraph_t subgraph; PDPattern pattern; @@ -204,13 +275,21 @@ TEST(OpCompatSensiblePass, IsCompatFail) { ProgramDesc prog; Graph g(prog); OpDesc fc_op; - fc_op.SetType("op1"); + std::unordered_map attr_map; + attr_map["in_num_col_dims"] = 1; + attr_map["activation_type"] = std::string("tanh"); + fc_op.SetAttrMap(attr_map); + fc_op.SetType("fc_test"); subgraph[pd_node] = g.CreateOpNode(&fc_op); - EXPECT_TRUE(test.TestIsCompat(subgraph, &g)); + EXPECT_FALSE(test.TestIsCompat(subgraph, &g)); - fc_op.SetType("mul"); + fc_op.SetType("op2"); subgraph[pd_node] = g.CreateOpNode(&fc_op); - EXPECT_FALSE(test.TestIsCompat(subgraph, &g)); + EXPECT_TRUE(test.TestIsCompat(subgraph, &g)); + + delete info.proto_; + OpInfoMap::Instance().mutable_map()->erase("fc_test"); + OpInfoMap::Instance().mutable_map()->erase("op2"); } } // namespace ir diff --git a/paddle/fluid/operators/compat/fc.pbtxt b/paddle/fluid/operators/compat/fc.pbtxt index 55e1a22ce4d..b7b9fe7acda 100644 --- a/paddle/fluid/operators/compat/fc.pbtxt +++ b/paddle/fluid/operators/compat/fc.pbtxt @@ -22,10 +22,6 @@ def { } } extra { - attrs { - name: "use_mkldnn" - type: BOOLEAN - } attrs { name: "padding_weights" type: BOOLEAN @@ -35,63 +31,19 @@ extra { type: BOOLEAN } attrs { - name: "use_quantizer" - type: BOOLEAN - } - attrs { - name: "mkldnn_data_type" - type: STRING - } - attrs { - name: "weight_scale" - type: FLOATS - } - attrs { - name: "Input_scale" + name: "Scale_in" type: FLOAT } attrs { - name: "out_scale" + name: "Scale_weights" type: FLOAT } attrs { - name: "out_threshold" + name: "Scale_out" type: FLOAT } attrs { name: "force_fp32_output" type: BOOLEAN } - attrs { - name: "enable_int8" - type: BOOLEAN - } - attrs { - name: "use_fc_padding" - type: BOOLEAN - } - attrs { - name: "use_gpu" - type: BOOLEAN - } - attrs { - name: "op_role" - type: INT - } - attrs { - name: "op_role_var" - type: STRINGS - } - attrs { - name: "op_namescope" - type: STRING - } - attrs { - name: "op_callstack" - type: STRINGS - } - attrs { - name: "op_device" - type: STRING - } } -- GitLab