未验证 提交 54f07019 编写于 作者: 王明冬 提交者: GitHub

fix the pass compat check position error, test=develop (#35272)

上级 ef536250
...@@ -264,10 +264,6 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -264,10 +264,6 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "fuse squeeze2+matmul to mul"; VLOG(4) << "fuse squeeze2+matmul to mul";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
...@@ -299,6 +295,10 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -299,6 +295,10 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
next_ops[0]->Name() == "elementwise_add"; next_ops[0]->Name() == "elementwise_add";
if (flag) { if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc; OpDesc desc;
desc.SetType("mul"); desc.SetType("mul");
desc.SetInput("X", {squeeze2_in_x->Name()}); desc.SetInput("X", {squeeze2_in_x->Name()});
...@@ -403,10 +403,6 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -403,10 +403,6 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0; int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "fuse reshape2+matmul to mul"; VLOG(4) << "fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
...@@ -441,6 +437,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -441,6 +437,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
next_ops[0]->Name() == "elementwise_add"; next_ops[0]->Name() == "elementwise_add";
if (flag) { if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc; OpDesc desc;
desc.SetType("mul"); desc.SetType("mul");
desc.SetInput("X", {reshape2_in_x->Name()}); desc.SetInput("X", {reshape2_in_x->Name()});
...@@ -483,11 +483,6 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -483,11 +483,6 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0; int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "fuse flatten2+matmul to mul"; VLOG(4) << "fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
...@@ -527,6 +522,10 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -527,6 +522,10 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
next_ops[0]->Name() == "elementwise_add"; next_ops[0]->Name() == "elementwise_add";
if (pattern_found) { if (pattern_found) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc; OpDesc desc;
desc.SetType("mul"); desc.SetType("mul");
desc.SetInput("X", {flatten2_in_x->Name()}); desc.SetInput("X", {flatten2_in_x->Name()});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册