未验证 提交 27819153 编写于 作者: S smilejames 提交者: GitHub

Merge pull request #463 from smilejames/develop

optimize gemm code
...@@ -26,12 +26,12 @@ alignas(64) float packedA[MC * KC]; ...@@ -26,12 +26,12 @@ alignas(64) float packedA[MC * KC];
alignas(64) float packedB[KC * NC]; alignas(64) float packedB[KC * NC];
alignas(64) float ab[MR * NR]; alignas(64) float ab[MR * NR];
// 将A矩阵分块复制到连续内存(ColMajor) // 将A矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int paddingM, const float *A, int lda, void PackMatrixA(int m, int k, const float *A, int lda, float *buffer) {
float *buffer) { int i, j, m_tail;
int i, j;
const float *Aij; const float *Aij;
for (i = 0; i < m - paddingM; i += MR) { m_tail = m % NR;
for (int j = 0; j < k; ++j) { for (i = 0; i < m - m_tail; i += MR) {
for (j = 0; j < k; ++j) {
Aij = &A(i, j); Aij = &A(i, j);
*buffer++ = *Aij; *buffer++ = *Aij;
*buffer++ = *(Aij + 1); *buffer++ = *(Aij + 1);
...@@ -39,13 +39,13 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda, ...@@ -39,13 +39,13 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda,
*buffer++ = *(Aij + 3); *buffer++ = *(Aij + 3);
} }
} }
if (paddingM != 0) { if (m_tail != 0) {
for (j = 0; j < k; ++j) { for (j = 0; j < k; ++j) {
Aij = &A(m - paddingM, j); Aij = &A(m - m_tail, j);
for (i = 0; i < paddingM; ++i) { for (i = 0; i < m_tail; ++i) {
*buffer++ = *(Aij + i); *buffer++ = *(Aij + i);
} }
for (i = paddingM; i < MR; ++i) { for (i = m_tail; i < MR; ++i) {
*buffer++ = 0; *buffer++ = 0;
} }
} }
...@@ -53,11 +53,11 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda, ...@@ -53,11 +53,11 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda,
} }
// 将A矩阵分块复制到连续内存(RowMajor) // 将A矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda, void PackMatrixA_(int m, int k, const float *A, int lda, float *buffer) {
float *buffer) { int i, j, m_tail;
int i, j;
const float *Ai, *Ai1, *Ai2, *Ai3; const float *Ai, *Ai1, *Ai2, *Ai3;
for (i = 0; i < m - paddingM; i += MR) { m_tail = m % NR;
for (i = 0; i < m - m_tail; i += MR) {
Ai = &A(i, 0); Ai = &A(i, 0);
Ai1 = &A(i + 1, 0); Ai1 = &A(i + 1, 0);
Ai2 = &A(i + 2, 0); Ai2 = &A(i + 2, 0);
...@@ -69,12 +69,12 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda, ...@@ -69,12 +69,12 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda,
*buffer++ = *Ai3++; *buffer++ = *Ai3++;
} }
} }
if (paddingM != 0) { if (m_tail != 0) {
for (j = 0; j < k; ++j) { for (j = 0; j < k; ++j) {
for (i = m - paddingM; i < m; ++i) { for (i = m - m_tail; i < m; ++i) {
*buffer++ = A(i, j); *buffer++ = A(i, j);
} }
for (i = m; i < m + (MR - paddingM); ++i) { for (i = m; i < m + (MR - m_tail); ++i) {
*buffer++ = 0; *buffer++ = 0;
} }
} }
...@@ -82,11 +82,11 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda, ...@@ -82,11 +82,11 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda,
} }
// 将B矩阵分块复制到连续内存(ColMajor) // 将B矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb, void PackMatrixB(int k, int n, const float *B, int ldb, float *buffer) {
float *buffer) { int i, j, n_tail;
int i, j;
const float *Bj, *Bj1, *Bj2, *Bj3; const float *Bj, *Bj1, *Bj2, *Bj3;
for (j = 0; j < n - paddingN; j += NR) { n_tail = n % NR;
for (j = 0; j < n - n_tail; j += NR) {
Bj = &B(0, j); Bj = &B(0, j);
Bj1 = &B(0, j + 1); Bj1 = &B(0, j + 1);
Bj2 = &B(0, j + 2); Bj2 = &B(0, j + 2);
...@@ -98,12 +98,12 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb, ...@@ -98,12 +98,12 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb,
*buffer++ = *Bj3++; *buffer++ = *Bj3++;
} }
} }
if (paddingN != 0) { if (n_tail != 0) {
for (i = 0; i < k; ++i) { for (i = 0; i < k; ++i) {
for (int j = n - paddingN; j < n; ++j) { for (int j = n - n_tail; j < n; ++j) {
*buffer++ = B(i, j); *buffer++ = B(i, j);
} }
for (int j = n; j < n + (NR - paddingN); ++j) { for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0; *buffer++ = 0;
} }
} }
...@@ -111,11 +111,11 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb, ...@@ -111,11 +111,11 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb,
} }
// 将B矩阵分块复制到连续内存(RowMajor) // 将B矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, void PackMatrixB_(int k, int n, const float *B, int ldb, float *buffer) {
float *buffer) { int i, j, n_tail;
int i, j;
const float *Bij; const float *Bij;
for (j = 0; j < n - paddingN; j += NR) { n_tail = n % NR;
for (j = 0; j < n - n_tail; j += NR) {
for (i = 0; i < k; ++i) { for (i = 0; i < k; ++i) {
Bij = &B(i, j); Bij = &B(i, j);
asm volatile( asm volatile(
...@@ -126,13 +126,13 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, ...@@ -126,13 +126,13 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
: "memory", "q0"); : "memory", "q0");
} }
} }
if (paddingN != 0) { if (n_tail != 0) {
for (i = 0; i < k; ++i) { for (i = 0; i < k; ++i) {
Bij = &B(i, n - paddingN); Bij = &B(i, n - n_tail);
for (int j = n - paddingN; j < n; ++j) { for (int j = n - n_tail; j < n; ++j) {
*buffer++ = *Bij++; *buffer++ = *Bij++;
} }
for (int j = n; j < n + (NR - paddingN); ++j) { for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0; *buffer++ = 0;
} }
} }
...@@ -143,33 +143,25 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, ...@@ -143,33 +143,25 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda, void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, const float *B, int ldb, float beta, float *C, int ldc,
int first_time) { int first_time) {
int Buff_A_M = m; int m_block = (m + MR - 1) / MR * MR;
int Buff_B_N = n; int n_block = (n + NR - 1) / NR * NR;
int _mc = m % MR; int m_tail = m % MR;
int _nc = n % NR; int n_tail = n % NR;
if (_mc != 0) {
Buff_A_M = m + (MR - _mc);
}
if (_nc != 0) {
Buff_B_N = n + (NR - _nc);
}
if (first_time) { if (first_time) {
PackMatrixB_(k, n, _nc, B, ldb, packedB); PackMatrixB_(k, n, B, ldb, packedB);
} }
PackMatrixA_(m, k, _mc, A, lda, packedA); PackMatrixA_(m, k, A, lda, packedA);
int i, j, mc, nc; int i, j, mc, nc;
// B 取 4 列, 打包预热 // B 取 4 列, 打包预热
for (j = 0; j < Buff_B_N; j += NR) { for (j = 0; j < n_block; j += NR) {
nc = (n - j) < NR ? _nc : NR; nc = (n - j) < NR ? n_tail : NR;
// A 取 4 行,打包预热 // A 取 4 行,打包预热
for (i = 0; i < Buff_A_M; i += MR) { for (i = 0; i < m_block; i += MR) {
mc = (m - i) < MR ? _mc : MR; mc = (m - i) < MR ? m_tail : MR;
AddDot4x4(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta, AddDot4x4(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta,
&C(i, j), ldc, mc, nc); &C(i, j), ldc, mc, nc);
} }
...@@ -180,36 +172,25 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda, ...@@ -180,36 +172,25 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
void InnerKernel_relu(int m, int n, int k, float alpha, const float *A, int lda, void InnerKernel_relu(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, const float *B, int ldb, float beta, float *C, int ldc,
int first_time, bool relu = false) { int first_time, bool relu = false) {
int Buff_A_M = m; int m_block = (m + MR - 1) / MR * MR;
int Buff_B_N = n; int n_block = (n + NR - 1) / NR * NR;
int _mc = m % MR;
int _nc = n % NR;
if (_mc != 0) {
Buff_A_M = m + (MR - _mc);
}
if (_nc != 0) {
Buff_B_N = n + (NR - _nc);
}
float packedA[MC * KC]; int m_tail = m % MR;
static float packedB[KC * NC]; int n_tail = n % NR;
if (first_time) { if (first_time) {
PackMatrixB_(k, n, _nc, B, ldb, packedB); PackMatrixB_(k, n, B, ldb, packedB);
} }
PackMatrixA_(m, k, _mc, A, lda, packedA); PackMatrixA_(m, k, A, lda, packedA);
int i, j, mc, nc; int i, j, mc, nc;
// B 取 4 列, 打包预热 // B 取 4 列, 打包预热
for (j = 0; j < Buff_B_N; j += NR) { for (j = 0; j < n_block; j += NR) {
nc = (n - j) < NR ? _nc : NR; nc = (n - j) < NR ? n_tail : NR;
// A 取 4 行,打包预热 // A 取 4 行,打包预热
for (i = 0; i < Buff_A_M; i += MR) { for (i = 0; i < m_block; i += MR) {
mc = (m - i) < MR ? _mc : MR; mc = (m - i) < MR ? m_tail : MR;
AddDot4x4_relu(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta, AddDot4x4_relu(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta,
&C(i, j), ldc, mc, nc, relu); &C(i, j), ldc, mc, nc, relu);
} }
...@@ -359,16 +340,16 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b, ...@@ -359,16 +340,16 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
"vmla.f32 q11, q3, d2[1] \n\t" "vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t" "vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t" "vmla.f32 q13, q3, d3[1] \n\t"
"vld1.32 {q0, q1}, [%[a]]! \n\t" "vld1.32 {q4, q5}, [%[a]]! \n\t"
"vld1.32 {q2, q3}, [%[b]]! \n\t" "vld1.32 {q6, q7}, [%[b]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t" "vmla.f32 q10, q6, d8[0] \n\t"
"vmla.f32 q11, q2, d0[1] \n\t" "vmla.f32 q11, q6, d8[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t" "vmla.f32 q12, q6, d9[0] \n\t"
"vmla.f32 q13, q2, d1[1] \n\t" "vmla.f32 q13, q6, d9[1] \n\t"
"vmla.f32 q10, q3, d2[0] \n\t" "vmla.f32 q10, q7, d10[0] \n\t"
"vmla.f32 q11, q3, d2[1] \n\t" "vmla.f32 q11, q7, d10[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t" "vmla.f32 q12, q7, d11[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t" "vmla.f32 q13, q7, d11[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t" "subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t" "bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t" "end_kc1_%=: \n\t"
...@@ -391,13 +372,11 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b, ...@@ -391,13 +372,11 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
"cmp %[nc], #4 \n\t" "cmp %[nc], #4 \n\t"
"bne temp_%= \n\t" "bne temp_%= \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"cmp %[flag_alpha], #1 \n\t" "cmp %[flag_alpha], #1 \n\t"
"bne alpha_%= \n\t" "bne alpha_%= \n\t"
"alpha_%=: \n\t" "alpha_%=: \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmul.f32 q10, q10, d8[0] \n\t" "vmul.f32 q10, q10, d8[0] \n\t"
"vmul.f32 q11, q11, d8[0] \n\t" "vmul.f32 q11, q11, d8[0] \n\t"
"vmul.f32 q12, q12, d8[0] \n\t" "vmul.f32 q12, q12, d8[0] \n\t"
...@@ -425,6 +404,7 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b, ...@@ -425,6 +404,7 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
"b memory_%= \n\t" "b memory_%= \n\t"
"beta_ne1_%=: \n\t" "beta_ne1_%=: \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"vmla.f32 q10, q0, d8[1] \n\t" "vmla.f32 q10, q0, d8[1] \n\t"
"vmla.f32 q11, q1, d8[1] \n\t" "vmla.f32 q11, q1, d8[1] \n\t"
"vmla.f32 q12, q2, d8[1] \n\t" "vmla.f32 q12, q2, d8[1] \n\t"
...@@ -448,7 +428,8 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b, ...@@ -448,7 +428,8 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
[kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha), [kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha),
[beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc), [beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc),
[flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta) [flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta)
: "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13"); : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13");
if (mc != MR || nc != NR) { if (mc != MR || nc != NR) {
int i, j; int i, j;
...@@ -512,28 +493,31 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, ...@@ -512,28 +493,31 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
"vmla.f32 q11, q3, d2[1] \n\t" "vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t" "vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t" "vmla.f32 q13, q3, d3[1] \n\t"
"vld1.32 {q0, q1}, [%[a]]! \n\t" "vld1.32 {q4, q5}, [%[a]]! \n\t"
"vld1.32 {q2, q3}, [%[b]]! \n\t" "vld1.32 {q6, q7}, [%[b]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t" "vmla.f32 q10, q6, d8[0] \n\t"
"vmla.f32 q11, q2, d0[1] \n\t" "vmla.f32 q11, q6, d8[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t" "vmla.f32 q12, q6, d9[0] \n\t"
"vmla.f32 q13, q2, d1[1] \n\t" "vmla.f32 q13, q6, d9[1] \n\t"
"vmla.f32 q10, q3, d2[0] \n\t" "vmla.f32 q10, q7, d10[0] \n\t"
"vmla.f32 q11, q3, d2[1] \n\t" "vmla.f32 q11, q7, d10[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t" "vmla.f32 q12, q7, d11[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t" "vmla.f32 q13, q7, d11[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t" "subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t" "bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t" "end_kc1_%=: \n\t"
"subs %[kc2], %[kc2], #1 \n\t" "subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t" "blt end_kc2_%= \n\t"
"loop_kc2_%=: \n\t"
"vld1.32 {q0}, [%[a]]! \n\t" "vld1.32 {q0}, [%[a]]! \n\t"
"vld1.32 {q1}, [%[b]]! \n\t" "vld1.32 {q1}, [%[b]]! \n\t"
"vmla.f32 q10, q1, d0[0] \n\t" "vmla.f32 q10, q1, d0[0] \n\t"
"vmla.f32 q11, q1, d0[1] \n\t" "vmla.f32 q11, q1, d0[1] \n\t"
"vmla.f32 q12, q1, d1[0] \n\t" "vmla.f32 q12, q1, d1[0] \n\t"
"vmla.f32 q13, q1, d1[1] \n\t" "vmla.f32 q13, q1, d1[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge loop_kc2_%= \n\t"
"end_kc2_%=: \n\t" "end_kc2_%=: \n\t"
"cmp %[mc], #4 \n\t" "cmp %[mc], #4 \n\t"
...@@ -541,13 +525,11 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, ...@@ -541,13 +525,11 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
"cmp %[nc], #4 \n\t" "cmp %[nc], #4 \n\t"
"bne temp_%= \n\t" "bne temp_%= \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"cmp %[flag_alpha], #1 \n\t" "cmp %[flag_alpha], #1 \n\t"
"bne alpha_%= \n\t" "bne alpha_%= \n\t"
"alpha_%=: \n\t" "alpha_%=: \n\t"
"vmov.f32 d8[0], %[alpha] \n\t"
"vmul.f32 q10, q10, d8[0] \n\t" "vmul.f32 q10, q10, d8[0] \n\t"
"vmul.f32 q11, q11, d8[0] \n\t" "vmul.f32 q11, q11, d8[0] \n\t"
"vmul.f32 q12, q12, d8[0] \n\t" "vmul.f32 q12, q12, d8[0] \n\t"
...@@ -575,16 +557,18 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, ...@@ -575,16 +557,18 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
"b memory_%= \n\t" "b memory_%= \n\t"
"beta_ne1_%=: \n\t" "beta_ne1_%=: \n\t"
"vmov.f32 d8[1], %[beta] \n\t"
"vmla.f32 q10, q0, d8[1] \n\t" "vmla.f32 q10, q0, d8[1] \n\t"
"vmla.f32 q11, q1, d8[1] \n\t" "vmla.f32 q11, q1, d8[1] \n\t"
"vmla.f32 q12, q2, d8[1] \n\t" "vmla.f32 q12, q2, d8[1] \n\t"
"vmla.f32 q13, q3, d8[1] \n\t" "vmla.f32 q13, q3, d8[1] \n\t"
"memory_%=: \n\t" "memory_%=: \n\t"
"vmax.f32 q10, q10, q14 \n\t" "vmov.f32 q14, #0.0 \n\t"
"vmax.f32 q11, q11, q14 \n\t" "vmax.f32 q10, q10, q14 \n\t"
"vmax.f32 q12, q12, q14 \n\t" "vmax.f32 q11, q11, q14 \n\t"
"vmax.f32 q13, q13, q14 \n\t" "vmax.f32 q12, q12, q14 \n\t"
"vmax.f32 q13, q13, q14 \n\t"
"mov r5, %[C] \n\t" "mov r5, %[C] \n\t"
"mov r6, %[bytes_ldc]\n\t" "mov r6, %[bytes_ldc]\n\t"
"vst1.32 {q10}, [r5], r6 \n\t" "vst1.32 {q10}, [r5], r6 \n\t"
...@@ -602,7 +586,8 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, ...@@ -602,7 +586,8 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
[kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha), [kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha),
[beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc), [beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc),
[flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta) [flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta)
: "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13"); : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13", "q14");
if (mc != MR || nc != NR) { if (mc != MR || nc != NR) {
int i, j; int i, j;
......
...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <iostream> #include <iostream>
#include "../test_helper.h"
#include "common/log.h" #include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm.h" #include "operators/math/gemm.h"
#define a(i, j) a[(i)*lda + (j)] #define a(i, j) a[(i)*lda + (j)]
...@@ -29,10 +31,15 @@ int main() { ...@@ -29,10 +31,15 @@ int main() {
int ldb = n; int ldb = n;
int ldc = n; int ldc = n;
float a[62 * 74]; float *a =
float b[74 * 63]; static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * k));
float c[62 * 63] = {0}; float *b =
float c1[62 * 63] = {0}; static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * k * n));
float *c =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
float *c1 =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
for (int i = 0; i < m * k; ++i) { for (int i = 0; i < m * k; ++i) {
a[i] = 2; a[i] = 2;
} }
...@@ -44,8 +51,11 @@ int main() { ...@@ -44,8 +51,11 @@ int main() {
c1[i] = 2; c1[i] = 2;
} }
auto time1 = time();
paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
ldc); ldc);
auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) { for (int i = 0; i < m * n; ++i) {
std::cout << c[i] << " | "; std::cout << c[i] << " | ";
if (i % n == (n - 1)) { if (i % n == (n - 1)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册