From f72a124b80da6ed6d87327a5647087e72f1e080b Mon Sep 17 00:00:00 2001 From: Jiaying Zhao Date: Wed, 26 Jun 2019 19:19:56 +0800 Subject: [PATCH] add gemm_int8 arm64 version without openmp (#1708) * add gemm_int8 arm64 version without openmp * add gemm_int8 arm64 version with openmp --- src/operators/math/gemm.h | 33 +- src/operators/math/gemm_int8.cpp | 826 +++++++++++++++++++----- src/operators/math/gemm_omp_int8.cpp | 111 +++- test/common/test_gemm_int8_accuracy.cpp | 2 +- 4 files changed, 808 insertions(+), 164 deletions(-) diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 113e04fe3c..fdbae47112 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -28,7 +28,7 @@ limitations under the License. */ #if __aarch64__ #define MR_INT8 4 -#define NR_INT8 2 +#define NR_INT8 4 #define MR 6 #define NR 16 #else @@ -181,12 +181,15 @@ class Gemm { std::string mode, float *bias, float *bias1); // 8 bits function cluster begins - // 8 bits int small block inner product + // 8 bits int small block inner product, data packed k = 1 void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); + void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); + // 8 bits int small block inner product, data packed k = 16 void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); - void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + void AddDot4x4(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); // 8 bits int inner product @@ -203,14 +206,16 @@ class Gemm { // 8 bits int pack function void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); - void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, - int32_t lda, int8_t *buffer); void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); - void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, - int32_t ldb, int8_t *buffer); void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); + void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); + void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); + void PackMatrixB_4c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, @@ -219,6 +224,8 @@ class Gemm { const int8_t *A, int32_t lda, int8_t *buffer); void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); + void PackMatrixB_omp_4c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer); // 8 bits int matrix product template @@ -314,7 +321,11 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, int32_t mc, nc; for (int32_t j = 0; j < n; j += NC) { nc = s_min(n - j, NC); +#if __aarch64__ + PackMatrixB_4c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8); +#else PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8); +#endif for (int32_t i = 0; i < m; i += MC) { mc = s_min(m - i, MC); PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8); @@ -375,7 
+386,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); #if __aarch64__ - // TODO(paddle mobile) + PackMatrixB_omp_4c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); #else PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); #endif @@ -397,7 +408,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, packedA_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); #if __aarch64__ - // TODO(paddle mobile) + PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); #else PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); #endif @@ -421,7 +432,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, int8_t *local_A = packedA_int8 + MC * KC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ - // TODO(paddle mobile) + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); #else PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); #endif @@ -451,7 +462,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, int8_t *local_B = packedB_int8 + KC * NC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ - // TODO(paddle mobile) + PackMatrixB_4c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); #else PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); #endif diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp index 16537adfec..ba7d076915 100644 --- a/src/operators/math/gemm_int8.cpp +++ b/src/operators/math/gemm_int8.cpp @@ -31,7 +31,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO() +// AddDot4x8 used only for aarch32 #else const int8_t *a_ptr, *b_ptr; a_ptr = a; @@ -249,7 +249,7 @@ void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO +// AddDot4x2 used only for aarch32 #else #define PADDLE_LABEL_LOOP "1" #define PADDLE_LABEL_AFTER_LOOP "2" @@ -371,12 +371,226 @@ void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, #endif // __ARM_NEON } +void Gemm::AddDot4x4(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON +#if __aarch64__ +#define PADDLE_LABEL_LOOP "1" +#define PADDLE_LABEL_AFTER_LOOP "2" + asm volatile( + // load data from matrix a and b,and set zero to result register + "ld1 {v0.16b}, [%[b]], #16\n" + "dup v16.4s, wzr\n" + "ld1 {v4.16b}, [%[a]], #16\n" + "dup v17.4s, wzr\n" + "ld1 {v1.16b}, [%[b]], #16\n" + "dup v18.4s, wzr\n" + "ld1 {v5.16b}, [%[a]], #16\n" + "dup v19.4s, wzr\n" + "ld1 {v2.16b}, [%[b]], #16\n" + "dup v20.4s, wzr\n" + "ld1 {v3.16b}, [%[b]], #16\n" + "dup v21.4s, wzr\n" + "ld1 {v6.16b}, [%[a]], #16\n" + "dup v22.4s, wzr\n" + "ld1 {v7.16b}, [%[a]], #16\n" + "dup v23.4s, wzr\n" + "dup v24.4s, wzr\n" + "dup v25.4s, wzr\n" + "dup v26.4s, wzr\n" + "dup v27.4s, wzr\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + // Multiply ldc by 4 == sizeof(int32) + "lsl %[ldc], %[ldc], #2\n" + + // first half + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // 
Multiply-accumulate second-half + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + "subs %[k], %[k], #16\n" + + // skip the loop + "beq " PADDLE_LABEL_AFTER_LOOP "f\n" + + // loop + PADDLE_LABEL_LOOP + ":\n" + + // first half + "sadalp v16.4s, v8.8h\n" + "ld1 {v4.16b}, [%[a]], #16\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "ld1 {v5.16b}, [%[a]], #16\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "ld1 {v6.16b}, [%[a]], #16\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "ld1 {v0.16b}, [%[b]], #16\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "ld1 {v1.16b}, [%[b]], #16\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "ld1 {v2.16b}, [%[b]], #16\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ld1 {v3.16b}, [%[b]], #16\n" + + // first half + "sadalp v24.4s, v8.8h\n" + "smull v8.8h, v0.8b, v4.8b\n" + "sadalp v25.4s, v9.8h\n" + "ld1 {v7.16b}, [%[a]], #16\n" + "smull v9.8h, v1.8b, v4.8b\n" + "sadalp v26.4s, v10.8h\n" + "smull v10.8h, v2.8b, v4.8b\n" + "sadalp v27.4s, v11.8h\n" + "smull v11.8h, v3.8b, v4.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v5.8b\n" + "sadalp v29.4s, v13.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v30.4s, v14.8h\n" + "smull v14.8h, v2.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Loop + "subs %[k], %[k], #16\n" + + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + "bne " PADDLE_LABEL_LOOP "b\n" + + // Final + PADDLE_LABEL_AFTER_LOOP + ":\n" + + // first half + "sadalp v16.4s, v8.8h\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + "smlal2 v12.8h, v0.16b, v7.16b\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v24.4s, v8.8h\n" + "sadalp v25.4s, v9.8h\n" + "sadalp v26.4s, v10.8h\n" + "sadalp v27.4s, v11.8h\n" + "sadalp v28.4s, v12.8h\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // Reduce 32bit 
accumulators horizontally. + "addp v0.4s, v16.4s, v17.4s\n" + "addp v1.4s, v18.4s, v19.4s\n" + "addp v2.4s, v20.4s, v21.4s\n" + "addp v3.4s, v22.4s, v23.4s\n" + "addp v4.4s, v24.4s, v25.4s\n" + "addp v5.4s, v26.4s, v27.4s\n" + "addp v6.4s, v28.4s, v29.4s\n" + "addp v7.4s, v30.4s, v31.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v12.4s, v0.4s, v1.4s\n" + "addp v13.4s, v2.4s, v3.4s\n" + "addp v14.4s, v4.4s, v5.4s\n" + "addp v15.4s, v6.4s, v7.4s\n" + + "st1 {v12.4s}, [%[c]], %[ldc] \n\t" + "st1 {v13.4s}, [%[c]], %[ldc] \n\t" + "st1 {v14.4s}, [%[c]], %[ldc] \n\t" + "st1 {v15.4s}, [%[c]] \n\t" + + : [k] "+r"(k), [a] "+r"(a), [b] "+r"(b), [c] "+r"(c) // outputs + : [ldc] "r"(ldc) // inputs + : "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31"); // clobbers +#undef PADDLE_LABEL_AFTER_LOOP +#undef PADDLE_LABEL_LOOP +#else +// AddDot4x4 used only for aarch64 +#endif // __aarch64__ +#endif // __ARM_NEON +} + // 8 bits int small block inner product void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO +// AddDot6x8 used only for aarch32 #else const int8_t *a_ptr, *b_ptr; a_ptr = a; @@ -681,10 +895,8 @@ void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, for (int32_t j = 0; j < nc; j += NR_INT8) { for (int32_t i = 0; i < mc; i += MR_INT8) { #if __aarch64__ - // TODO + AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); #else - // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); - // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); #endif // __aarch64__ } @@ -704,10 +916,8 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, for (int32_t j = 0; j < nc; j += NR_INT8) { for (int32_t i = 0; i < mc; i += MR_INT8) { #if __aarch64__ - // TODO + AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); #else - // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); - // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); #endif // __aarch64__ } @@ -730,6 +940,149 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, int32_t *c, int32_t *C, int32_t ldc, bool relu, int32_t *bias, bool addOnRow) {} +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer) { + const int8_t *a0, *a1, *a2, *a3; + for (int32_t i = 0; i < m - m_tail; i += 4) { + a0 = A + i * lda; + a1 = A + (i + 1) * lda; + a2 = A + (i + 2) * lda; + a3 = A + (i + 3) * lda; + for (int32_t j = 0; j < k; ++j) { + *buffer++ = *a0++; + *buffer++ = *a1++; + *buffer++ = *a2++; + *buffer++ = *a3++; + } + } + + if (m_tail != 0) { + a0 = &A(m - m_tail, 0); + a1 = a0 + lda; + a2 = a0 + 2 * lda; + a3 = a0 + 3 * lda; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int j = 0; j < k; ++j) { + *buffer++ = *a0++; + *buffer++ = *a1++; + *buffer++ = *a2++; + *buffer++ = *a3++; + } + } +} + +// 8 bits int PackMatrixA_6r +void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer)
{ + const int32_t i_length = m - m_tail; + for (int32_t i = 0; i < i_length; i += 6) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + const int8_t *a4 = A + (i + 4) * lda; + const int8_t *a5 = A + (i + 5) * lda; + int8_t *local_buffer = buffer + i * k; + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + const int8_t *a4 = a0 + 4 * lda; + const int8_t *a5 = a0 + 5 * lda; + int8_t *local_buffer = buffer + i_length * k; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + case 4: + a4 = zero_int8; + case 5: + a5 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k; ++j) { + *local_buffer++ = *a0++; + *local_buffer++ = *a1++; + *local_buffer++ = *a2++; + *local_buffer++ = *a3++; + *local_buffer++ = *a4++; + *local_buffer++ = *a5++; + } + } +} + +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + for (int32_t j = 0; j < j_length; j += 8) { + int8_t *local_buffer = buffer + j * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j); +#if __ARM_NEON +#if __aarch64__ +// PackMatrixB_8c used only for aarch32 +#else + asm volatile( + // "pld [%[b0]] \n\t" + "vld1.s8 {d0}, [%[b0]] \n\t" + "vst1.s8 {d0}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer) + : [b0] "r"(b0) + : "memory", "q0"); +#endif // __aarch64__ +#else + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; + *local_buffer++ = *b0++; +#endif // __ARM_NEON + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * k; + for (int32_t i = 0; i < k; ++i) { + const int8_t *b0 = &B(i, j_length); + for (int32_t j = j_length; j < n; ++j) { + *local_buffer++ = *b0++; + } + for (int32_t j = n; j < j_length + 8; ++j) { + *local_buffer++ = 0; + } + } + } +} + // 8 bits int PackMatrixA_4r void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { @@ -746,7 +1099,19 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, for (int32_t j = 0; j < k_count; ++j) { #if __ARM_NEON #if __aarch64__ - // TODO + asm volatile( + "ld1 {v0.16b}, [%[a0]], #16 \n\t" + "ld1 {v1.16b}, [%[a1]], #16 \n\t" + "ld1 {v2.16b}, [%[a2]], #16 \n\t" + "ld1 {v3.16b}, [%[a3]], #16 \n\t" + "st1 {v0.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v1.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v2.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v3.16b}, [%[local_buffer]], #16 \n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "v0", "v1", "v2", "v3"); #else asm volatile( "vld1.s8 {d0, d1}, [%[a0]]! 
\n\t" @@ -826,7 +1191,19 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, for (int32_t j = 0; j < k_count; ++j) { #if __ARM_NEON #if __aarch64__ - // TODO + asm volatile( + "ld1 {v0.16b}, [%[a0]], #16 \n\t" + "ld1 {v1.16b}, [%[a1]], #16 \n\t" + "ld1 {v2.16b}, [%[a2]], #16 \n\t" + "ld1 {v3.16b}, [%[a3]], #16 \n\t" + "st1 {v0.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v1.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v2.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v3.16b}, [%[local_buffer]], #16 \n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "v0", "v1", "v2", "v3"); #else asm volatile( "vld1.s8 {d0, d1}, [%[a0]]! \n\t" @@ -887,103 +1264,6 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, } } -// 8 bits int PackMatrixA_4r -void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, - int32_t lda, int8_t *buffer) { - const int8_t *a0, *a1, *a2, *a3; - for (int32_t i = 0; i < m - m_tail; i += 4) { - a0 = A + i * lda; - a1 = A + (i + 1) * lda; - a2 = A + (i + 2) * lda; - a3 = A + (i + 3) * lda; - for (int32_t j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - } - } - - if (m_tail != 0) { - a0 = &A(m - m_tail, 0); - a1 = a0 + lda; - a2 = a0 + 2 * lda; - a3 = a0 + 3 * lda; - switch (m_tail) { - case 1: - a1 = zero_int8; - case 2: - a2 = zero_int8; - case 3: - a3 = zero_int8; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - } - } -} - -// 8 bits int PackMatrixA_6r -void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, - int32_t lda, int8_t *buffer) { - const int32_t i_length = m - m_tail; - for (int32_t i = 0; i < i_length; i += 6) { - const int8_t *a0 = A + i * lda; - const int8_t *a1 = A + (i + 1) * lda; - const int8_t *a2 = A + (i + 2) * lda; - const int8_t *a3 = A + (i + 3) * lda; - const int8_t *a4 = A + (i + 4) * lda; - const int8_t *a5 = A + (i + 5) * lda; - int8_t *local_buffer = buffer + i * k; - for (int32_t j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } - if (m_tail != 0) { - const int8_t *a0 = &A(i_length, 0); - const int8_t *a1 = a0 + lda; - const int8_t *a2 = a0 + 2 * lda; - const int8_t *a3 = a0 + 3 * lda; - const int8_t *a4 = a0 + 4 * lda; - const int8_t *a5 = a0 + 5 * lda; - int8_t *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero_int8; - case 2: - a2 = zero_int8; - case 3: - a3 = zero_int8; - case 4: - a4 = zero_int8; - case 5: - a5 = zero_int8; - break; - default: - break; - } - for (int32_t j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } -} - // 8 bits int PackMatrixB void Gemm::PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer) { @@ -1052,46 +1332,79 @@ void Gemm::PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, } } -// 8 bits int PackMatrixB -void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, - int32_t ldb, int8_t *buffer) { +void Gemm::PackMatrixB_4c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { const int32_t j_length = n 
- n_tail; - for (int32_t j = 0; j < j_length; j += 8) { - int8_t *local_buffer = buffer + j * k; - for (int32_t i = 0; i < k; ++i) { - const int8_t *b0 = &B(i, j); -#if __ARM_NEON -#if __aarch64__ - // TODO -#else - asm volatile( - // "pld [%[b0]] \n\t" - "vld1.s8 {d0}, [%[b0]] \n\t" - "vst1.s8 {d0}, [%[local_buffer]]! \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "q0"); -#endif // __aarch64__ -#else - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; -#endif // __ARM_NEON + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; + for (int32_t j = 0; j < n; j += 4) { + int8_t *local_buffer = buffer + j * KC; + const int8_t *b0 = &B(0, j); + const int8_t *b1 = b0 + 1; + const int8_t *b2 = b0 + 2; + const int8_t *b3 = b0 + 3; + if (j > j_length) { + switch (n_tail) { + case 1: + b1 = zero_int8; + case 2: + b2 = zero_int8; + case 3: + b3 = zero_int8; + break; + default: + break; + } } - } - if (n_tail != 0) { - int8_t *local_buffer = buffer + j_length * k; - for (int32_t i = 0; i < k; ++i) { - const int8_t *b0 = &B(i, j_length); - for (int32_t j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; + + for (int32_t i = 0; i < k_count; ++i) { + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; } - for (int32_t j = n; j < j_length + 8; ++j) { + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b2; + b2 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b3; + b3 += ldb; + } + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b2; + b2 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b3; + b3 += ldb; + } + for (int32_t j = k; j < KC; ++j) { *local_buffer++ = 0; } } @@ -1104,7 +1417,35 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc) { #if __ARM_NEON #if __aarch64__ -// TODO + int32_t nc1 = nc / 4; + int32_t _nc1 = nc % 4; + + int32_t *c_ptr, *C_ptr; + int32x4_t cv; + for (int32_t i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + for (int32_t j = 0; j < nc1; ++j) { + cv = vld1q_s32(c_ptr); + vst1q_s32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_s32(c_ptr); + if (_nc1 >= 1) { + vst1q_lane_s32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_s32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_s32(C_ptr, cv, 2); + } + } + } #else int32_t nc1 = nc >> 4; int32_t _nc1 = nc & 15; @@ -1168,7 +1509,67 @@ void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, int32_t ldc, int32_t *bias, float scale) { #if __ARM_NEON #if __aarch64__ -// TODO + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + + int32_t *c_ptr; + int8_t *C_ptr; + int32x4_t cv0; + int32x4_t cv1; + int16x8_t cv_h; + int8x8_t cv_b; + int32x4_t biasv; + int8_t min = -127; + int8x8_t minv = vdup_n_s8(min); + for (int32_t i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + 
C_ptr = C + i * ldc; + biasv = vld1q_dup_s32(bias + i); + for (int32_t j = 0; j < nc1; ++j) { + cv0 = vld1q_s32(c_ptr); + cv1 = vld1q_s32(c_ptr + 4); + cv0 = vqaddq_s32(cv0, biasv); + cv1 = vqaddq_s32(cv1, biasv); + + cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1)); + cv_b = vqmovn_s16(cv_h); + + cv_b = vmax_s8(cv_b, minv); + vst1_s8(C_ptr, cv_b); + c_ptr += 8; + C_ptr += 8; + } + if (_nc1 != 0) { + cv0 = vld1q_s32(c_ptr); + cv1 = vld1q_s32(c_ptr + 4); + cv0 = vqaddq_s32(cv0, biasv); + cv1 = vqaddq_s32(cv1, biasv); + + cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1)); + cv_b = vqmovn_s16(cv_h); + + cv_b = vmax_s8(cv_b, minv); + + switch (_nc1) { + case 7: + vst1_lane_s8(C_ptr + 6, cv_b, 6); + case 6: + vst1_lane_s8(C_ptr + 5, cv_b, 5); + case 5: + vst1_lane_s8(C_ptr + 4, cv_b, 4); + case 4: + vst1_lane_s8(C_ptr + 3, cv_b, 3); + case 3: + vst1_lane_s8(C_ptr + 2, cv_b, 2); + case 2: + vst1_lane_s8(C_ptr + 1, cv_b, 1); + case 1: + vst1_lane_s8(C_ptr, cv_b, 0); + default: + break; + } + } + } #else int8_t narrow = -128; int32_t nc1 = nc >> 3; @@ -1291,7 +1692,74 @@ void Gemm::WriteWithAddScaleT(int32_t mc, int32_t nc, int32_t *c, int8_t *C, int32_t ldc, int32_t *bias, float scale) { #if __ARM_NEON #if __aarch64__ -// TODO + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + + int32_t *c_ptr; + int8_t *C_ptr; + int32x4_t cv0; + int32x4_t cv1; + int16x8_t cv_h; + int8x8_t cv_b; + int32_t *bias_ptr; + int32x4_t biasv0; + int32x4_t biasv1; + int8_t min = -127; + int8x8_t minv = vdup_n_s8(min); + for (int32_t i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias_ptr = bias; + for (int32_t j = 0; j < nc1; ++j) { + cv0 = vld1q_s32(c_ptr); + cv1 = vld1q_s32(c_ptr + 4); + biasv0 = vld1q_s32(bias_ptr); + biasv1 = vld1q_s32(bias_ptr + 4); + cv0 = vqaddq_s32(cv0, biasv0); + cv1 = vqaddq_s32(cv1, biasv1); + + cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1)); + cv_b = vqmovn_s16(cv_h); + + cv_b = vmax_s8(cv_b, minv); + vst1_s8(C_ptr, cv_b); + c_ptr += 8; + C_ptr += 8; + bias_ptr += 8; + } + if (_nc1 != 0) { + cv0 = vld1q_s32(c_ptr); + cv1 = vld1q_s32(c_ptr + 4); + biasv0 = vld1q_s32(bias_ptr); + biasv1 = vld1q_s32(bias_ptr + 4); + cv0 = vqaddq_s32(cv0, biasv0); + cv1 = vqaddq_s32(cv1, biasv1); + + cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1)); + cv_b = vqmovn_s16(cv_h); + + cv_b = vmax_s8(cv_b, minv); + + switch (_nc1) { + case 7: + vst1_lane_s8(C_ptr + 6, cv_b, 6); + case 6: + vst1_lane_s8(C_ptr + 5, cv_b, 5); + case 5: + vst1_lane_s8(C_ptr + 4, cv_b, 4); + case 4: + vst1_lane_s8(C_ptr + 3, cv_b, 3); + case 3: + vst1_lane_s8(C_ptr + 2, cv_b, 2); + case 2: + vst1_lane_s8(C_ptr + 1, cv_b, 1); + case 1: + vst1_lane_s8(C_ptr, cv_b, 0); + default: + break; + } + } + } #else int8_t narrow = -128; int32_t nc1 = nc >> 3; @@ -1414,7 +1882,67 @@ void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, int32_t ldc, int32_t *bias, float scale) { #if __ARM_NEON #if __aarch64__ -// TODO + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + + int32_t *c_ptr; + int8_t *C_ptr; + int32x4_t cv0; + int32x4_t cv1; + int16x8_t cv_h; + int8x8_t cv_b; + int32x4_t biasv; + int32x4_t zero = vdupq_n_s32(0); + for (int32_t i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + biasv = vld1q_dup_s32(bias + i); + for (int32_t j = 0; j < nc1; ++j) { + cv0 = vld1q_s32(c_ptr); + cv1 = vld1q_s32(c_ptr + 4); + cv0 = vqaddq_s32(cv0, biasv); + cv1 = vqaddq_s32(cv1, biasv); + cv0 = vmaxq_s32(cv0, zero); + cv1 = vmaxq_s32(cv1, zero); + + cv_h = 
vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1)); + cv_b = vqmovn_s16(cv_h); + + vst1_s8(C_ptr, cv_b); + c_ptr += 8; + C_ptr += 8; + } + if (_nc1 != 0) { + cv0 = vld1q_s32(c_ptr); + cv1 = vld1q_s32(c_ptr + 4); + cv0 = vqaddq_s32(cv0, biasv); + cv1 = vqaddq_s32(cv1, biasv); + cv0 = vmaxq_s32(cv0, zero); + cv1 = vmaxq_s32(cv1, zero); + + cv_h = vcombine_s16(vqmovn_s32(cv0), vqmovn_s32(cv1)); + cv_b = vqmovn_s16(cv_h); + + switch (_nc1) { + case 7: + vst1_lane_s8(C_ptr + 6, cv_b, 6); + case 6: + vst1_lane_s8(C_ptr + 5, cv_b, 5); + case 5: + vst1_lane_s8(C_ptr + 4, cv_b, 4); + case 4: + vst1_lane_s8(C_ptr + 3, cv_b, 3); + case 3: + vst1_lane_s8(C_ptr + 2, cv_b, 2); + case 2: + vst1_lane_s8(C_ptr + 1, cv_b, 1); + case 1: + vst1_lane_s8(C_ptr, cv_b, 0); + default: + break; + } + } + } #else int32_t zero = 0; int32_t nc1 = nc >> 3; diff --git a/src/operators/math/gemm_omp_int8.cpp b/src/operators/math/gemm_omp_int8.cpp index 61f0be418f..2ea4520181 100644 --- a/src/operators/math/gemm_omp_int8.cpp +++ b/src/operators/math/gemm_omp_int8.cpp @@ -37,7 +37,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ - // TODO +// PackMatrixB_omp_8c used only for aarch32 #else asm volatile( // "pld [%[b0]] \n\t" @@ -133,7 +133,19 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, for (int32_t j = 0; j < k_count; ++j) { #if __ARM_NEON #if __aarch64__ - // TODO + asm volatile( + "ld1 {v0.16b}, [%[a0]], #16 \n\t" + "ld1 {v1.16b}, [%[a1]], #16 \n\t" + "ld1 {v2.16b}, [%[a2]], #16 \n\t" + "ld1 {v3.16b}, [%[a3]], #16 \n\t" + "st1 {v0.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v1.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v2.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v3.16b}, [%[local_buffer]], #16 \n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "v0", "v1", "v2", "v3"); #else asm volatile( "vld1.s8 {d0, d1}, [%[a0]]! \n\t" @@ -213,7 +225,19 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, for (int32_t j = 0; j < k_count; ++j) { #if __ARM_NEON #if __aarch64__ - // TODO + asm volatile( + "ld1 {v0.16b}, [%[a0]], #16 \n\t" + "ld1 {v1.16b}, [%[a1]], #16 \n\t" + "ld1 {v2.16b}, [%[a2]], #16 \n\t" + "ld1 {v3.16b}, [%[a3]], #16 \n\t" + "st1 {v0.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v1.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v2.16b}, [%[local_buffer]], #16 \n\t" + "st1 {v3.16b}, [%[local_buffer]], #16 \n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "v0", "v1", "v2", "v3"); #else asm volatile( "vld1.s8 {d0, d1}, [%[a0]]! 
\n\t" @@ -343,6 +367,87 @@ void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail, } } +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_omp_4c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; +#pragma omp parallel for + for (int32_t j = 0; j < n; j += 4) { + int8_t *local_buffer = buffer + j * KC; + const int8_t *b0 = &B(0, j); + const int8_t *b1 = b0 + 1; + const int8_t *b2 = b0 + 2; + const int8_t *b3 = b0 + 3; + if (j > j_length) { + switch (n_tail) { + case 1: + b1 = zero_int8; + case 2: + b2 = zero_int8; + case 3: + b3 = zero_int8; + break; + default: + break; + } + } + + for (int32_t i = 0; i < k_count; ++i) { + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b2; + b2 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b3; + b3 += ldb; + } + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b2; + b2 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b3; + b3 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp index 493a33af95..7d20a178c1 100644 --- a/test/common/test_gemm_int8_accuracy.cpp +++ b/test/common/test_gemm_int8_accuracy.cpp @@ -174,7 +174,7 @@ int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr, int lda = k; int ldb = n; int ldc = n; - float scale = 0.00628f; + float scale = 1; default_random_engine e; uniform_int_distribution pixel(-127, 127); int8_t *a = static_cast( -- GitLab