Commit 050fc184, authored by J jiweibo

use lite::arm::math::sgemm func to implement matmul

Parent 03fd37e4
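Summary of the change: the old kernel assembled matmul by hand from NaiveTranspose, an sgemv fast path for n_ == 1 (followed by a separate alpha multiply), and prepackA plus sgemm_prepack for the general case. This commit replaces those paths with calls to lite::arm::math::sgemm, which takes the transpose flags, alpha, and the leading dimensions directly. Below is a plain-C++ reference sketch of what one such call computes, using the same argument order as the calls added in the hunks that follow; it illustrates the semantics only, is not the ARM implementation, and my reading of the trailing bias/flag arguments is an assumption (every call in this diff passes nullptr, false, false for them).

#include <cassert>

// Reference (illustration only) for one lite::arm::math::sgemm call in this
// commit: C = alpha * op(A) * op(B) + beta * C in row-major storage, where
// op() is selected by the transpose flags and lda/ldb/ldc are the stored row
// lengths of A, B and C. The bias/has_bias/has_relu tail mirrors the trailing
// arguments in the diff; the calls here pass nullptr/false/false, so those
// arguments are not modelled.
void ref_sgemm(bool trans_a, bool trans_b, int m, int n, int k, float alpha,
               const float* a, int lda, const float* b, int ldb, float beta,
               float* c, int ldc, const float* bias, bool has_bias,
               bool has_relu) {
  (void)bias;
  assert(!has_bias && !has_relu);  // unused by the calls in this commit
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        // op(A)(i, p): A is stored m x k when !trans_a, k x m when trans_a.
        const float av = trans_a ? a[p * lda + i] : a[i * lda + p];
        // op(B)(p, j): B is stored k x n when !trans_b, n x k when trans_b.
        const float bv = trans_b ? b[j * ldb + p] : b[p * ldb + j];
        acc += av * bv;
      }
      c[i * ldc + j] = alpha * acc + beta * c[i * ldc + j];
    }
  }
}

Because alpha and the transposes are handled inside the call, the post-hoc o_data[i] *= param.alpha loops and the malloc'd NaiveTranspose buffers in the removed code become unnecessary.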
......@@ -23,14 +23,6 @@ namespace lite {
namespace kernels {
namespace arm {
static void NaiveTranspose(int m, int n, const float* src, float* dst) {
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
dst[j * m + i] = src[i * n + j];
}
}
}
void MatMulCompute::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
}
......@@ -76,24 +68,26 @@ void MatMulCompute::Run() {
<< y_transpose;
}
int lda, ldb, ldc;
if (!x_transpose) {
m_ = x_dims[x_dims.size() - 2];
k_ = x_dims[x_dims.size() - 1];
lda = k_;
} else {
m_ = x_dims[x_dims.size() - 1];
k_ = x_dims[x_dims.size() - 2];
lda = m_;
}
if (!y_transpose) {
n_ = y_dims[y_dims.size() - 1];
ldb = n_;
} else {
n_ = y_dims[y_dims.size() - 2];
ldb = k_;
}
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = 0;
m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
ldc = n_;
int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
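The hunk above introduces the leading-dimension bookkeeping the new sgemm calls need: lda is k_ or m_ and ldb is n_ or k_ depending on the transpose flags, and ldc = n_, while the get_hblock / ExtendWorkspace sizing that fed the prepack path appears to be dropped. For row-major slices the rule amounts to "the leading dimension is the stored row length", which this small self-contained check illustrates (the helper name DescribeA and the 2 x 3 sizes are mine, not from the diff):

#include <cassert>

// m/k/lda for the x operand, derived from its last two stored dims (d0, d1)
// and the transpose flag, following the assignments in the hunk above.
struct OperandA { int m, k, lda; };

OperandA DescribeA(int d0, int d1, bool transposed) {
  // Row-major: the stored row length is d1 either way, so lda == d1.
  return transposed ? OperandA{d1, d0, d1} : OperandA{d0, d1, d1};
}

int main() {
  OperandA plain = DescribeA(2, 3, false);  // m_ = 2, k_ = 3, lda = 3
  OperandA trans = DescribeA(2, 3, true);   // m_ = 3, k_ = 2, lda = 3
  assert(plain.lda == 3 && trans.lda == 3);
  return 0;
}

The same reasoning gives ldb as y's stored row length (n_ untransposed, k_ transposed) and ldc = n_ for the row-major output.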
......@@ -105,171 +99,46 @@ void MatMulCompute::Run() {
}
if (y_dims.size() > 2) {
if (n_ == 1) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
if (x_transpose) {
NaiveTranspose(x_dims[x_dims.size() - 2],
x_dims[x_dims.size() - 1],
x_data + i * x_inner,
x_data_trans);
lite::arm::math::sgemv(x_data_trans,
y_data + i * y_inner,
o_data + i * out_inner,
false,
m_,
k_,
false,
nullptr,
false);
} else {
lite::arm::math::sgemv(x_data + i * x_inner,
y_data + i * y_inner,
o_data + i * out_inner,
false,
lite::arm::math::sgemm(x_transpose,
y_transpose,
m_,
n_,
k_,
false,
nullptr,
false);
}
}
if (fabsf(param.alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= param.alpha;
}
}
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
if (x_transpose) {
NaiveTranspose(x_dims[x_dims.size() - 2],
x_dims[x_dims.size() - 1],
x_data + i * x_inner,
x_data_trans);
lite::arm::math::prepackA(packed_x,
x_data_trans,
alpha,
k_,
0,
m_,
0,
k_,
false,
&ctx);
} else {
lite::arm::math::prepackA(packed_x,
x_data + i * x_inner,
alpha,
k_,
0,
m_,
0,
k_,
false,
&ctx);
}
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
lite::arm::math::sgemm_prepack(y_transpose,
m_,
n_,
k_,
packed_x,
lda,
y_data + i * y_inner,
ldb,
0.f,
o_data + i * out_inner,
n_,
ldc,
nullptr,
false,
false,
&ctx);
}
}
} else {
if (n_ == 1) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
if (x_transpose) {
NaiveTranspose(x_dims[x_dims.size() - 2],
x_dims[x_dims.size() - 1],
x_data + i * x_inner,
x_data_trans);
lite::arm::math::sgemv(x_data_trans,
y_data,
o_data + i * out_inner,
false,
m_,
k_,
false,
nullptr,
false);
} else {
lite::arm::math::sgemv(x_data + i * x_inner,
y_data,
o_data + i * out_inner,
false,
m_,
k_,
false,
nullptr,
false);
}
}
if (fabsf(param.alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= param.alpha;
}
}
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
if (x_transpose) {
NaiveTranspose(x_dims[x_dims.size() - 2],
x_dims[x_dims.size() - 1],
x_data + i * x_inner,
x_data_trans);
lite::arm::math::prepackA(
packed_x, x_data_trans, alpha, k_, 0, m_, 0, k_, false, &ctx);
} else {
lite::arm::math::prepackA(packed_x,
x_data + i * x_inner,
alpha,
k_,
0,
m_,
0,
k_,
false,
&ctx);
}
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
          // old call, removed by this commit:
          lite::arm::math::sgemm_prepack(y_transpose,
                                         m_,
                                         n_,
                                         k_,
                                         packed_x,
                                         y_data,
                                         ldb,
                                         0.f,
                                         o_data + i * out_inner,
                                         n_,
                                         nullptr,
                                         false,
                                         false,
                                         &ctx);
          // new call, added by this commit:
          lite::arm::math::sgemm(x_transpose,
                                 y_transpose,
                                 m_,
                                 n_,
                                 k_,
                                 alpha,
                                 x_data + i * x_inner,
                                 lda,
                                 y_data,
                                 ldb,
                                 0.f,
                                 o_data + i * out_inner,
                                 ldc,
                                 nullptr,
                                 false,
                                 false,
                                 &ctx);
}
}
}
if (x_data_trans) {
free(x_data_trans);
}
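In the hunk above (the batched case), the removed code dispatched per slice: sgemv when n_ == 1, otherwise prepackA followed by sgemm_prepack, with a NaiveTranspose into a malloc'd buffer whenever x_transpose was set and an extra alpha multiply after the sgemv path. The added lines, which the diff view partly interleaves with the removed ones, collapse this into one sgemm call per slice. The sketch below pieces that loop together from those added lines; it reuses the ref_sgemm reference from the first sketch in place of lite::arm::math::sgemm (which additionally takes &ctx), and batch and y_step are names introduced here for illustration, with y_step equal to y_inner when y is batched as well and 0 when the single y is shared:

#include <cstddef>

// Declaration of the reference from the sketch after the commit header.
void ref_sgemm(bool, bool, int, int, int, float, const float*, int,
               const float*, int, float, float*, int, const float*, bool, bool);

void BatchedMatMulSketch(bool x_transpose, bool y_transpose,
                         int m, int n, int k, float alpha,
                         const float* x_data, int x_inner, int lda,
                         const float* y_data, int y_step, int ldb,
                         float* o_data, int out_inner, int ldc,
                         size_t batch) {
  for (size_t i = 0; i < batch; ++i) {
    ref_sgemm(x_transpose, y_transpose, m, n, k, alpha,
              x_data + i * x_inner, lda,    // i-th x slice
              y_data + i * y_step, ldb,     // i-th y slice, or the shared y
              /*beta=*/0.f,
              o_data + i * out_inner, ldc,  // i-th output slice
              /*bias=*/nullptr, false, false);
  }
}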
......@@ -297,74 +166,42 @@ void MatMulCompute::Run() {
<< y_transpose;
}
int lda, ldb, ldc;
if (!x_transpose) {
m_ = x_dims[0];
k_ = x_dims[1];
lda = k_;
} else {
m_ = x_dims[1];
k_ = x_dims[0];
lda = m_;
}
if (!y_transpose) {
n_ = y_dims[1];
ldb = n_;
} else {
n_ = y_dims[0];
}
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = 0;
m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
if (n_ == 1) {
// lite::arm::math::sgemv doesn't support transpose.
if (x_transpose) {
float* x_data_trans =
static_cast<float*>(malloc(sizeof(float) * x_dims[0] * x_dims[1]));
NaiveTranspose(x_dims[0], x_dims[1], x_data, x_data_trans);
lite::arm::math::sgemv(
x_data_trans, y_data, o_data, false, m_, k_, false, nullptr, false);
} else {
lite::arm::math::sgemv(
x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
}
if (fabsf(param.alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= param.alpha;
}
}
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
// prepackA does not seem to support transpose.
if (x_transpose) {
float* x_data_trans =
static_cast<float*>(malloc(sizeof(float) * x_dims[0] * x_dims[1]));
NaiveTranspose(x_dims[0], x_dims[1], x_data, x_data_trans);
lite::arm::math::prepackA(
packed_x, x_data_trans, alpha, k_, 0, m_, 0, k_, false, &ctx);
} else {
lite::arm::math::prepackA(
packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx);
}
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
    // old call, removed by this commit:
    lite::arm::math::sgemm_prepack(y_transpose,
                                   m_,
                                   n_,
                                   k_,
                                   packed_x,
                                   y_data,
                                   ldb,
                                   0.f,
                                   o_data,
                                   n_,
                                   nullptr,
                                   false,
                                   false,
                                   &ctx);
    // new code, added by this commit:
    ldc = n_;
    lite::arm::math::sgemm(x_transpose,
                           y_transpose,
                           m_,
                           n_,
                           k_,
                           alpha,
                           x_data,
                           lda,
                           y_data,
                           ldb,
                           0.f,
                           o_data,
                           ldc,
                           nullptr,
                           false,
                           false,
                           &ctx);
}
} else if (x_dims.size() > 2 && y_dims.size() == 1) {
// x: [B, M, K], y: [K], out: [B, M]
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
......@@ -390,6 +227,9 @@ void MatMulCompute::Run() {
m_ = x_dims[0];
k_ = 1;
n_ = y_dims[0];
int lda = k_;
int ldb = n_;
int ldc = n_;
if (n_ == 1) {
lite::arm::math::sgemv(
x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
......@@ -399,21 +239,19 @@ void MatMulCompute::Run() {
}
}
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
lite::arm::math::prepackA(
packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx);
int ldb = n_;
lite::arm::math::sgemm_prepack(false,
lite::arm::math::sgemm(false,
false,
m_,
n_,
k_,
packed_x,
alpha,
x_data,
lda,
y_data,
ldb,
0.f,
o_data,
n_,
ldc,
nullptr,
false,
false,
......
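The remaining hunks (truncated above) make the same swap for the other shape combinations. The last branch shown sets m_ = x_dims[0], k_ = 1, n_ = y_dims[0] with lda = k_ and ldb = ldc = n_, keeps the n_ == 1 sgemv fast path (those lines appear as unchanged context), and replaces prepackA + sgemm_prepack with sgemm(false, false, ...). With k_ = 1 that GEMM is just an outer product, which this small check against the ref_sgemm reference from the first sketch demonstrates (the sizes 2 and 3 are hypothetical):

#include <cstdio>

// Declaration of the reference from the sketch after the commit header.
void ref_sgemm(bool, bool, int, int, int, float, const float*, int,
               const float*, int, float, float*, int, const float*, bool, bool);

int main() {
  // x: [2], y: [3]  ->  out: [2, 3]; with k = 1, out[i][j] = x[i] * y[j].
  const float x[2] = {1.f, 2.f};
  const float y[3] = {10.f, 20.f, 30.f};
  float out[6] = {0.f};
  // lda = k_ = 1, ldb = ldc = n_ = 3, as the added lines set them.
  ref_sgemm(false, false, /*m=*/2, /*n=*/3, /*k=*/1, /*alpha=*/1.f,
            x, 1, y, 3, /*beta=*/0.f, out, 3, nullptr, false, false);
  std::printf("%g %g %g\n%g %g %g\n",
              out[0], out[1], out[2], out[3], out[4], out[5]);
  return 0;  // prints 10 20 30 / 20 40 60
}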