未验证 提交 f5007051 编写于 作者: W Wojciech Uss 提交者: GitHub

A fix for oneDNN matmul kernel. Fixes issue #30309 for oneDNN 1.6 (#31066)

* A fix for oneDNN matmul kernel. Fixes issue #30309 (#30723)

* A fix for #30309 with oneDNN 1.6
上级 36710ebc
...@@ -65,17 +65,13 @@ class MatMulFactory { ...@@ -65,17 +65,13 @@ class MatMulFactory {
public: public:
void CreateAndExecute(const ExecutionContext& ctx) { void CreateAndExecute(const ExecutionContext& ctx) {
SetDNNLEngine(ctx); SetDNNLEngine(ctx);
if (IsInitialized()) { if (!IsInitialized()) {
UpdateDataPointers(ctx); CreateMemories(ctx);
Execute(); CreatePrimitive(ctx);
SetOutputFormat(ctx); SetInitialized();
return;
} }
CreateMemories(ctx); Execute(ctx);
CreatePrimitive(ctx);
Execute();
SetOutputFormat(ctx); SetOutputFormat(ctx);
SetInitialized();
} }
private: private:
...@@ -181,41 +177,63 @@ class MatMulFactory { ...@@ -181,41 +177,63 @@ class MatMulFactory {
} }
MatMulDims GetMatmulDims(const ExecutionContext& ctx) { MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
math::MatDescriptor mat_dim_x; math::MatDescriptor x_mat_dims;
memory::dims strides_x; memory::dims strides_x;
std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X"); std::tie(x_mat_dims, strides_x) = GetInputDimsAndStrides(ctx, "X");
math::MatDescriptor mat_dim_y; math::MatDescriptor y_mat_dims;
memory::dims strides_y; memory::dims strides_y;
std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y"); std::tie(y_mat_dims, strides_y) = GetInputDimsAndStrides(ctx, "Y");
auto x_mat_bs = x_mat_dims.batch_size_;
auto y_mat_bs = y_mat_dims.batch_size_;
PADDLE_ENFORCE_EQ(x_mat_bs > 0 && y_mat_bs > 0 && x_mat_bs != y_mat_bs,
false, platform::errors::InvalidArgument(
"If batch sizes of X and Y are positive,"
"they have to be equal."));
memory::dim out_mat_bs =
x_mat_bs || y_mat_bs ? std::max(x_mat_bs, y_mat_bs) : 1;
const memory::dim M = x_mat_dims.height_;
const memory::dim N = y_mat_dims.width_;
const memory::dim K = x_mat_dims.width_;
// Find total batch size of the data
const memory::dim x_bs = (x_mat_bs) ? x_mat_bs : 1;
const memory::dim y_bs = (y_mat_bs) ? y_mat_bs : 1;
const memory::dim total_bs = std::max(x_bs, y_bs);
// Find batch size for oneDNN primitive
memory::dim onednn_bs = std::min(x_bs, y_bs);
// Find the number of times the oneDNN primitive has to be executed
execute_loop_steps_ = total_bs / onednn_bs;
if (execute_loop_steps_ > 1) {
x_mat_bs /= execute_loop_steps_;
y_mat_bs /= execute_loop_steps_;
out_mat_bs /= execute_loop_steps_;
}
const auto x_bs = mat_dim_x.batch_size_; // Take original format batch size into account
const auto y_bs = mat_dim_y.batch_size_; if (out_mat_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false, auto x_orig_bs = ctx.Input<Tensor>("X")->dims()[0];
platform::errors::InvalidArgument( auto y_orig_bs = ctx.Input<Tensor>("Y")->dims()[0];
"If batch sizes of X and Y are positive," auto orig_bs = x_mat_bs > y_mat_bs ? x_orig_bs : y_orig_bs;
"they have to be equal.")); execute_loop_steps_ *= orig_bs;
onednn_bs /= orig_bs;
// Store 1 if both batches are zero, otherwise save the nonzero batch x_mat_bs /= orig_bs;
const memory::dim BS = x_bs || y_bs ? std::max(x_bs, y_bs) : 1; y_mat_bs /= orig_bs;
const memory::dim M = mat_dim_x.height_; out_mat_bs /= orig_bs;
const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_;
batch_size_ = 1;
auto b = BS;
if (BS > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
auto& x_dims = ctx.Input<Tensor>("X")->dims();
auto& y_dims = ctx.Input<Tensor>("Y")->dims();
batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
b = BS / batch_size_;
} }
memory::dims x_dims = {b, M, K};
memory::dims y_dims = {b, K, N};
memory::dims out_dims = {b, M, N};
x_offset_ = b * M * K * sizeof(XT); // Set dimensions for the oneDNN memories
y_offset_ = b * K * N * sizeof(YT); memory::dims x_dims = {onednn_bs, M, K};
out_offset_ = b * M * N * sizeof(OT); memory::dims y_dims = {onednn_bs, K, N};
memory::dims out_dims = {onednn_bs, M, N};
// Find data offsets for each oneDNN primitive execution step
x_offset_ = x_mat_bs * M * K * sizeof(XT);
y_offset_ = y_mat_bs * K * N * sizeof(YT);
out_offset_ = out_mat_bs * M * N * sizeof(OT);
// Translate transA and transB // Translate transA and transB
if (strides_x.empty()) if (strides_x.empty())
...@@ -226,7 +244,7 @@ class MatMulFactory { ...@@ -226,7 +244,7 @@ class MatMulFactory {
: memory::dims{N * K, 1, K}; : memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1}; memory::dims out_strides = memory::dims{M * N, N, 1};
CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides); CorrectStridesWhenFloatOutputFused(ctx, N, out_mat_bs, &out_strides);
return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides}; return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
} }
...@@ -266,13 +284,15 @@ class MatMulFactory { ...@@ -266,13 +284,15 @@ class MatMulFactory {
matmul_prim_ = dnnl::matmul(matmul_pd); matmul_prim_ = dnnl::matmul(matmul_pd);
} }
void Execute() { void Execute(const ExecutionContext& ctx) {
dnnl::stream stream(engine_); dnnl::stream stream(engine_);
auto* x = ctx.Input<Tensor>("X");
void* x_ptr = x_mem_.get_data_handle(); auto* y = ctx.Input<Tensor>("Y");
void* y_ptr = y_mem_.get_data_handle(); auto* out = ctx.Output<Tensor>("Out");
void* out_ptr = out_mem_.get_data_handle(); void* x_ptr = to_void_cast(x->data<XT>());
for (uint16_t i = 0; i < batch_size_; i++) { void* y_ptr = to_void_cast(y->data<YT>());
void* out_ptr = to_void_cast(out->mutable_data<OT>(ctx.GetPlace()));
for (uint16_t i = 0; i < execute_loop_steps_; i++) {
x_mem_.set_data_handle(x_ptr); x_mem_.set_data_handle(x_ptr);
y_mem_.set_data_handle(y_ptr); y_mem_.set_data_handle(y_ptr);
out_mem_.set_data_handle(out_ptr); out_mem_.set_data_handle(out_ptr);
...@@ -297,15 +317,6 @@ class MatMulFactory { ...@@ -297,15 +317,6 @@ class MatMulFactory {
out->set_layout(DataLayout::kMKLDNN); out->set_layout(DataLayout::kMKLDNN);
} }
// Points the x_mem_, y_mem_ and out_mem_ memory objects at the buffers of the
// "X" and "Y" input tensors and the "Out" output tensor taken from `ctx`.
void UpdateDataPointers(const ExecutionContext& ctx) {
  // Look up all three tensors first, then rebind the handles one by one.
  auto* input_x = ctx.Input<Tensor>("X");
  auto* input_y = ctx.Input<Tensor>("Y");
  auto* output = ctx.Output<Tensor>("Out");

  void* x_handle = to_void_cast(input_x->data<XT>());
  x_mem_.set_data_handle(x_handle);

  void* y_handle = to_void_cast(input_y->data<YT>());
  y_mem_.set_data_handle(y_handle);

  // The output buffer is requested via mutable_data for the context's place.
  auto* out_handle = output->mutable_data<OT>(ctx.GetPlace());
  out_mem_.set_data_handle(out_handle);
}
// If initialized, x memory should've been already initialized // If initialized, x memory should've been already initialized
bool IsInitialized() { return initialized_; } bool IsInitialized() { return initialized_; }
...@@ -326,7 +337,7 @@ class MatMulFactory { ...@@ -326,7 +337,7 @@ class MatMulFactory {
uint32_t x_offset_; uint32_t x_offset_;
uint32_t y_offset_; uint32_t y_offset_;
uint32_t out_offset_; uint32_t out_offset_;
uint16_t batch_size_; uint16_t execute_loop_steps_;
bool initialized_ = false; bool initialized_ = false;
}; };
......
...@@ -48,6 +48,20 @@ class TestDnnlMatMulOp(OpTest): ...@@ -48,6 +48,20 @@ class TestDnnlMatMulOp(OpTest):
self.check_output() self.check_output()
class TestDnnlMatMulOpMixedDims1(TestDnnlMatMulOp):
    """Matmul case with a 3-D X of shape (17, 2, 3) and a 2-D Y of shape (3, 4)."""

    def generate_data(self):
        """Create random float32 inputs and the NumPy reference output."""
        x_shape = (17, 2, 3)
        y_shape = (3, 4)
        # X is drawn before Y so the random stream is consumed in a fixed order.
        lhs = np.random.random(x_shape).astype("float32")
        rhs = np.random.random(y_shape).astype("float32")
        self.x = lhs
        self.y = rhs
        self.out = np.matmul(lhs, rhs)
class TestDnnlMatMulOpMixedDims2(TestDnnlMatMulOp):
    """Matmul case with a 2-D X of shape (2, 3) and a 3-D Y of shape (17, 3, 4)."""

    def generate_data(self):
        """Create random float32 inputs and the NumPy reference output."""
        x_shape = (2, 3)
        y_shape = (17, 3, 4)
        # X is drawn before Y so the random stream is consumed in a fixed order.
        lhs = np.random.random(x_shape).astype("float32")
        rhs = np.random.random(y_shape).astype("float32")
        self.x = lhs
        self.y = rhs
        self.out = np.matmul(lhs, rhs)
class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp): class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp):
def generate_data(self): def generate_data(self):
self.x = np.random.random((17, 2, 3)).astype("float32") self.x = np.random.random((17, 2, 3)).astype("float32")
...@@ -396,10 +410,10 @@ class TestMatMulOpTransposeReshapeBasicFloat( ...@@ -396,10 +410,10 @@ class TestMatMulOpTransposeReshapeBasicFloat(
TestMatMulOpTransposeReshapeEmptyFloat): TestMatMulOpTransposeReshapeEmptyFloat):
def generate_data(self): def generate_data(self):
self.bs = 8 self.bs = 8
self.x = np.random.random( self.x = np.random.random([self.bs, 12, 128,
[self.bs, 12, 128, 128]).astype(self.data_type_) 128]).astype(self.data_type_)
self.y = np.random.random( self.y = np.random.random([self.bs, 12, 128,
[self.bs, 12, 128, 64]).astype(self.data_type_) 64]).astype(self.data_type_)
def init_params_and_out(self): def init_params_and_out(self):
self.transpose_out = [0, 2, 1, 3] self.transpose_out = [0, 2, 1, 3]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册