未验证 提交 cc2f9462 编写于 作者: W wawltor 提交者: GitHub

add the support the op version check for matmul, test=op_version (#30011)

* add the support the op version check for matmul, test=op_version
上级 b33aaea8
......@@ -227,7 +227,7 @@ REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass);
REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.LE("matmul", 1)
.EQ("mul", 0));
REGISTER_PASS(squeeze2_matmul_fuse_pass,
......@@ -235,7 +235,7 @@ REGISTER_PASS(squeeze2_matmul_fuse_pass,
REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.LE("matmul", 1)
.EQ("squeeze2", 0)
.EQ("mul", 0));
......@@ -244,6 +244,6 @@ REGISTER_PASS(reshape2_matmul_fuse_pass,
REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.LE("matmul", 1)
.EQ("reshape2", 0)
.EQ("mul", 0));
......@@ -103,6 +103,6 @@ REGISTER_PASS(matmul_transpose_reshape_fuse_pass,
REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.LE("matmul", 1)
.EQ("transpose", 0)
.EQ("reshape", 0));
......@@ -96,4 +96,4 @@ REGISTER_PASS_CAPABILITY(scale_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("scale", 0)
.EQ("matmul", 0));
.LE("matmul", 1));
......@@ -720,5 +720,5 @@ REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.EQ("matmul", 0)
.LE("matmul", 1)
.EQ("softmax", 0));
......@@ -389,7 +389,7 @@ REGISTER_PASS(squared_mat_sub_fuse_pass,
REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("square", 0)
.LE("elementwise_mul", 1)
......
......@@ -396,4 +396,4 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.EQ("gelu", 0)
.EQ("layer_norm", 0)
.EQ("scale", 0)
.EQ("matmul", 0));
.LE("matmul", 1));
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -932,3 +933,14 @@ REGISTER_OP_CUDA_KERNEL(
ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>);
#endif
REGISTER_OP_VERSION(matmul)
.AddCheckpoint(
R"ROC(Register matmul for adding the attribute of
fused_reshape_Y)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"fused_reshape_Y",
"In order to support the function of fused the input Y "
" and input X into the input X when "
"using the operator of matmul, and get raw shape of input Y.",
std::vector<int>{}));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册