diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc old mode 100755 new mode 100644 index 72e6742f8f34c1745637f1d0b492edadf3048bc4..9542d3d3d43f311d4e4237e2efa41fe3f998603d --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -124,6 +124,60 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() { .End(); } +Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); + + AddOpCompat(OpCompat("Squeeze2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axes") + .IsType>() + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -211,6 +265,10 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { VLOG(4) << "fuse squeeze2+matmul to mul"; + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); @@ -260,6 +318,10 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { IR_NODE_LINK_TO(mul_node, matmul_out); GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op}); ++found_count; + if (!IsCompat(desc)) { + LOG(WARNING) << "Squeeze2MatmulFusePass in out mul op compat failed."; + return; + } } }; diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h index 5dc5caae21ea9895fb30647fa3228317c43fe7ba..192dcfc00f9d34bf286b8ddebe355aa1b8d381be 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h @@ -67,6 +67,7 @@ class MapMatmul2MulPass : public FusePassBase { class Squeeze2MatmulFusePass : public FusePassBase { public: + Squeeze2MatmulFusePass(); virtual ~Squeeze2MatmulFusePass() {} protected: