未验证 提交 cc2f9462 编写于 作者: W wawltor 提交者: GitHub

add the support the op version check for matmul, test=op_version (#30011)

* add the support the op version check for matmul, test=op_version
上级 b33aaea8
...@@ -227,7 +227,7 @@ REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass); ...@@ -227,7 +227,7 @@ REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass);
REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass) REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0) .LE("matmul", 1)
.EQ("mul", 0)); .EQ("mul", 0));
REGISTER_PASS(squeeze2_matmul_fuse_pass, REGISTER_PASS(squeeze2_matmul_fuse_pass,
...@@ -235,7 +235,7 @@ REGISTER_PASS(squeeze2_matmul_fuse_pass, ...@@ -235,7 +235,7 @@ REGISTER_PASS(squeeze2_matmul_fuse_pass,
REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass) REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0) .LE("matmul", 1)
.EQ("squeeze2", 0) .EQ("squeeze2", 0)
.EQ("mul", 0)); .EQ("mul", 0));
...@@ -244,6 +244,6 @@ REGISTER_PASS(reshape2_matmul_fuse_pass, ...@@ -244,6 +244,6 @@ REGISTER_PASS(reshape2_matmul_fuse_pass,
REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass) REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0) .LE("matmul", 1)
.EQ("reshape2", 0) .EQ("reshape2", 0)
.EQ("mul", 0)); .EQ("mul", 0));
...@@ -103,6 +103,6 @@ REGISTER_PASS(matmul_transpose_reshape_fuse_pass, ...@@ -103,6 +103,6 @@ REGISTER_PASS(matmul_transpose_reshape_fuse_pass,
REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass) REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0) .LE("matmul", 1)
.EQ("transpose", 0) .EQ("transpose", 0)
.EQ("reshape", 0)); .EQ("reshape", 0));
...@@ -96,4 +96,4 @@ REGISTER_PASS_CAPABILITY(scale_matmul_fuse_pass) ...@@ -96,4 +96,4 @@ REGISTER_PASS_CAPABILITY(scale_matmul_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("scale", 0) .EQ("scale", 0)
.EQ("matmul", 0)); .LE("matmul", 1));
...@@ -720,5 +720,5 @@ REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2) ...@@ -720,5 +720,5 @@ REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2)
.EQ("reshape2", 0) .EQ("reshape2", 0)
.EQ("transpose2", 0) .EQ("transpose2", 0)
.EQ("scale", 0) .EQ("scale", 0)
.EQ("matmul", 0) .LE("matmul", 1)
.EQ("softmax", 0)); .EQ("softmax", 0));
...@@ -389,7 +389,7 @@ REGISTER_PASS(squared_mat_sub_fuse_pass, ...@@ -389,7 +389,7 @@ REGISTER_PASS(squared_mat_sub_fuse_pass,
REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass) REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0) .LE("matmul", 1)
.EQ("matmul_v2", 0) .EQ("matmul_v2", 0)
.EQ("square", 0) .EQ("square", 0)
.LE("elementwise_mul", 1) .LE("elementwise_mul", 1)
......
...@@ -396,4 +396,4 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) ...@@ -396,4 +396,4 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.EQ("gelu", 0) .EQ("gelu", 0)
.EQ("layer_norm", 0) .EQ("layer_norm", 0)
.EQ("scale", 0) .EQ("scale", 0)
.EQ("matmul", 0)); .LE("matmul", 1));
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
...@@ -932,3 +933,14 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -932,3 +933,14 @@ REGISTER_OP_CUDA_KERNEL(
ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>, ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>); ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>);
#endif #endif
REGISTER_OP_VERSION(matmul)
.AddCheckpoint(
R"ROC(Register matmul for adding the attribute of
fused_reshape_Y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"fused_reshape_Y",
"In order to support the function of fused the input Y "
" and input X into the input X when "
"using the operator of matmul, and get raw shape of input Y.",
std::vector<int>{}));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册