From 54f07019469f950c9d01d146a8f12f84855caaf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Tue, 31 Aug 2021 13:59:32 +0800 Subject: [PATCH] fix the pass compat check position error, test=develop (#35272) --- .../framework/ir/map_matmul_to_mul_pass.cc | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc index 61376828473..b8666c1c73e 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -264,10 +264,6 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { VLOG(4) << "fuse squeeze2+matmul to mul"; - if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass in op compat failed."; - return; - } GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); @@ -299,6 +295,10 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { next_ops[0]->Name() == "elementwise_add"; if (flag) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } OpDesc desc; desc.SetType("mul"); desc.SetInput("X", {squeeze2_in_x->Name()}); @@ -403,10 +403,6 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass in op compat failed."; - return; - } VLOG(4) << "fuse reshape2+matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern); @@ -441,6 +437,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { next_ops[0]->Name() == "elementwise_add"; if (flag) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } OpDesc desc; desc.SetType("mul"); desc.SetInput("X", {reshape2_in_x->Name()}); @@ -483,11 +483,6 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass in op compat failed."; - return; - } - VLOG(4) << "fuse flatten2+matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern); @@ -527,6 +522,10 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { next_ops[0]->Name() == "elementwise_add"; if (pattern_found) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } OpDesc desc; desc.SetType("mul"); desc.SetInput("X", {flatten2_in_x->Name()}); -- GitLab