diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h
index adc6924d8ad273012a9b44677f8ad1a29bc37787..3b00ae0a3c314ebd8e18490504fca4e657965a1a 100644
--- a/src/operators/math/gemm.h
+++ b/src/operators/math/gemm.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include
 #include
 
 #include "common/log.h"
@@ -25,6 +26,7 @@ limitations under the License. */
 #define MR 6
 #define NR 16
 #else
+#define MR_INT8 4
 #define MR 6
 #define NR 8
 #endif
@@ -189,6 +191,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
 
   // 8 bits function cluster begins
   // 8 bits int small block inner product
+  void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
+                 int32_t ldc);
   void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
                  int32_t ldc);
 
@@ -199,6 +203,8 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
                            int8_t *bias);
 
   // 8 bits int pack function
+  void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
+                      int32_t lda, int8_t *buffer);
   void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                       int32_t lda, int8_t *buffer);
   void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp
index 1eac419fee921d7fd8c57437dbf66dd418cd1c92..9b722410beb92f5c507a34fd02422e3eb9ba7486 100644
--- a/src/operators/math/gemm_int8.cpp
+++ b/src/operators/math/gemm_int8.cpp
@@ -26,11 +26,228 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {
 namespace math {
+void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
+                     int32_t ldc) {
+#if __ARM_NEON
+#if __aarch64__
+// TODO
+#else
+  const int8_t *a_ptr, *b_ptr;
+  a_ptr = a;
+  b_ptr = b;
+  int32_t kc1 = k >> 3;
+  int32_t kc2 = k & 7;
+  int32_t kc3 = kc2 >> 2;
+  int32_t kc4 = kc2 & 3;
+  int32_t kc5 = kc4 >> 1;
+  int32_t kc6 = kc4 & 1;
+  int32_t step = sizeof(int32_t) * ldc;
+  asm volatile(
+      // q8-q15: save 32 results
+      "pld [%[a_ptr]] \n\t"
+      "pld [%[b_ptr]] \n\t"
+      "pld [%[b_ptr], #64] \n\t"
+      "vmov.s32 q8, #0 \n\t"
+      "vmov.s32 q9, q8 \n\t"
+      "vmov.s32 q10, q8 \n\t"
+      "vmov.s32 q11, q8 \n\t"
+      "vmov.s32 q12, q8 \n\t"
+      "vmov.s32 q13, q8 \n\t"
+      "vmov.s32 q14, q8 \n\t"
+      "vmov.s32 q15, q8 \n\t"
+      "subs %[kc1], %[kc1], #1 \n\t"
+      "blt 1f \n\t"
+      "0: \n\t"
+      "pld [%[a_ptr], #64] \n\t"
+      "pld [%[b_ptr], #128] \n\t"
+      "vld1.s8 {d0-d3}, [%[a_ptr]]! \n\t"   // load A 8 cols
+      "vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t"  // load B first 4 rows
+      "vmovl.s8 q2, d0 \n\t"                // process B first 4 rows
+      "vmovl.s8 q3, d8 \n\t"
+      "vmlal.s16 q8, d6, d4[0]\n\t"
+      "vmlal.s16 q9, d7, d4[0]\n\t"
+      "vmlal.s16 q10, d6, d4[1]\n\t"
+      "vmlal.s16 q11, d7, d4[1]\n\t"
+      "vmlal.s16 q12, d6, d4[2]\n\t"
+      "vmlal.s16 q13, d7, d4[2]\n\t"
+      "vmlal.s16 q14, d6, d4[3]\n\t"
+      "vmlal.s16 q15, d7, d4[3]\n\t"
+      "vmovl.s8 q3, d9 \n\t"
+      "vmlal.s16 q8, d6, d5[0]\n\t"
+      "vmlal.s16 q9, d7, d5[0]\n\t"
+      "vmlal.s16 q10, d6, d5[1]\n\t"
+      "vmlal.s16 q11, d7, d5[1]\n\t"
+      "vmlal.s16 q12, d6, d5[2]\n\t"
+      "vmlal.s16 q13, d7, d5[2]\n\t"
+      "vmlal.s16 q14, d6, d5[3]\n\t"
+      "vmlal.s16 q15, d7, d5[3]\n\t"
+      "vld1.s8 {d12-d15}, [%[b_ptr]]! \n\t"  // load B second 4 rows
+      "vmovl.s8 q2, d1 \n\t"
+      "vmovl.s8 q3, d10 \n\t"
+      "vmlal.s16 q8, d6, d4[0]\n\t"
+      "vmlal.s16 q9, d7, d4[0]\n\t"
+      "vmlal.s16 q10, d6, d4[1]\n\t"
+      "vmlal.s16 q11, d7, d4[1]\n\t"
+      "vmlal.s16 q12, d6, d4[2]\n\t"
+      "vmlal.s16 q13, d7, d4[2]\n\t"
+      "vmlal.s16 q14, d6, d4[3]\n\t"
+      "vmlal.s16 q15, d7, d4[3]\n\t"
+      "vmovl.s8 q3, d11 \n\t"
+      "vmlal.s16 q8, d6, d5[0]\n\t"
+      "vmlal.s16 q9, d7, d5[0]\n\t"
+      "vmlal.s16 q10, d6, d5[1]\n\t"
+      "vmlal.s16 q11, d7, d5[1]\n\t"
+      "vmlal.s16 q12, d6, d5[2]\n\t"
+      "vmlal.s16 q13, d7, d5[2]\n\t"
+      "vmlal.s16 q14, d6, d5[3]\n\t"
+      "vmlal.s16 q15, d7, d5[3]\n\t"
+      "vmovl.s8 q2, d2 \n\t"  // process B second 4 rows
+      "vmovl.s8 q3, d12 \n\t"
+      "vmlal.s16 q8, d6, d4[0]\n\t"
+      "vmlal.s16 q9, d7, d4[0]\n\t"
+      "vmlal.s16 q10, d6, d4[1]\n\t"
+      "vmlal.s16 q11, d7, d4[1]\n\t"
+      "vmlal.s16 q12, d6, d4[2]\n\t"
+      "vmlal.s16 q13, d7, d4[2]\n\t"
+      "vmlal.s16 q14, d6, d4[3]\n\t"
+      "vmlal.s16 q15, d7, d4[3]\n\t"
+      "vmovl.s8 q3, d13 \n\t"
+      "vmlal.s16 q8, d6, d5[0]\n\t"
+      "vmlal.s16 q9, d7, d5[0]\n\t"
+      "vmlal.s16 q10, d6, d5[1]\n\t"
+      "vmlal.s16 q11, d7, d5[1]\n\t"
+      "vmlal.s16 q12, d6, d5[2]\n\t"
+      "vmlal.s16 q13, d7, d5[2]\n\t"
+      "vmlal.s16 q14, d6, d5[3]\n\t"
+      "vmlal.s16 q15, d7, d5[3]\n\t"
+      "vmovl.s8 q2, d3 \n\t"
+      "vmovl.s8 q3, d14 \n\t"
+      "vmlal.s16 q8, d6, d4[0]\n\t"
+      "vmlal.s16 q9, d7, d4[0]\n\t"
+      "vmlal.s16 q10, d6, d4[1]\n\t"
+      "vmlal.s16 q11, d7, d4[1]\n\t"
+      "vmlal.s16 q12, d6, d4[2]\n\t"
+      "vmlal.s16 q13, d7, d4[2]\n\t"
+      "vmlal.s16 q14, d6, d4[3]\n\t"
+      "vmlal.s16 q15, d7, d4[3]\n\t"
+      "vmovl.s8 q3, d15 \n\t"
+      "vmlal.s16 q8, d6, d5[0]\n\t"
+      "vmlal.s16 q9, d7, d5[0]\n\t"
+      "vmlal.s16 q10, d6, d5[1]\n\t"
+      "vmlal.s16 q11, d7, d5[1]\n\t"
+      "vmlal.s16 q12, d6, d5[2]\n\t"
+      "vmlal.s16 q13, d7, d5[2]\n\t"
+      "vmlal.s16 q14, d6, d5[3]\n\t"
+      "vmlal.s16 q15, d7, d5[3]\n\t"
+
+      "subs %[kc1], %[kc1], #1 \n\t"
+      "bge 0b \n\t"
+      "1: \n\t"  // last 4 rows
+      "subs %[kc3], %[kc3], #1 \n\t"
+      "blt 2f \n\t"
+      "vld1.s8 {d0-d1}, [%[a_ptr]]! \n\t"   // load A 4 cols
+      "vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t"  // load B 4 rows
+      "vmovl.s8 q2, d0 \n\t"
+      "vmovl.s8 q3, d8 \n\t"
+      "vmlal.s16 q8, d6, d4[0]\n\t"
+      "vmlal.s16 q9, d7, d4[0]\n\t"
+      "vmlal.s16 q10, d6, d4[1]\n\t"
+      "vmlal.s16 q11, d7, d4[1]\n\t"
+      "vmlal.s16 q12, d6, d4[2]\n\t"
+      "vmlal.s16 q13, d7, d4[2]\n\t"
+      "vmlal.s16 q14, d6, d4[3]\n\t"
+      "vmlal.s16 q15, d7, d4[3]\n\t"
+      "vmovl.s8 q3, d9 \n\t"
+      "vmlal.s16 q8, d6, d5[0]\n\t"
+      "vmlal.s16 q9, d7, d5[0]\n\t"
+      "vmlal.s16 q10, d6, d5[1]\n\t"
+      "vmlal.s16 q11, d7, d5[1]\n\t"
+      "vmlal.s16 q12, d6, d5[2]\n\t"
+      "vmlal.s16 q13, d7, d5[2]\n\t"
+      "vmlal.s16 q14, d6, d5[3]\n\t"
+      "vmlal.s16 q15, d7, d5[3]\n\t"
+      "vmovl.s8 q2, d1 \n\t"
+      "vmovl.s8 q3, d10 \n\t"
+      "vmlal.s16 q8, d6, d4[0]\n\t"
+      "vmlal.s16 q9, d7, d4[0]\n\t"
+      "vmlal.s16 q10, d6, d4[1]\n\t"
+      "vmlal.s16 q11, d7, d4[1]\n\t"
+      "vmlal.s16 q12, d6, d4[2]\n\t"
+      "vmlal.s16 q13, d7, d4[2]\n\t"
+      "vmlal.s16 q14, d6, d4[3]\n\t"
+      "vmlal.s16 q15, d7, d4[3]\n\t"
+      "vmovl.s8 q3, d11 \n\t"
+      "vmlal.s16 q8, d6, d5[0]\n\t"
+      "vmlal.s16 q9, d7, d5[0]\n\t"
+      "vmlal.s16 q10, d6, d5[1]\n\t"
+      "vmlal.s16 q11, d7, d5[1]\n\t"
+      "vmlal.s16 q12, d6, d5[2]\n\t"
+      "vmlal.s16 q13, d7, d5[2]\n\t"
+      "vmlal.s16 q14, d6, d5[3]\n\t"
+      "vmlal.s16 q15, d7, d5[3]\n\t"
+      "2: \n\t"  // last 2 rows
+      "subs %[kc5], %[kc5], #1 \n\t"
+      "blt 3f \n\t"
+      "vld1.s8 {d0}, [%[a_ptr]]! \n\t"     // load A 2 cols
+      "vld1.s8 {d8-d9}, [%[b_ptr]]! \n\t"  // load B 2 rows
+      "vmovl.s8 q2, d0 \n\t"
+      "vmovl.s8 q3, d8 \n\t"
+      "vmlal.s16 q8, d6, d4[0]\n\t"
+      "vmlal.s16 q9, d7, d4[0]\n\t"
+      "vmlal.s16 q10, d6, d4[1]\n\t"
+      "vmlal.s16 q11, d7, d4[1]\n\t"
+      "vmlal.s16 q12, d6, d4[2]\n\t"
+      "vmlal.s16 q13, d7, d4[2]\n\t"
+      "vmlal.s16 q14, d6, d4[3]\n\t"
+      "vmlal.s16 q15, d7, d4[3]\n\t"
+      "vmovl.s8 q3, d9 \n\t"
+      "vmlal.s16 q8, d6, d5[0]\n\t"
+      "vmlal.s16 q9, d7, d5[0]\n\t"
+      "vmlal.s16 q10, d6, d5[1]\n\t"
+      "vmlal.s16 q11, d7, d5[1]\n\t"
+      "vmlal.s16 q12, d6, d5[2]\n\t"
+      "vmlal.s16 q13, d7, d5[2]\n\t"
+      "vmlal.s16 q14, d6, d5[3]\n\t"
+      "vmlal.s16 q15, d7, d5[3]\n\t"
+      "3: \n\t"  // last 1 row
+      "subs %[kc6], %[kc6], #1 \n\t"
+      "blt 4f \n\t"
+      "vld1.s8 {d0}, [%[a_ptr]] \n\t"  // load A 1 col
+      "vld1.s8 {d8}, [%[b_ptr]] \n\t"  // load B 1 row
+      "vmovl.s8 q2, d0 \n\t"
+      "vmovl.s8 q3, d8 \n\t"
+      "vmlal.s16 q8, d6, d4[0]\n\t"
+      "vmlal.s16 q9, d7, d4[0]\n\t"
+      "vmlal.s16 q10, d6, d4[1]\n\t"
+      "vmlal.s16 q11, d7, d4[1]\n\t"
+      "vmlal.s16 q12, d6, d4[2]\n\t"
+      "vmlal.s16 q13, d7, d4[2]\n\t"
+      "vmlal.s16 q14, d6, d4[3]\n\t"
+      "vmlal.s16 q15, d7, d4[3]\n\t"
+      "4: \n\t"
+      "vst1.32 {q8, q9}, [%[c]], %[step] \n\t"
+      "vst1.32 {q10, q11}, [%[c]], %[step] \n\t"
+      "vst1.32 {q12, q13}, [%[c]], %[step] \n\t"
+      "vst1.32 {q14, q15}, [%[c]] \n\t"
+      :
+      : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
+        [kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step)
+      : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+        "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
+#endif  // __aarch64__
+#endif  // __ARM_NEON
+}
 
 // 8 bits int small block inner product
 void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
                      int32_t ldc) {
 #if __ARM_NEON
+#if __aarch64__
+// TODO
+#else
   const int8_t *a_ptr, *b_ptr;
   a_ptr = a;
   b_ptr = b;
@@ -317,6 +534,7 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
         [kc3] "r"(kc3), [kc5] "r"(kc5), [kc6] "r"(kc6), [step] "r"(step)
       : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
         "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
+#endif  // __aarch64__
 #endif  // __ARM_NEON
 }
 
@@ -327,8 +545,9 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
                                int8_t *bias) {
 #pragma omp parallel for
   for (int32_t j = 0; j < nc; j += NR) {
-    for (int32_t i = 0; i < mc; i += MR) {
-      AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+    for (int32_t i = 0; i < mc; i += MR_INT8) {
+      //      AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
+      AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
     }
   }
   if (alpha != 1) {
@@ -356,12 +575,53 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha,
     return;
   }
 }
+// 8 bits int PackMatrixA_4r
+void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
+                          int32_t lda, int8_t *buffer) {
+  const int8_t *a0, *a1, *a2, *a3;
+  for (int32_t i = 0; i < m - m_tail; i += MR_INT8) {
+    a0 = A + i * lda;
+    a1 = A + (i + 1) * lda;
+    a2 = A + (i + 2) * lda;
+    a3 = A + (i + 3) * lda;
+    for (int32_t j = 0; j < k; ++j) {
+      *buffer++ = *a0++;
+      *buffer++ = *a1++;
+      *buffer++ = *a2++;
+      *buffer++ = *a3++;
+    }
+  }
 
-// 8 bits int PackMatrixA
+  if (m_tail != 0) {
+    a0 = &A(m - m_tail, 0);
+    a1 = a0 + lda;
+    a2 = a0 + 2 * lda;
+    a3 = a0 + 3 * lda;
+    switch (m_tail) {
+      case 1:
+        a1 = zero_int8;
+      case 2:
+        a2 = zero_int8;
+      case 3:
+        a3 = zero_int8;
+        break;
+      default:
+        break;
+    }
+    for (int j = 0; j < k; ++j) {
+      *buffer++ = *a0++;
+      *buffer++ = *a1++;
+      *buffer++ = *a2++;
+      *buffer++ = *a3++;
+    }
+  }
+}
+
+// 8 bits int PackMatrixA_6r
 void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A,
                           int32_t lda, int8_t *buffer) {
   const int32_t i_length = m - m_tail;
-  for (int32_t i = 0; i < i_length; i += MR) {
+  for (int32_t i = 0; i < i_length; i += MR_INT8) {
     const int8_t *a0 = A + i * lda;
     const int8_t *a1 = A + (i + 1) * lda;
     const int8_t *a2 = A + (i + 2) * lda;
@@ -421,6 +681,9 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
     for (int32_t i = 0; i < k; ++i) {
       const int8_t *b0 = &B(i, j);
 #if __ARM_NEON
+#if __aarch64__
+      // TODO
+#else
       asm volatile(
           //  "pld [%[b0]] \n\t"
           "vld1.s8 {d0}, [%[b0]] \n\t"
@@ -428,6 +691,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
           : [local_buffer] "+r"(local_buffer)
          : [b0] "r"(b0)
          : "memory", "q0");
+#endif  // __aarch64__
 #else
       *local_buffer++ = *b0++;
       *local_buffer++ = *b0++;
@@ -467,13 +731,13 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
   MC = L1 / (KC * sizeof(int8_t));
   NC = L2 / (KC * sizeof(int8_t));
 
-  // make sure MC is multiple of MR, and NC is multiple of NR
+  // make sure MC is multiple of MR_INT8, and NC is multiple of NR
   if (MC == 0) {
-    MC = MR;
+    MC = MR_INT8;
   } else {
     int32_t mblock_num = (m + MC - 1) / MC;
     MC = (m + mblock_num - 1) / mblock_num;
-    MC = (MC + MR - 1) / MR * MR;
+    MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8;
   }
   // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
   if (NC == 0) {
@@ -500,7 +764,8 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A,
     PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8);
     for (int32_t i = 0; i < m; i += MC) {
       mc = s_min(m - i, MC);
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA_int8);
+      //      PackMatrixA_6r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
+      PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8);
       if (bias == nullptr) {
         InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta,
                             packedC_int8, &C(i, j), ldc, relu, nullptr);
@@ -525,6 +790,9 @@ void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
 void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
                       int32_t ldc) {
 #if __ARM_NEON
+#if __aarch64__
+// TODO
+#else
   int32_t nc1 = nc >> 4;
   int32_t _nc1 = nc & 15;
   int32_t step = sizeof(int32_t) * ldc;
@@ -578,6 +846,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
       }
     }
   }
+#endif  // __aarch64__
 #endif  // __ARM_NEON
 }
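
For reference, here is a minimal scalar sketch of what the new 4x8 micro-kernel computes on the packed blocks. It is not part of the patch: the function name `AddDot4x8_reference` is hypothetical, and it assumes the layouts produced by `PackMatrixA_4r` (4 row values per k-column, MR_INT8 = 4) and `PackMatrixB_8c` (8 column values per k-row, NR = 8), with the 4x8 int32 tile written (not accumulated) into `c` with row stride `ldc`, matching the `vst1.32` stores at the end of `AddDot4x8`.

```cpp
#include <cstdint>

// Scalar reference for the 4x8 int8 micro-kernel (assumed packed layouts):
//   a: k columns of 4 int8 values -> a[j * 4 + mr] is row mr of the A block at depth j
//   b: k rows of 8 int8 values    -> b[j * 8 + nr] is column nr of the B block at depth j
//   c: 4 x 8 int32 output tile, row stride ldc, overwritten by the result
void AddDot4x8_reference(int32_t k, const int8_t *a, const int8_t *b,
                         int32_t *c, int32_t ldc) {
  for (int32_t mr = 0; mr < 4; ++mr) {
    for (int32_t nr = 0; nr < 8; ++nr) {
      int32_t acc = 0;
      for (int32_t j = 0; j < k; ++j) {
        acc += static_cast<int32_t>(a[j * 4 + mr]) *
               static_cast<int32_t>(b[j * 8 + nr]);
      }
      c[mr * ldc + nr] = acc;
    }
  }
}
```

The NEON version in the diff follows the same structure: it widens int8 to int16 with `vmovl.s8` and accumulates into eight int32 q registers (q8-q15) with `vmlal.s16`, processing 8 depth steps per main-loop iteration and handling the remaining 4/2/1 steps in the tail sections labeled 1, 2, and 3.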