未验证 提交 4d97b25d 编写于 作者: S Sławomir Siwek 提交者: GitHub

Remove oneDNN-specific attributes from matmul (#49444)

* replace matmul with matmul_v2 in fuse passes

* Remove fusion logic from matmul

* removing fusion methods

* add proper name

* adjust namespaces

* clean attrs in python tests

* delete checkpoint and restore matmul version

* remove unused code

* matmul and reshape/transpose fuses migrated

* split MatmulOneDNN headers

* fuse activation and eltwise_add

* add fuse_activation

* matmul_transpose_reshape/reshape_transpose_matmul

* matmul + elementwise_add (fused)

* activation temporary modifciation

* restore matmul(v1) version 0

* merge newest develop

* remove depedency from other PR

* revert pbtxt

* remove placeholders from matmul_v2

* add description in OPMaker

* remove matmul_v2_op.h and all depedencies

* remove dims changing in base op

* add possibility to fuse already fused_matmul

* restart broken CI

* Empty-Commit

* revert matmul_utils.h

* codestyle

* adjust imports

* add pbtxt file

* 100% matmul unit tests coverage

* trigger CI with minimal changes to develop

* adjust changes to develop

* add fused_matmul op

* inherit base ops

* add "v2"

* move OPMaker

* Gradually add fused_matmul files

* second batch of fused_matmul changes

* split infershapes of matmul_v2 and fused_matmul

* merge code from other PR

* 2023

* inherit fused_matmul from matmul_v2

* Update paddle/phi/backends/onednn/onednn_reuse.h
Co-authored-by: NTomasz Socha <tomasz.socha@intel.com>

* Update paddle/phi/kernels/fusion/onednn/fused_matmul_kernel.cc
Co-authored-by: NTomasz Socha <tomasz.socha@intel.com>

* resolve conflicts

* codestyle

* simplify isgemmlinear

* 2023

* remove import

* reuse methods

* matmul_v2_mkldnn cleanup

* simplify ExecuteMatMulV1Grad

* matmul refactored

* fc

* SetOutMemDescWithLogicalLayoutFusesSupport

* matmul_v2

* alpha support

* group repetetive funcs

* matmul utils

* execute matmul methods

* restore registered kernel names

* split header and impl files

* remove double negatives

* reduce numer of modified files

* adjust ExecuteMatmul

* add scales for ut

* dates

* limit number of modified files

* fluid imports

* remove alpha

* codestyle

---------
Co-authored-by: NTomasz Socha <tomasz.socha@intel.com>
上级 a7ec8958
...@@ -81,8 +81,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph, ...@@ -81,8 +81,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
// currently. The conditions below are used to verify wether matmul_v2 // currently. The conditions below are used to verify wether matmul_v2
// is created by paddle.nn.Linear // is created by paddle.nn.Linear
auto matmul_op_desc = matmul_op->Op(); auto matmul_op_desc = matmul_op->Op();
if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc)) if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape)) return;
return;
bool trans_x, trans_y; bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y); GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y);
...@@ -165,8 +164,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd( ...@@ -165,8 +164,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
// currently. The conditions below are used to verify wether matmul_v2 // currently. The conditions below are used to verify wether matmul_v2
// is created by paddle.nn.Linear // is created by paddle.nn.Linear
auto matmul_op_desc = matmul_op->Op(); auto matmul_op_desc = matmul_op->Op();
if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc)) if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape)) return;
return;
auto activation = act_op->Op()->Type(); auto activation = act_op->Op()->Type();
...@@ -291,9 +289,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph, ...@@ -291,9 +289,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
// currently. The conditions below are used to verify wether matmul_v2 // currently. The conditions below are used to verify wether matmul_v2
// is created by paddle.nn.Linear // is created by paddle.nn.Linear
auto matmul_grad_op_desc = matmul_grad_op->Op(); auto matmul_grad_op_desc = matmul_grad_op->Op();
if (!IsGemmFromLinear_( if (!IsGemmFromLinear_(matmul_grad_x_shape, matmul_grad_w_shape)) return;
matmul_grad_x_shape, matmul_grad_w_shape, matmul_grad_op_desc))
return;
bool trans_x, trans_y; bool trans_x, trans_y;
GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y); GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y);
...@@ -430,9 +426,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( ...@@ -430,9 +426,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
// currently. The conditions below are used to verify wether matmul_v2 // currently. The conditions below are used to verify wether matmul_v2
// is created by paddle.nn.Linear // is created by paddle.nn.Linear
auto matmul_grad_op_desc = matmul_grad_op->Op(); auto matmul_grad_op_desc = matmul_grad_op->Op();
if (!IsGemmFromLinear_( if (!IsGemmFromLinear_(matmul_grad_x_shape, matmul_grad_w_shape)) return;
matmul_grad_x_shape, matmul_grad_w_shape, matmul_grad_op_desc))
return;
auto activation_grad = act_grad_op->Op()->Type(); auto activation_grad = act_grad_op->Op()->Type();
...@@ -509,22 +503,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( ...@@ -509,22 +503,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
bool FuseGemmEpiloguePass::IsGemmFromLinear_( bool FuseGemmEpiloguePass::IsGemmFromLinear_(
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &w_shape, const std::vector<int64_t> &w_shape) const {
OpDesc *matmul_v2_op) const { return (w_shape.size() == 2 && x_shape.size() >= 2);
if (w_shape.size() != 2 || x_shape.size() < 2) return false;
for (auto attr_name : {"fused_reshape_Out",
"fused_reshape_X",
"fused_reshape_Y",
"fused_transpose_Out",
"fused_transpose_X",
"fused_transpose_Y"}) {
if (matmul_v2_op->HasAttr(attr_name)) {
std::vector<int> tmp_vec =
PADDLE_GET_CONST(std::vector<int>, matmul_v2_op->GetAttr(attr_name));
if (tmp_vec.size() > 0) return false;
}
}
return true;
} }
} // namespace ir } // namespace ir
......
...@@ -90,8 +90,7 @@ class FuseGemmEpiloguePass : public FusePassBase { ...@@ -90,8 +90,7 @@ class FuseGemmEpiloguePass : public FusePassBase {
private: private:
bool IsGemmFromLinear_(const std::vector<int64_t> &x_shape, bool IsGemmFromLinear_(const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &w_shape, const std::vector<int64_t> &w_shape) const;
OpDesc *matmul_v2_op) const;
const std::string GetReserveSpaceCacheKey(const std::string var_name, const std::string GetReserveSpaceCacheKey(const std::string var_name,
int block_id) const { int block_id) const {
return std::to_string(block_id) + var_name; return std::to_string(block_id) + var_name;
......
...@@ -75,28 +75,4 @@ extra { ...@@ -75,28 +75,4 @@ extra {
name: "force_fp32_output" name: "force_fp32_output"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "fused_reshape_Out"
type: INTS
}
attrs {
name: "fused_reshape_X"
type: INTS
}
attrs {
name: "fused_reshape_Y"
type: INTS
}
attrs {
name: "fused_transpose_Out"
type: INTS
}
attrs {
name: "fused_transpose_X"
type: INTS
}
attrs {
name: "fused_transpose_Y"
type: INTS
}
} }
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -167,12 +167,6 @@ void GetLinearOpGrad(const std::vector<T> &x_vec, ...@@ -167,12 +167,6 @@ void GetLinearOpGrad(const std::vector<T> &x_vec,
dout_ptr, dout_vec.data(), size_z * sizeof(T), cudaMemcpyHostToDevice); dout_ptr, dout_vec.data(), size_z * sizeof(T), cudaMemcpyHostToDevice);
bool use_mkldnn = false; bool use_mkldnn = false;
std::vector<int> fused_reshape_X = {};
std::vector<int> fused_reshape_Y = {};
std::vector<int> fused_reshape_Out = {};
std::vector<int> fused_transpose_X = {};
std::vector<int> fused_transpose_Y = {};
std::vector<int> fused_transpose_Out = {};
bool use_quantizer = false, force_fp32_output = false; bool use_quantizer = false, force_fp32_output = false;
std::string mkldnn_data_type = "float32"; std::string mkldnn_data_type = "float32";
float Scale_x = 1.0, Scale_y = 1.0, Scale_out = 1.0; float Scale_x = 1.0, Scale_y = 1.0, Scale_out = 1.0;
...@@ -182,12 +176,6 @@ void GetLinearOpGrad(const std::vector<T> &x_vec, ...@@ -182,12 +176,6 @@ void GetLinearOpGrad(const std::vector<T> &x_vec,
attrs.insert({"transpose_Y", transpose_b}); attrs.insert({"transpose_Y", transpose_b});
attrs.insert({"alpha", alpha}); attrs.insert({"alpha", alpha});
attrs.insert({"use_mkldnn", use_mkldnn}); attrs.insert({"use_mkldnn", use_mkldnn});
attrs.insert({"fused_reshape_X", fused_reshape_X});
attrs.insert({"fused_reshape_Y", fused_reshape_Y});
attrs.insert({"fused_reshape_Out", fused_reshape_Out});
attrs.insert({"fused_transpose_X", fused_transpose_X});
attrs.insert({"fused_transpose_Y", fused_transpose_Y});
attrs.insert({"fused_transpose_Out", fused_transpose_Out});
attrs.insert({"use_quantizer", use_quantizer}); attrs.insert({"use_quantizer", use_quantizer});
attrs.insert({"mkldnn_data_type", mkldnn_data_type}); attrs.insert({"mkldnn_data_type", mkldnn_data_type});
attrs.insert({"Scale_x", Scale_x}); attrs.insert({"Scale_x", Scale_x});
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
...@@ -16,9 +16,6 @@ limitations under the License. */ ...@@ -16,9 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -303,7 +300,7 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -303,7 +300,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
bool transpose_y = context.Attr<bool>("transpose_Y"); bool transpose_y = context.Attr<bool>("transpose_Y");
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims; phi::DDim dx_dims;
if (dx) { if (dx) {
dx_dims = dx->dims(); dx_dims = dx->dims();
if (dx_dims != x.dims()) { if (dx_dims != x.dims()) {
...@@ -311,7 +308,7 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -311,7 +308,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
} }
} }
framework::DDim dy_dims; phi::DDim dy_dims;
if (dy) { if (dy) {
dy_dims = dy->dims(); dy_dims = dy->dims();
if (dy_dims != y.dims()) { if (dy_dims != y.dims()) {
...@@ -346,23 +343,15 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -346,23 +343,15 @@ class MatMulGradKernel : public framework::OpKernel<T> {
} }
}; };
framework::DDim GetDimForInput(const framework::InferShapeContext &ctx, phi::DDim GetDimForInput(const framework::InferShapeContext &ctx,
std::string input_name) { std::string input_name) {
auto shape = ctx.Attrs().Get<std::vector<int>>("fused_reshape_" + input_name);
auto axis =
ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
auto dim = ctx.GetInputDim(input_name); auto dim = ctx.GetInputDim(input_name);
PADDLE_ENFORCE_GT(dim.size(), PADDLE_ENFORCE_GT(dim.size(),
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The Input(%s) has not been initialized properly. The " "The Input(%s) has not been initialized properly. The "
"shape of Input(%s) = [%s].", "shape of Input(%s) = [%s].",
dim)); dim));
if (!shape.empty() && !axis.empty()) {
dim = dim.reshape(shape).transpose(axis);
}
return dim; return dim;
} }
...@@ -449,7 +438,7 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> { ...@@ -449,7 +438,7 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims; phi::DDim dx_dims;
if (dx) { if (dx) {
dx_dims = dx->dims(); dx_dims = dx->dims();
if (dx_dims != x.dims()) { if (dx_dims != x.dims()) {
...@@ -457,7 +446,7 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> { ...@@ -457,7 +446,7 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
} }
} }
framework::DDim dy_dims; phi::DDim dy_dims;
if (dy) { if (dy) {
dy_dims = dy->dims(); dy_dims = dy->dims();
if (dy_dims != y.dims()) { if (dy_dims != y.dims()) {
...@@ -465,7 +454,7 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> { ...@@ -465,7 +454,7 @@ class MatMulDoubleGradKernel : public framework::OpKernel<T> {
} }
} }
framework::DDim ddout_dims; phi::DDim ddout_dims;
if (ddout) { if (ddout) {
ddout_dims = ddout->dims(); ddout_dims = ddout->dims();
if (ddout_dims != dout.dims()) { if (ddout_dims != dout.dims()) {
...@@ -617,7 +606,7 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -617,7 +606,7 @@ class MatMulOp : public framework::OperatorWithKernel {
mat_dim_x.batch_size_ == mat_dim_y.batch_size_ || mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0, mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0,
true, true,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The batch size of the two matrices should be equal, or " "The batch size of the two matrices should be equal, or "
"at least one is zero.\n" "at least one is zero.\n"
"But received X's shape: %s, Y's shape: %s.", "But received X's shape: %s, Y's shape: %s.",
...@@ -633,7 +622,7 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -633,7 +622,7 @@ class MatMulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
head_number, head_number,
mat_dim_x.width_, mat_dim_x.width_,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Unsatisfied mkl acceleration library requirements: " "Unsatisfied mkl acceleration library requirements: "
"The number of heads " "The number of heads "
"(%d) must be equal to X's width. But received X's shape: %s.", "(%d) must be equal to X's width. But received X's shape: %s.",
...@@ -647,7 +636,7 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -647,7 +636,7 @@ class MatMulOp : public framework::OperatorWithKernel {
#else #else
PADDLE_ENFORCE_EQ(mat_dim_x.width_, PADDLE_ENFORCE_EQ(mat_dim_x.width_,
mat_dim_y.height_, mat_dim_y.height_,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Input X's width should be equal to the Y's height, " "Input X's width should be equal to the Y's height, "
"but received X's shape: [%s], " "but received X's shape: [%s], "
"Y's shape: [%s].", "Y's shape: [%s].",
...@@ -681,16 +670,8 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -681,16 +670,8 @@ class MatMulOp : public framework::OperatorWithKernel {
dim_out = {1}; dim_out = {1};
} }
framework::DDim ddim_out = phi::make_ddim(dim_out); phi::DDim ddim_out = phi::make_ddim(dim_out);
#ifdef PADDLE_WITH_MKLDNN
auto shape = context->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto axis = context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!shape.empty() && !axis.empty()) {
ddim_out = ddim_out.transpose(axis).reshape(shape);
}
#endif
context->SetOutputDim("Out", ddim_out); context->SetOutputDim("Out", ddim_out);
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
} }
...@@ -749,34 +730,6 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -749,34 +730,6 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Indicates if MKL-DNN kernel will be used") "(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false) .SetDefault(false)
.AsExtra(); .AsExtra();
AddAttr<std::vector<int>>("fused_reshape_X",
R"DOC(Shape of fused reshape of `X` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_reshape_Y",
R"DOC(Shape of fused reshape of `Y` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_transpose_X",
R"DOC(Axis of fused transpose of `X` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_transpose_Y",
R"DOC(Axis of fused transpose of `Y` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>(
"fused_reshape_Out",
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
"it's a shape attribute of fused reshape for `Out` output.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>(
"fused_transpose_Out",
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
"it's a axis attribute of fused transpose for `Out` output.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<bool>( AddAttr<bool>(
"use_quantizer", "use_quantizer",
"(bool, default false) " "(bool, default false) "
......
...@@ -71,16 +71,6 @@ phi::DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx, ...@@ -71,16 +71,6 @@ phi::DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx,
return output; return output;
} }
phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<phi::DenseTensor>(input_name)->dims();
if (!shape.empty() && !axis.empty()) {
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
}
template <typename XT, typename YT, typename OT> template <typename XT, typename YT, typename OT>
class MatMulV1OneDNNHandler class MatMulV1OneDNNHandler
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> { : public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
...@@ -91,10 +81,7 @@ class MatMulV1OneDNNHandler ...@@ -91,10 +81,7 @@ class MatMulV1OneDNNHandler
const std::vector<int64_t> &x_org_dims, const std::vector<int64_t> &x_org_dims,
bool trans_x, bool trans_x,
const std::vector<int64_t> &y_org_dims, const std::vector<int64_t> &y_org_dims,
bool trans_y, bool trans_y)
bool is_output_fused,
const std::vector<int64_t> &x_strides_override,
const std::vector<int64_t> &y_strides_override)
: phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine, : phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) { cpu_place) {
// M X K * K X N // M X K * K X N
...@@ -121,47 +108,27 @@ class MatMulV1OneDNNHandler ...@@ -121,47 +108,27 @@ class MatMulV1OneDNNHandler
y_strides.reserve(x_dims.size()); y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size()); out_strides.reserve(x_dims.size());
if (x_strides_override.empty()) { if (trans_x) {
if (trans_x) { x_strides.insert(x_strides.end(), {M * K, 1, M});
x_strides.insert(x_strides.end(), {M * K, 1, M});
} else {
x_strides.insert(x_strides.end(), {M * K, K, 1});
}
} else { } else {
x_strides = x_strides_override; x_strides.insert(x_strides.end(), {M * K, K, 1});
} }
if (trans_y) {
if (y_strides_override.empty()) { y_strides.insert(y_strides.end(), {N * K, 1, K});
if (trans_y) {
y_strides.insert(y_strides.end(), {N * K, 1, K});
} else {
y_strides.insert(y_strides.end(), {N * K, N, 1});
}
} else { } else {
y_strides = y_strides_override; y_strides.insert(y_strides.end(), {N * K, N, 1});
} }
out_strides.insert(out_strides.end(), {M * N, N, 1}); out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(), out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N}); {std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
for (int i = x_dims.size() - 4; i >= 0; --i) { for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]); out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) { x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
x_strides[i] = x_dims[i + 1] * x_strides[i + 1]; y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
} }
// TODO(jczaja): Why not for int8??
if (!phi::funcs::is_int8<OT>() && is_output_fused) {
std::vector<int> transpose_axis = {0, 2, 1, 3};
out_strides = phi::funcs::FakeTransposeStrides(out_ddims, transpose_axis);
}
auto x_md = auto x_md =
memory::desc(x_dims, phi::funcs::OneDNNGetDataType<XT>(), x_strides); memory::desc(x_dims, phi::funcs::OneDNNGetDataType<XT>(), x_strides);
auto y_md = auto y_md =
...@@ -191,34 +158,10 @@ class MatMulV1OneDNNHandler ...@@ -191,34 +158,10 @@ class MatMulV1OneDNNHandler
dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext &ctx) { dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext &ctx) {
dnnl::primitive_attr matmul_attrs; dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float scale_out = ComputeOutputScale(ctx); float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) { if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out}); matmul_attrs.set_output_scales(0, {scale_out});
} }
if (ctx.HasInput("ResidualData")) {
auto *residual_data = ctx.Input<phi::DenseTensor>("ResidualData");
auto residual_data_tz = phi::vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
phi::funcs::OneDNNGetDataType<OT>(),
dnnl::memory::format_tag::any);
post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md);
if (ctx.HasAttr("Scale_in_eltwise")) {
float sum_scale = scale_out / ctx.Attr<float>("Scale_in_eltwise");
post_operations.append_sum(sum_scale);
}
}
if (ctx.HasAttr("fused_output_scale")) {
float scale_alpha = ctx.Attr<float>("fused_output_scale");
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs; return matmul_attrs;
} }
...@@ -272,10 +215,10 @@ class MatMulOneDNNHandler ...@@ -272,10 +215,10 @@ class MatMulOneDNNHandler
memory::dims out_dims = {out_bs, M, N}; memory::dims out_dims = {out_bs, M, N};
memory::dims x_strides = memory::dims x_strides =
!trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M}; trans_x ? memory::dims{M * K, 1, M} : memory::dims{M * K, K, 1};
memory::dims y_strides = memory::dims y_strides =
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K}; trans_y ? memory::dims{N * K, 1, K} : memory::dims{N * K, N, 1};
memory::dims out_strides = memory::dims{M * N, N, 1}; memory::dims out_strides = memory::dims{M * N, N, 1};
auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides); auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides);
...@@ -364,12 +307,6 @@ void ReshapeXYOutToMatrixSequence(phi::DenseTensor *x, ...@@ -364,12 +307,6 @@ void ReshapeXYOutToMatrixSequence(phi::DenseTensor *x,
ReshapeTensorToMatrixSequence(y, mat_dim_y); ReshapeTensorToMatrixSequence(y, mat_dim_y);
} }
bool IsOutputFused(const ExecutionContext &ctx) {
auto &fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto &fused_transpose_Out = ctx.Attr<std::vector<int>>("fused_transpose_Out");
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
template <typename T, typename T_out> template <typename T, typename T_out>
void ExecuteMatMulV1(const ExecutionContext &ctx, void ExecuteMatMulV1(const ExecutionContext &ctx,
const dnnl::engine onednn_engine, const dnnl::engine onednn_engine,
...@@ -380,29 +317,8 @@ void ExecuteMatMulV1(const ExecutionContext &ctx, ...@@ -380,29 +317,8 @@ void ExecuteMatMulV1(const ExecutionContext &ctx,
const std::vector<int64_t> &y_dims, const std::vector<int64_t> &y_dims,
bool trans_y, bool trans_y,
phi::DenseTensor *out) { phi::DenseTensor *out) {
std::vector<int64_t> x_strides_override = phi::funcs::GetInputStrides( MatMulV1OneDNNHandler<T, T, T_out> handler(
"X", ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims, trans_y);
x->dims(),
trans_x,
ctx.Attr<std::vector<int>>("fused_reshape_X"),
ctx.Attr<std::vector<int>>("fused_transpose_X"));
std::vector<int64_t> y_strides_override = phi::funcs::GetInputStrides(
"Y",
y->dims(),
trans_y,
ctx.Attr<std::vector<int>>("fused_reshape_Y"),
ctx.Attr<std::vector<int>>("fused_transpose_Y"));
MatMulV1OneDNNHandler<T, T, T_out> handler(ctx,
onednn_engine,
ctx.GetPlace(),
x_dims,
trans_x,
y_dims,
trans_y,
IsOutputFused(ctx),
x_strides_override,
y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x); const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y); const auto weights_memory_p = handler.AcquireWeightsMemory(y);
const auto dst_memory_p = handler.AcquireDstMemory(out); const auto dst_memory_p = handler.AcquireDstMemory(out);
...@@ -414,27 +330,12 @@ void ExecuteMatMulV1(const ExecutionContext &ctx, ...@@ -414,27 +330,12 @@ void ExecuteMatMulV1(const ExecutionContext &ctx,
{DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p}};
if (ctx.HasInput("ResidualData")) {
auto *residual_data = ctx.Input<phi::DenseTensor>("ResidualData");
const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*residual_data_memory_p});
}
auto &astream = OneDNNContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); matmul_p->execute(astream, matmul_args);
astream.wait(); astream.wait();
// TODO(jczaja): Explain why int8 format of dst is ABCD and do not need out->set_mem_desc(
// permute dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
if (IsOutputFused(ctx) && !phi::funcs::is_int8<T_out>()) {
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
out->set_mem_desc(permuted_md.reshape(vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
} }
template <typename T> template <typename T>
...@@ -462,13 +363,11 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -462,13 +363,11 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto *x = ctx.Input<phi::DenseTensor>("X"); auto *x = ctx.Input<phi::DenseTensor>("X");
auto *y = ctx.Input<phi::DenseTensor>("Y"); auto *y = ctx.Input<phi::DenseTensor>("Y");
auto *out = ctx.Output<phi::DenseTensor>("Out"); auto *out = ctx.Output<phi::DenseTensor>("Out");
bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr<bool>("trans_x") bool trans_x = ctx.Attr<bool>("transpose_X");
: ctx.Attr<bool>("transpose_X"); bool trans_y = ctx.Attr<bool>("transpose_Y");
bool trans_y = ctx.HasAttr("trans_y") ? ctx.Attr<bool>("trans_y")
: ctx.Attr<bool>("transpose_Y");
auto x_dims = vectorize(GetDimForInput(ctx, "X")); auto x_dims = vectorize(x->dims());
auto y_dims = vectorize(GetDimForInput(ctx, "Y")); auto y_dims = vectorize(y->dims());
int ndims = std::max(x_dims.size(), y_dims.size()); int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3); ndims = std::max(ndims, 3);
...@@ -539,7 +438,7 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -539,7 +438,7 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
} }
} }
if (!IsOutputFused(ctx) && x_dims.size() > 2 && y_dims.size() > 2) { if (x_dims.size() > 2 && y_dims.size() > 2) {
auto out_dims = vectorize(out->dims()); auto out_dims = vectorize(out->dims());
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) { for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -587,12 +486,8 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -587,12 +486,8 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto *dy = auto *dy =
ctx.Output<phi::DenseTensor>(paddle::framework::GradVarName("Y")); ctx.Output<phi::DenseTensor>(paddle::framework::GradVarName("Y"));
bool transpose_x = ctx.HasAttr("transpose_X") bool transpose_x = ctx.Attr<bool>("transpose_X");
? ctx.Attr<bool>("transpose_X") bool transpose_y = ctx.Attr<bool>("transpose_Y");
: ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.HasAttr("transpose_Y")
? ctx.Attr<bool>("transpose_Y")
: ctx.Attr<bool>("trans_y");
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
...@@ -696,14 +591,14 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -696,14 +591,14 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
out->dims().size() == 2; out->dims().size() == 2;
phi::DenseTensor x_combined, y_combined; phi::DenseTensor x_combined, y_combined;
if (!need_combine) { if (need_combine) {
x_combined = *x;
y_combined = *y;
} else {
x_combined = is_fold_init_dims_x ? FoldOuterDims(*x) x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
: FoldFirstAndLastDims<T>(dev_ctx, x); : FoldFirstAndLastDims<T>(dev_ctx, x);
y_combined = is_fold_init_dims_y ? FoldOuterDims(*y) y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
: FoldFirstAndLastDims<T>(dev_ctx, y); : FoldFirstAndLastDims<T>(dev_ctx, y);
} else {
x_combined = *x;
y_combined = *y;
} }
MatMulOneDNNHandler<T, T, T> handler(engine, MatMulOneDNNHandler<T, T, T> handler(engine,
......
...@@ -1283,9 +1283,7 @@ ...@@ -1283,9 +1283,7 @@
outputs : outputs :
out : Out out : Out
extra : extra :
attrs : [bool use_mkldnn = false, 'int[] fused_reshape_Out = {}', 'int[] fused_transpose_Out = {}', attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
str mkldnn_data_type = "float32", 'int[] fused_reshape_X = {}', 'int[] fused_reshape_Y = {}',
'int[] fused_transpose_X = {}', 'int[] fused_transpose_Y = {}']
- op : matmul_with_flatten (mul) - op : matmul_with_flatten (mul)
backward : matmul_with_flatten_grad (mul_grad) backward : matmul_with_flatten_grad (mul_grad)
......
...@@ -93,23 +93,7 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> { ...@@ -93,23 +93,7 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
auto y_md = memory::desc(y_dims, OneDNNGetDataType<YT>(), y_strides); auto y_md = memory::desc(y_dims, OneDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_ddims, OneDNNGetDataType<OT>(), out_strides); auto out_md = memory::desc(out_ddims, OneDNNGetDataType<OT>(), out_strides);
const auto matmul_attrs = CreateMatmulAttrs(dev_ctx); this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext& dev_ctx) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float scale_out = dev_ctx.HasDnnAttr("alpha")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("alpha"))
: 1.0f;
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
} }
std::shared_ptr<memory> AcquireWeightsMemory(const DenseTensor* input) { std::shared_ptr<memory> AcquireWeightsMemory(const DenseTensor* input) {
...@@ -175,27 +159,6 @@ inline void ExecuteMatmul(const OneDNNContext& dev_ctx, ...@@ -175,27 +159,6 @@ inline void ExecuteMatmul(const OneDNNContext& dev_ctx,
bool trans_x, bool trans_x,
bool trans_y, bool trans_y,
DenseTensor* out) { DenseTensor* out) {
auto shape_x = dev_ctx.HasDnnAttr("fused_reshape_X")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_X"))
: std::vector<int>();
auto axis_x = dev_ctx.HasDnnAttr("fused_transpose_X")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_X"))
: std::vector<int>();
auto shape_y = dev_ctx.HasDnnAttr("fused_reshape_Y")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_Y"))
: std::vector<int>();
auto axis_y = dev_ctx.HasDnnAttr("fused_transpose_Y")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_Y"))
: std::vector<int>();
auto x_strides_override =
GetInputStrides("X", x.dims(), trans_x, shape_x, shape_x);
auto y_strides_override =
GetInputStrides("Y", y.dims(), trans_y, shape_y, axis_y);
MatmulOneDNNHandler<T, T, T_out> handler( MatmulOneDNNHandler<T, T, T_out> handler(
dev_ctx, x_dims, y_dims, trans_x, trans_y); dev_ctx, x_dims, y_dims, trans_x, trans_y);
......
...@@ -28,25 +28,6 @@ using phi::ReshapeToMatrix; ...@@ -28,25 +28,6 @@ using phi::ReshapeToMatrix;
namespace phi { namespace phi {
DDim GetDimsForInput(const OneDNNContext &dev_ctx,
DDim input_dims,
std::string input_name) {
auto shape =
dev_ctx.HasDnnAttr("fused_reshape_" + input_name)
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_" + input_name))
: std::vector<int>();
auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name)
? PADDLE_GET_CONST(
std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_" + input_name))
: std::vector<int>();
if (!shape.empty() && !axis.empty()) {
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
}
void CalculateMatrixDims(const std::vector<int64_t> &x_dims, void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims, const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims, std::vector<int64_t> *x_bd_dims,
...@@ -120,9 +101,8 @@ void MatmulKernel(const Context &dev_ctx, ...@@ -120,9 +101,8 @@ void MatmulKernel(const Context &dev_ctx,
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
: false; : false;
auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X")); auto x_dims = vectorize(x.dims());
auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y")); auto y_dims = vectorize(y.dims());
int ndims = std::max(x_dims.size(), y_dims.size()); int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3); ndims = std::max(ndims, 3);
......
...@@ -93,12 +93,6 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest): ...@@ -93,12 +93,6 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest):
alpha=alpha, alpha=alpha,
transpose_X=transpose_X, transpose_X=transpose_X,
transpose_Y=transpose_Y, transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
) )
add_op = OpConfig( add_op = OpConfig(
......
...@@ -96,12 +96,6 @@ class TestMapMatmulToMulPass(PassAutoScanTest): ...@@ -96,12 +96,6 @@ class TestMapMatmulToMulPass(PassAutoScanTest):
alpha=alpha, alpha=alpha,
transpose_X=transpose_X, transpose_X=transpose_X,
transpose_Y=transpose_Y, transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
) )
ops = [ ops = [
......
...@@ -84,12 +84,6 @@ class TestMatmulScaleFusePass(PassAutoScanTest): ...@@ -84,12 +84,6 @@ class TestMatmulScaleFusePass(PassAutoScanTest):
transpose_X=transpose_X, transpose_X=transpose_X,
transpose_Y=transpose_Y, transpose_Y=transpose_Y,
alpha=alpha, alpha=alpha,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
head_number=1, head_number=1,
) )
is_scale_tensor = draw(st.booleans()) is_scale_tensor = draw(st.booleans())
......
...@@ -113,12 +113,6 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -113,12 +113,6 @@ class TestScaleMatmulMkldnnFusePass(PassAutoScanTest):
'transpose_X': attrs[1]['transpose_X'], 'transpose_X': attrs[1]['transpose_X'],
'transpose_Y': attrs[1]['transpose_Y'], 'transpose_Y': attrs[1]['transpose_Y'],
'alpha': attrs[1]['alpha'], 'alpha': attrs[1]['alpha'],
'fused_reshape_X': [],
'fused_reshape_Y': [],
'fused_transpose_X': [],
'fused_transpose_Y': [],
'fused_reshape_Out': [],
'fused_transpose_Out': [],
}, },
}, },
] ]
......
...@@ -118,12 +118,6 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest): ...@@ -118,12 +118,6 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
alpha=0.125, alpha=0.125,
transpose_X=False, transpose_X=False,
transpose_Y=False, transpose_Y=False,
fused_reshape_Out=[],
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_Out=[],
fused_transpose_X=[],
fused_transpose_Y=[],
) )
ele_3 = OpConfig( ele_3 = OpConfig(
"elementwise_add", "elementwise_add",
...@@ -151,12 +145,6 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest): ...@@ -151,12 +145,6 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
alpha=1.0, alpha=1.0,
transpose_X=False, transpose_X=False,
transpose_Y=False, transpose_Y=False,
fused_reshape_Out=[],
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_Out=[],
fused_transpose_X=[],
fused_transpose_Y=[],
) )
transpose_3 = OpConfig( transpose_3 = OpConfig(
"transpose2", "transpose2",
......
...@@ -263,12 +263,6 @@ class TestMultiheadMatmulRoformerFusePass(PassAutoScanTest): ...@@ -263,12 +263,6 @@ class TestMultiheadMatmulRoformerFusePass(PassAutoScanTest):
alpha=1.0, alpha=1.0,
transpose_X=False, transpose_X=False,
transpose_Y=True, transpose_Y=True,
fused_reshape_Out=[],
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_Out=[],
fused_transpose_X=[],
fused_transpose_Y=[],
) )
ele_3 = OpConfig( ele_3 = OpConfig(
"elementwise_add", "elementwise_add",
...@@ -316,12 +310,6 @@ class TestMultiheadMatmulRoformerFusePass(PassAutoScanTest): ...@@ -316,12 +310,6 @@ class TestMultiheadMatmulRoformerFusePass(PassAutoScanTest):
alpha=1.0, alpha=1.0,
transpose_X=False, transpose_X=False,
transpose_Y=False, transpose_Y=False,
fused_reshape_Out=[],
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_Out=[],
fused_transpose_X=[],
fused_transpose_Y=[],
) )
ops = [ ops = [
mul_0, mul_0,
......
...@@ -67,15 +67,9 @@ class TestOneDNNMatmulTransposeReshapeFusePass(PassAutoScanTest): ...@@ -67,15 +67,9 @@ class TestOneDNNMatmulTransposeReshapeFusePass(PassAutoScanTest):
inputs={'X': ['input_data1'], 'Y': ['input_data2']}, inputs={'X': ['input_data1'], 'Y': ['input_data2']},
outputs={'Out': ['matmul_output']}, outputs={'Out': ['matmul_output']},
attrs={ attrs={
'transpose_X': transpose_X, "transpose_X": transpose_X,
'transpose_Y': transpose_Y, "transpose_Y": transpose_Y,
'alpha': alpha, "alpha": alpha,
'fused_reshape_X': [],
'fused_reshape_Y': [],
'fused_transpose_X': [],
'fused_transpose_Y': [],
'fused_reshape_Out': [],
'fused_transpose_Out': [],
}, },
) )
......
...@@ -124,12 +124,6 @@ class TestOneDNNReshapeTransposeMatmulFusePass(PassAutoScanTest): ...@@ -124,12 +124,6 @@ class TestOneDNNReshapeTransposeMatmulFusePass(PassAutoScanTest):
'transpose_X': attrs[2]['transpose_X'], 'transpose_X': attrs[2]['transpose_X'],
'transpose_Y': attrs[2]['transpose_Y'], 'transpose_Y': attrs[2]['transpose_Y'],
'alpha': attrs[2]['alpha'], 'alpha': attrs[2]['alpha'],
'fused_reshape_X': [],
'fused_reshape_Y': [],
'fused_transpose_X': [],
'fused_transpose_Y': [],
'fused_reshape_Out': [],
'fused_transpose_Out': [],
}, },
}, },
] ]
......
...@@ -102,12 +102,6 @@ class TestReshape2MatmulFusePass(PassAutoScanTest): ...@@ -102,12 +102,6 @@ class TestReshape2MatmulFusePass(PassAutoScanTest):
alpha=alpha, alpha=alpha,
transpose_X=transpose_X, transpose_X=transpose_X,
transpose_Y=transpose_Y, transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
) )
add_op = OpConfig( add_op = OpConfig(
......
...@@ -56,12 +56,6 @@ class TestSquaredMatSubFusePass(PassAutoScanTest): ...@@ -56,12 +56,6 @@ class TestSquaredMatSubFusePass(PassAutoScanTest):
"transpose_X": transpose_X, "transpose_X": transpose_X,
"transpose_Y": transpose_Y, "transpose_Y": transpose_Y,
"alpha": alpha1, "alpha": alpha1,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
) )
...@@ -94,12 +88,6 @@ class TestSquaredMatSubFusePass(PassAutoScanTest): ...@@ -94,12 +88,6 @@ class TestSquaredMatSubFusePass(PassAutoScanTest):
"transpose_X": transpose_X, "transpose_X": transpose_X,
"transpose_Y": transpose_Y, "transpose_Y": transpose_Y,
"alpha": alpha2, "alpha": alpha2,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
) )
......
...@@ -104,12 +104,6 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest): ...@@ -104,12 +104,6 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest):
alpha=alpha, alpha=alpha,
transpose_X=transpose_X, transpose_X=transpose_X,
transpose_Y=transpose_Y, transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
) )
add_op = OpConfig( add_op = OpConfig(
......
...@@ -52,12 +52,6 @@ class TrtConvertMatmulTest_static(TrtLayerAutoScanTest): ...@@ -52,12 +52,6 @@ class TrtConvertMatmulTest_static(TrtLayerAutoScanTest):
"transpose_X": trans_x, "transpose_X": trans_x,
"transpose_Y": trans_y, "transpose_Y": trans_y,
"alpha": alpha, "alpha": alpha,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
} }
] ]
ops_config = [ ops_config = [
...@@ -143,12 +137,6 @@ class TrtConvertMatmulTest_dynamic(TrtLayerAutoScanTest): ...@@ -143,12 +137,6 @@ class TrtConvertMatmulTest_dynamic(TrtLayerAutoScanTest):
"transpose_X": trans_x, "transpose_X": trans_x,
"transpose_Y": trans_y, "transpose_Y": trans_y,
"alpha": alpha, "alpha": alpha,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
} }
] ]
ops_config = [ ops_config = [
......
...@@ -72,12 +72,6 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -72,12 +72,6 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest):
"alpha": 1.0, "alpha": 1.0,
"transpose_X": False, "transpose_X": False,
"transpose_Y": True, "transpose_Y": True,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
{"axis": axis}, {"axis": axis},
{"axis": -1, "is_test": True}, {"axis": -1, "is_test": True},
...@@ -92,12 +86,6 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -92,12 +86,6 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest):
"alpha": 1.0, "alpha": 1.0,
"transpose_X": False, "transpose_X": False,
"transpose_Y": False, "transpose_Y": False,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
{"axis": [0, 2, 1, 3]}, {"axis": [0, 2, 1, 3]},
{"shape": [0, 0, 768]}, {"shape": [0, 0, 768]},
...@@ -512,12 +500,6 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): ...@@ -512,12 +500,6 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
"alpha": 1.0, "alpha": 1.0,
"transpose_X": False, "transpose_X": False,
"transpose_Y": True, "transpose_Y": True,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
{"axis": axis}, {"axis": axis},
{"axis": -1, "is_test": True}, {"axis": -1, "is_test": True},
...@@ -532,12 +514,6 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): ...@@ -532,12 +514,6 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
"alpha": 1.0, "alpha": 1.0,
"transpose_X": False, "transpose_X": False,
"transpose_Y": False, "transpose_Y": False,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
{"axis": [0, 2, 1, 3]}, {"axis": [0, 2, 1, 3]},
{"shape": [0, 0, 768]}, {"shape": [0, 0, 768]},
...@@ -1142,12 +1118,6 @@ class TrtConvertMultiHeadMatmulTest_biasqk_seqseq(TrtLayerAutoScanTest): ...@@ -1142,12 +1118,6 @@ class TrtConvertMultiHeadMatmulTest_biasqk_seqseq(TrtLayerAutoScanTest):
"alpha": 1.0, "alpha": 1.0,
"transpose_X": False, "transpose_X": False,
"transpose_Y": True, "transpose_Y": True,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
{"axis": axis}, {"axis": axis},
{"axis": -1, "is_test": True}, {"axis": -1, "is_test": True},
...@@ -1162,12 +1132,6 @@ class TrtConvertMultiHeadMatmulTest_biasqk_seqseq(TrtLayerAutoScanTest): ...@@ -1162,12 +1132,6 @@ class TrtConvertMultiHeadMatmulTest_biasqk_seqseq(TrtLayerAutoScanTest):
"alpha": 1.0, "alpha": 1.0,
"transpose_X": False, "transpose_X": False,
"transpose_Y": False, "transpose_Y": False,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
{"axis": [0, 2, 1, 3]}, {"axis": [0, 2, 1, 3]},
{"shape": [0, 0, 768]}, {"shape": [0, 0, 768]},
......
...@@ -98,12 +98,6 @@ class TrtConvertMultiHeadMatmulRoformerTest(TrtLayerAutoScanTest): ...@@ -98,12 +98,6 @@ class TrtConvertMultiHeadMatmulRoformerTest(TrtLayerAutoScanTest):
"alpha": 1.0, "alpha": 1.0,
"transpose_X": False, "transpose_X": False,
"transpose_Y": True, "transpose_Y": True,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
{"axis": axis}, {"axis": axis},
{"axis": -1, "is_test": True}, {"axis": -1, "is_test": True},
...@@ -118,12 +112,6 @@ class TrtConvertMultiHeadMatmulRoformerTest(TrtLayerAutoScanTest): ...@@ -118,12 +112,6 @@ class TrtConvertMultiHeadMatmulRoformerTest(TrtLayerAutoScanTest):
"alpha": 1.0, "alpha": 1.0,
"transpose_X": False, "transpose_X": False,
"transpose_Y": False, "transpose_Y": False,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": [],
}, },
{"axis": [0, 2, 1, 3]}, {"axis": [0, 2, 1, 3]},
{"shape": [0, 0, 768]}, {"shape": [0, 0, 768]},
......
...@@ -111,12 +111,6 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest): ...@@ -111,12 +111,6 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest):
alpha=alpha, alpha=alpha,
transpose_X=transpose_X, transpose_X=transpose_X,
transpose_Y=transpose_Y, transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
) )
add_op = OpConfig( add_op = OpConfig(
......
...@@ -116,12 +116,6 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest): ...@@ -116,12 +116,6 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest):
alpha=alpha, alpha=alpha,
transpose_X=transpose_X, transpose_X=transpose_X,
transpose_Y=transpose_Y, transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
) )
add_op = OpConfig( add_op = OpConfig(
......
...@@ -17,10 +17,7 @@ import unittest ...@@ -17,10 +17,7 @@ import unittest
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.eager_op_test import ( from paddle.fluid.tests.unittests.eager_op_test import OpTest
OpTest,
skip_check_grad_ci,
)
class TestDnnlMatMulOp(OpTest): class TestDnnlMatMulOp(OpTest):
...@@ -257,321 +254,6 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp): ...@@ -257,321 +254,6 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
self.attrs = {'force_fp32_output': True} self.attrs = {'force_fp32_output': True}
@skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.")
class TestReshapeTransposeMatMulOp(OpTest):
def init_data_type(self):
self.data_type_ = 'float32'
def generate_data(self):
self.x = (
np.random.random([2, 128, 768])
.astype("float32")
.reshape([2, 128, 12, 64])
.transpose([0, 2, 1, 3])
)
self.y = (
np.random.random([2, 128, 768])
.astype("float32")
.reshape([2, 128, 12, 64])
.transpose([0, 2, 1, 3])
)
self.out = np.matmul(self.x, self.y.transpose([0, 1, 3, 2]))
self.fused_reshape_X = []
self.fused_transpose_X = []
self.fused_reshape_Y = []
self.fused_transpose_Y = []
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul"
self.transpose_y_name = "transpose_Y"
def setUp(self):
self.set_op_type_and_transpose_y_name()
self._cpu_only = True
self.use_mkldnn = True
self.transpose_y = True
self.init_data_type()
self.generate_data()
self.inputs = {'X': self.x, 'Y': self.y}
self.attrs = {
'use_mkldnn': self.use_mkldnn,
self.transpose_y_name: self.transpose_y,
}
if len(self.fused_transpose_X) > 0:
self.attrs['fused_transpose_X'] = self.fused_transpose_X
if len(self.fused_transpose_Y) > 0:
self.attrs['fused_transpose_Y'] = self.fused_transpose_Y
if len(self.fused_reshape_X) > 0:
self.attrs['fused_reshape_X'] = self.fused_reshape_X
if len(self.fused_reshape_Y) > 0:
self.attrs['fused_reshape_Y'] = self.fused_reshape_Y
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
class TestReshapeTransposeMatMulOp4DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = (
np.random.random([2, 128, 768])
.astype("float32")
.reshape([2, 128, 12, 64])
.transpose([0, 2, 1, 3])
)
self.fused_transpose_X = [0, 2, 1, 3]
self.fused_reshape_X = [0, 0, 12, 64]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([2, 128, 12, 64]).transpose([0, 2, 1, 3]),
self.y.transpose([0, 1, 3, 2]),
)
class TestReshapeTransposeMatMulOp4DXInt8(TestReshapeTransposeMatMulOp4DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp4DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = (
np.random.random([2, 128, 768])
.astype("float32")
.reshape([2, 128, 12, 64])
.transpose([0, 2, 1, 3])
)
self.y = np.random.random([2, 128, 768]).astype("float32")
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [0, 2, 1, 3]
self.fused_reshape_Y = [0, 0, 12, 64]
self.out = np.matmul(
self.x, self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1])
)
class TestReshapeTransposeMatMulOp4DYInt8(TestReshapeTransposeMatMulOp4DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp4DXYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = np.random.random([2, 128, 768]).astype("float32")
self.fused_transpose_X = [0, 2, 1, 3]
self.fused_reshape_X = [0, 0, 12, 64]
self.fused_transpose_Y = [0, 2, 1, 3]
self.fused_reshape_Y = [0, 0, 12, 64]
self.out = np.matmul(
self.x.reshape([2, 128, 12, 64]).transpose([0, 2, 1, 3]),
self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1]),
)
class TestReshapeTransposeMatMulOp4DXYInt8(
TestReshapeTransposeMatMulOp4DXYFloat
):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp2DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 5, 10]).astype("float32")
self.y = (
np.random.random([2, 5, 10])
.astype("float32")
.reshape([10, 10])
.transpose([1, 0])
)
self.fused_transpose_X = [1, 0]
self.fused_reshape_X = [10, 10]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([10, 10]).transpose([1, 0]), self.y.transpose([1, 0])
)
class TestReshapeTransposeMatMulOp2DXInt8(TestReshapeTransposeMatMulOp2DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp2DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = (
np.random.random([2, 5, 10])
.astype("float32")
.reshape([10, 10])
.transpose([1, 0])
)
self.y = np.random.random([2, 5, 10]).astype("float32")
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [1, 0]
self.fused_reshape_Y = [10, 10]
self.out = np.matmul(self.x, self.y.reshape([10, 10]))
class TestReshapeTransposeMatMulOp2DYInt8(TestReshapeTransposeMatMulOp2DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp3DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 2, 5, 5]).astype("float32")
self.y = (
np.random.random([2, 2, 5, 5])
.astype("float32")
.reshape([2, 10, 5])
.transpose([0, 2, 1])
)
self.fused_transpose_X = [0, 2, 1]
self.fused_reshape_X = [2, 10, 5]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([2, 10, 5]).transpose(0, 2, 1),
self.y.transpose(0, 2, 1),
)
class TestReshapeTransposeMatMulOp3DXInt8(TestReshapeTransposeMatMulOp3DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp3DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = (
np.random.random([2, 2, 5, 5])
.astype(self.data_type_)
.reshape([2, 10, 5])
.transpose([0, 2, 1])
)
self.y = np.random.random([2, 2, 5, 5]).astype(self.data_type_)
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [0, 2, 1]
self.fused_reshape_Y = [2, 10, 5]
self.out = np.matmul(self.x, self.y.reshape([2, 10, 5]))
class TestReshapeTransposeMatMulOp3DYInt8(TestReshapeTransposeMatMulOp3DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
@skip_check_grad_ci(reason="Tests inference only optimization.")
class TestMatMulOpTransposeReshapeEmptyFloat(OpTest):
def init_data_type(self):
self.data_type_ = np.float32
def generate_data(self):
self.bs = 1
self.x = np.random.random([self.bs, 128, 128]).astype(self.data_type_)
self.y = np.random.random([self.bs, 128, 64]).astype(self.data_type_)
def init_params_and_out(self):
self.transpose_out = []
self.reshape_out = []
self.out = np.matmul(self.x, self.y)
def set_op_type(self):
self.op_type = "matmul"
def setUp(self):
self.set_op_type()
self._cpu_only = True
self.use_mkldnn = True
self.init_data_type()
self.generate_data()
self.init_params_and_out()
self.inputs = {'X': self.x, 'Y': self.y}
self.attrs = {'use_mkldnn': self.use_mkldnn}
if len(self.reshape_out) > 0:
self.attrs['fused_reshape_Out'] = self.reshape_out
if len(self.transpose_out) > 0:
self.attrs['fused_transpose_Out'] = self.transpose_out
self.inputs = {'X': self.x, 'Y': self.y}
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
def check_raise_error(self, msg):
try:
self.check_output()
except Exception as e:
if msg in str(e):
raise AttributeError
else:
print(e)
class TestMatMulOpTransposeReshapeIntEmptyInt(
TestMatMulOpTransposeReshapeEmptyFloat
):
def init_data_type(self):
self.data_type_ = np.int8
class TestMatMulOpTransposeReshapeBasicFloat(
TestMatMulOpTransposeReshapeEmptyFloat
):
def generate_data(self):
self.bs = 8
self.x = np.random.random([self.bs, 12, 128, 128]).astype(
self.data_type_
)
self.y = np.random.random([self.bs, 12, 128, 64]).astype(
self.data_type_
)
def init_params_and_out(self):
self.transpose_out = [0, 2, 1, 3]
self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
self.out = (
np.matmul(self.x, self.y)
.transpose([0, 2, 1, 3])
.reshape([self.bs, -1, self.x.shape[1] * self.y.shape[-1]])
)
class TestMatMulOpTransposeReshapeBasicInt(
TestMatMulOpTransposeReshapeBasicFloat
):
def init_data_type(self):
self.data_type_ = np.int8
class TestMatMulOpTransposeReshapeOtherDimFloat(
TestMatMulOpTransposeReshapeBasicFloat
):
def generate_data(self):
self.bs = 11
self.x = np.random.random([self.bs, 12, 14, 18]).astype(self.data_type_)
self.y = np.random.random([self.bs, 12, 18, 13]).astype(self.data_type_)
class TestMatMulOpTransposeReshapeOtherDimInt(
TestMatMulOpTransposeReshapeOtherDimFloat
):
def init_data_type(self):
self.data_type_ = np.int8
if __name__ == "__main__": if __name__ == "__main__":
from paddle import enable_static from paddle import enable_static
......
...@@ -15,19 +15,6 @@ ...@@ -15,19 +15,6 @@
import unittest import unittest
import numpy as np import numpy as np
from test_matmul_mkldnn_op import (
TestMatMulOpTransposeReshapeBasicFloat,
TestMatMulOpTransposeReshapeEmptyFloat,
TestMatMulOpTransposeReshapeOtherDimFloat,
TestReshapeTransposeMatMulOp,
TestReshapeTransposeMatMulOp2DXFloat,
TestReshapeTransposeMatMulOp2DYFloat,
TestReshapeTransposeMatMulOp3DXFloat,
TestReshapeTransposeMatMulOp3DYFloat,
TestReshapeTransposeMatMulOp4DXFloat,
TestReshapeTransposeMatMulOp4DXYFloat,
TestReshapeTransposeMatMulOp4DYFloat,
)
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
...@@ -472,89 +459,6 @@ create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp) ...@@ -472,89 +459,6 @@ create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp)
class TestMatMulV2OpTransposeReshapeEmptyFloat(
TestMatMulOpTransposeReshapeEmptyFloat
):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeBasicFloat(
TestMatMulOpTransposeReshapeBasicFloat
):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeOtherDimFloat(
TestMatMulOpTransposeReshapeOtherDimFloat
):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpReshapeTranspose(TestReshapeTransposeMatMulOp):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DXFloat(
TestReshapeTransposeMatMulOp4DXFloat
):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DYFloat(
TestReshapeTransposeMatMulOp4DYFloat
):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DXYFloat(
TestReshapeTransposeMatMulOp4DXYFloat
):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose2DXFloat(
TestReshapeTransposeMatMulOp2DXFloat
):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose2DYFloat(
TestReshapeTransposeMatMulOp2DYFloat
):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose3DXFloat(
TestReshapeTransposeMatMulOp3DXFloat
):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose3DYFloat(
TestReshapeTransposeMatMulOp3DYFloat
):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册