提交 050fc184 编写于 作者: J jiweibo

Use lite::arm::math::sgemm to implement matmul

上级 03fd37e4
...@@ -23,14 +23,6 @@ namespace lite { ...@@ -23,14 +23,6 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
// Transposes a dense row-major matrix.
//
// `src` is read as an m x n row-major matrix; `dst` receives the n x m
// row-major transpose, i.e. dst[j * m + i] = src[i * n + j] for every
// 0 <= i < m and 0 <= j < n. Nothing is written when m or n is not positive.
// Both buffers must hold at least m * n floats and must not overlap.
static void NaiveTranspose(int m, int n, const float* src, float* dst) {
  for (int row = 0; row < m; ++row) {
    // Row `row` of the source becomes column `row` of the destination.
    const float* in_row = src + row * n;
    float* out_col = dst + row;
    for (int col = 0; col < n; ++col) {
      out_col[col * m] = in_row[col];
    }
  }
}
// Pre-run hook of the matmul kernel. The visible body only binds a reference
// to the ARM context held in ctx_; that reference is not used any further here.
// NOTE(review): this line pair is unchanged context from a side-by-side diff
// view, so the old and new columns appear twice on each line.
void MatMulCompute::PrepareForRun() { void MatMulCompute::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
} }
...@@ -76,24 +68,26 @@ void MatMulCompute::Run() { ...@@ -76,24 +68,26 @@ void MatMulCompute::Run() {
<< y_transpose; << y_transpose;
} }
int lda, ldb, ldc;
if (!x_transpose) { if (!x_transpose) {
m_ = x_dims[x_dims.size() - 2]; m_ = x_dims[x_dims.size() - 2];
k_ = x_dims[x_dims.size() - 1]; k_ = x_dims[x_dims.size() - 1];
lda = k_;
} else { } else {
m_ = x_dims[x_dims.size() - 1]; m_ = x_dims[x_dims.size() - 1];
k_ = x_dims[x_dims.size() - 2]; k_ = x_dims[x_dims.size() - 2];
lda = m_;
} }
if (!y_transpose) { if (!y_transpose) {
n_ = y_dims[y_dims.size() - 1]; n_ = y_dims[y_dims.size() - 1];
ldb = n_;
} else { } else {
n_ = y_dims[y_dims.size() - 2]; n_ = y_dims[y_dims.size() - 2];
ldb = k_;
} }
int hblock = lite::arm::math::get_hblock(ctx.arch()); ldc = n_;
int m_round = 0;
m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1]; int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1]; int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
...@@ -105,169 +99,44 @@ void MatMulCompute::Run() { ...@@ -105,169 +99,44 @@ void MatMulCompute::Run() {
} }
if (y_dims.size() > 2) { if (y_dims.size() > 2) {
if (n_ == 1) { for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { lite::arm::math::sgemm(x_transpose,
if (x_transpose) { y_transpose,
NaiveTranspose(x_dims[x_dims.size() - 2], m_,
x_dims[x_dims.size() - 1], n_,
x_data + i * x_inner, k_,
x_data_trans); alpha,
lite::arm::math::sgemv(x_data_trans, x_data + i * x_inner,
y_data + i * y_inner, lda,
o_data + i * out_inner, y_data + i * y_inner,
false, ldb,
m_, 0.f,
k_, o_data + i * out_inner,
false, ldc,
nullptr, nullptr,
false); false,
} else { false,
lite::arm::math::sgemv(x_data + i * x_inner, &ctx);
y_data + i * y_inner,
o_data + i * out_inner,
false,
m_,
k_,
false,
nullptr,
false);
}
}
if (fabsf(param.alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= param.alpha;
}
}
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
if (x_transpose) {
NaiveTranspose(x_dims[x_dims.size() - 2],
x_dims[x_dims.size() - 1],
x_data + i * x_inner,
x_data_trans);
lite::arm::math::prepackA(packed_x,
x_data_trans,
alpha,
k_,
0,
m_,
0,
k_,
false,
&ctx);
} else {
lite::arm::math::prepackA(packed_x,
x_data + i * x_inner,
alpha,
k_,
0,
m_,
0,
k_,
false,
&ctx);
}
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
lite::arm::math::sgemm_prepack(y_transpose,
m_,
n_,
k_,
packed_x,
y_data + i * y_inner,
ldb,
0.f,
o_data + i * out_inner,
n_,
nullptr,
false,
false,
&ctx);
}
} }
} else { } else {
if (n_ == 1) { for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { lite::arm::math::sgemm(x_transpose,
if (x_transpose) { y_transpose,
NaiveTranspose(x_dims[x_dims.size() - 2], m_,
x_dims[x_dims.size() - 1], n_,
x_data + i * x_inner, k_,
x_data_trans); alpha,
lite::arm::math::sgemv(x_data_trans, x_data + i * x_inner,
y_data, lda,
o_data + i * out_inner, y_data,
false, ldb,
m_, 0.f,
k_, o_data + i * out_inner,
false, ldc,
nullptr, nullptr,
false); false,
} else { false,
lite::arm::math::sgemv(x_data + i * x_inner, &ctx);
y_data,
o_data + i * out_inner,
false,
m_,
k_,
false,
nullptr,
false);
}
}
if (fabsf(param.alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= param.alpha;
}
}
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
if (x_transpose) {
NaiveTranspose(x_dims[x_dims.size() - 2],
x_dims[x_dims.size() - 1],
x_data + i * x_inner,
x_data_trans);
lite::arm::math::prepackA(
packed_x, x_data_trans, alpha, k_, 0, m_, 0, k_, false, &ctx);
} else {
lite::arm::math::prepackA(packed_x,
x_data + i * x_inner,
alpha,
k_,
0,
m_,
0,
k_,
false,
&ctx);
}
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
lite::arm::math::sgemm_prepack(y_transpose,
m_,
n_,
k_,
packed_x,
y_data,
ldb,
0.f,
o_data + i * out_inner,
n_,
nullptr,
false,
false,
&ctx);
}
} }
} }
if (x_data_trans) { if (x_data_trans) {
...@@ -297,74 +166,42 @@ void MatMulCompute::Run() { ...@@ -297,74 +166,42 @@ void MatMulCompute::Run() {
<< y_transpose; << y_transpose;
} }
int lda, ldb, ldc;
if (!x_transpose) { if (!x_transpose) {
m_ = x_dims[0]; m_ = x_dims[0];
k_ = x_dims[1]; k_ = x_dims[1];
lda = k_;
} else { } else {
m_ = x_dims[1]; m_ = x_dims[1];
k_ = x_dims[0]; k_ = x_dims[0];
lda = m_;
} }
if (!y_transpose) { if (!y_transpose) {
n_ = y_dims[1]; n_ = y_dims[1];
ldb = n_;
} else { } else {
n_ = y_dims[0]; n_ = y_dims[0];
ldb = k_;
} }
int hblock = lite::arm::math::get_hblock(ctx.arch()); ldc = n_;
int m_round = 0;
m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
if (n_ == 1) { lite::arm::math::sgemm(x_transpose,
// lite::arm::math::sgemv doesn't support transpose. y_transpose,
if (x_transpose) { m_,
float* x_data_trans = n_,
static_cast<float*>(malloc(sizeof(float) * x_dims[0] * x_dims[1])); k_,
NaiveTranspose(x_dims[0], x_dims[1], x_data, x_data_trans); alpha,
lite::arm::math::sgemv( x_data,
x_data_trans, y_data, o_data, false, m_, k_, false, nullptr, false); lda,
} else { y_data,
lite::arm::math::sgemv( ldb,
x_data, y_data, o_data, false, m_, k_, false, nullptr, false); 0.f,
} o_data,
if (fabsf(param.alpha - 1.f) > 1e-8f) { ldc,
for (size_t i = 0; i < param.Out->dims().production(); ++i) { nullptr,
o_data[i] *= param.alpha; false,
} false,
} &ctx);
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
// prepackA seems that doesn't support transpose.
if (x_transpose) {
float* x_data_trans =
static_cast<float*>(malloc(sizeof(float) * x_dims[0] * x_dims[1]));
NaiveTranspose(x_dims[0], x_dims[1], x_data, x_data_trans);
lite::arm::math::prepackA(
packed_x, x_data_trans, alpha, k_, 0, m_, 0, k_, false, &ctx);
} else {
lite::arm::math::prepackA(
packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx);
}
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
lite::arm::math::sgemm_prepack(y_transpose,
m_,
n_,
k_,
packed_x,
y_data,
ldb,
0.f,
o_data,
n_,
nullptr,
false,
false,
&ctx);
}
} else if (x_dims.size() > 2 && y_dims.size() == 1) { } else if (x_dims.size() > 2 && y_dims.size() == 1) {
// x: [B, M, K], y: [K], out: [B, M] // x: [B, M, K], y: [K], out: [B, M]
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0]) CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
...@@ -390,6 +227,9 @@ void MatMulCompute::Run() { ...@@ -390,6 +227,9 @@ void MatMulCompute::Run() {
m_ = x_dims[0]; m_ = x_dims[0];
k_ = 1; k_ = 1;
n_ = y_dims[0]; n_ = y_dims[0];
int lda = k_;
int ldb = n_;
int ldc = n_;
if (n_ == 1) { if (n_ == 1) {
lite::arm::math::sgemv( lite::arm::math::sgemv(
x_data, y_data, o_data, false, m_, k_, false, nullptr, false); x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
...@@ -399,25 +239,23 @@ void MatMulCompute::Run() { ...@@ -399,25 +239,23 @@ void MatMulCompute::Run() {
} }
} }
} else { } else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) + lite::arm::math::sgemm(false,
ctx.llc_size() / sizeof(float); false,
lite::arm::math::prepackA( m_,
packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx); n_,
int ldb = n_; k_,
lite::arm::math::sgemm_prepack(false, alpha,
m_, x_data,
n_, lda,
k_, y_data,
packed_x, ldb,
y_data, 0.f,
ldb, o_data,
0.f, ldc,
o_data, nullptr,
n_, false,
nullptr, false,
false, &ctx);
false,
&ctx);
} }
} }
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册