未验证 提交 f5007051 编写于 作者: W Wojciech Uss 提交者: GitHub

A fix for oneDNN matmul kernel. Fixes issue #30309 for oneDNN 1.6 (#31066)

* A fix for oneDNN matmul kernel. Fixes issue #30309 (#30723)

* A fix for #30309 with oneDNN 1.6
上级 36710ebc
......@@ -65,18 +65,14 @@ class MatMulFactory {
public:
void CreateAndExecute(const ExecutionContext& ctx) {
SetDNNLEngine(ctx);
if (IsInitialized()) {
UpdateDataPointers(ctx);
Execute();
SetOutputFormat(ctx);
return;
}
if (!IsInitialized()) {
CreateMemories(ctx);
CreatePrimitive(ctx);
Execute();
SetOutputFormat(ctx);
SetInitialized();
}
Execute(ctx);
SetOutputFormat(ctx);
}
private:
struct MatMulDims {
......@@ -181,41 +177,63 @@ class MatMulFactory {
}
MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
math::MatDescriptor mat_dim_x;
math::MatDescriptor x_mat_dims;
memory::dims strides_x;
std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X");
math::MatDescriptor mat_dim_y;
std::tie(x_mat_dims, strides_x) = GetInputDimsAndStrides(ctx, "X");
math::MatDescriptor y_mat_dims;
memory::dims strides_y;
std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
std::tie(y_mat_dims, strides_y) = GetInputDimsAndStrides(ctx, "Y");
const auto x_bs = mat_dim_x.batch_size_;
const auto y_bs = mat_dim_y.batch_size_;
PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false,
platform::errors::InvalidArgument(
auto x_mat_bs = x_mat_dims.batch_size_;
auto y_mat_bs = y_mat_dims.batch_size_;
PADDLE_ENFORCE_EQ(x_mat_bs > 0 && y_mat_bs > 0 && x_mat_bs != y_mat_bs,
false, platform::errors::InvalidArgument(
"If batch sizes of X and Y are positive,"
"they have to be equal."));
memory::dim out_mat_bs =
x_mat_bs || y_mat_bs ? std::max(x_mat_bs, y_mat_bs) : 1;
const memory::dim M = x_mat_dims.height_;
const memory::dim N = y_mat_dims.width_;
const memory::dim K = x_mat_dims.width_;
// Find total batch size of the data
const memory::dim x_bs = (x_mat_bs) ? x_mat_bs : 1;
const memory::dim y_bs = (y_mat_bs) ? y_mat_bs : 1;
const memory::dim total_bs = std::max(x_bs, y_bs);
// Find batch size for oneDNN primitive
memory::dim onednn_bs = std::min(x_bs, y_bs);
// Find the number of times the oneDNN primitive has to be executed
execute_loop_steps_ = total_bs / onednn_bs;
if (execute_loop_steps_ > 1) {
x_mat_bs /= execute_loop_steps_;
y_mat_bs /= execute_loop_steps_;
out_mat_bs /= execute_loop_steps_;
}
// Store 1 if both batches are zero, otherwise save the nonzero batch
const memory::dim BS = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
const memory::dim M = mat_dim_x.height_;
const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_;
batch_size_ = 1;
auto b = BS;
if (BS > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
auto& x_dims = ctx.Input<Tensor>("X")->dims();
auto& y_dims = ctx.Input<Tensor>("Y")->dims();
batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
b = BS / batch_size_;
// Take original format batch size into account
if (out_mat_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
auto x_orig_bs = ctx.Input<Tensor>("X")->dims()[0];
auto y_orig_bs = ctx.Input<Tensor>("Y")->dims()[0];
auto orig_bs = x_mat_bs > y_mat_bs ? x_orig_bs : y_orig_bs;
execute_loop_steps_ *= orig_bs;
onednn_bs /= orig_bs;
x_mat_bs /= orig_bs;
y_mat_bs /= orig_bs;
out_mat_bs /= orig_bs;
}
memory::dims x_dims = {b, M, K};
memory::dims y_dims = {b, K, N};
memory::dims out_dims = {b, M, N};
x_offset_ = b * M * K * sizeof(XT);
y_offset_ = b * K * N * sizeof(YT);
out_offset_ = b * M * N * sizeof(OT);
// Set dimensions for the oneDNN memories
memory::dims x_dims = {onednn_bs, M, K};
memory::dims y_dims = {onednn_bs, K, N};
memory::dims out_dims = {onednn_bs, M, N};
// Find data offsets for each oneDNN primitive execution step
x_offset_ = x_mat_bs * M * K * sizeof(XT);
y_offset_ = y_mat_bs * K * N * sizeof(YT);
out_offset_ = out_mat_bs * M * N * sizeof(OT);
// Translate transA and transB
if (strides_x.empty())
......@@ -226,7 +244,7 @@ class MatMulFactory {
: memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};
CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides);
CorrectStridesWhenFloatOutputFused(ctx, N, out_mat_bs, &out_strides);
return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
}
......@@ -266,13 +284,15 @@ class MatMulFactory {
matmul_prim_ = dnnl::matmul(matmul_pd);
}
void Execute() {
void Execute(const ExecutionContext& ctx) {
dnnl::stream stream(engine_);
void* x_ptr = x_mem_.get_data_handle();
void* y_ptr = y_mem_.get_data_handle();
void* out_ptr = out_mem_.get_data_handle();
for (uint16_t i = 0; i < batch_size_; i++) {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out");
void* x_ptr = to_void_cast(x->data<XT>());
void* y_ptr = to_void_cast(y->data<YT>());
void* out_ptr = to_void_cast(out->mutable_data<OT>(ctx.GetPlace()));
for (uint16_t i = 0; i < execute_loop_steps_; i++) {
x_mem_.set_data_handle(x_ptr);
y_mem_.set_data_handle(y_ptr);
out_mem_.set_data_handle(out_ptr);
......@@ -297,15 +317,6 @@ class MatMulFactory {
out->set_layout(DataLayout::kMKLDNN);
}
void UpdateDataPointers(const ExecutionContext& ctx) {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out");
x_mem_.set_data_handle(to_void_cast(x->data<XT>()));
y_mem_.set_data_handle(to_void_cast(y->data<YT>()));
out_mem_.set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
}
// If initialized, x memory should've been already initialized
bool IsInitialized() { return initialized_; }
......@@ -326,7 +337,7 @@ class MatMulFactory {
uint32_t x_offset_;
uint32_t y_offset_;
uint32_t out_offset_;
uint16_t batch_size_;
uint16_t execute_loop_steps_;
bool initialized_ = false;
};
......
......@@ -48,6 +48,20 @@ class TestDnnlMatMulOp(OpTest):
self.check_output()
class TestDnnlMatMulOpMixedDims1(TestDnnlMatMulOp):
def generate_data(self):
self.x = np.random.random((17, 2, 3)).astype("float32")
self.y = np.random.random((3, 4)).astype("float32")
self.out = np.matmul(self.x, self.y)
class TestDnnlMatMulOpMixedDims2(TestDnnlMatMulOp):
def generate_data(self):
self.x = np.random.random((2, 3)).astype("float32")
self.y = np.random.random((17, 3, 4)).astype("float32")
self.out = np.matmul(self.x, self.y)
class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp):
def generate_data(self):
self.x = np.random.random((17, 2, 3)).astype("float32")
......@@ -396,10 +410,10 @@ class TestMatMulOpTransposeReshapeBasicFloat(
TestMatMulOpTransposeReshapeEmptyFloat):
def generate_data(self):
self.bs = 8
self.x = np.random.random(
[self.bs, 12, 128, 128]).astype(self.data_type_)
self.y = np.random.random(
[self.bs, 12, 128, 64]).astype(self.data_type_)
self.x = np.random.random([self.bs, 12, 128,
128]).astype(self.data_type_)
self.y = np.random.random([self.bs, 12, 128,
64]).astype(self.data_type_)
def init_params_and_out(self):
self.transpose_out = [0, 2, 1, 3]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册