From 48f09caa5f86215d1e0706a620ad069b30d977ae Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 14 Apr 2020 09:57:56 +0800 Subject: [PATCH] Optimize matmul for size(x_dims)=2 size(y_dims)>2 (#3400) * Optimize matmul for size(x_dims)=2 size(y_dims)>2 --- lite/api/benchmark.cc | 2 +- lite/kernels/arm/matmul_compute.cc | 82 ++++++++++-------------------- lite/operators/matmul_op.cc | 77 +++++++++++++--------------- 3 files changed, 62 insertions(+), 99 deletions(-) diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc index 17932cc5cd..f82fcb87ff 100644 --- a/lite/api/benchmark.cc +++ b/lite/api/benchmark.cc @@ -213,7 +213,7 @@ void print_usage() { " --param_filename (The filename of param file, set param_file when\n" " the model is combined formate. Otherwise, it is not necessary\n" " to set it.) type: string \n" - " --input_shape (Tet input shapes according to the model, separated by\n" + " --input_shape (Set input shapes according to the model, separated by\n" " colon and comma, such as 1,3,244,244) type: string\n" " default: 1,3,224,224 \n" " --input_img_path (The path of input image, if not set\n" diff --git a/lite/kernels/arm/matmul_compute.cc b/lite/kernels/arm/matmul_compute.cc index 2841fa13f7..d22b14155a 100644 --- a/lite/kernels/arm/matmul_compute.cc +++ b/lite/kernels/arm/matmul_compute.cc @@ -45,32 +45,13 @@ void MatMulCompute::Run() { operators::ActivationParam act_param; act_param.has_active = false; - if (x_dims.size() > 2 && y_dims.size() >= 2) { + if ((x_dims.size() >= 2 && y_dims.size() >= 2) && + (x_dims.size() != 2 || y_dims.size() != 2)) { // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N] // x: [B, M, K], y: [K, N], out: [B, M, N] - - if (!x_transpose && !y_transpose) { - CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << ") x_transpose is " << x_transpose << "y_transpose is " - << y_transpose; - } else if 
(!x_transpose && y_transpose) { - CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 1]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << ") x_transpose is " << x_transpose << "y_transpose is " - << y_transpose; - } else if (x_transpose && !y_transpose) { - CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 2]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << ") x_transpose is " << x_transpose << "y_transpose is " - << y_transpose; - } else { - CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 1]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << ") x_transpose is " << x_transpose << "y_transpose is " - << y_transpose; - } - + // or + // x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N] + // x: [M, K], y: [B, K, N], out: [B, M, N] int lda, ldb, ldc; if (!x_transpose) { m_ = x_dims[x_dims.size() - 2]; @@ -96,11 +77,7 @@ void MatMulCompute::Run() { int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1]; int out_inner = o_dims[o_dims.size() - 2] * o_dims[o_dims.size() - 1]; - float* x_data_trans = nullptr; - if (x_transpose) { - x_data_trans = static_cast<float*>(malloc(sizeof(float) * x_inner)); - } - if (y_dims.size() > 2) { + if (x_dims.size() > 2 && y_dims.size() > 2) { for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { lite::arm::math::sgemm(x_transpose, y_transpose, @@ -120,7 +97,7 @@ void MatMulCompute::Run() { act_param, &ctx); } - } else { + } else if (x_dims.size() > 2 && y_dims.size() == 2) { for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { lite::arm::math::sgemm(x_transpose, y_transpose, @@ -140,34 +117,29 @@ void MatMulCompute::Run() { act_param, &ctx); } - } - if (x_data_trans) { - free(x_data_trans); + } else if (x_dims.size() == 2 && y_dims.size() > 2) { + for (size_t i = 0; i < y_dims.count(0, y_dims.size() - 2); ++i) { + lite::arm::math::sgemm(x_transpose, + y_transpose, + m_, + n_, + k_, + alpha, + x_data, + lda, 
y_data + i * y_inner, + ldb, + 0.f, + o_data + i * out_inner, + ldc, + nullptr, + false, + act_param, + &ctx); + } } } else if (x_dims.size() == 2 && y_dims.size() == 2) { // x: [M, K], y: [K, N], out: [M, N] - if (!x_transpose && !y_transpose) { - CHECK_EQ(x_dims[1], y_dims[0]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << "), x_transpose is " << x_transpose << ", y_transpose is " - << y_transpose; - } else if (!x_transpose && y_transpose) { - CHECK_EQ(x_dims[1], y_dims[1]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << "), x_transpose is " << x_transpose << ", y_transpose is " - << y_transpose; - } else if (x_transpose && !y_transpose) { - CHECK_EQ(x_dims[0], y_dims[0]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << "), x_transpose is " << x_transpose << ", y_transpose is " - << y_transpose; - } else { - CHECK_EQ(x_dims[0], y_dims[1]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << "), x_transpose is " << x_transpose << ", y_transpose is " - << y_transpose; - } - int lda, ldb, ldc; if (!x_transpose) { m_ = x_dims[0]; diff --git a/lite/operators/matmul_op.cc b/lite/operators/matmul_op.cc index 1cdcdfa167..04a0fc97d7 100644 --- a/lite/operators/matmul_op.cc +++ b/lite/operators/matmul_op.cc @@ -24,19 +24,12 @@ bool MatMulOpLite::CheckShape() const { CHECK_OR_FALSE(param_.Y); CHECK_OR_FALSE(param_.Out); - return true; -} - -bool MatMulOpLite::InferShapeImpl() const { const auto x_dims = param_.X->dims(); const auto y_dims = param_.Y->dims(); bool x_transpose = param_.transpose_X; bool y_transpose = param_.transpose_Y; - std::vector<int64_t> dim_out_vec; - if (x_dims.size() > 2 && y_dims.size() >= 2) { - // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N] - // x: [B, M, K], y: [K, N], out: [B, M, N] + if (x_dims.size() > 1 && y_dims.size() > 1) { if (!x_transpose && !y_transpose) { CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2]) << "not supported 
x_dims(" << x_dims << ") and y_dims(" << y_dims @@ -54,48 +47,49 @@ bool MatMulOpLite::InferShapeImpl() const { << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims << ")"; } + } else if (x_dims.size() > 2 && y_dims.size() == 1) { + CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0]) + << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims + << ")"; + } + return true; +} - dim_out_vec.resize(x_dims.size()); - for (size_t i = 0; i < x_dims.size() - 2; ++i) { - dim_out_vec[i] = x_dims[i]; +bool MatMulOpLite::InferShapeImpl() const { + const auto x_dims = param_.X->dims(); + const auto y_dims = param_.Y->dims(); + bool x_transpose = param_.transpose_X; + bool y_transpose = param_.transpose_Y; + std::vector<int64_t> dim_out_vec; + + if ((x_dims.size() >= 2 && y_dims.size() >= 2) && + (x_dims.size() != 2 || y_dims.size() != 2)) { + // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N] + // x: [B, M, K], y: [K, N], out: [B, M, N] + // or + // x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N] + // x: [M, K], y: [B, K, N], out: [B, M, N] + DDim dims = x_dims.size() >= y_dims.size() ? 
x_dims : y_dims; + dim_out_vec.resize(dims.size()); + for (size_t i = 0; i < dims.size() - 2; ++i) { + dim_out_vec[i] = dims[i]; } if (!x_transpose && !y_transpose) { - dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 2]; - dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1]; + dim_out_vec[dims.size() - 2] = x_dims[x_dims.size() - 2]; + dim_out_vec[dims.size() - 1] = y_dims[y_dims.size() - 1]; } else if (!x_transpose && y_transpose) { - dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 2]; - dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 2]; + dim_out_vec[dims.size() - 2] = x_dims[x_dims.size() - 2]; + dim_out_vec[dims.size() - 1] = y_dims[y_dims.size() - 2]; } else if (x_transpose && !y_transpose) { - dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 1]; - dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1]; + dim_out_vec[dims.size() - 2] = x_dims[x_dims.size() - 1]; + dim_out_vec[dims.size() - 1] = y_dims[y_dims.size() - 1]; } else { - dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 1]; - dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 2]; + dim_out_vec[dims.size() - 2] = x_dims[x_dims.size() - 1]; + dim_out_vec[dims.size() - 1] = y_dims[y_dims.size() - 2]; } } else if (x_dims.size() == 2 && y_dims.size() == 2) { // x: [M, K], y: [K, N], out: [M, N] // x: [M, K], y: [K, N], out: [M, N] - if (!x_transpose && !y_transpose) { - CHECK_EQ(x_dims[1], y_dims[0]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << "), x_transpose is " << x_transpose << ", y_transpose is " - << y_transpose; - } else if (!x_transpose && y_transpose) { - CHECK_EQ(x_dims[1], y_dims[1]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << "), x_transpose is " << x_transpose << ", y_transpose is " - << y_transpose; - } else if (x_transpose && !y_transpose) { - CHECK_EQ(x_dims[0], y_dims[0]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << "), x_transpose is " << 
x_transpose << ", y_transpose is " - << y_transpose; - } else { - CHECK_EQ(x_dims[0], y_dims[1]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << "), x_transpose is " << x_transpose << ", y_transpose is " - << y_transpose; - } dim_out_vec.resize(x_dims.size()); if (x_transpose) { dim_out_vec[0] = x_dims[1]; @@ -109,9 +103,6 @@ bool MatMulOpLite::InferShapeImpl() const { } } else if (x_dims.size() > 2 && y_dims.size() == 1) { // x: [B, M, K], y: [K], out: [B, M] - CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0]) - << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims - << ")"; dim_out_vec.resize(x_dims.size() - 1); for (size_t i = 0; i < dim_out_vec.size(); ++i) { dim_out_vec[i] = x_dims[i]; -- GitLab