Commit 1c893a02 authored by Zhen Wang

add gemm_int8 UT

Parent 95e6a2ee
@@ -142,61 +142,6 @@ void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
}
}
// 8-bit int PackMatrixA function
void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const int8_t *A, int lda,
int8_t *buffer) {
const int i_length = m - m_tail;
for (int i = 0; i < i_length; i += MR) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
const int8_t *a4 = A + (i + 4) * lda;
const int8_t *a5 = A + (i + 5) * lda;
int8_t *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
const int8_t *a4 = a0 + 4 * lda;
const int8_t *a5 = a0 + 5 * lda;
int8_t *local_buffer = buffer + i_length * k;
switch (m_tail) {
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
case 4:
a4 = zero_int8;
case 5:
a5 = zero_int8;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
}
void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const int i_length = m - m_tail;
@@ -439,48 +384,6 @@ void Gemm::PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
}
}
// 8-bit int PackMatrixB function
void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const int8_t *B, int ldb,
int8_t *buffer) {
const int j_length = n - n_tail;
for (int j = 0; j < j_length; j += NR) {
int8_t *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
asm volatile(
// "pld [%[b0]] \n\t"
"vld1.s8 {d0}, [%[b0]] \n\t"
"vst1.s8 {d0}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "q0");
#else
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
#endif // __ARM_NEON
}
}
if (n_tail != 0) {
int8_t *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
}
}
// Copy blocks of matrix B into contiguous memory (RowMajor)
void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
@@ -745,42 +648,6 @@ void Gemm::InnerKernel(int mc, int nc, float alpha, const float *a,
}
}
// 8-bit int blocked matrix multiplication
void Gemm::InnerKernelWithBias(int mc, int nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int *c, int *C,
int ldc, bool relu, int8_t *bias) {
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
}
}
if (alpha != 1) {
WriteWithAlphaBeta(mc, nc, c, C, ldc);
return;
}
if (beta == 0) {
WriteBasic(mc, nc, c, C, ldc);
return;
}
if (beta == 1 && !relu) {
if (bias == nullptr) {
WriteWithAdd(mc, nc, c, C, ldc);
} else {
WriteWithAddV1(mc, nc, c, C, ldc, bias);
}
return;
}
if (beta == 1 && relu) {
if (bias == nullptr) {
WriteWithAddRelu(mc, nc, c, C, ldc);
} else {
WriteWithAddReluV1(mc, nc, c, C, ldc, bias);
}
return;
}
}
// Blocked matrix multiplication
void Gemm::InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
@@ -2007,63 +1874,6 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
"q10", "q11", "q12", "q13", "q14", "q15");
}
// C = A * B, 8-bit int
void Gemm::WriteBasic(int mc, int nc, int *c, int *C, int ldc) {
int nc1 = nc >> 4;
int _nc1 = nc & 15;
int step = sizeof(int) * ldc;
int step1 = sizeof(int) * (NC - (nc1 << 4));
int volatile m = mc;
int *volatile c_ptr, *volatile C_ptr;
int *C0, *c0;
c_ptr = c;
C_ptr = C;
if (nc1 > 0) {
asm volatile(
"subs %[mc], %[mc], #1 \n\t"
"blt end_mc_%= \n\t"
"loop_mc_%=: \n\t"
"mov r6, %[C_ptr] \n\t"
"mov r5, %[nc1] \n\t"
"subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c_ptr]]! \n\t"
"vst1.32 {q0, q1}, [r6]! \n\t"
"vld1.32 {q2, q3}, [%[c_ptr]]! \n\t"
"vst1.32 {q2, q3}, [r6]! \n\t"
"subs r5, r5, #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"add %[C_ptr], %[C_ptr], %[step] \n\t"
"add %[c_ptr], %[c_ptr], %[step1] \n\t"
"subs %[mc], %[mc], #1 \n\t"
"bge loop_mc_%= \n\t"
"end_mc_%=: \n\t"
:
: [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1),
[step] "r"(step), [step1] "r"(step1)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3");
}
if (_nc1 != 0) {
for (int i = 0; i < mc; i++) {
C0 = C_ptr + nc1 * 16 + i * ldc;
c0 = c_ptr + nc1 * 16 + i * NC;
for (int j = 0; j < _nc1; j++) {
*C0++ = *c0++;
}
}
}
}
// C = A * B
void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
@@ -2121,14 +1931,9 @@ void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = alpha * A * B + beta * C
void Gemm::WriteWithAlphaBeta(int mc, int nc, int *c, int *C, int ldc) {}
// C = alpha * A * B + beta * C
void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
// C = A * B + C
void Gemm::WriteWithAdd(int mc, int nc, int *c, int *C, int ldc) {}
// C = A * B + C
void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
@@ -2193,9 +1998,6 @@ void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias
void Gemm::WriteWithAddV1(int mc, int nc, int *c, int *C, int ldc,
int8_t *bias) {}
// C = A * B + bias
void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
@@ -2235,9 +2037,6 @@ void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc,
}
}
// C = A * B + C, relu(C)
void Gemm::WriteWithAddRelu(int mc, int nc, int *c, int *C, int ldc) {}
// C = A * B + C, relu(C)
void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
@@ -2311,9 +2110,6 @@ void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
}
// C = A * B + bias, relu(C)
void Gemm::WriteWithAddReluV1(int mc, int nc, int *c, int *C, int ldc,
int8_t *bias) {}
// C = A * B + bias, relu(C)
void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
@@ -3200,69 +2996,6 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
#endif  // __ARM_NEON
// 8-bit int matrix multiplication (product of m*k and k*n)
void Gemm::Sgemm(int m, int n, int k, float alpha, const int8_t *A, int lda,
const int8_t *B, int ldb, float beta, int *C, int ldc,
bool relu, int8_t *bias) {
// L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int L1 = 32 * 1024;
int L2 = 512 * 1024;
KC = k;
MC = L1 / (KC * sizeof(int8_t));
NC = L2 / (KC * sizeof(int8_t));
// make sure MC is multiple of MR, and NC is multiple of NR
if (MC == 0) {
MC = MR;
} else {
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
packedC_int8 = static_cast<int *>(
paddle_mobile::memory::Alloc(sizeof(int) * MC * NC));
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8);
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8);
if (bias == nullptr) {
InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int8, &C(i, j), ldc, relu, nullptr);
} else {
InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int8, &C(i, j), ldc, relu, bias + i);
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int8);
paddle_mobile::memory::Free(zero_int8);
}
// 32-bit float matrix multiplication
void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
@@ -3856,125 +3589,6 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
paddle_mobile::memory::Free(zero);
}
void Gemm::AddDot6x8(int k, const int8_t *a, const int8_t *b, int *c, int ldc) {
#if __ARM_NEON
const int8_t *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k >> 1;
int kc2 = k & 1;
int step = sizeof(int) * ldc;
asm volatile(
// q4-q15: save 48 results
"vmov.s8 q4, #0 \n\t"
"vmov.s8 q5, #0 \n\t"
"vmov.s8 q6, #0 \n\t"
"vmov.s8 q7, #0 \n\t"
"vmov.s8 q8, #0 \n\t"
"vmov.s8 q9, #0 \n\t"
"vmov.s8 q10, #0 \n\t"
"vmov.s8 q11, #0 \n\t"
"vmov.s8 q12, #0 \n\t"
"vmov.s8 q13, #0 \n\t"
"vmov.s8 q14, #0 \n\t"
"vmov.s8 q15, #0 \n\t"
"mov r0, #6 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 1f \n\t"
"0: \n\t"
"vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0
"vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 used
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B row1, q1 used
"vmov.s8 q2, #0 \n\t" // q2 used
"vdup.s8 d6, d0[0] \n\t" // q3 used(but d7 is free)
"vmlal.s8 q2, d2, d6 \n\t" // A col00 * B row0
"vdup.s8 d6, d1[0] \n\t"
"vmlal.s8 q2, d3, d6 \n\t" // A col10 * B row1, q3 free
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[1] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vdup.s8 d6, d1[1] \n\t"
"vmlal.s8 q2, d3, d6 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[2] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vdup.s8 d6, d1[2] \n\t"
"vmlal.s8 q2, d3, d6 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[3] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vdup.s8 d6, d1[3] \n\t"
"vmlal.s8 q2, d3, d6 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[4] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vdup.s8 d6, d1[4] \n\t"
"vmlal.s8 q2, d3, d6 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[5] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vdup.s8 d6, d1[5] \n\t"
"vmlal.s8 q2, d3, d6 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"subs %[kc1], %[kc1], #1 \n\t"
"bge 0b \n\t"
"1: \n\t" // odd, last row
"subs %[kc2], %[kc2], #1 \n\t"
"blt 2f \n\t"
"vld1.s8 {d0}, [%[a_ptr]] \n\t"
"vld1.s8 {d1}, [%[b_ptr]] \n\t"
"vdup.s8 d2, d0[0] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vdup.s8 d2, d0[1] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vdup.s8 d2, d0[2] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vdup.s8 d2, d0[3] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vdup.s8 d2, d0[4] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vdup.s8 d2, d0[5] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 4
"2: \n\t"
"vst1.32 {q4, q5}, [%[c]], %[step] \n\t"
"vst1.32 {q6, q7}, [%[c]], %[step] \n\t"
"vst1.32 {q8, q9}, [%[c]], %[step] \n\t"
"vst1.32 {q10, q11}, [%[c]], %[step] \n\t"
"vst1.32 {q12, q13}, [%[c]], %[step] \n\t"
"vst1.32 {q14, q15}, [%[c]] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
}
void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
#if __ARM_NEON
#if __aarch64__
...
@@ -80,12 +80,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// 8-bit int
void PackMatrixA_6r(int m, int k, int m_tail, const int8_t *A, int lda,
int8_t *buffer);
void PackMatrixB_8c(int k, int n, int n_tail, const int8_t *B, int ldb,
int8_t *buffer);
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
@@ -104,11 +98,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
// 8-bit int
void InnerKernelWithBias(int mc, int nc, float alpha, const int8_t *a,
const int8_t *b, float beta, int *c, int *C, int ldc,
bool relu, int8_t *bias);
/*
// Vector-matrix multiplication (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
@@ -127,8 +116,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const int8_t *a, const int8_t *b, int *c, int ldc);
// Write-back of blocked matrix multiplication results
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
@@ -154,20 +141,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *new_scale, float *new_bias);
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1);
// 8-bit int blocked matrix multiplication write-back
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, int *c, int *C, int ldc);
// C = A * B
void WriteBasic(int mc, int nc, int *c, int *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, int *c, int *C, int ldc);
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, int *c, int *C, int ldc, int8_t *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, int *c, int *C, int ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, int *c, int *C, int ldc,
int8_t *bias);
/*
// Vector-matrix multiplication result write-back
// C = A * B
@@ -186,11 +160,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *new_bias);
*/
// 8-bit int matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const int8_t *A, int lda,
const int8_t *B, int ldb, float beta, int *C, int ldc, bool relu,
int8_t *bias);
// 32-bit float matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu,
@@ -219,6 +188,47 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1);
/************************ 8 bit function cluster ************************/
// 8 bit int small block inner product
void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc);
// 8 bit int inner product
void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
const int8_t *a, const int8_t *b, int8_t beta,
int32_t *c, int32_t *C, int32_t ldc, bool relu,
int8_t *bias);
// 8 bit int pack function
void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer);
void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer);
// 8 bit int matrix product
void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C,
int32_t ldc, bool relu, int8_t *bias);
// 8 bit int write back
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B
void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc);
// C = A * B + C
void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias
void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias);
private:
int MC = 0;
int KC = 0;
@@ -230,10 +240,10 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *packedC;
float *zero;
-// 8 int
+// 8 bit int
int8_t *packedA_int8;
int8_t *packedB_int8;
-int *packedC_int8;
+int32_t *packedC_int8;
int8_t *zero_int8;
};
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string.h>
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm.h"
#if __ARM_NEON
#include <arm_neon.h>
#endif
#ifdef _OPENMP
#include <omp.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
// 8 bit int small block inner product
void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
const int8_t *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int32_t kc1 = k >> 1;
int32_t kc2 = k & 1;
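// kc1 counts the depth pairs the main loop consumes two at a time;
// kc2 is the single leftover depth step when k is odd.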
int32_t step = sizeof(int32_t) * ldc;
asm volatile(
// q4-q15: save 48 results
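// Each q register holds 4 int32 lanes, so q4-q15 give 6 rows x 8 cols
// of accumulators; vmlal.s8 widens int8 products to int16 and
// vaddw.s16 folds them into the int32 totals.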
"pld [%[a_ptr]] \n\t"
"pld [%[b_ptr]] \n\t"
"vmov.s8 q4, #0 \n\t"
"vmov.s8 q5, #0 \n\t"
"vmov.s8 q6, #0 \n\t"
"vmov.s8 q7, #0 \n\t"
"vmov.s8 q8, #0 \n\t"
"vmov.s8 q9, #0 \n\t"
"vmov.s8 q10, #0 \n\t"
"vmov.s8 q11, #0 \n\t"
"vmov.s8 q12, #0 \n\t"
"vmov.s8 q13, #0 \n\t"
"vmov.s8 q14, #0 \n\t"
"vmov.s8 q15, #0 \n\t"
"mov r0, #6 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 1f \n\t"
"0: \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0
"vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 used
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B row1, q1 used
"vmov.s8 q2, #0 \n\t" // q2 used
"vdup.s8 d6, d0[0] \n\t"
"vdup.s8 d7, d1[0] \n\t" // q3 used
"vmlal.s8 q2, d2, d6 \n\t" // A col00 * B row0
"vmlal.s8 q2, d3, d7 \n\t" // A col10 * B row1, q3 free
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[1] \n\t"
"vdup.s8 d7, d1[1] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[2] \n\t"
"vdup.s8 d7, d1[2] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[3] \n\t"
"vdup.s8 d7, d1[3] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[4] \n\t"
"vdup.s8 d7, d1[4] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[5] \n\t"
"vdup.s8 d7, d1[5] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"subs %[kc1], %[kc1], #1 \n\t"
"bge 0b \n\t"
"1: \n\t" // odd, last row
"subs %[kc2], %[kc2], #1 \n\t"
"blt 2f \n\t"
"vld1.s8 {d0}, [%[a_ptr]] \n\t"
"vld1.s8 {d1}, [%[b_ptr]] \n\t"
"vdup.s8 d2, d0[0] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vdup.s8 d2, d0[1] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vdup.s8 d2, d0[2] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vdup.s8 d2, d0[3] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vdup.s8 d2, d0[4] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vdup.s8 d2, d0[5] \n\t"
"vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 4
"2: \n\t"
"vst1.32 {q4, q5}, [%[c]], %[step] \n\t"
"vst1.32 {q6, q7}, [%[c]], %[step] \n\t"
"vst1.32 {q8, q9}, [%[c]], %[step] \n\t"
"vst1.32 {q10, q11}, [%[c]], %[step] \n\t"
"vst1.32 {q12, q13}, [%[c]], %[step] \n\t"
"vst1.32 {q14, q15}, [%[c]] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
}
// 8 bit int inner product
void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
const int8_t *a, const int8_t *b, int8_t beta,
int32_t *c, int32_t *C, int32_t ldc, bool relu,
int8_t *bias) {
#pragma omp parallel for
for (int32_t j = 0; j < nc; j += NR) {
for (int32_t i = 0; i < mc; i += MR) {
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
}
}
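// Dispatch to a write-back routine: alpha != 1 takes the (currently stubbed)
// alpha/beta path, beta == 0 overwrites C, beta == 1 accumulates into C,
// each with optional bias and relu variants.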
if (alpha != 1) {
WriteWithAlphaBeta(mc, nc, c, C, ldc);
return;
}
if (beta == 0) {
WriteBasic(mc, nc, c, C, ldc);
return;
}
if (beta == 1 && !relu) {
if (bias == nullptr) {
WriteWithAdd(mc, nc, c, C, ldc);
} else {
WriteWithAddV1(mc, nc, c, C, ldc, bias);
}
return;
}
if (beta == 1 && relu) {
if (bias == nullptr) {
WriteWithAddRelu(mc, nc, c, C, ldc);
} else {
WriteWithAddReluV1(mc, nc, c, C, ldc, bias);
}
return;
}
}
// 8 bit int PackMatrixA
void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
int32_t lda, int8_t *buffer) {
const int32_t i_length = m - m_tail;
for (int32_t i = 0; i < i_length; i += MR) {
const int8_t *a0 = A + i * lda;
const int8_t *a1 = A + (i + 1) * lda;
const int8_t *a2 = A + (i + 2) * lda;
const int8_t *a3 = A + (i + 3) * lda;
const int8_t *a4 = A + (i + 4) * lda;
const int8_t *a5 = A + (i + 5) * lda;
int8_t *local_buffer = buffer + i * k;
for (int32_t j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
if (m_tail != 0) {
const int8_t *a0 = &A(i_length, 0);
const int8_t *a1 = a0 + lda;
const int8_t *a2 = a0 + 2 * lda;
const int8_t *a3 = a0 + 3 * lda;
const int8_t *a4 = a0 + 4 * lda;
const int8_t *a5 = a0 + 5 * lda;
int8_t *local_buffer = buffer + i_length * k;
switch (m_tail) {
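// Intentional fallthrough: all rows past the m_tail boundary are pointed
// at the shared zero buffer so the packed strip is zero-padded to 6 rows.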
case 1:
a1 = zero_int8;
case 2:
a2 = zero_int8;
case 3:
a3 = zero_int8;
case 4:
a4 = zero_int8;
case 5:
a5 = zero_int8;
break;
default:
break;
}
for (int32_t j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
}
// 8 bit int PackMatrixB
void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
int32_t ldb, int8_t *buffer) {
const int32_t j_length = n - n_tail;
for (int32_t j = 0; j < j_length; j += NR) {
int8_t *local_buffer = buffer + j * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
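// Bulk-copy one 8-column chunk of B (8 int8 values fill one d register).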
asm volatile(
// "pld [%[b0]] \n\t"
"vld1.s8 {d0}, [%[b0]] \n\t"
"vst1.s8 {d0}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "q0");
#else
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
#endif // __ARM_NEON
}
}
if (n_tail != 0) {
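// Ragged final column block: copy the live columns, then zero-fill so the
// packed panel is always NR columns wide.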
int8_t *local_buffer = buffer + j_length * k;
for (int32_t i = 0; i < k; ++i) {
const int8_t *b0 = &B(i, j_length);
for (int32_t j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int32_t j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
}
}
// 8 bit int matrix product (m*k x k*n)
void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
int32_t lda, const int8_t *B, int32_t ldb, int8_t beta,
int32_t *C, int32_t ldc, bool relu, int8_t *bias) {
// L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int32_t L1 = 32 * 1024;
int32_t L2 = 512 * 1024;
KC = k;
MC = L1 / (KC * sizeof(int8_t));
NC = L2 / (KC * sizeof(int8_t));
// make sure MC is multiple of MR, and NC is multiple of NR
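// e.g. with k = 256, MC starts at 32768 / 256 = 128; for m = 100 this gives
// mblock_num = 1, MC = 100, rounded up to 102 (the next multiple of MR = 6).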
if (MC == 0) {
MC = MR;
} else {
int32_t mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR;
} else {
int32_t nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
packedC_int8 = static_cast<int32_t *>(
paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC));
zero_int8 =
static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC));
memset(static_cast<void *>(zero_int8), 0, sizeof(int8_t) * KC);
int32_t mc, nc;
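// Walk B in NC-wide column panels: pack each panel once, then sweep every
// MC-high strip of A against it before moving on.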
for (int32_t j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8);
for (int32_t i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8);
if (bias == nullptr) {
InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int8, &C(i, j), ldc, relu, nullptr);
} else {
InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
packedC_int8, &C(i, j), ldc, relu, bias + i);
}
}
}
paddle_mobile::memory::Free(packedA_int8);
paddle_mobile::memory::Free(packedB_int8);
paddle_mobile::memory::Free(packedC_int8);
paddle_mobile::memory::Free(zero_int8);
}
// 8 bit int write back
// C = alpha * A * B + beta * C
void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {}
// C = A * B
void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {
int32_t nc1 = nc >> 4;
int32_t _nc1 = nc & 15;
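// nc1 counts the full 16-int32 chunks per row (handled by the NEON loop
// below); _nc1 is the column remainder copied by the scalar tail.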
int32_t step = sizeof(int32_t) * ldc;
int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 4));
int32_t volatile m = mc;
int32_t *volatile c_ptr, *volatile C_ptr;
int32_t *C0, *c0;
c_ptr = c;
C_ptr = C;
if (nc1 > 0) {
asm volatile(
"subs %[mc], %[mc], #1 \n\t"
"blt end_mc_%= \n\t"
"loop_mc_%=: \n\t"
"mov r6, %[C_ptr] \n\t"
"mov r5, %[nc1] \n\t"
"subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c_ptr]]! \n\t"
"vst1.32 {q0, q1}, [r6]! \n\t"
"vld1.32 {q2, q3}, [%[c_ptr]]! \n\t"
"vst1.32 {q2, q3}, [r6]! \n\t"
"subs r5, r5, #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"add %[C_ptr], %[C_ptr], %[step] \n\t"
"add %[c_ptr], %[c_ptr], %[step1] \n\t"
"subs %[mc], %[mc], #1 \n\t"
"bge loop_mc_%= \n\t"
"end_mc_%=: \n\t"
:
: [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1),
[step] "r"(step), [step1] "r"(step1)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3");
}
if (_nc1 != 0) {
for (int32_t i = 0; i < mc; i++) {
C0 = C_ptr + nc1 * 16 + i * ldc;
c0 = c_ptr + nc1 * 16 + i * NC;
for (int32_t j = 0; j < _nc1; j++) {
*C0++ = *c0++;
}
}
}
}
// C = A * B + C
void Gemm::WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {}
// C = A * B + bias
void Gemm::WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias) {}
// C = A * B + C, relu(C)
void Gemm::WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {}
// C = A * B + bias, relu(C)
void Gemm::WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc, int8_t *bias) {}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
@@ -25,7 +25,7 @@ template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
-float *bias = nullptr);
+T *bias = nullptr);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstring>
#include <string>
#include "operators/math/gemm.h"
#include "operators/math/math_function.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <>
void matmul<int8_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
int8_t alpha, framework::Tensor *matrix_out, int8_t beta,
bool relu, int8_t *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
int32_t M = dim_out[0];
int32_t N = dim_out[1];
int32_t K = (!trans_a) ? dim_a[1] : dim_a[0];
Gemm gemm;
if (trans_a) {
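// Gemm::Sgemm expects row-major inputs, so materialize the transpose of
// matrix_a into a temporary buffer before calling it.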
int32_t numel = matrix_a.numel();
int32_t m = matrix_a.dims()[0];
int32_t n = matrix_a.dims()[1];
int8_t *tmp = (int8_t *)(matrix_a.data<int8_t>()); // NOLINT
int8_t *a = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * numel));
int32_t index = 0;
for (int32_t j = 0; j < n; j++) {
for (int32_t i = 0; i < m; i++) {
a[index++] = tmp[i * n + j];
}
}
gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<int8_t>(), N, beta,
matrix_out->data<int32_t>(), N, relu, bias);
} else {
gemm.Sgemm(M, N, K, alpha, matrix_a.data<int8_t>(), K,
matrix_b.data<int8_t>(), N, beta, matrix_out->data<int32_t>(), N,
relu, bias);
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
@@ -57,14 +57,14 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) {
int ldc = n;
default_random_engine e;
uniform_int_distribution<int8_t> pixel(-127, 127);
-int8_t *a =
-static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k));
-int8_t *b =
-static_cast<int8_t *>(paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n));
-int32_t *c =
-static_cast<int32_t *>(paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n));
-int32_t *c1 =
-static_cast<int32_t *>(paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n));
+int8_t *a = static_cast<int8_t *>(
+paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k));
+int8_t *b = static_cast<int8_t *>(
+paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n));
+int32_t *c = static_cast<int32_t *>(
+paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n));
+int32_t *c1 = static_cast<int32_t *>(
+paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n));
for (int i = 0; i < m * k; ++i) {
a[i] = pixel(e);
@@ -84,8 +84,8 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) {
}
paddle_mobile::operators::math::Gemm gemm;
-gemm.Sgemm(m, n, k, 1, a, lda, b, ldb, 0, c, ldc, relu,
-nullptr);
+gemm.Sgemm(m, n, k, static_cast<int8_t>(1), a, lda, b, ldb,
+static_cast<int8_t>(0), c, ldc, relu, nullptr);
int eq = 0;
int neq = 0;
for (int i = 0; i < m * n; ++i) {
@@ -124,7 +124,8 @@ int main() {
do_sgemm(512, 256, 384, false, 0);
do_sgemm(1366, 768, 256, false, 0);
do_sgemm(1255, 755, 333, false, 0);
do_sgemm(555, 777, 999, false, 0);
do_sgemm(1024, 1024, 1024, false, 0);
return 0;
}
@@ -28,13 +28,11 @@ limitations under the License. */
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
-paddle_mobile.SetThreadNum(4);
-Tensor aa, bb, cc, scale, bias;
+paddle_mobile.SetThreadNum(1);
+Tensor aa, bb, cc;
auto aaptr = aa.mutable_data<float>({m, k});
auto bbptr = bb.mutable_data<float>({k, n});
auto ccptr = cc.mutable_data<float>({m, n});
-auto scaleptr = scale.mutable_data<float>({m});
-auto biasptr = bias.mutable_data<float>({m});
for (int i = 0; i < m * k; ++i) {
aaptr[i] = 2;
@@ -45,23 +43,55 @@ int main() {
for (int i = 0; i < m * n; ++i) {
ccptr[i] = 2;
}
-for (int i = 0; i < m; ++i) {
-scaleptr[i] = 1;
-biasptr[i] = 0;
-}
+Tensor aa_int8, bb_int8, cc_int8;
+auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k});
+auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n});
+auto ccptr_int8 = cc_int8.mutable_data<int32_t>({m, n});
+for (int i = 0; i < m * k; ++i) {
+aaptr_int8[i] = static_cast<int8_t>(2);
+}
+for (int i = 0; i < k * n; ++i) {
+bbptr_int8[i] = static_cast<int8_t>(2);
+}
+for (int i = 0; i < m * n; ++i) {
+ccptr_int8[i] = static_cast<int32_t>(2);
+}
-auto time1 = time();
+// float
+// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float>(
-aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
-false, biasptr);
-
-// paddle_mobile::operators::math::matmulWithBn<float>(
-//     aa, false, bb, false, static_cast<float>(1), &cc,
-//     static_cast<float>(0), true, &scale, &bias, 0);
+aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
+false, nullptr);
+}
+auto time1 = time();
+for (int j = 0; j < 10; ++j) {
+paddle_mobile::operators::math::matmul<float>(
+aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
+false, nullptr);
}
auto time2 = time();
-std::cout << "gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
+std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n";
+
+// int8_t
+// warm-up 10 times
+for (int j = 0; j < 10; ++j) {
+paddle_mobile::operators::math::matmul<int8_t>(
+aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
+static_cast<int8_t>(0), false, nullptr);
+}
+auto time3 = time();
+for (int j = 0; j < 10; ++j) {
+paddle_mobile::operators::math::matmul<int8_t>(
+aa_int8, false, bb_int8, false, static_cast<int8_t>(1), &cc_int8,
+static_cast<int8_t>(0), false, nullptr);
+}
+auto time4 = time();
+std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
return 0;
}