diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc index 76148a90074c1650946d02492b8664007fe7e6b3..8c4e6f330587773226caee1779fbf31eb80d3137 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -227,7 +227,7 @@ REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass); REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("mul", 0)); REGISTER_PASS(squeeze2_matmul_fuse_pass, @@ -235,7 +235,7 @@ REGISTER_PASS(squeeze2_matmul_fuse_pass, REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("squeeze2", 0) .EQ("mul", 0)); @@ -244,6 +244,6 @@ REGISTER_PASS(reshape2_matmul_fuse_pass, REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("reshape2", 0) .EQ("mul", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc index 41b859f0af665eae6d9ccb6a08cd29db5ce67fdf..fbc97a0a929c48c4eba3baa881061654dd802b62 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -103,6 +103,6 @@ REGISTER_PASS(matmul_transpose_reshape_fuse_pass, REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("transpose", 0) .EQ("reshape", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc 
b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc index 0784a1a024cfd31cfb2d2a3ea205518416c2ad13..a552e42619f368c2e8e2a51213ac10d9317151cf 100644 --- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc @@ -96,4 +96,4 @@ REGISTER_PASS_CAPABILITY(scale_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .EQ("scale", 0) - .EQ("matmul", 0)); + .LE("matmul", 1)); diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index bb9613d0c1764df7b66b049ed1a4e71d578e9db4..224272a5a039fccd331ab050d25b8fa2d00bc6d9 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -720,5 +720,5 @@ REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2) .EQ("reshape2", 0) .EQ("transpose2", 0) .EQ("scale", 0) - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index d17212f4aa35ea80ab459c8e86b93a57955e2149..c0420e6b5f3c212721b278ce04bf7ece090a5cc5 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -389,7 +389,7 @@ REGISTER_PASS(squared_mat_sub_fuse_pass, REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("matmul_v2", 0) .EQ("square", 0) .LE("elementwise_mul", 1) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index a67908fe7f22f2e20c445d363499b3ab7d2af4bd..4bd804dfca4d56cc7cb78984f12ae6d2233f77b8 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ 
b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -396,4 +396,4 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .EQ("gelu", 0) .EQ("layer_norm", 0) .EQ("scale", 0) - .EQ("matmul", 0)); + .LE("matmul", 1)); diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index d45669a9f075b5dfcbd9df27df9868758891ae4d..668445d2429e2977f26c569e01a50da66f136130 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include <utility> #include <vector> #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/math/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -932,3 +933,14 @@ REGISTER_OP_CUDA_KERNEL( ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>, ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>); #endif + + REGISTER_OP_VERSION(matmul) .AddCheckpoint( R"ROC(Register matmul for adding the attribute of fused_reshape_Y)ROC", paddle::framework::compatible::OpVersionDesc().NewAttr( "fused_reshape_Y", "In order to support the function of fused the input Y " " and input X into the input X when " "using the operator of matmul, and get raw shape of input Y.", std::vector<int>{}));