diff --git a/lite/kernels/arm/matmul_compute.cc b/lite/kernels/arm/matmul_compute.cc
index 17bac148d9211969e3903bb2d09d2e7f8b740d62..d6928652eee456b1345eb7882267a40540aa88a9 100644
--- a/lite/kernels/arm/matmul_compute.cc
+++ b/lite/kernels/arm/matmul_compute.cc
@@ -23,14 +23,6 @@ namespace lite {
 namespace kernels {
 namespace arm {
 
-static void NaiveTranspose(int m, int n, const float* src, float* dst) {
-  for (int i = 0; i < m; ++i) {
-    for (int j = 0; j < n; ++j) {
-      dst[j * m + i] = src[i * n + j];
-    }
-  }
-}
-
 void MatMulCompute::PrepareForRun() {
   auto& ctx = this->ctx_->template As<ARMContext>();
 }
@@ -76,24 +68,26 @@ void MatMulCompute::Run() {
           << y_transpose;
     }
 
+    int lda, ldb, ldc;
     if (!x_transpose) {
       m_ = x_dims[x_dims.size() - 2];
       k_ = x_dims[x_dims.size() - 1];
+      lda = k_;
     } else {
       m_ = x_dims[x_dims.size() - 1];
       k_ = x_dims[x_dims.size() - 2];
+      lda = m_;
     }
     if (!y_transpose) {
       n_ = y_dims[y_dims.size() - 1];
+      ldb = n_;
     } else {
       n_ = y_dims[y_dims.size() - 2];
+      ldb = k_;
     }
 
-    int hblock = lite::arm::math::get_hblock(ctx.arch());
-    int m_round = 0;
-    m_round = hblock * ((m_ + hblock - 1) / hblock);
-    ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
+    ldc = n_;
 
     int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
     int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
 
@@ -105,169 +99,44 @@ void MatMulCompute::Run() {
     }
 
     if (y_dims.size() > 2) {
-      if (n_ == 1) {
-        for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
-          if (x_transpose) {
-            NaiveTranspose(x_dims[x_dims.size() - 2],
-                           x_dims[x_dims.size() - 1],
-                           x_data + i * x_inner,
-                           x_data_trans);
-            lite::arm::math::sgemv(x_data_trans,
-                                   y_data + i * y_inner,
-                                   o_data + i * out_inner,
-                                   false,
-                                   m_,
-                                   k_,
-                                   false,
-                                   nullptr,
-                                   false);
-          } else {
-            lite::arm::math::sgemv(x_data + i * x_inner,
-                                   y_data + i * y_inner,
-                                   o_data + i * out_inner,
-                                   false,
-                                   m_,
-                                   k_,
-                                   false,
-                                   nullptr,
-                                   false);
-          }
-        }
-        if (fabsf(param.alpha - 1.f) > 1e-8f) {
-          for (size_t i = 0; i < param.Out->dims().production(); ++i) {
-            o_data[i] *= param.alpha;
-          }
-        }
-      } else {
-        float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
-                          ctx.llc_size() / sizeof(float);
-        for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
-          if (x_transpose) {
-            NaiveTranspose(x_dims[x_dims.size() - 2],
-                           x_dims[x_dims.size() - 1],
-                           x_data + i * x_inner,
-                           x_data_trans);
-            lite::arm::math::prepackA(packed_x,
-                                      x_data_trans,
-                                      alpha,
-                                      k_,
-                                      0,
-                                      m_,
-                                      0,
-                                      k_,
-                                      false,
-                                      &ctx);
-          } else {
-            lite::arm::math::prepackA(packed_x,
-                                      x_data + i * x_inner,
-                                      alpha,
-                                      k_,
-                                      0,
-                                      m_,
-                                      0,
-                                      k_,
-                                      false,
-                                      &ctx);
-          }
-
-          int ldb = n_;
-          if (y_transpose) {
-            ldb = k_;
-          }
-          lite::arm::math::sgemm_prepack(y_transpose,
-                                         m_,
-                                         n_,
-                                         k_,
-                                         packed_x,
-                                         y_data + i * y_inner,
-                                         ldb,
-                                         0.f,
-                                         o_data + i * out_inner,
-                                         n_,
-                                         nullptr,
-                                         false,
-                                         false,
-                                         &ctx);
-        }
+      for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
+        lite::arm::math::sgemm(x_transpose,
+                               y_transpose,
+                               m_,
+                               n_,
+                               k_,
+                               alpha,
+                               x_data + i * x_inner,
+                               lda,
+                               y_data + i * y_inner,
+                               ldb,
+                               0.f,
+                               o_data + i * out_inner,
+                               ldc,
+                               nullptr,
+                               false,
+                               false,
+                               &ctx);
       }
     } else {
-      if (n_ == 1) {
-        for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
-          if (x_transpose) {
-            NaiveTranspose(x_dims[x_dims.size() - 2],
-                           x_dims[x_dims.size() - 1],
-                           x_data + i * x_inner,
-                           x_data_trans);
-            lite::arm::math::sgemv(x_data_trans,
-                                   y_data,
-                                   o_data + i * out_inner,
-                                   false,
-                                   m_,
-                                   k_,
-                                   false,
-                                   nullptr,
-                                   false);
-          } else {
-            lite::arm::math::sgemv(x_data + i * x_inner,
-                                   y_data,
-                                   o_data + i * out_inner,
-                                   false,
-                                   m_,
-                                   k_,
-                                   false,
-                                   nullptr,
-                                   false);
-
-          }
-        }
-        if (fabsf(param.alpha - 1.f) > 1e-8f) {
-          for (size_t i = 0; i < param.Out->dims().production(); ++i) {
-            o_data[i] *= param.alpha;
-          }
-        }
-      } else {
-        float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
-                          ctx.llc_size() / sizeof(float);
-        for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
-          if (x_transpose) {
-            NaiveTranspose(x_dims[x_dims.size() - 2],
-                           x_dims[x_dims.size() - 1],
-                           x_data + i * x_inner,
-                           x_data_trans);
-            lite::arm::math::prepackA(
-                packed_x, x_data_trans, alpha, k_, 0, m_, 0, k_, false, &ctx);
-          } else {
-            lite::arm::math::prepackA(packed_x,
-                                      x_data + i * x_inner,
-                                      alpha,
-                                      k_,
-                                      0,
-                                      m_,
-                                      0,
-                                      k_,
-                                      false,
-                                      &ctx);
-          }
-
-          int ldb = n_;
-          if (y_transpose) {
-            ldb = k_;
-          }
-          lite::arm::math::sgemm_prepack(y_transpose,
-                                         m_,
-                                         n_,
-                                         k_,
-                                         packed_x,
-                                         y_data,
-                                         ldb,
-                                         0.f,
-                                         o_data + i * out_inner,
-                                         n_,
-                                         nullptr,
-                                         false,
-                                         false,
-                                         &ctx);
-        }
+      for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
+        lite::arm::math::sgemm(x_transpose,
+                               y_transpose,
+                               m_,
+                               n_,
+                               k_,
+                               alpha,
+                               x_data + i * x_inner,
+                               lda,
+                               y_data,
+                               ldb,
+                               0.f,
+                               o_data + i * out_inner,
+                               ldc,
+                               nullptr,
+                               false,
+                               false,
+                               &ctx);
       }
     }
     if (x_data_trans) {
@@ -297,74 +166,42 @@ void MatMulCompute::Run() {
           << y_transpose;
     }
 
+    int lda, ldb, ldc;
     if (!x_transpose) {
      m_ = x_dims[0];
      k_ = x_dims[1];
+      lda = k_;
    } else {
      m_ = x_dims[1];
      k_ = x_dims[0];
+      lda = m_;
    }
    if (!y_transpose) {
      n_ = y_dims[1];
+      ldb = n_;
    } else {
      n_ = y_dims[0];
+      ldb = k_;
    }
-    int hblock = lite::arm::math::get_hblock(ctx.arch());
-    int m_round = 0;
-    m_round = hblock * ((m_ + hblock - 1) / hblock);
-    ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
+    ldc = n_;
 
-    if (n_ == 1) {
-      // lite::arm::math::sgemv doesn't support transpose.
-      if (x_transpose) {
-        float* x_data_trans =
-            static_cast<float*>(malloc(sizeof(float) * x_dims[0] * x_dims[1]));
-        NaiveTranspose(x_dims[0], x_dims[1], x_data, x_data_trans);
-        lite::arm::math::sgemv(
-            x_data_trans, y_data, o_data, false, m_, k_, false, nullptr, false);
-      } else {
-        lite::arm::math::sgemv(
-            x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
-      }
-      if (fabsf(param.alpha - 1.f) > 1e-8f) {
-        for (size_t i = 0; i < param.Out->dims().production(); ++i) {
-          o_data[i] *= param.alpha;
-        }
-      }
-    } else {
-      float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
-                        ctx.llc_size() / sizeof(float);
-      // prepackA seems that doesn't support transpose.
-      if (x_transpose) {
-        float* x_data_trans =
-            static_cast<float*>(malloc(sizeof(float) * x_dims[0] * x_dims[1]));
-        NaiveTranspose(x_dims[0], x_dims[1], x_data, x_data_trans);
-        lite::arm::math::prepackA(
-            packed_x, x_data_trans, alpha, k_, 0, m_, 0, k_, false, &ctx);
-      } else {
-        lite::arm::math::prepackA(
-            packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx);
-      }
-
-      int ldb = n_;
-      if (y_transpose) {
-        ldb = k_;
-      }
-      lite::arm::math::sgemm_prepack(y_transpose,
-                                     m_,
-                                     n_,
-                                     k_,
-                                     packed_x,
-                                     y_data,
-                                     ldb,
-                                     0.f,
-                                     o_data,
-                                     n_,
-                                     nullptr,
-                                     false,
-                                     false,
-                                     &ctx);
-    }
+    lite::arm::math::sgemm(x_transpose,
+                           y_transpose,
+                           m_,
+                           n_,
+                           k_,
+                           alpha,
+                           x_data,
+                           lda,
+                           y_data,
+                           ldb,
+                           0.f,
+                           o_data,
+                           ldc,
+                           nullptr,
+                           false,
+                           false,
+                           &ctx);
   } else if (x_dims.size() > 2 && y_dims.size() == 1) {
     // x: [B, M, K], y: [K], out: [B, M]
     CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
@@ -390,6 +227,9 @@ void MatMulCompute::Run() {
       m_ = x_dims[0];
       k_ = 1;
       n_ = y_dims[0];
+      int lda = k_;
+      int ldb = n_;
+      int ldc = n_;
       if (n_ == 1) {
         lite::arm::math::sgemv(
             x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
@@ -399,25 +239,23 @@ void MatMulCompute::Run() {
           }
         }
       } else {
-        float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
-                          ctx.llc_size() / sizeof(float);
-        lite::arm::math::prepackA(
-            packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx);
-        int ldb = n_;
-        lite::arm::math::sgemm_prepack(false,
-                                       m_,
-                                       n_,
-                                       k_,
-                                       packed_x,
-                                       y_data,
-                                       ldb,
-                                       0.f,
-                                       o_data,
-                                       n_,
-                                       nullptr,
-                                       false,
-                                       false,
-                                       &ctx);
+        lite::arm::math::sgemm(false,
+                               false,
+                               m_,
+                               n_,
+                               k_,
+                               alpha,
+                               x_data,
+                               lda,
+                               y_data,
+                               ldb,
+                               0.f,
+                               o_data,
+                               ldc,
+                               nullptr,
+                               false,
+                               false,
+                               &ctx);
       }
     }
   } else {
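For context, not part of the patch: each removed sgemv / NaiveTranspose / prepackA / sgemm_prepack branch is collapsed into one call to the full lite::arm::math::sgemm interface, whose row-major leading dimensions are lda = m_ when x is transposed (otherwise k_), ldb = k_ when y is transposed (otherwise n_), and ldc = n_. The sketch below is a minimal, naive reference of that leading-dimension convention only; reference_sgemm is a hypothetical stand-in for illustration, not the optimized ARM kernel or its exact signature.

// Naive row-major GEMM: C = alpha * op(A) * op(B) + beta * C.
// lda/ldb/ldc follow the convention used by the patch:
// lda = transA ? m : k, ldb = transB ? k : n, ldc = n.
#include <cstdio>
#include <vector>

static void reference_sgemm(bool transA, bool transB,
                            int m, int n, int k, float alpha,
                            const float* A, int lda,
                            const float* B, int ldb, float beta,
                            float* C, int ldc) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        // A transposed operand swaps its row/column indices, which is why
        // its leading dimension changes from k to m (or n to k for B).
        float a = transA ? A[p * lda + i] : A[i * lda + p];
        float b = transB ? B[j * ldb + p] : B[p * ldb + j];
        acc += a * b;
      }
      C[i * ldc + j] = alpha * acc + beta * C[i * ldc + j];
    }
  }
}

int main() {
  // x is stored as [K, M] and used transposed, so lda = m (as in the patch
  // when x_transpose is true); y is [K, N] with ldb = n; out is [M, N].
  const int m = 2, n = 2, k = 3;
  std::vector<float> x = {1, 2, 3, 4, 5, 6};  // [K = 3, M = 2]
  std::vector<float> y = {1, 0, 0, 1, 1, 1};  // [K = 3, N = 2]
  std::vector<float> out(m * n, 0.f);
  reference_sgemm(true, false, m, n, k, 1.f,
                  x.data(), /*lda=*/m, y.data(), /*ldb=*/n, 0.f,
                  out.data(), /*ldc=*/n);
  for (int i = 0; i < m; ++i) {
    printf("%.0f %.0f\n", out[i * n], out[i * n + 1]);
  }
  return 0;
}

Run as a plain C++ program, this prints the 2x2 product 6 8 / 8 10, the same result the patched kernel would produce for this shape with x_transpose = true and alpha = 1.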