diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc old mode 100644 new mode 100755 index c36123f65f6644289cfba2b2729862efa601e2fd..20761f2f1eacba2a50aea028fef8c0f992ebf8d3 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -16,6 +16,7 @@ #include #include +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -26,6 +27,103 @@ namespace ir { class Node; +MapMatmul2MulPass::MapMatmul2MulPass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + +Flatten2MatmulFusePass::Flatten2MatmulFusePass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); + + AddOpCompat(OpCompat("flatten2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumGE(0) + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -39,6 +137,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } + VLOG(4) << "map matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); @@ -82,6 +185,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { IR_NODE_LINK_TO(mul_node, matmul_out); GraphSafeRemoveNodes(graph, {matmul_op}); ++found_count; + + if (!IsCompat(desc)) { + LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed."; + return; + } } }; @@ -244,6 +352,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } + VLOG(4) << "fuse flatten2+matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern); @@ -301,6 +414,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { IR_NODE_LINK_TO(mul_node, matmul_out); GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op}); ++found_count; + + if (!IsCompat(desc)) { + LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed."; + return; + } } }; diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h index 85067a6f642fe4637467541cd08f89bba3b397db..27828f9c43829ca66f1a513b8984c9366755babc 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h @@ -39,6 +39,7 @@ class Graph; class MapMatmul2MulPass : public FusePassBase { public: + MapMatmul2MulPass(); virtual ~MapMatmul2MulPass() {} protected: @@ -103,6 +104,7 @@ class Reshape2MatmulFusePass : public FusePassBase { class Flatten2MatmulFusePass : public FusePassBase { public: + Flatten2MatmulFusePass(); virtual ~Flatten2MatmulFusePass() {} protected: