diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 49c63d2e70e0744e57cd896d294b4bcabdb7546c..2990f7a0f8d4712a3dc3c429d9b57e5aa3809325 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -142,61 +142,6 @@ void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, } } -// 8位 int PackMatrixA函数 -void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const int8_t *A, int lda, - int8_t *buffer) { - const int i_length = m - m_tail; - for (int i = 0; i < i_length; i += MR) { - const int8_t *a0 = A + i * lda; - const int8_t *a1 = A + (i + 1) * lda; - const int8_t *a2 = A + (i + 2) * lda; - const int8_t *a3 = A + (i + 3) * lda; - const int8_t *a4 = A + (i + 4) * lda; - const int8_t *a5 = A + (i + 5) * lda; - int8_t *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } - if (m_tail != 0) { - const int8_t *a0 = &A(i_length, 0); - const int8_t *a1 = a0 + lda; - const int8_t *a2 = a0 + 2 * lda; - const int8_t *a3 = a0 + 3 * lda; - const int8_t *a4 = a0 + 4 * lda; - const int8_t *a5 = a0 + 5 * lda; - int8_t *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero_int8; - case 2: - a2 = zero_int8; - case 3: - a3 = zero_int8; - case 4: - a4 = zero_int8; - case 5: - a5 = zero_int8; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } -} - void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, float *buffer) { const int i_length = m - m_tail; @@ -439,48 +384,6 @@ void Gemm::PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, } } -// 8位 int PackMatrixB函数 -void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const int8_t *B, int ldb, - int8_t *buffer) { - const int j_length = n - n_tail; - for (int j = 0; j < j_length; j += NR) { - int8_t *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { - const int8_t *b0 = &B(i, j); -#if __ARM_NEON - asm volatile( - // "pld [%[b0]] \n\t" - "vld1.s8 {d0}, [%[b0]] \n\t" - "vst1.s8 {d0}, [%[local_buffer]]! \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "q0"); -#else - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; -#endif // __ARM_NEON - } - } - if (n_tail != 0) { - int8_t *local_buffer = buffer + j_length * k; - for (int i = 0; i < k; ++i) { - const int8_t *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } - } - } -} - // 将B矩阵分块复制到连续内存(RowMajor) void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { @@ -745,42 +648,6 @@ void Gemm::InnerKernel(int mc, int nc, float alpha, const float *a, } } -// 8位 int 分块矩阵乘法 -void Gemm::InnerKernelWithBias(int mc, int nc, float alpha, const int8_t *a, - const int8_t *b, float beta, int *c, int *C, - int ldc, bool relu, int8_t *bias) { -#pragma omp parallel for - for (int j = 0; j < nc; j += NR) { - for (int i = 0; i < mc; i += MR) { - AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); - } - } - if (alpha != 1) { - WriteWithAlphaBeta(mc, nc, c, C, ldc); - return; - } - if (beta == 0) { - WriteBasic(mc, nc, c, C, ldc); - return; - } - if (beta == 1 && !relu) { - if (bias == nullptr) { - WriteWithAdd(mc, nc, c, C, ldc); - } else { - WriteWithAddV1(mc, nc, c, C, ldc, bias); - } - return; - } - if (beta == 1 && relu) { - if (bias == nullptr) { - WriteWithAddRelu(mc, nc, c, C, ldc); - } else { - WriteWithAddReluV1(mc, nc, c, C, ldc, bias); - } - return; - } -} - // 分块矩阵乘法 void Gemm::InnerKernelWithBias(int mc, int nc, float alpha, const float *a, const float *b, float beta, float *c, float *C, @@ -2007,63 +1874,6 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { "q10", "q11", "q12", "q13", "q14", "q15"); } -// C = A * B, 8位 int -void Gemm::WriteBasic(int mc, int nc, int *c, int *C, int ldc) { - int nc1 = nc >> 4; - int _nc1 = nc & 15; - int step = sizeof(int) * ldc; - int step1 = sizeof(int) * (NC - (nc1 << 4)); - int volatile m = mc; - - int *volatile c_ptr, *volatile C_ptr; - int *C0, *c0; - c_ptr = c; - C_ptr = C; - if (nc1 > 0) { - asm volatile( - "subs %[mc], %[mc], #1 \n\t" - "blt end_mc_%= \n\t" - "loop_mc_%=: \n\t" - - "mov r6, %[C_ptr] \n\t" - "mov r5, %[nc1] \n\t" - "subs r5, r5, #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" - "vst1.32 {q0, q1}, [r6]! \n\t" - - "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" - "vst1.32 {q2, q3}, [r6]! \n\t" - - "subs r5, r5, #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "add %[C_ptr], %[C_ptr], %[step] \n\t" - "add %[c_ptr], %[c_ptr], %[step1] \n\t" - "subs %[mc], %[mc], #1 \n\t" - "bge loop_mc_%= \n\t" - "end_mc_%=: \n\t" - - : - : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), - [step] "r"(step), [step1] "r"(step1) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); - } - - if (_nc1 != 0) { - for (int i = 0; i < mc; i++) { - C0 = C_ptr + nc1 * 16 + i * ldc; - c0 = c_ptr + nc1 * 16 + i * NC; - for (int j = 0; j < _nc1; j++) { - *C0++ = *c0++; - } - } - } -} - // C = A * B void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) { int nc1 = nc / 16; @@ -2121,14 +1931,9 @@ void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) { } } -// C = alpha * A * B + beta * C -void Gemm::WriteWithAlphaBeta(int mc, int nc, int *c, int *C, int ldc) {} - // C = alpha * A * B + beta * C void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} -// C = A * B + C -void Gemm::WriteWithAdd(int mc, int nc, int *c, int *C, int ldc) {} // C = A * B + C void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) { int nc1 = nc / 16; @@ -2193,9 +1998,6 @@ void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) { } } -// C = A * B + bias -void Gemm::WriteWithAddV1(int mc, int nc, int *c, int *C, int ldc, - int8_t *bias) {} // C = A * B + bias void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) { @@ -2235,9 +2037,6 @@ void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, } } -// C = A * B + C, relu(C) -void Gemm::WriteWithAddRelu(int mc, int nc, int *c, int *C, int ldc) {} - // C = A * B + C, relu(C) void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { int nc1 = nc / 16; @@ -2311,9 +2110,6 @@ void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { } } } -// C = A * B + bias, relu(C) -void Gemm::WriteWithAddReluV1(int mc, int nc, int *c, int *C, int ldc, - int8_t *bias) {} // C = A * B + bias, relu(C) void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc, @@ -3200,69 +2996,6 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, #endif // __ARM_NEON -// 8位 int 矩阵乘法 (m*k与k*n的乘积) -void Gemm::Sgemm(int m, int n, int k, float alpha, const int8_t *A, int lda, - const int8_t *B, int ldb, float beta, int *C, int ldc, - bool relu, int8_t *bias) { - // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) - // L2 cache is 0.5~4 Mib (Contex-A72 cluster) - int L1 = 32 * 1024; - int L2 = 512 * 1024; - - KC = k; - MC = L1 / (KC * sizeof(int8_t)); - NC = L2 / (KC * sizeof(int8_t)); - - // make sure MC is multiple of MR, and NC is multiple of NR - if (MC == 0) { - MC = MR; - } else { - int mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; - } - // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; - if (NC == 0) { - NC = NR; - } else { - int nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; - } - // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; - packedA_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); - packedB_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); - packedC_int8 = static_cast( - paddle_mobile::memory::Alloc(sizeof(int) * MC * NC)); - zero_int8 = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); - - memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); - int mc, nc; - for (int j = 0; j < n; j += NC) { - nc = s_min(n - j, NC); - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); - for (int i = 0; i < m; i += MC) { - mc = s_min(m - i, MC); - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8); - if (bias == nullptr) { - InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, - packedC_int8, &C(i, j), ldc, relu, nullptr); - } else { - InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, - packedC_int8, &C(i, j), ldc, relu, bias + i); - } - } - } - - paddle_mobile::memory::Free(packedA_int8); - paddle_mobile::memory::Free(packedB_int8); - paddle_mobile::memory::Free(packedC_int8); - paddle_mobile::memory::Free(zero_int8); -} - // 32位 float 矩阵乘法 void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, @@ -3856,125 +3589,6 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, paddle_mobile::memory::Free(zero); } -void Gemm::AddDot6x8(int k, const int8_t *a, const int8_t *b, int *c, int ldc) { -#if __ARM_NEON - const int8_t *a_ptr, *b_ptr; - a_ptr = a; - b_ptr = b; - int kc1 = k >> 1; - int kc2 = k & 1; - int step = sizeof(int) * ldc; - asm volatile( - // q4-q15: save 48 results - "vmov.s8 q4, #0 \n\t" - "vmov.s8 q5, #0 \n\t" - "vmov.s8 q6, #0 \n\t" - "vmov.s8 q7, #0 \n\t" - "vmov.s8 q8, #0 \n\t" - "vmov.s8 q9, #0 \n\t" - "vmov.s8 q10, #0 \n\t" - "vmov.s8 q11, #0 \n\t" - "vmov.s8 q12, #0 \n\t" - "vmov.s8 q13, #0 \n\t" - "vmov.s8 q14, #0 \n\t" - "vmov.s8 q15, #0 \n\t" - "mov r0, #6 \n\t" - "subs %[kc1], %[kc1], #1 \n\t" - "blt 1f \n\t" - "0: \n\t" - "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0 - "vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 used - "vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B row1, q1 used - "vmov.s8 q2, #0 \n\t" // q2 used - "vdup.s8 d6, d0[0] \n\t" // q3 used(but d7 is free) - "vmlal.s8 q2, d2, d6 \n\t" // A col00 * B row0 - "vdup.s8 d6, d1[0] \n\t" - "vmlal.s8 q2, d3, d6 \n\t" // A col10 * B row1, q3 free - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[1] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vdup.s8 d6, d1[1] \n\t" - "vmlal.s8 q2, d3, d6 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[2] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vdup.s8 d6, d1[2] \n\t" - "vmlal.s8 q2, d3, d6 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[3] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vdup.s8 d6, d1[3] \n\t" - "vmlal.s8 q2, d3, d6 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[4] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vdup.s8 d6, d1[4] \n\t" - "vmlal.s8 q2, d3, d6 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vmov.s8 q2, #0 \n\t" - "vdup.s8 d6, d0[5] \n\t" - "vmlal.s8 q2, d2, d6 \n\t" - "vdup.s8 d6, d1[5] \n\t" - "vmlal.s8 q2, d3, d6 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 5 - - "subs %[kc1], %[kc1], #1 \n\t" - "bge 0b \n\t" - "1: \n\t" // odd, last row - "subs %[kc2], %[kc2], #1 \n\t" - "blt 2f \n\t" - "vld1.s8 {d0}, [%[a_ptr]] \n\t" - "vld1.s8 {d1}, [%[b_ptr]] \n\t" - "vdup.s8 d2, d0[0] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q4, q4, d4 \n\t" - "vaddw.s16 q5, q5, d5 \n\t" // res row 0 - "vdup.s8 d2, d0[1] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q6, q6, d4 \n\t" - "vaddw.s16 q7, q7, d5 \n\t" // res row 1 - "vdup.s8 d2, d0[2] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q8, q8, d4 \n\t" - "vaddw.s16 q9, q9, d5 \n\t" // res row 2 - "vdup.s8 d2, d0[3] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q10, q10, d4 \n\t" - "vaddw.s16 q11, q11, d5 \n\t" // res row 3 - "vdup.s8 d2, d0[4] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q12, q12, d4 \n\t" - "vaddw.s16 q13, q13, d5 \n\t" // res row 4 - "vdup.s8 d2, d0[5] \n\t" - "vmull.s8 q2, d1, d2 \n\t" - "vaddw.s16 q14, q14, d4 \n\t" - "vaddw.s16 q15, q15, d5 \n\t" // res row 4 - "2: \n\t" - "vst1.32 {q4, q5}, [%[c]], %[step] \n\t" - "vst1.32 {q6, q7}, [%[c]], %[step] \n\t" - "vst1.32 {q8, q9}, [%[c]], %[step] \n\t" - "vst1.32 {q10, q11}, [%[c]], %[step] \n\t" - "vst1.32 {q12, q13}, [%[c]], %[step] \n\t" - "vst1.32 {q14, q15}, [%[c]] \n\t" - - : - : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), - [kc2] "r"(kc2), [step] "r"(step) - : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); -#endif -} - void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { #if __ARM_NEON #if __aarch64__ diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index c28098b68d24300c45be3feae616446742ce9a01..77c3293bf468bb30437dde58d3a56c9bad6358e3 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -80,12 +80,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, float *buffer); - // 8位 int - void PackMatrixA_6r(int m, int k, int m_tail, const int8_t *A, int lda, - int8_t *buffer); - void PackMatrixB_8c(int k, int n, int n_tail, const int8_t *B, int ldb, - int8_t *buffer); - // 分块矩阵乘法 void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b, float beta, float *c, float *C, int ldc, bool relu); @@ -104,11 +98,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *c, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); - // 8位 int - void InnerKernelWithBias(int mc, int nc, float alpha, const int8_t *a, - const int8_t *b, float beta, int *c, int *C, int ldc, - bool relu, int8_t *bias); - /* // 向量矩阵乘法 (M = 1) void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, @@ -127,8 +116,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc); void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc); - void AddDot6x8(int k, const int8_t *a, const int8_t *b, int *c, int ldc); - // 分块矩阵乘法结果回写 // C = A * B void WriteBasic(int mc, int nc, float *c, float *C, int ldc); @@ -154,20 +141,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *new_scale, float *new_bias); void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, float *new_scale, float *new_bias, float *bias1); - // 8位 int 分块矩阵乘法结果回写 - // C = alpha * A * B + beta * C - void WriteWithAlphaBeta(int mc, int nc, int *c, int *C, int ldc); - // C = A * B - void WriteBasic(int mc, int nc, int *c, int *C, int ldc); - // C = A * B + C - void WriteWithAdd(int mc, int nc, int *c, int *C, int ldc); - // C = A * B + bias - void WriteWithAddV1(int mc, int nc, int *c, int *C, int ldc, int8_t *bias); - // C = A * B + C, relu(C) - void WriteWithAddRelu(int mc, int nc, int *c, int *C, int ldc); - // C = A * B + bias, relu(C) - void WriteWithAddReluV1(int mc, int nc, int *c, int *C, int ldc, - int8_t *bias); + /* // 向量矩阵乘法结果回写 // C = A * B @@ -186,11 +160,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *new_bias); */ - // 8位 int 矩阵乘法 - void Sgemm(int m, int n, int k, float alpha, const int8_t *A, int lda, - const int8_t *B, int ldb, float beta, int *C, int ldc, bool relu, - int8_t *bias); - // 32位 float 矩阵乘法 void Sgemm(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, @@ -219,6 +188,47 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, const float *B, int ldb, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); + /************************ 8 bit function cluster ************************/ + // 8 bit int small block inner product + void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); + + // 8 bit int inner product + void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, + const int8_t *a, const int8_t *b, int8_t beta, + int32_t *c, int32_t *C, int32_t ldc, bool relu, + int8_t *bias); + + // 8 bit int pack function + void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); + void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); + + // 8 bit int matrix product + void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C, + int32_t ldc, bool relu, int8_t *bias); + + // 8 bit int write back + // C = alpha * A * B + beta * C + void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc); + // C = A * B + C + void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + bias + void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias); + // C = A * B + C, relu(C) + void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc); + // C = A * B + bias, relu(C) + void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias); + private: int MC = 0; int KC = 0; @@ -230,10 +240,10 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *packedC; float *zero; - // 8位 int + // 8 bit int int8_t *packedA_int8; int8_t *packedB_int8; - int *packedC_int8; + int32_t *packedC_int8; int8_t *zero_int8; }; diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a885acc0d2b0b5fdf5daeab9d4ccdf4b1e904dbc --- /dev/null +++ b/src/operators/math/gemm_int8.cpp @@ -0,0 +1,431 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "common/log.h" +#include "memory/t_malloc.h" +#include "operators/math/gemm.h" +#if __ARM_NEON +#include +#endif +#ifdef _OPENMP +#include +#endif + +namespace paddle_mobile { +namespace operators { +namespace math { + +// 8 bit int small block inner product +void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON + const int8_t *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int32_t kc1 = k >> 1; + int32_t kc2 = k & 1; + int32_t step = sizeof(int32_t) * ldc; + asm volatile( + // q4-q15: save 48 results + "pld [%[a_ptr]] \n\t" + "pld [%[b_ptr]] \n\t" + "vmov.s8 q4, #0 \n\t" + "vmov.s8 q5, #0 \n\t" + "vmov.s8 q6, #0 \n\t" + "vmov.s8 q7, #0 \n\t" + "vmov.s8 q8, #0 \n\t" + "vmov.s8 q9, #0 \n\t" + "vmov.s8 q10, #0 \n\t" + "vmov.s8 q11, #0 \n\t" + "vmov.s8 q12, #0 \n\t" + "vmov.s8 q13, #0 \n\t" + "vmov.s8 q14, #0 \n\t" + "vmov.s8 q15, #0 \n\t" + "mov r0, #6 \n\t" + "subs %[kc1], %[kc1], #1 \n\t" + "blt 1f \n\t" + "0: \n\t" + "pld [%[a_ptr], #64] \n\t" + "pld [%[b_ptr], #64] \n\t" + "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0 + "vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 used + "vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B row1, q1 used + "vmov.s8 q2, #0 \n\t" // q2 used + "vdup.s8 d6, d0[0] \n\t" + "vdup.s8 d7, d1[0] \n\t" // q3 used + "vmlal.s8 q2, d2, d6 \n\t" // A col00 * B row0 + "vmlal.s8 q2, d3, d7 \n\t" // A col10 * B row1, q3 free + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[1] \n\t" + "vdup.s8 d7, d1[1] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[2] \n\t" + "vdup.s8 d7, d1[2] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[3] \n\t" + "vdup.s8 d7, d1[3] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[4] \n\t" + "vdup.s8 d7, d1[4] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vmov.s8 q2, #0 \n\t" + "vdup.s8 d6, d0[5] \n\t" + "vdup.s8 d7, d1[5] \n\t" + "vmlal.s8 q2, d2, d6 \n\t" + "vmlal.s8 q2, d3, d7 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 5 + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 0b \n\t" + "1: \n\t" // odd, last row + "subs %[kc2], %[kc2], #1 \n\t" + "blt 2f \n\t" + "vld1.s8 {d0}, [%[a_ptr]] \n\t" + "vld1.s8 {d1}, [%[b_ptr]] \n\t" + "vdup.s8 d2, d0[0] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q4, q4, d4 \n\t" + "vaddw.s16 q5, q5, d5 \n\t" // res row 0 + "vdup.s8 d2, d0[1] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q6, q6, d4 \n\t" + "vaddw.s16 q7, q7, d5 \n\t" // res row 1 + "vdup.s8 d2, d0[2] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q8, q8, d4 \n\t" + "vaddw.s16 q9, q9, d5 \n\t" // res row 2 + "vdup.s8 d2, d0[3] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q10, q10, d4 \n\t" + "vaddw.s16 q11, q11, d5 \n\t" // res row 3 + "vdup.s8 d2, d0[4] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q12, q12, d4 \n\t" + "vaddw.s16 q13, q13, d5 \n\t" // res row 4 + "vdup.s8 d2, d0[5] \n\t" + "vmull.s8 q2, d1, d2 \n\t" + "vaddw.s16 q14, q14, d4 \n\t" + "vaddw.s16 q15, q15, d5 \n\t" // res row 4 + "2: \n\t" + "vst1.32 {q4, q5}, [%[c]], %[step] \n\t" + "vst1.32 {q6, q7}, [%[c]], %[step] \n\t" + "vst1.32 {q8, q9}, [%[c]], %[step] \n\t" + "vst1.32 {q10, q11}, [%[c]], %[step] \n\t" + "vst1.32 {q12, q13}, [%[c]], %[step] \n\t" + "vst1.32 {q14, q15}, [%[c]] \n\t" + + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [kc2] "r"(kc2), [step] "r"(step) + : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif +} + +// 8 bit int inner product +void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, + const int8_t *a, const int8_t *b, int8_t beta, + int32_t *c, int32_t *C, int32_t ldc, bool relu, + int8_t *bias) { +#pragma omp parallel for + for (int32_t j = 0; j < nc; j += NR) { + for (int32_t i = 0; i < mc; i += MR) { + AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + } + } + if (alpha != 1) { + WriteWithAlphaBeta(mc, nc, c, C, ldc); + return; + } + if (beta == 0) { + WriteBasic(mc, nc, c, C, ldc); + return; + } + if (beta == 1 && !relu) { + if (bias == nullptr) { + WriteWithAdd(mc, nc, c, C, ldc); + } else { + WriteWithAddV1(mc, nc, c, C, ldc, bias); + } + return; + } + if (beta == 1 && relu) { + if (bias == nullptr) { + WriteWithAddRelu(mc, nc, c, C, ldc); + } else { + WriteWithAddReluV1(mc, nc, c, C, ldc, bias); + } + return; + } +} + +// 8 bit int PackMatrixA +void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + for (int32_t i = 0; i < i_length; i += MR) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + const int8_t *a4 = A + (i + 4) * lda; + const int8_t *a5 = A + (i + 5) * lda; + int8_t *local_buffer = buffer + i * k; + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + const int8_t *a4 = a0 + 4 * lda; + const int8_t *a5 = a0 + 5 * lda; + int8_t *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + case 4: + a4 = zero_int8; + case 5: + a5 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } +} + +// 8 bit int PackMatrixB +void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + for (int32_t j = 0; j < j_length; j += NR) { + int8_t *local_buffer = buffer + j * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j); +#if __ARM_NEON + asm volatile( + // "pld [%[b0]] \n\t" + "vld1.s8 {d0}, [%[b0]] \n\t" + "vst1.s8 {d0}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "q0"); +#else + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; +#endif // __ARM_NEON + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j_length); + for (int32_t j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int32_t j = n; j < j_length + NR; ++j) { + *local_buffer++ = 0; + } + } + } +} + +// 8 bit int matrix product (m*k x k*n) +void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, + int32_t *C, int32_t ldc, bool relu, int8_t *bias) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int32_t L1 = 32 * 1024; + int32_t L2 = 512 * 1024; + + KC = k; + MC = L1 / (KC * sizeof(int8_t)); + NC = L2 / (KC * sizeof(int8_t)); + + // make sure MC is multiple of MR, and NC is multiple of NR + if (MC == 0) { + MC = MR; + } else { + int32_t mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR; + } else { + int32_t nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); + packedC_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); + zero_int8 = + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); + + memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); + int32_t mc, nc; + for (int32_t j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); + for (int32_t i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8); + if (bias == nullptr) { + InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int8, &C(i, j), ldc, relu, nullptr); + } else { + InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int8, &C(i, j), ldc, relu, bias + i); + } + } + } + + paddle_mobile::memory::Free(packedA_int8); + paddle_mobile::memory::Free(packedB_int8); + paddle_mobile::memory::Free(packedC_int8); + paddle_mobile::memory::Free(zero_int8); +} + +// 8 bit int write back +// C = alpha * A * B + beta * C +void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} +// C = A * B, 8位 int32_t +void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) { + int32_t nc1 = nc >> 4; + int32_t _nc1 = nc & 15; + int32_t step = sizeof(int32_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 4)); + int32_t volatile m = mc; + + int32_t *volatile c_ptr, *volatile C_ptr; + int32_t *C0, *c0; + c_ptr = c; + C_ptr = C; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" + + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vst1.32 {q0, q1}, [r6]! \n\t" + + "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" + "vst1.32 {q2, q3}, [r6]! \n\t" + + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step] "r"(step), [step1] "r"(step1) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); + } + + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 16 + i * ldc; + c0 = c_ptr + nc1 * 16 + i * NC; + for (int32_t j = 0; j < _nc1; j++) { + *C0++ = *c0++; + } + } + } +} + +// C = A * B + C +void Gemm::WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} + +// C = A * B + bias +void Gemm::WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias) {} +// C = A * B + C, relu(C) +void Gemm::WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc) {} + +// C = A * B + bias, relu(C) +void Gemm::WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, + int32_t ldc, int8_t *bias) {} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index de19e3df2ab69c8ac490b09af2852bf2fa806c64..b70dfb43ba11400e555365485f2a632c854279ac 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -25,7 +25,7 @@ template void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu = false, - float *bias = nullptr); + T *bias = nullptr); template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..70677223d12ded2da07ab53bc371f1e8da9fe293 --- /dev/null +++ b/src/operators/math/math_function_int8.cpp @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "operators/math/gemm.h" +#include "operators/math/math_function.h" + +namespace paddle_mobile { +namespace operators { +namespace math { +template <> +void matmul(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, + int8_t alpha, framework::Tensor *matrix_out, int8_t beta, + bool relu, int8_t *bias) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_MOBILE_ENFORCE( + dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + int32_t M = dim_out[0]; + int32_t N = dim_out[1]; + int32_t K = (!trans_a) ? dim_a[1] : dim_a[0]; + Gemm gemm; + + if (trans_a) { + int32_t numel = matrix_a.numel(); + int32_t m = matrix_a.dims()[0]; + int32_t n = matrix_a.dims()[1]; + int8_t *tmp = (int8_t *)(matrix_a.data()); // NOLINT + int8_t *a = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * numel)); + int32_t index = 0; + for (int32_t j = 0; j < n; j++) { + for (int32_t i = 0; i < m; i++) { + a[index++] = tmp[i * n + j]; + } + } + + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), N, + relu, bias); + } +} +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp index fd01d9d545af9fdc71ce80e02366d903547f5e02..80ddd40e121c81032c903955bd7116cf52695569 100644 --- a/test/common/test_gemm_int8_accuracy.cpp +++ b/test/common/test_gemm_int8_accuracy.cpp @@ -57,14 +57,14 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) { int ldc = n; default_random_engine e; uniform_int_distribution pixel(-127, 127); - int8_t *a = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k)); - int8_t *b = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n)); - int32_t *c = - static_cast(paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n)); - int32_t *c1 = - static_cast(paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n)); + int8_t *a = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k)); + int8_t *b = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n)); + int32_t *c = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n)); + int32_t *c1 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * m * n)); for (int i = 0; i < m * k; ++i) { a[i] = pixel(e); @@ -84,8 +84,8 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) { } paddle_mobile::operators::math::Gemm gemm; - gemm.Sgemm(m, n, k, 1, a, lda, b, ldb, 0, c, ldc, relu, - nullptr); + gemm.Sgemm(m, n, k, static_cast(1), a, lda, b, ldb, + static_cast(0), c, ldc, relu, nullptr); int eq = 0; int neq = 0; for (int i = 0; i < m * n; ++i) { @@ -124,7 +124,8 @@ int main() { do_sgemm(512, 256, 384, false, 0); do_sgemm(1366, 768, 256, false, 0); do_sgemm(1255, 755, 333, false, 0); - do_sgemm(555, 777, 999, false, 0); + do_sgemm(555, 777, 999, false, 0); + do_sgemm(1024, 1024, 1024, false, 0); return 0; } diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 386c09d71a3d5709842991bffd2e8ea039edc940..89f0012ae8effaab383719c1b85748c24eb2bf73 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -28,13 +28,11 @@ limitations under the License. */ int main() { paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(4); - Tensor aa, bb, cc, scale, bias; + paddle_mobile.SetThreadNum(1); + Tensor aa, bb, cc; auto aaptr = aa.mutable_data({m, k}); auto bbptr = bb.mutable_data({k, n}); auto ccptr = cc.mutable_data({m, n}); - auto scaleptr = scale.mutable_data({m}); - auto biasptr = bias.mutable_data({m}); for (int i = 0; i < m * k; ++i) { aaptr[i] = 2; @@ -45,23 +43,55 @@ int main() { for (int i = 0; i < m * n; ++i) { ccptr[i] = 2; } - for (int i = 0; i < m; ++i) { - scaleptr[i] = 1; - biasptr[i] = 0; + + Tensor aa_int8, bb_int8, cc_int8; + auto aaptr_int8 = aa_int8.mutable_data({m, k}); + auto bbptr_int8 = bb_int8.mutable_data({k, n}); + auto ccptr_int8 = cc_int8.mutable_data({m, n}); + + for (int i = 0; i < m * k; ++i) { + aaptr_int8[i] = static_cast(2); + } + for (int i = 0; i < k * n; ++i) { + bbptr_int8[i] = static_cast(2); + } + for (int i = 0; i < m * n; ++i) { + ccptr_int8[i] = static_cast(2); } - auto time1 = time(); + // float + // warm-up 10 times for (int j = 0; j < 10; ++j) { paddle_mobile::operators::math::matmul( aa, false, bb, false, static_cast(1), &cc, static_cast(0), - false, biasptr); + false, nullptr); + } - // paddle_mobile::operators::math::matmulWithBn( - // aa, false, bb, false, static_cast(1), &cc, - // static_cast(0), true, &scale, &bias, 0); + auto time1 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa, false, bb, false, static_cast(1), &cc, static_cast(0), + false, nullptr); } auto time2 = time(); - std::cout << "gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; + std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; + + // int8_t + // warm-up 10 times + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), false, nullptr); + } + + auto time3 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), false, nullptr); + } + auto time4 = time(); + std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; return 0; }