diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index 3d8e655c5b2730fd36651c67d2f7c37b7dd5ecd9..e422a9bae31181d064bd36359ff1ebe38da2cac6 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -250,6 +250,32 @@ OpCompat& OpCompatSensiblePass::AddOpCompat(OpCompat&& op_compat) { return *(op_compat_judgers_[name]); } +//! Tell the Op compability of a subgraph. +bool OpCompatSensiblePass::IsCompat( + const GraphPatternDetector::subgraph_t& subgraph, Graph*) const { + PADDLE_ENFORCE_EQ(op_compat_judgers_.empty(), false, + platform::errors::InvalidArgument( + "At least one OpCompat instance should be added")); + // Check the all the ops in the subgraph are contained in the + // op_compat. + for (auto& node_pair : subgraph) { + if (!node_pair.second->IsOp()) continue; + auto op_type = node_pair.second->Op()->Type(); + if (!op_compat_judgers_.count(op_type)) { + if (HasOpDef(op_type)) { + LOG(WARNING) << op_type << "compat not registered!"; + return false; + } + continue; + } + auto& judger = *op_compat_judgers_.at(op_type); + if (!judger.Judge(*(node_pair.second->Op()))) { + return false; + } + } + return true; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.h b/paddle/fluid/framework/ir/op_compat_sensible_pass.h index 3aa985c6d46fa262bd4050f63e668c68e55237ac..7346ca3756f361f00fb67090d4127995fbe89b30 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.h +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.h @@ -195,26 +195,7 @@ class OpCompatSensiblePass : public Pass { //! Tell the Op compability of a subgraph. bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) const { - CHECK(!op_compat_judgers_.empty()) - << "At least one OpCompat instance should be added in the " - "OpCompatSensiblePass."; - // Check the all the ops in the subgraph are contained in the - // op_compat. - for (auto& node_pair : subgraph) { - if (!node_pair.second->IsOp()) continue; - auto op_type = node_pair.second->Op()->Type(); - if (!op_compat_judgers_.count(op_type)) { - LOG(WARNING) << op_type << "compat not registered!"; - return false; - } - auto& judger = *op_compat_judgers_.at(op_type); - if (!judger.Judge(*(node_pair.second->Op()))) { - return false; - } - } - return true; - } + Graph* g) const; //! Tell the op compatibility of a single Op. bool IsCompat(const OpDesc& op_desc) const { diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc index 87e28ae3a3aadda63ef67c82596d20cfb0c644f4..9074a9876f9f7d200d4c464fdab57b641c1d3b5a 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc @@ -151,6 +151,10 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass { public: OpCompatSensiblePassTest(); bool TestIsCompat(const OpDesc& op_desc) { return IsCompat(op_desc); } + bool TestIsCompat(const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + return IsCompat(subgraph, g); + } }; OpCompatSensiblePassTest::OpCompatSensiblePassTest() { @@ -192,6 +196,23 @@ TEST(OpCompatSensiblePass, IsCompat) { EXPECT_TRUE(test.TestIsCompat(fc_op)); } +TEST(OpCompatSensiblePass, IsCompatFail) { + OpCompatSensiblePassTest test; + GraphPatternDetector::subgraph_t subgraph; + PDPattern pattern; + PDNode* pd_node = pattern.NewNode(); + ProgramDesc prog; + Graph g(prog); + OpDesc fc_op; + fc_op.SetType("op1"); + subgraph[pd_node] = g.CreateOpNode(&fc_op); + EXPECT_TRUE(test.TestIsCompat(subgraph, &g)); + + fc_op.SetType("mul"); + subgraph[pd_node] = g.CreateOpNode(&fc_op); + EXPECT_FALSE(test.TestIsCompat(subgraph, &g)); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_def_api.cc b/paddle/fluid/framework/op_def_api.cc index 5e758fe4105097e0c6f3032d1d4e150b661ff5f5..b950f000bb8e50973d6d6ecbc32c416958b92ed3 100644 --- a/paddle/fluid/framework/op_def_api.cc +++ b/paddle/fluid/framework/op_def_api.cc @@ -68,5 +68,9 @@ const proto::OpDef& GetOpDef(const std::string& op_name) { } return ops_definition.at(op_name); } + +bool HasOpDef(const std::string& op_name) { + return op_def_map.find(op_name) != op_def_map.end(); +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_def_api.h b/paddle/fluid/framework/op_def_api.h index 4ec2089f9b1f88de18305cb5a6615f96f2718d39..1ef2254d0da361915f29b713e2d9a53d5c35cb8a 100644 --- a/paddle/fluid/framework/op_def_api.h +++ b/paddle/fluid/framework/op_def_api.h @@ -19,5 +19,7 @@ namespace paddle { namespace framework { const proto::OpDef& GetOpDef(const std::string& op_name); + +bool HasOpDef(const std::string& op_name); } }