提交 7bc66b71 编写于 作者: S smilejames 提交者: GitHub

Merge pull request #463 from smilejames/develop

optimize gemm code
......@@ -26,12 +26,12 @@ alignas(64) float packedA[MC * KC];
alignas(64) float packedB[KC * NC];
alignas(64) float ab[MR * NR];
// 将A矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int paddingM, const float *A, int lda,
float *buffer) {
int i, j;
void PackMatrixA(int m, int k, const float *A, int lda, float *buffer) {
int i, j, m_tail;
const float *Aij;
for (i = 0; i < m - paddingM; i += MR) {
for (int j = 0; j < k; ++j) {
m_tail = m % NR;
for (i = 0; i < m - m_tail; i += MR) {
for (j = 0; j < k; ++j) {
Aij = &A(i, j);
*buffer++ = *Aij;
*buffer++ = *(Aij + 1);
......@@ -39,13 +39,13 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda,
*buffer++ = *(Aij + 3);
}
}
if (paddingM != 0) {
if (m_tail != 0) {
for (j = 0; j < k; ++j) {
Aij = &A(m - paddingM, j);
for (i = 0; i < paddingM; ++i) {
Aij = &A(m - m_tail, j);
for (i = 0; i < m_tail; ++i) {
*buffer++ = *(Aij + i);
}
for (i = paddingM; i < MR; ++i) {
for (i = m_tail; i < MR; ++i) {
*buffer++ = 0;
}
}
......@@ -53,11 +53,11 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda,
}
// 将A矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda,
float *buffer) {
int i, j;
void PackMatrixA_(int m, int k, const float *A, int lda, float *buffer) {
int i, j, m_tail;
const float *Ai, *Ai1, *Ai2, *Ai3;
for (i = 0; i < m - paddingM; i += MR) {
m_tail = m % NR;
for (i = 0; i < m - m_tail; i += MR) {
Ai = &A(i, 0);
Ai1 = &A(i + 1, 0);
Ai2 = &A(i + 2, 0);
......@@ -69,12 +69,12 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda,
*buffer++ = *Ai3++;
}
}
if (paddingM != 0) {
if (m_tail != 0) {
for (j = 0; j < k; ++j) {
for (i = m - paddingM; i < m; ++i) {
for (i = m - m_tail; i < m; ++i) {
*buffer++ = A(i, j);
}
for (i = m; i < m + (MR - paddingM); ++i) {
for (i = m; i < m + (MR - m_tail); ++i) {
*buffer++ = 0;
}
}
......@@ -82,11 +82,11 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda,
}
// 将B矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb,
float *buffer) {
int i, j;
void PackMatrixB(int k, int n, const float *B, int ldb, float *buffer) {
int i, j, n_tail;
const float *Bj, *Bj1, *Bj2, *Bj3;
for (j = 0; j < n - paddingN; j += NR) {
n_tail = n % NR;
for (j = 0; j < n - n_tail; j += NR) {
Bj = &B(0, j);
Bj1 = &B(0, j + 1);
Bj2 = &B(0, j + 2);
......@@ -98,12 +98,12 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb,
*buffer++ = *Bj3++;
}
}
if (paddingN != 0) {
if (n_tail != 0) {
for (i = 0; i < k; ++i) {
for (int j = n - paddingN; j < n; ++j) {
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = B(i, j);
}
for (int j = n; j < n + (NR - paddingN); ++j) {
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
......@@ -111,11 +111,11 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb,
}
// 将B矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
float *buffer) {
int i, j;
void PackMatrixB_(int k, int n, const float *B, int ldb, float *buffer) {
int i, j, n_tail;
const float *Bij;
for (j = 0; j < n - paddingN; j += NR) {
n_tail = n % NR;
for (j = 0; j < n - n_tail; j += NR) {
for (i = 0; i < k; ++i) {
Bij = &B(i, j);
asm volatile(
......@@ -126,13 +126,13 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
: "memory", "q0");
}
}
if (paddingN != 0) {
if (n_tail != 0) {
for (i = 0; i < k; ++i) {
Bij = &B(i, n - paddingN);
for (int j = n - paddingN; j < n; ++j) {
Bij = &B(i, n - n_tail);
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = *Bij++;
}
for (int j = n; j < n + (NR - paddingN); ++j) {
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
......@@ -143,33 +143,25 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
int first_time) {
int Buff_A_M = m;
int Buff_B_N = n;
int m_block = (m + MR - 1) / MR * MR;
int n_block = (n + NR - 1) / NR * NR;
int _mc = m % MR;
int _nc = n % NR;
if (_mc != 0) {
Buff_A_M = m + (MR - _mc);
}
if (_nc != 0) {
Buff_B_N = n + (NR - _nc);
}
int m_tail = m % MR;
int n_tail = n % NR;
if (first_time) {
PackMatrixB_(k, n, _nc, B, ldb, packedB);
PackMatrixB_(k, n, B, ldb, packedB);
}
PackMatrixA_(m, k, _mc, A, lda, packedA);
PackMatrixA_(m, k, A, lda, packedA);
int i, j, mc, nc;
// B 取 4 列, 打包预热
for (j = 0; j < Buff_B_N; j += NR) {
nc = (n - j) < NR ? _nc : NR;
for (j = 0; j < n_block; j += NR) {
nc = (n - j) < NR ? n_tail : NR;
// A 取 4 行,打包预热
for (i = 0; i < Buff_A_M; i += MR) {
mc = (m - i) < MR ? _mc : MR;
for (i = 0; i < m_block; i += MR) {
mc = (m - i) < MR ? m_tail : MR;
AddDot4x4(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta,
&C(i, j), ldc, mc, nc);
}
......@@ -180,36 +172,25 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
void InnerKernel_relu(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
int first_time, bool relu = false) {
int Buff_A_M = m;
int Buff_B_N = n;
int _mc = m % MR;
int _nc = n % NR;
if (_mc != 0) {
Buff_A_M = m + (MR - _mc);
}
if (_nc != 0) {
Buff_B_N = n + (NR - _nc);
}
int m_block = (m + MR - 1) / MR * MR;
int n_block = (n + NR - 1) / NR * NR;
float packedA[MC * KC];
static float packedB[KC * NC];
int m_tail = m % MR;
int n_tail = n % NR;
if (first_time) {
PackMatrixB_(k, n, _nc, B, ldb, packedB);
PackMatrixB_(k, n, B, ldb, packedB);
}
PackMatrixA_(m, k, _mc, A, lda, packedA);
PackMatrixA_(m, k, A, lda, packedA);
int i, j, mc, nc;
// B 取 4 列, 打包预热
for (j = 0; j < Buff_B_N; j += NR) {
nc = (n - j) < NR ? _nc : NR;
for (j = 0; j < n_block; j += NR) {
nc = (n - j) < NR ? n_tail : NR;
// A 取 4 行,打包预热
for (i = 0; i < Buff_A_M; i += MR) {
mc = (m - i) < MR ? _mc : MR;
for (i = 0; i < m_block; i += MR) {
mc = (m - i) < MR ? m_tail : MR;
AddDot4x4_relu(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta,
&C(i, j), ldc, mc, nc, relu);
}
......@@ -359,16 +340,16 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
"vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t"
"vld1.32 {q0, q1}, [%[a]]! \n\t"
"vld1.32 {q2, q3}, [%[b]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"vmla.f32 q11, q2, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q2, d1[1] \n\t"
"vmla.f32 q10, q3, d2[0] \n\t"
"vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t"
"vld1.32 {q4, q5}, [%[a]]! \n\t"
"vld1.32 {q6, q7}, [%[b]]! \n\t"
"vmla.f32 q10, q6, d8[0] \n\t"
"vmla.f32 q11, q6, d8[1] \n\t"
"vmla.f32 q12, q6, d9[0] \n\t"
"vmla.f32 q13, q6, d9[1] \n\t"
"vmla.f32 q10, q7, d10[0] \n\t"
"vmla.f32 q11, q7, d10[1] \n\t"
"vmla.f32 q12, q7, d11[0] \n\t"
"vmla.f32 q13, q7, d11[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
......@@ -391,13 +372,11 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
"cmp %[nc], #4 \n\t"
"bne temp_%= \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"cmp %[flag_alpha], #1 \n\t"
"bne alpha_%= \n\t"
"alpha_%=: \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmul.f32 q10, q10, d8[0] \n\t"
"vmul.f32 q11, q11, d8[0] \n\t"
"vmul.f32 q12, q12, d8[0] \n\t"
......@@ -425,6 +404,7 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
"b memory_%= \n\t"
"beta_ne1_%=: \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"vmla.f32 q10, q0, d8[1] \n\t"
"vmla.f32 q11, q1, d8[1] \n\t"
"vmla.f32 q12, q2, d8[1] \n\t"
......@@ -448,7 +428,8 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
[kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha),
[beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc),
[flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta)
: "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13");
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13");
if (mc != MR || nc != NR) {
int i, j;
......@@ -512,28 +493,31 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
"vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t"
"vld1.32 {q0, q1}, [%[a]]! \n\t"
"vld1.32 {q2, q3}, [%[b]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"vmla.f32 q11, q2, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q2, d1[1] \n\t"
"vmla.f32 q10, q3, d2[0] \n\t"
"vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t"
"vld1.32 {q4, q5}, [%[a]]! \n\t"
"vld1.32 {q6, q7}, [%[b]]! \n\t"
"vmla.f32 q10, q6, d8[0] \n\t"
"vmla.f32 q11, q6, d8[1] \n\t"
"vmla.f32 q12, q6, d9[0] \n\t"
"vmla.f32 q13, q6, d9[1] \n\t"
"vmla.f32 q10, q7, d10[0] \n\t"
"vmla.f32 q11, q7, d10[1] \n\t"
"vmla.f32 q12, q7, d11[0] \n\t"
"vmla.f32 q13, q7, d11[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t"
"loop_kc2_%=: \n\t"
"vld1.32 {q0}, [%[a]]! \n\t"
"vld1.32 {q1}, [%[b]]! \n\t"
"vmla.f32 q10, q1, d0[0] \n\t"
"vmla.f32 q11, q1, d0[1] \n\t"
"vmla.f32 q12, q1, d1[0] \n\t"
"vmla.f32 q13, q1, d1[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge loop_kc2_%= \n\t"
"end_kc2_%=: \n\t"
"cmp %[mc], #4 \n\t"
......@@ -541,13 +525,11 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
"cmp %[nc], #4 \n\t"
"bne temp_%= \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"cmp %[flag_alpha], #1 \n\t"
"bne alpha_%= \n\t"
"alpha_%=: \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmul.f32 q10, q10, d8[0] \n\t"
"vmul.f32 q11, q11, d8[0] \n\t"
"vmul.f32 q12, q12, d8[0] \n\t"
......@@ -575,16 +557,18 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
"b memory_%= \n\t"
"beta_ne1_%=: \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"vmla.f32 q10, q0, d8[1] \n\t"
"vmla.f32 q11, q1, d8[1] \n\t"
"vmla.f32 q12, q2, d8[1] \n\t"
"vmla.f32 q13, q3, d8[1] \n\t"
"memory_%=: \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vmax.f32 q11, q11, q14 \n\t"
"vmax.f32 q12, q12, q14 \n\t"
"vmax.f32 q13, q13, q14 \n\t"
"vmov.f32 q14, #0.0 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vmax.f32 q11, q11, q14 \n\t"
"vmax.f32 q12, q12, q14 \n\t"
"vmax.f32 q13, q13, q14 \n\t"
"mov r5, %[C] \n\t"
"mov r6, %[bytes_ldc]\n\t"
"vst1.32 {q10}, [r5], r6 \n\t"
......@@ -602,7 +586,8 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
[kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha),
[beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc),
[flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta)
: "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13");
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13", "q14");
if (mc != MR || nc != NR) {
int i, j;
......
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm.h"
#define a(i, j) a[(i)*lda + (j)]
......@@ -29,10 +31,15 @@ int main() {
int ldb = n;
int ldc = n;
float a[62 * 74];
float b[74 * 63];
float c[62 * 63] = {0};
float c1[62 * 63] = {0};
float *a =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * k));
float *b =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * k * n));
float *c =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
float *c1 =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
for (int i = 0; i < m * k; ++i) {
a[i] = 2;
}
......@@ -44,8 +51,11 @@ int main() {
c1[i] = 2;
}
auto time1 = time();
paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
ldc);
auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) {
std::cout << c[i] << " | ";
if (i % n == (n - 1)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册