未验证 提交 dffc331f 编写于 作者: 王明冬 提交者: GitHub

make the compatiable pass only check op has pbtxt, test=develop (#33397)

上级 a526b3e0
...@@ -250,6 +250,32 @@ OpCompat& OpCompatSensiblePass::AddOpCompat(OpCompat&& op_compat) { ...@@ -250,6 +250,32 @@ OpCompat& OpCompatSensiblePass::AddOpCompat(OpCompat&& op_compat) {
return *(op_compat_judgers_[name]); return *(op_compat_judgers_[name]);
} }
//! Tell the Op compability of a subgraph.
bool OpCompatSensiblePass::IsCompat(
const GraphPatternDetector::subgraph_t& subgraph, Graph*) const {
PADDLE_ENFORCE_EQ(op_compat_judgers_.empty(), false,
platform::errors::InvalidArgument(
"At least one OpCompat instance should be added"));
// Check the all the ops in the subgraph are contained in the
// op_compat.
for (auto& node_pair : subgraph) {
if (!node_pair.second->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) {
if (HasOpDef(op_type)) {
LOG(WARNING) << op_type << "compat not registered!";
return false;
}
continue;
}
auto& judger = *op_compat_judgers_.at(op_type);
if (!judger.Judge(*(node_pair.second->Op()))) {
return false;
}
}
return true;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -195,26 +195,7 @@ class OpCompatSensiblePass : public Pass { ...@@ -195,26 +195,7 @@ class OpCompatSensiblePass : public Pass {
//! Tell the Op compability of a subgraph. //! Tell the Op compability of a subgraph.
bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) const { Graph* g) const;
CHECK(!op_compat_judgers_.empty())
<< "At least one OpCompat instance should be added in the "
"OpCompatSensiblePass.";
// Check the all the ops in the subgraph are contained in the
// op_compat.
for (auto& node_pair : subgraph) {
if (!node_pair.second->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) {
LOG(WARNING) << op_type << "compat not registered!";
return false;
}
auto& judger = *op_compat_judgers_.at(op_type);
if (!judger.Judge(*(node_pair.second->Op()))) {
return false;
}
}
return true;
}
//! Tell the op compatibility of a single Op. //! Tell the op compatibility of a single Op.
bool IsCompat(const OpDesc& op_desc) const { bool IsCompat(const OpDesc& op_desc) const {
......
...@@ -151,6 +151,10 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass { ...@@ -151,6 +151,10 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass {
public: public:
OpCompatSensiblePassTest(); OpCompatSensiblePassTest();
bool TestIsCompat(const OpDesc& op_desc) { return IsCompat(op_desc); } bool TestIsCompat(const OpDesc& op_desc) { return IsCompat(op_desc); }
bool TestIsCompat(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
return IsCompat(subgraph, g);
}
}; };
OpCompatSensiblePassTest::OpCompatSensiblePassTest() { OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
...@@ -192,6 +196,23 @@ TEST(OpCompatSensiblePass, IsCompat) { ...@@ -192,6 +196,23 @@ TEST(OpCompatSensiblePass, IsCompat) {
EXPECT_TRUE(test.TestIsCompat(fc_op)); EXPECT_TRUE(test.TestIsCompat(fc_op));
} }
TEST(OpCompatSensiblePass, IsCompatFail) {
OpCompatSensiblePassTest test;
GraphPatternDetector::subgraph_t subgraph;
PDPattern pattern;
PDNode* pd_node = pattern.NewNode();
ProgramDesc prog;
Graph g(prog);
OpDesc fc_op;
fc_op.SetType("op1");
subgraph[pd_node] = g.CreateOpNode(&fc_op);
EXPECT_TRUE(test.TestIsCompat(subgraph, &g));
fc_op.SetType("mul");
subgraph[pd_node] = g.CreateOpNode(&fc_op);
EXPECT_FALSE(test.TestIsCompat(subgraph, &g));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -68,5 +68,9 @@ const proto::OpDef& GetOpDef(const std::string& op_name) { ...@@ -68,5 +68,9 @@ const proto::OpDef& GetOpDef(const std::string& op_name) {
} }
return ops_definition.at(op_name); return ops_definition.at(op_name);
} }
bool HasOpDef(const std::string& op_name) {
return op_def_map.find(op_name) != op_def_map.end();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -19,5 +19,7 @@ ...@@ -19,5 +19,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
const proto::OpDef& GetOpDef(const std::string& op_name); const proto::OpDef& GetOpDef(const std::string& op_name);
bool HasOpDef(const std::string& op_name);
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册