From dffc331fa29d411bf5c04c46c5bf61b429a6a59f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Thu, 10 Jun 2021 10:20:03 +0800 Subject: [PATCH] make the compatiable pass only check op has pbtxt, test=develop (#33397) --- .../framework/ir/op_compat_sensible_pass.cc | 26 +++++++++++++++++++ .../framework/ir/op_compat_sensible_pass.h | 21 +-------------- .../ir/op_compat_sensible_pass_tester.cc | 21 +++++++++++++++ paddle/fluid/framework/op_def_api.cc | 4 +++ paddle/fluid/framework/op_def_api.h | 2 ++ 5 files changed, 54 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index 3d8e655c5b2..e422a9bae31 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -250,6 +250,32 @@ OpCompat& OpCompatSensiblePass::AddOpCompat(OpCompat&& op_compat) { return *(op_compat_judgers_[name]); } +//! Tell the Op compability of a subgraph. +bool OpCompatSensiblePass::IsCompat( + const GraphPatternDetector::subgraph_t& subgraph, Graph*) const { + PADDLE_ENFORCE_EQ(op_compat_judgers_.empty(), false, + platform::errors::InvalidArgument( + "At least one OpCompat instance should be added")); + // Check the all the ops in the subgraph are contained in the + // op_compat. + for (auto& node_pair : subgraph) { + if (!node_pair.second->IsOp()) continue; + auto op_type = node_pair.second->Op()->Type(); + if (!op_compat_judgers_.count(op_type)) { + if (HasOpDef(op_type)) { + LOG(WARNING) << op_type << "compat not registered!"; + return false; + } + continue; + } + auto& judger = *op_compat_judgers_.at(op_type); + if (!judger.Judge(*(node_pair.second->Op()))) { + return false; + } + } + return true; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.h b/paddle/fluid/framework/ir/op_compat_sensible_pass.h index 3aa985c6d46..7346ca3756f 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.h +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.h @@ -195,26 +195,7 @@ class OpCompatSensiblePass : public Pass { //! Tell the Op compability of a subgraph. bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) const { - CHECK(!op_compat_judgers_.empty()) - << "At least one OpCompat instance should be added in the " - "OpCompatSensiblePass."; - // Check the all the ops in the subgraph are contained in the - // op_compat. - for (auto& node_pair : subgraph) { - if (!node_pair.second->IsOp()) continue; - auto op_type = node_pair.second->Op()->Type(); - if (!op_compat_judgers_.count(op_type)) { - LOG(WARNING) << op_type << "compat not registered!"; - return false; - } - auto& judger = *op_compat_judgers_.at(op_type); - if (!judger.Judge(*(node_pair.second->Op()))) { - return false; - } - } - return true; - } + Graph* g) const; //! Tell the op compatibility of a single Op. bool IsCompat(const OpDesc& op_desc) const { diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc index 87e28ae3a3a..9074a9876f9 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc @@ -151,6 +151,10 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass { public: OpCompatSensiblePassTest(); bool TestIsCompat(const OpDesc& op_desc) { return IsCompat(op_desc); } + bool TestIsCompat(const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + return IsCompat(subgraph, g); + } }; OpCompatSensiblePassTest::OpCompatSensiblePassTest() { @@ -192,6 +196,23 @@ TEST(OpCompatSensiblePass, IsCompat) { EXPECT_TRUE(test.TestIsCompat(fc_op)); } +TEST(OpCompatSensiblePass, IsCompatFail) { + OpCompatSensiblePassTest test; + GraphPatternDetector::subgraph_t subgraph; + PDPattern pattern; + PDNode* pd_node = pattern.NewNode(); + ProgramDesc prog; + Graph g(prog); + OpDesc fc_op; + fc_op.SetType("op1"); + subgraph[pd_node] = g.CreateOpNode(&fc_op); + EXPECT_TRUE(test.TestIsCompat(subgraph, &g)); + + fc_op.SetType("mul"); + subgraph[pd_node] = g.CreateOpNode(&fc_op); + EXPECT_FALSE(test.TestIsCompat(subgraph, &g)); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_def_api.cc b/paddle/fluid/framework/op_def_api.cc index 5e758fe4105..b950f000bb8 100644 --- a/paddle/fluid/framework/op_def_api.cc +++ b/paddle/fluid/framework/op_def_api.cc @@ -68,5 +68,9 @@ const proto::OpDef& GetOpDef(const std::string& op_name) { } return ops_definition.at(op_name); } + +bool HasOpDef(const std::string& op_name) { + return op_def_map.find(op_name) != op_def_map.end(); +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_def_api.h b/paddle/fluid/framework/op_def_api.h index 4ec2089f9b1..1ef2254d0da 100644 --- a/paddle/fluid/framework/op_def_api.h +++ b/paddle/fluid/framework/op_def_api.h @@ -19,5 +19,7 @@ namespace paddle { namespace framework { const proto::OpDef& GetOpDef(const std::string& op_name); + +bool HasOpDef(const std::string& op_name); } } -- GitLab