未验证 提交 dffc331f 编写于 作者: 王明冬 提交者: GitHub

make the compatiable pass only check op has pbtxt, test=develop (#33397)

上级 a526b3e0
......@@ -250,6 +250,32 @@ OpCompat& OpCompatSensiblePass::AddOpCompat(OpCompat&& op_compat) {
return *(op_compat_judgers_[name]);
}
//! Tell the Op compability of a subgraph.
bool OpCompatSensiblePass::IsCompat(
const GraphPatternDetector::subgraph_t& subgraph, Graph*) const {
PADDLE_ENFORCE_EQ(op_compat_judgers_.empty(), false,
platform::errors::InvalidArgument(
"At least one OpCompat instance should be added"));
// Check the all the ops in the subgraph are contained in the
// op_compat.
for (auto& node_pair : subgraph) {
if (!node_pair.second->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) {
if (HasOpDef(op_type)) {
LOG(WARNING) << op_type << "compat not registered!";
return false;
}
continue;
}
auto& judger = *op_compat_judgers_.at(op_type);
if (!judger.Judge(*(node_pair.second->Op()))) {
return false;
}
}
return true;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -195,26 +195,7 @@ class OpCompatSensiblePass : public Pass {
//! Tell the Op compability of a subgraph.
bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) const {
CHECK(!op_compat_judgers_.empty())
<< "At least one OpCompat instance should be added in the "
"OpCompatSensiblePass.";
// Check the all the ops in the subgraph are contained in the
// op_compat.
for (auto& node_pair : subgraph) {
if (!node_pair.second->IsOp()) continue;
auto op_type = node_pair.second->Op()->Type();
if (!op_compat_judgers_.count(op_type)) {
LOG(WARNING) << op_type << "compat not registered!";
return false;
}
auto& judger = *op_compat_judgers_.at(op_type);
if (!judger.Judge(*(node_pair.second->Op()))) {
return false;
}
}
return true;
}
Graph* g) const;
//! Tell the op compatibility of a single Op.
bool IsCompat(const OpDesc& op_desc) const {
......
......@@ -151,6 +151,10 @@ class OpCompatSensiblePassTest : public OpCompatSensiblePass {
public:
OpCompatSensiblePassTest();
bool TestIsCompat(const OpDesc& op_desc) { return IsCompat(op_desc); }
bool TestIsCompat(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
return IsCompat(subgraph, g);
}
};
OpCompatSensiblePassTest::OpCompatSensiblePassTest() {
......@@ -192,6 +196,23 @@ TEST(OpCompatSensiblePass, IsCompat) {
EXPECT_TRUE(test.TestIsCompat(fc_op));
}
TEST(OpCompatSensiblePass, IsCompatFail) {
OpCompatSensiblePassTest test;
GraphPatternDetector::subgraph_t subgraph;
PDPattern pattern;
PDNode* pd_node = pattern.NewNode();
ProgramDesc prog;
Graph g(prog);
OpDesc fc_op;
fc_op.SetType("op1");
subgraph[pd_node] = g.CreateOpNode(&fc_op);
EXPECT_TRUE(test.TestIsCompat(subgraph, &g));
fc_op.SetType("mul");
subgraph[pd_node] = g.CreateOpNode(&fc_op);
EXPECT_FALSE(test.TestIsCompat(subgraph, &g));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -68,5 +68,9 @@ const proto::OpDef& GetOpDef(const std::string& op_name) {
}
return ops_definition.at(op_name);
}
bool HasOpDef(const std::string& op_name) {
return op_def_map.find(op_name) != op_def_map.end();
}
} // namespace framework
} // namespace paddle
......@@ -19,5 +19,7 @@
namespace paddle {
namespace framework {
const proto::OpDef& GetOpDef(const std::string& op_name);
bool HasOpDef(const std::string& op_name);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册