未验证 提交 48f09caa 编写于 作者: C cc 提交者: GitHub

Optimize matmul for size(x_dims)=2 size(y_dims)>2 (#3400)

* Optimize matmul for size(x_dims)=2  size(y_dims)>2
上级 dcf6acce
...@@ -213,7 +213,7 @@ void print_usage() { ...@@ -213,7 +213,7 @@ void print_usage() {
" --param_filename (The filename of param file, set param_file when\n" " --param_filename (The filename of param file, set param_file when\n"
" the model is combined formate. Otherwise, it is not necessary\n" " the model is combined formate. Otherwise, it is not necessary\n"
" to set it.) type: string \n" " to set it.) type: string \n"
" --input_shape (Tet input shapes according to the model, separated by\n" " --input_shape (Set input shapes according to the model, separated by\n"
" colon and comma, such as 1,3,244,244) type: string\n" " colon and comma, such as 1,3,244,244) type: string\n"
" default: 1,3,224,224 \n" " default: 1,3,224,224 \n"
" --input_img_path (The path of input image, if not set\n" " --input_img_path (The path of input image, if not set\n"
......
...@@ -45,32 +45,13 @@ void MatMulCompute::Run() { ...@@ -45,32 +45,13 @@ void MatMulCompute::Run() {
operators::ActivationParam act_param; operators::ActivationParam act_param;
act_param.has_active = false; act_param.has_active = false;
if (x_dims.size() > 2 && y_dims.size() >= 2) { if ((x_dims.size() >= 2 && y_dims.size() >= 2) &&
(x_dims.size() != 2 || y_dims.size() != 2)) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N] // x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N] // x: [B, M, K], y: [K, N], out: [B, M, N]
// or
if (!x_transpose && !y_transpose) { // x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N]
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2]) // x: [M, K], y: [B, K, N], out: [B, M, N]
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ") x_transpose is " << x_transpose << "y_transpose is "
<< y_transpose;
} else if (!x_transpose && y_transpose) {
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ") x_transpose is " << x_transpose << "y_transpose is "
<< y_transpose;
} else if (x_transpose && !y_transpose) {
CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 2])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ") x_transpose is " << x_transpose << "y_transpose is "
<< y_transpose;
} else {
CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ") x_transpose is " << x_transpose << "y_transpose is "
<< y_transpose;
}
int lda, ldb, ldc; int lda, ldb, ldc;
if (!x_transpose) { if (!x_transpose) {
m_ = x_dims[x_dims.size() - 2]; m_ = x_dims[x_dims.size() - 2];
...@@ -96,11 +77,7 @@ void MatMulCompute::Run() { ...@@ -96,11 +77,7 @@ void MatMulCompute::Run() {
int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1]; int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
int out_inner = o_dims[o_dims.size() - 2] * o_dims[o_dims.size() - 1]; int out_inner = o_dims[o_dims.size() - 2] * o_dims[o_dims.size() - 1];
float* x_data_trans = nullptr; if (x_dims.size() > 2 && y_dims.size() > 2) {
if (x_transpose) {
x_data_trans = static_cast<float*>(malloc(sizeof(float) * x_inner));
}
if (y_dims.size() > 2) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
lite::arm::math::sgemm(x_transpose, lite::arm::math::sgemm(x_transpose,
y_transpose, y_transpose,
...@@ -120,7 +97,7 @@ void MatMulCompute::Run() { ...@@ -120,7 +97,7 @@ void MatMulCompute::Run() {
act_param, act_param,
&ctx); &ctx);
} }
} else { } else if (x_dims.size() > 2 && y_dims.size() == 2) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) { for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
lite::arm::math::sgemm(x_transpose, lite::arm::math::sgemm(x_transpose,
y_transpose, y_transpose,
...@@ -140,34 +117,29 @@ void MatMulCompute::Run() { ...@@ -140,34 +117,29 @@ void MatMulCompute::Run() {
act_param, act_param,
&ctx); &ctx);
} }
} } else if (x_dims.size() == 2 && y_dims.size() > 2) {
if (x_data_trans) { for (size_t i = 0; i < y_dims.count(0, y_dims.size() - 2); ++i) {
free(x_data_trans); lite::arm::math::sgemm(x_transpose,
y_transpose,
m_,
n_,
k_,
alpha,
x_data,
lda,
y_data + i * y_inner,
ldb,
0.f,
o_data + i * out_inner,
ldc,
nullptr,
false,
act_param,
&ctx);
}
} }
} else if (x_dims.size() == 2 && y_dims.size() == 2) { } else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N] // x: [M, K], y: [K, N], out: [M, N]
if (!x_transpose && !y_transpose) {
CHECK_EQ(x_dims[1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (!x_transpose && y_transpose) {
CHECK_EQ(x_dims[1], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (x_transpose && !y_transpose) {
CHECK_EQ(x_dims[0], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else {
CHECK_EQ(x_dims[0], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
}
int lda, ldb, ldc; int lda, ldb, ldc;
if (!x_transpose) { if (!x_transpose) {
m_ = x_dims[0]; m_ = x_dims[0];
......
...@@ -24,19 +24,12 @@ bool MatMulOpLite::CheckShape() const { ...@@ -24,19 +24,12 @@ bool MatMulOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Y); CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
return true;
}
bool MatMulOpLite::InferShapeImpl() const {
const auto x_dims = param_.X->dims(); const auto x_dims = param_.X->dims();
const auto y_dims = param_.Y->dims(); const auto y_dims = param_.Y->dims();
bool x_transpose = param_.transpose_X; bool x_transpose = param_.transpose_X;
bool y_transpose = param_.transpose_Y; bool y_transpose = param_.transpose_Y;
std::vector<int64_t> dim_out_vec;
if (x_dims.size() > 2 && y_dims.size() >= 2) { if (x_dims.size() > 1 && y_dims.size() > 1) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N]
if (!x_transpose && !y_transpose) { if (!x_transpose && !y_transpose) {
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2]) CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
...@@ -54,48 +47,49 @@ bool MatMulOpLite::InferShapeImpl() const { ...@@ -54,48 +47,49 @@ bool MatMulOpLite::InferShapeImpl() const {
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")"; << ")";
} }
} else if (x_dims.size() > 2 && y_dims.size() == 1) {
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
}
return true;
}
dim_out_vec.resize(x_dims.size()); bool MatMulOpLite::InferShapeImpl() const {
for (size_t i = 0; i < x_dims.size() - 2; ++i) { const auto x_dims = param_.X->dims();
dim_out_vec[i] = x_dims[i]; const auto y_dims = param_.Y->dims();
bool x_transpose = param_.transpose_X;
bool y_transpose = param_.transpose_Y;
std::vector<int64_t> dim_out_vec;
if ((x_dims.size() >= 2 && y_dims.size() >= 2) &&
(x_dims.size() != 2 || y_dims.size() != 2)) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N]
// or
// x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [M, K], y: [B, K, N], out: [B, M, N]
DDim dims = x_dims.size() >= y_dims.size() ? x_dims : y_dims;
dim_out_vec.resize(dims.size());
for (size_t i = 0; i < dims.size() - 2; ++i) {
dim_out_vec[i] = dims[i];
} }
if (!x_transpose && !y_transpose) { if (!x_transpose && !y_transpose) {
dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 2]; dim_out_vec[dims.size() - 2] = x_dims[x_dims.size() - 2];
dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1]; dim_out_vec[dims.size() - 1] = y_dims[y_dims.size() - 1];
} else if (!x_transpose && y_transpose) { } else if (!x_transpose && y_transpose) {
dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 2]; dim_out_vec[dims.size() - 2] = x_dims[x_dims.size() - 2];
dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 2]; dim_out_vec[dims.size() - 1] = y_dims[y_dims.size() - 2];
} else if (x_transpose && !y_transpose) { } else if (x_transpose && !y_transpose) {
dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 1]; dim_out_vec[dims.size() - 2] = x_dims[x_dims.size() - 1];
dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1]; dim_out_vec[dims.size() - 1] = y_dims[y_dims.size() - 1];
} else { } else {
dim_out_vec[x_dims.size() - 2] = x_dims[x_dims.size() - 1]; dim_out_vec[dims.size() - 2] = x_dims[x_dims.size() - 1];
dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 2]; dim_out_vec[dims.size() - 1] = y_dims[y_dims.size() - 2];
} }
} else if (x_dims.size() == 2 && y_dims.size() == 2) { } else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N] // x: [M, K], y: [K, N], out: [M, N]
// x: [M, K], y: [K, N], out: [M, N] // x: [M, K], y: [K, N], out: [M, N]
if (!x_transpose && !y_transpose) {
CHECK_EQ(x_dims[1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (!x_transpose && y_transpose) {
CHECK_EQ(x_dims[1], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (x_transpose && !y_transpose) {
CHECK_EQ(x_dims[0], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else {
CHECK_EQ(x_dims[0], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
}
dim_out_vec.resize(x_dims.size()); dim_out_vec.resize(x_dims.size());
if (x_transpose) { if (x_transpose) {
dim_out_vec[0] = x_dims[1]; dim_out_vec[0] = x_dims[1];
...@@ -109,9 +103,6 @@ bool MatMulOpLite::InferShapeImpl() const { ...@@ -109,9 +103,6 @@ bool MatMulOpLite::InferShapeImpl() const {
} }
} else if (x_dims.size() > 2 && y_dims.size() == 1) { } else if (x_dims.size() > 2 && y_dims.size() == 1) {
// x: [B, M, K], y: [K], out: [B, M] // x: [B, M, K], y: [K], out: [B, M]
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
dim_out_vec.resize(x_dims.size() - 1); dim_out_vec.resize(x_dims.size() - 1);
for (size_t i = 0; i < dim_out_vec.size(); ++i) { for (size_t i = 0; i < dim_out_vec.size(); ++i) {
dim_out_vec[i] = x_dims[i]; dim_out_vec[i] = x_dims[i];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册