Unverified commit fc002405, authored by Wojciech Uss, committed by GitHub

A fix for the oneDNN matmul kernel. Fixes issue #30309 (#30723)

Parent 46989e88
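The change below teaches the oneDNN matmul kernel to handle mixed-rank inputs, where only one operand carries a batch dimension (the case reported in issue #30309). For reference, this is the NumPy broadcasting behavior the kernel has to reproduce; the shapes are the same ones the new tests use, so this is only a sketch of the expected semantics, not Paddle code:

import numpy as np

# Batched 3-D X times unbatched 2-D Y: Y is broadcast across
# all 17 batches, so the result has shape (17, 2, 4).
x = np.random.random((17, 2, 3)).astype("float32")
y = np.random.random((3, 4)).astype("float32")
assert np.matmul(x, y).shape == (17, 2, 4)

# The symmetric case: unbatched 2-D X against batched 3-D Y.
x = np.random.random((2, 3)).astype("float32")
y = np.random.random((17, 3, 4)).astype("float32")
assert np.matmul(x, y).shape == (17, 2, 4)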
@@ -188,34 +188,34 @@ class MatMulFactory {
     memory::dims strides_y;
     std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
-    const auto x_bs = mat_dim_x.batch_size_;
-    const auto y_bs = mat_dim_y.batch_size_;
+    auto x_bs = mat_dim_x.batch_size_;
+    auto y_bs = mat_dim_y.batch_size_;
     PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false,
                       platform::errors::InvalidArgument(
                           "If batch sizes of X and Y are positive,"
                           "they have to be equal."));
     // Store 1 if both batches are zero, otherwise save the nonzero batch
-    const memory::dim BS = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
+    memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
     const memory::dim M = mat_dim_x.height_;
     const memory::dim N = mat_dim_y.width_;
     const memory::dim K = mat_dim_x.width_;
     batch_size_ = 1;
-    auto b = BS;
-    if (BS > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
+    if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
       auto& x_dims = ctx.Input<Tensor>("X")->dims();
       auto& y_dims = ctx.Input<Tensor>("Y")->dims();
       batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
-      b = BS / batch_size_;
+      x_bs /= batch_size_;
+      y_bs /= batch_size_;
+      out_bs /= batch_size_;
     }
-    memory::dims x_dims = {b, M, K};
-    memory::dims y_dims = {b, K, N};
-    memory::dims out_dims = {b, M, N};
+    memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
+    memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
+    memory::dims out_dims = {out_bs, M, N};
-    x_offset_ = b * M * K * sizeof(XT);
-    y_offset_ = b * K * N * sizeof(YT);
-    out_offset_ = b * M * N * sizeof(OT);
+    x_offset_ = x_bs * M * K * sizeof(XT);
+    y_offset_ = y_bs * K * N * sizeof(YT);
+    out_offset_ = out_bs * M * N * sizeof(OT);
     // Translate transA and transB
     if (strides_x.empty())

@@ -226,7 +226,7 @@ class MatMulFactory {
                            : memory::dims{N * K, 1, K};
     memory::dims out_strides = memory::dims{M * N, N, 1};
-    CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides);
+    CorrectStridesWhenFloatOutputFused(ctx, N, out_bs, &out_strides);
     return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
   }
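In the fixed code above, an operand whose batch size is zero (a plain 2-D matrix) gets a leading oneDNN dimension of 1, which the oneDNN matmul primitive can broadcast across the batch, and its per-batch byte offset becomes 0 so the same buffer is re-read on every iteration instead of reading past the end of the smaller tensor. A minimal Python mirror of that arithmetic, for illustration only (the fused-operator branch is omitted; names follow the C++ variables):

def matmul_dims_and_offsets(x_bs, y_bs, M, K, N, elem_size=4):
    # Mirror of the kernel's dimension logic; x_bs / y_bs are the
    # batch sizes of X and Y (0 when the input is 2-D).
    assert not (x_bs > 0 and y_bs > 0 and x_bs != y_bs), \
        "If batch sizes of X and Y are positive, they have to be equal."
    out_bs = max(x_bs, y_bs) if (x_bs or y_bs) else 1
    x_dims = (x_bs if x_bs > 0 else 1, M, K)
    y_dims = (y_bs if y_bs > 0 else 1, K, N)
    out_dims = (out_bs, M, N)
    # A zero batch size yields a zero offset: the unbatched operand
    # is not advanced between batch iterations.
    x_offset = x_bs * M * K * elem_size
    y_offset = y_bs * K * N * elem_size
    out_offset = out_bs * M * N * elem_size
    return x_dims, y_dims, out_dims, x_offset, y_offset, out_offset

# X of shape (17, 2, 3) times Y of shape (3, 4), float32:
# -> ((17, 2, 3), (1, 3, 4), (17, 2, 4), 408, 0, 544)
print(matmul_dims_and_offsets(x_bs=17, y_bs=0, M=2, K=3, N=4))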
@@ -48,6 +48,20 @@ class TestDnnlMatMulOp(OpTest):
         self.check_output()
 
 
+class TestDnnlMatMulOpMixedDims1(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((17, 2, 3)).astype("float32")
+        self.y = np.random.random((3, 4)).astype("float32")
+        self.out = np.matmul(self.x, self.y)
+
+
+class TestDnnlMatMulOpMixedDims2(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((2, 3)).astype("float32")
+        self.y = np.random.random((17, 3, 4)).astype("float32")
+        self.out = np.matmul(self.x, self.y)
+
+
 class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp):
     def generate_data(self):
         self.x = np.random.random((17, 2, 3)).astype("float32")
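The two new MixedDims tests cover both directions of the fix: a batched X against an unbatched Y, and the reverse. Each subclass only overrides generate_data; the base TestDnnlMatMulOp precomputes the reference result with np.matmul and compares the oneDNN kernel against it in check_output. A further case of the same shape pattern (hypothetical, not part of this commit, and assuming the base class forwards self.alpha as the kernel's alpha attribute the way TestDnnlMatMulOpAlpha suggests) would look like:

class TestDnnlMatMulOpMixedDimsAlpha(TestDnnlMatMulOp):
    # Hypothetical extension: mixed-rank inputs plus output scaling.
    def generate_data(self):
        self.x = np.random.random((17, 2, 3)).astype("float32")
        self.y = np.random.random((3, 4)).astype("float32")
        self.alpha = 2.0
        self.out = self.alpha * np.matmul(self.x, self.y)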
@@ -396,10 +410,10 @@ class TestMatMulOpTransposeReshapeBasicFloat(
         TestMatMulOpTransposeReshapeEmptyFloat):
     def generate_data(self):
         self.bs = 8
-        self.x = np.random.random(
-            [self.bs, 12, 128, 128]).astype(self.data_type_)
-        self.y = np.random.random(
-            [self.bs, 12, 128, 64]).astype(self.data_type_)
+        self.x = np.random.random([self.bs, 12, 128,
+                                   128]).astype(self.data_type_)
+        self.y = np.random.random([self.bs, 12, 128,
+                                   64]).astype(self.data_type_)
 
     def init_params_and_out(self):
         self.transpose_out = [0, 2, 1, 3]