Unverified · Commit e2a3a6f7 authored by J jakpiase, committed by GitHub

Added oneDNN matmul grad BF16/FP32 kernel (#32968)

* added support for most matmul cases

* added more functionality

* full functionality of matmul op, fp32 only

* added bf16 tests and functionality

* added formatting

* changes after review

* minor change

* added reviewers suggestions
Parent 79d918d9
......@@ -825,6 +825,21 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
context->SetOutputDim(y_grad_name, y_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace platform {
......@@ -37,6 +37,111 @@ using platform::MKLDNNGetDataType;
using platform::to_void_cast;
using Tensor = framework::Tensor;
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
static framework::Tensor FoldOuterDims(const Tensor& input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
}
return output;
}
// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename T>
static framework::Tensor FoldFirstAndLastDims(
const MKLDNNDeviceContext& dev_ctx, const Tensor* input) {
auto input_dims = framework::vectorize(input->dims());
if (input_dims.size() != 3) {
return *input;
}
framework::Tensor output;
output.Resize({input_dims[1], input_dims[0], input_dims[2]});
auto output_dims = framework::vectorize(output.dims());
memory::data_type input_type = framework::ToMKLDNNDataType(input->type());
std::string key = platform::CreateKey(dev_ctx, input_dims, input->format(),
input->format(), input_type);
platform::ReorderMKLDNNHandler reorder_handler(output_dims, input->type(),
input_type, dev_ctx,
dev_ctx.GetEngine(), key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
memory::format_tag::abc, platform::to_void_cast(input->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
&output, memory::format_tag::bac, dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
output.Resize({input_dims[1], input_dims[0] * input_dims[2]});
return output;
}
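For reference, the fold performed by FoldFirstAndLastDims above is equivalent to a transpose followed by a reshape. A minimal NumPy sketch of the same transformation (not the oneDNN reorder path itself; the helper name is hypothetical):

import numpy as np

def fold_first_and_last_dims(a):
    # P x M x N -> M x (P * N); identity for non-rank-3 inputs,
    # mirroring the abc -> bac reorder plus the final Resize above.
    if a.ndim != 3:
        return a
    p, m, n = a.shape
    return a.transpose(1, 0, 2).reshape(m, p * n)

x = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
print(fold_first_and_last_dims(x).shape)  # (3, 8)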
template <typename T>
class MatMulMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::matmul> {
public:
MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place,
Tensor* x, bool trans_x, Tensor* y, bool trans_y,
Tensor* out, float scale, const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
auto mat_dim_x = math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
auto mat_dim_y = math::CreateMatrixDescriptor(y->dims(), 0, trans_y);
memory::dim x_bs = mat_dim_x.batch_size_;
memory::dim y_bs = mat_dim_y.batch_size_;
memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
const memory::dim M = mat_dim_x.height_;
const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_;
memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
memory::dims out_dims = {out_bs, M, N};
memory::dims x_strides =
!trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};
memory::dims y_strides =
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);
dnnl::primitive_attr attrs;
if (scale != 1.0f) attrs.set_output_scales(0, {scale});
this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
}
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data),
"@weights_mem_p");
}
};
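Note that the handler above never physically transposes X or Y: trans_x and trans_y only change the stride triples passed to the memory descriptors, so oneDNN reads the existing buffer in transposed order. A minimal NumPy illustration of the same stride idea (not oneDNN API):

import numpy as np

a = np.arange(12, dtype=np.float32).reshape(4, 3)  # stored as K x M
a_t = a.T                                          # viewed as M x K
# Only the strides differ; the buffer is shared, which is what the
# {M * K, 1, M} per-batch stride triple above encodes.
print(a.strides, a_t.strides)    # (12, 4) (4, 12)
print(np.shares_memory(a, a_t))  # True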
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
......@@ -44,7 +149,7 @@ constexpr bool IsInt8() {
template <typename T>
constexpr bool IsBfloat16() {
return std::is_same<T, paddle::platform::bfloat16>::value;
return std::is_same<T, platform::bfloat16>::value;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
......@@ -60,6 +165,60 @@ static framework::DDim ColumnMatrixDimsFromVector(
return y_dim.size() > 1 ? y_dim : framework::make_ddim({y_dim[0], 1});
}
/**
 * Reshape a tensor to a 3-D or 2-D tensor according to its matrix descriptor.
 *
 * The resulting shape is [BatchSize, H, W] or [H, W].
 * If the descriptor is transposed, `H` and `W` are swapped.
 */
static void ReshapeTensorToMatrixSequence(
framework::Tensor* x, const math::MatDescriptor& descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
if (descriptor.batch_size_) {
x->Resize({descriptor.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
/**
 * Reshape the x, y and out tensors to 3-D or 2-D tensors by matrix descriptor,
 * where Out = matmul(x, y).
 *
 * This method first reshapes X and Y into their matrix-sequence forms and
 * then derives the output shape.
 *
 * Assume X = [BatchSize, H1, W1] and Y = [BatchSize, H2, W2];
 * then Out = [BatchSize, H1, W2].
 *
 * If neither `X` nor `Y` has a batch size, the output will be [H1, W2].
 * If either `X` or `Y` has a batch size BatchSize, the output will carry
 * that BatchSize as well.
 */
static void ReshapeXYOutToMatrixSequence(framework::Tensor* x,
framework::Tensor* y,
framework::Tensor* out, bool trans_x,
bool trans_y) {
auto x_dim = RowMatrixDimsFromVector(x->dims());
auto y_dim = ColumnMatrixDimsFromVector(y->dims());
auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_, mat_dim_y.width_});
}
ReshapeTensorToMatrixSequence(x, mat_dim_x);
ReshapeTensorToMatrixSequence(y, mat_dim_y);
}
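For illustration, the shape rules documented above can be checked with plain NumPy (an editorial sketch with trans_x = trans_y = False, not output of the framework itself):

import numpy as np

x, y = np.zeros((8, 2, 3)), np.zeros((3, 4))
print(np.matmul(x, y).shape)  # (8, 2, 4): batched X against a shared 2-D Y
v, w = np.zeros(5), np.zeros(5)
# Vectors are promoted to a row matrix and a column matrix, respectively.
print(np.matmul(v.reshape(1, 5), w.reshape(5, 1)).shape)  # (1, 1)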
template <typename XT, typename YT, typename OT>
class MatMulFactory {
public:
......@@ -372,7 +531,7 @@ static void ExecuteMatMul(const ExecutionContext& ctx) {
template <typename T>
class DNNLMatMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const ExecutionContext& ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1,
......@@ -385,6 +544,137 @@ class DNNLMatMulKernel : public framework::OpKernel<T> {
ExecuteMatMul<T, T>(ctx);
}
};
template <typename T>
class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1,
platform::errors::Unimplemented(
"DNNL matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
RunKernel<T>(ctx);
}
private:
void ExecuteMatMulGrad(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y,
bool is_fold_init_dims_y, Tensor* out,
int execution_number) const {
// The gradient is computed differently when broadcasting is involved.
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2;
Tensor x_combined, y_combined;
if (!need_combine) {
x_combined = *x;
y_combined = *y;
} else {
x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
: FoldFirstAndLastDims<T>(dev_ctx, x);
y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
: FoldFirstAndLastDims<T>(dev_ctx, y);
}
MatMulMKLDNNHandler<T> handler(
dev_ctx, engine, ctx.GetPlace(), &x_combined, trans_x, &y_combined,
trans_y, out, ctx.Attr<float>("alpha"),
ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number));
const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
const auto dst_memory_p = handler.AcquireDstMemory(out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape(
framework::vectorize<int64_t>(out->dims()))));
}
template <typename Tout = T>
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto x = *ctx.Input<Tensor>("X");
auto y = *ctx.Input<Tensor>("Y");
auto dout = *ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
bool transpose_x = ctx.Attr<bool>("transpose_X");
bool transpose_y = ctx.Attr<bool>("transpose_Y");
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true,
&dout, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true,
&x, true, false, dy, 1);
} else if (transpose_x) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
&dout, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
&dout, false, true, dy, 1);
} else if (transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, false, true, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true,
&x, false, true, dy, 1);
} else {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, true, false, dx, 0);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true,
&dout, false, true, dy, 1);
}
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
}
};
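As a cross-check, the four transpose branches in RunKernel above implement the standard gradient identities for Out = alpha * op(X) * op(Y). A minimal NumPy reference (function name hypothetical; alpha is applied explicitly here, whereas the kernel folds it into the oneDNN output scale):

import numpy as np

def ref_matmul_grads(x, y, dout, trans_x, trans_y, alpha=1.0):
    # Reference gradients for Out = alpha * op(X) @ op(Y).
    T = lambda a: np.swapaxes(a, -1, -2)
    if trans_x and trans_y:
        dx, dy = T(y) @ T(dout), T(dout) @ T(x)
    elif trans_x:
        dx, dy = y @ T(dout), x @ dout
    elif trans_y:
        dx, dy = dout @ y, T(dout) @ x
    else:
        dx, dy = dout @ T(y), T(x) @ dout
    return alpha * dx, alpha * dy

x, y = np.random.rand(2, 3), np.random.rand(3, 4)
dx, dy = ref_matmul_grads(x, y, np.ones((2, 4)), False, False)
print(dx.shape, dy.shape)  # (2, 3) (3, 4)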
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
......@@ -394,3 +684,7 @@ REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace,
ops::DNNLMatMulKernel<paddle::platform::bfloat16>,
ops::DNNLMatMulKernel<int8_t>,
ops::DNNLMatMulKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulGradMKLDNNKernel<float>,
ops::MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);
......@@ -26,22 +26,23 @@ from paddle import enable_static
"place does not support BF16 evaluation")
class TestMatmulBf16MklDNNOp(OpTest):
def generate_data(self):
self.x = np.random.random((25, 2, 2)).astype(np.float32)
self.y = np.random.random((25, 2, 2)).astype(np.float32)
self.alpha = 1.0
self.out = self.alpha * np.matmul(self.x, self.y)
self.x_fp32 = np.random.random((25, 2, 2)).astype(np.float32)
self.y_fp32 = np.random.random((25, 2, 2)).astype(np.float32)
self.out = self.alpha * np.matmul(self.x_fp32, self.y_fp32)
def set_attributes(self):
self.alpha = self.alpha if hasattr(self, 'alpha') else 1.0
self.attrs = {
'alpha': self.alpha,
"use_mkldnn": self.use_mkldnn,
"mkldnn_data_type": self.mkldnn_data_type,
"force_fp32_output": self.force_fp32_output
"force_fp32_output": self.force_fp32_output,
'transpose_X': False,
'transpose_Y': False
}
def setUp(self):
self.op_type = "matmul"
self.alpha = 1.0
self.use_mkldnn = True
self.dtype = np.uint16
self.mkldnn_data_type = "bfloat16"
......@@ -53,67 +54,113 @@ class TestMatmulBf16MklDNNOp(OpTest):
self.out = convert_float_to_uint16(self.out)
self.outputs = {'Out': self.out}
self.x = convert_float_to_uint16(self.x)
self.y = convert_float_to_uint16(self.y)
self.inputs = {'X': self.x, 'Y': self.y}
self.x_bf16 = convert_float_to_uint16(self.x_fp32)
self.y_bf16 = convert_float_to_uint16(self.y_fp32)
self.inputs = {'X': self.x_bf16, 'Y': self.y_bf16}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
check_dygraph=False,
user_defined_grads=[self.dx, self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
def matmul_grad(self, x, transpose_x, y, transpose_y):
x_transpose_axes = [1, 0] if x.ndim == 2 else [0, 2, 1]
y_transpose_axes = [1, 0] if y.ndim == 2 else [0, 2, 1]
x = np.transpose(x, x_transpose_axes) if transpose_x else x
y = np.transpose(y, y_transpose_axes) if transpose_y else y
return self.alpha * np.matmul(x, y)
def calculate_grads(self):
x_transpose_axes = [1, 0] if self.x_fp32.ndim == 2 else [0, 2, 1]
y_transpose_axes = [1, 0] if self.y_fp32.ndim == 2 else [0, 2, 1]
x = np.transpose(self.x_fp32, x_transpose_axes) if self.attrs[
'transpose_X'] is True else self.x_fp32
y = np.transpose(self.y_fp32, y_transpose_axes) if self.attrs[
'transpose_Y'] is True else self.y_fp32
dout = self.alpha * np.matmul(x, y)
if self.attrs['transpose_X'] is True and self.attrs[
'transpose_Y'] is True:
self.dx = self.matmul_grad(self.y_fp32, True, dout, True)
self.dy = self.matmul_grad(dout, True, self.x_fp32, True)
elif self.attrs['transpose_X'] is True and self.attrs[
'transpose_Y'] is False:
self.dx = self.matmul_grad(self.y_fp32, False, dout, True)
self.dy = self.matmul_grad(self.x_fp32, False, dout, False)
elif self.attrs['transpose_X'] is False and self.attrs[
'transpose_Y'] is True:
self.dx = self.matmul_grad(dout, False, self.y_fp32, False)
self.dy = self.matmul_grad(dout, True, self.x_fp32, False)
else:
self.dx = self.matmul_grad(dout, False, self.y_fp32, True)
self.dy = self.matmul_grad(self.x_fp32, True, dout, False)
self.dout = dout
class TestDnnlMatMulOpAlpha(TestMatmulBf16MklDNNOp):
def generate_data(self):
self.x = np.random.random((17, 2, 3)).astype(np.float32)
self.y = np.random.random((17, 3, 2)).astype(np.float32)
self.x_fp32 = np.random.random((17, 2, 3)).astype(np.float32)
self.y_fp32 = np.random.random((17, 3, 2)).astype(np.float32)
self.alpha = 2.0
self.out = self.alpha * np.matmul(self.x, self.y)
self.out = self.alpha * np.matmul(self.x_fp32, self.y_fp32)
class TestDnnlMatMulOp2D(TestMatmulBf16MklDNNOp):
def generate_data(self):
self.x = np.random.random((12, 9)).astype(np.float32)
self.y = np.random.random((9, 12)).astype(np.float32)
self.out = np.matmul(self.x, self.y)
self.x_fp32 = np.random.random((12, 9)).astype(np.float32)
self.y_fp32 = np.random.random((9, 12)).astype(np.float32)
self.out = np.matmul(self.x_fp32, self.y_fp32)
class TestDnnlMatMulOpTransposeX(TestMatmulBf16MklDNNOp):
def generate_data(self):
self.x = np.random.random((12, 9)).astype(np.float32)
self.y = np.random.random((12, 9)).astype(np.float32)
self.out = np.matmul(np.transpose(self.x), self.y)
self.x_fp32 = np.random.random((12, 9)).astype(np.float32)
self.y_fp32 = np.random.random((12, 9)).astype(np.float32)
self.out = np.matmul(np.transpose(self.x_fp32), self.y_fp32)
def set_attributes(self):
self.attrs = {
"use_mkldnn": self.use_mkldnn,
"mkldnn_data_type": self.mkldnn_data_type,
'transpose_X': True
'transpose_X': True,
'transpose_Y': False
}
class TestDnnlMatMulOpTransposeY(TestMatmulBf16MklDNNOp):
def generate_data(self):
self.x = np.random.random((12, 9)).astype(np.float32)
self.y = np.random.random((12, 9)).astype(np.float32)
self.out = np.matmul(self.x, np.transpose(self.y))
self.x_fp32 = np.random.random((12, 9)).astype(np.float32)
self.y_fp32 = np.random.random((12, 9)).astype(np.float32)
self.out = np.matmul(self.x_fp32, np.transpose(self.y_fp32))
def set_attributes(self):
self.attrs = {
"use_mkldnn": self.use_mkldnn,
"mkldnn_data_type": self.mkldnn_data_type,
'transpose_Y': True
'transpose_Y': True,
'transpose_X': False
}
class TestMatmulBf16MklDNNForceFp32Output(TestMatmulBf16MklDNNOp):
def generate_data(self):
self.x = np.random.random((12, 9)).astype(np.float32)
self.y = np.random.random((9, 12)).astype(np.float32)
self.x_fp32 = np.random.random((12, 9)).astype(np.float32)
self.y_fp32 = np.random.random((9, 12)).astype(np.float32)
self.force_fp32_output = True
self.alpha = 0.5
self.out = self.alpha * np.matmul(self.x, self.y)
self.out = self.alpha * np.matmul(self.x_fp32, self.y_fp32)
if __name__ == "__main__":
......
......@@ -19,7 +19,6 @@ import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
@skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.")
class TestDnnlMatMulOp(OpTest):
def generate_data(self):
self.x = np.random.random((25, 2, 2)).astype("float32")
......@@ -48,21 +47,99 @@ class TestDnnlMatMulOp(OpTest):
self.check_output()
class TestDnnlMatMulOpMixedDims1(TestDnnlMatMulOp):
class TestDnnlMatMulWithGradOp(TestDnnlMatMulOp):
def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-2)
class TestDnnlMatMulOpMixedDims1(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((17, 2, 3)).astype("float32")
self.y = np.random.random((3, 4)).astype("float32")
self.out = np.matmul(self.x, self.y)
class TestDnnlMatMulOpMixedDims2(TestDnnlMatMulOp):
class TestDnnlMatMulOpMixedDimsYWiderTransposeY(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((8, 2, 3)).astype("float32")
self.y = np.random.random((4, 3)).astype("float32")
self.out = np.matmul(self.x, np.transpose(self.y))
def set_attributes(self):
self.attrs = {'transpose_Y': True}
class TestDnnlMatMulOpMixedDimsYWiderTransposeX(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((8, 3, 2)).astype("float32")
self.y = np.random.random((3, 4)).astype("float32")
self.out = np.matmul(np.transpose(self.x, (0, 2, 1)), self.y)
def set_attributes(self):
self.attrs = {'transpose_X': True}
class TestDnnlMatMulOpMixedDimsXWiderTransposeXY(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((8, 3, 2)).astype("float32")
self.y = np.random.random((4, 3)).astype("float32")
self.out = np.matmul(
np.transpose(self.x, (0, 2, 1)), np.transpose(self.y))
def set_attributes(self):
self.attrs = {'transpose_X': True, 'transpose_Y': True}
class TestDnnlMatMulOpMixedDimsYWiderTransposeXY(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((3, 2)).astype("float32")
self.y = np.random.random((8, 4, 3)).astype("float32")
self.out = np.matmul(
np.transpose(self.x), np.transpose(self.y, (0, 2, 1)))
def set_attributes(self):
self.attrs = {'transpose_X': True, 'transpose_Y': True}
class TestDnnlMatMulOpMixedDimsXWiderTransposeX(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((5, 4)).astype("float32")
self.y = np.random.random((8, 5, 4)).astype("float32")
self.out = np.matmul(np.transpose(self.x), self.y)
def set_attributes(self):
self.attrs = {'transpose_X': True}
class TestDnnlMatMulOpVectorMultiply(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((5)).astype("float32")
self.y = np.random.random((5)).astype("float32")
self.out = np.matmul(self.x, self.y)
class TestDnnlMatMulOpVectorMultiplyTranspose(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((5)).astype("float32")
x_resized = np.copy(self.x)
x_resized = np.expand_dims(x_resized, 1)
self.y = np.random.random((6)).astype("float32")
y_resized = np.copy(self.y)
y_resized = np.expand_dims(y_resized, 0)
self.out = np.matmul(x_resized, y_resized)
def set_attributes(self):
self.attrs = {'transpose_Y': True, 'transpose_X': True}
class TestDnnlMatMulOpMixedDims2(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((2, 3)).astype("float32")
self.y = np.random.random((17, 3, 4)).astype("float32")
self.out = np.matmul(self.x, self.y)
class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp):
class TestDnnlMatMulOpAlpha(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((17, 2, 3)).astype("float32")
self.y = np.random.random((17, 3, 2)).astype("float32")
......@@ -70,18 +147,14 @@ class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp):
self.out = self.alpha * np.matmul(self.x, self.y)
class TestDnnlMatMulOp2D(TestDnnlMatMulOp):
def print_tensor(self, name, tensor):
print(name)
print(tensor)
class TestDnnlMatMulOp2D(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((12, 9)).astype("float32")
self.y = np.random.random((9, 12)).astype("float32")
self.out = np.matmul(self.x, self.y)
class TestDnnlMatMulOpTransposeX(TestDnnlMatMulOp):
class TestDnnlMatMulOpTransposeX(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((12, 9)).astype("float32")
self.y = np.random.random((12, 9)).astype("float32")
......@@ -91,7 +164,7 @@ class TestDnnlMatMulOpTransposeX(TestDnnlMatMulOp):
self.attrs = {'transpose_X': True}
class TestDnnlMatMulOpTransposeY(TestDnnlMatMulOp):
class TestDnnlMatMulOpTransposeY(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((12, 9)).astype("float32")
self.y = np.random.random((12, 9)).astype("float32")
......@@ -101,7 +174,7 @@ class TestDnnlMatMulOpTransposeY(TestDnnlMatMulOp):
self.attrs = {'transpose_Y': True}
class TestDnnlMatMulOpTransposeY3D(TestDnnlMatMulOp):
class TestDnnlMatMulOpTransposeY3D(TestDnnlMatMulWithGradOp):
def generate_data(self):
self.x = np.random.random((17, 3, 2)).astype("float32")
self.y = np.random.random((17, 3, 2)).astype("float32")
......@@ -480,4 +553,6 @@ class TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException(
if __name__ == "__main__":
from paddle import enable_static
enable_static()
unittest.main()