diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h
index b01a654c713f2328d62714f23af68d606380d203..ce111ed78f7b81affffc646b49a00e6d15cbb697 100644
--- a/src/operators/kernel/central-arm-func/conv_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_arm_func.h
@@ -107,9 +107,15 @@ inline void GemmConv(const ConvParam &param) {
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul(filter_slice, false, col_matrix, false,
+      if (param.Input()->type() == typeid(int8_t)) {
+        math::matmul_int8(filter_slice, false, col_matrix, false,
                          static_cast<float>(1), &out_slice, static_cast<float>(0));
+      } else {
+        math::matmul(filter_slice, false, col_matrix, false,
+                     static_cast<float>(1), &out_slice,
+                     static_cast<float>(0));
+      }
     }
   }
 }
diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h
index 07e634e3be9648520357871d91d6677aec6b5c0e..62e8ae03d9119cafc3c5716042569a90f077325c 100644
--- a/src/operators/kernel/central-arm-func/mul_arm_func.h
+++ b/src/operators/kernel/central-arm-func/mul_arm_func.h
@@ -73,8 +73,8 @@ void MulCompute(const MulParam &param) {
   }
   if (param.InputX()->type() == typeid(int8_t)) {
     out->mutable_data<int32_t>();
-    math::matmul(x_matrix, false, y_matrix, false,
-                 static_cast<int8_t>(1), out, static_cast<int8_t>(0));
+    math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1),
+                      out, static_cast<float>(0));
   } else {
     out->mutable_data<float>();
diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h
index 8498992fcecbcb2c9a773fba874e108c013a04fc..e409fe07dc55bcf68748f0f25b3b63480d25cd56 100644
--- a/src/operators/math/gemm.h
+++ b/src/operators/math/gemm.h
@@ -23,10 +23,12 @@ limitations under the License.
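For context, the conv and mul kernels above pick the GEMM path at run time from the input tensor's element type: int8 inputs go through the new math::matmul_int8 (which accumulates into int32), everything else stays on the float math::matmul. A minimal standalone sketch of that dispatch pattern follows; the ToyTensor type and the run_*_gemm helpers are illustrative stand-ins, not paddle-mobile APIs.

```cpp
// Illustrative sketch of the typeid-based dispatch used in GemmConv/MulCompute.
// ToyTensor and the run_*_gemm helpers are assumptions for this example only.
#include <cstdint>
#include <iostream>
#include <typeindex>
#include <typeinfo>
#include <vector>

struct ToyTensor {
  std::type_index dtype;        // element type, as the kernels query via type()
  std::vector<int8_t> i8_data;  // used when dtype == int8_t
  std::vector<float> f32_data;  // used when dtype == float
};

void run_int8_gemm(const ToyTensor &t) {
  std::cout << "int8 path: " << t.i8_data.size()
            << " int8 inputs, accumulate into int32\n";
}

void run_float_gemm(const ToyTensor &t) {
  std::cout << "float path: " << t.f32_data.size() << " float inputs\n";
}

void compute(const ToyTensor &t) {
  // Mirrors `if (param.Input()->type() == typeid(int8_t))` in the kernels above.
  if (t.dtype == std::type_index(typeid(int8_t))) {
    run_int8_gemm(t);
  } else {
    run_float_gemm(t);
  }
}

int main() {
  compute(ToyTensor{std::type_index(typeid(int8_t)), {1, 2, 3}, {}});
  compute(ToyTensor{std::type_index(typeid(float)), {}, {1.f, 2.f, 3.f}});
  return 0;
}
```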
*/ #if __aarch64__ #define MR_INT8 4 +#define NR_INT8 2 #define MR 6 #define NR 16 #else #define MR_INT8 4 +#define NR_INT8 2 #define MR 6 #define NR 8 #endif @@ -193,52 +195,58 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, // 8 bits int small block inner product void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); + void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); // 8 bits int inner product - void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, - const int8_t *a, const int8_t *b, int8_t beta, - int32_t *c, int32_t *C, int32_t ldc, bool relu, - int8_t *bias); + void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, int32_t *C, + int32_t ldc, bool relu); + void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, int8_t *C, + int32_t ldc, bool relu, int32_t *bias); // 8 bits int pack function void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); + void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer); // 8 bits int matrix product - void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C, - int32_t ldc, bool relu, int8_t *bias); - void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, - int32_t *C, int32_t ldc, bool relu, int8_t *bias); + void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, int32_t *C, + int32_t ldc, bool relu, int32_t *bias); + void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, int8_t *C, + int32_t ldc, bool relu, int32_t *bias); + void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, + int32_t *C, int32_t ldc, bool relu, int32_t *bias); // 8 bits int write back - // C = alpha * A * B + beta * C - void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); // C = A * B void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc); - // C = A * B + C - void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); - // C = A * B + bias - void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias); - // C 
= A * B + C, relu(C) - void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); - // C = A * B + bias, relu(C) - void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias); + // C = A * B + bias, scale * relu(C) + void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale); + // C = A * B + bias, scale * C + void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale); private: int MC = 0; @@ -254,7 +262,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, // 8 bits int int8_t *packedA_int8; int8_t *packedB_int8; - int32_t *packedC_int8; + int32_t *packedC_int32; int8_t *zero_int8; }; diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp index b16db7fe6acf0c3c7fb2902c9fb3f6e3dc81a65f..555672720f2be51631ea10808ce6891b08df0721 100644 --- a/src/operators/math/gemm_int8.cpp +++ b/src/operators/math/gemm_int8.cpp @@ -18,6 +18,8 @@ limitations under the License. */ #include "operators/math/gemm.h" #if __ARM_NEON #include +#include + #endif #ifdef _OPENMP #include @@ -62,7 +64,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, "pld [%[b_ptr], #128] \n\t" "vld1.s8 {d0-d3}, [%[a_ptr]]! \n\t" // load A 8 cols "vld1.s8 {d8-d11}, [%[b_ptr]]! \n\t" // load B first 4 rows - "vmovl.s8 q2, d0 \n\t" // process B first 4 + "vmovl.s8 q2, d0 \n\t" // process B first // rows "vmovl.s8 q3, d8 \n\t" "vmlal.s16 q8, d6, d4[0]\n\t" @@ -241,6 +243,132 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, #endif // __ARM_NEON } +void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON +#if __aarch64__ +// TODO(wzzju) +#else +#define PADDLE_LABEL_LOOP "1" +#define PADDLE_LABEL_AFTER_LOOP "2" + asm volatile( + "lsl %[ldc], %[ldc], #2 \n\t" // sizeof(int32) == 4 + "vldr d0, [%[b], #0] \n\t" + "vmov.s32 q8, #0 \n\t" + "vldr d4, [%[a], #0] \n\t" + "vmov.s32 q9, q8 \n\t" + "vldr d2, [%[b], #16] \n\t" + "vmov.s32 q10, q8 \n\t" + "vldr d6, [%[a], #16] \n\t" + "vmov.s32 q11, q8 \n\t" + "vldr d1, [%[b], #8]\n\t" + "vmov.s32 q12, q8 \n\t" + "vldr d5, [%[a], #8]\n" + "vmov.s32 q13, q8 \n\t" + "vldr d3, [%[b], #24]\n\t" + "vmov.s32 q14, q8 \n\t" + "vldr d7, [%[a], #24]\n" + "vmov.s32 q15, q8 \n\t" + + PADDLE_LABEL_LOOP + ": \n\t" + "vmull.s8 q4, d0, d4 \n\t" // first half + "add %[b], %[b], #32 \n\t" + "vmull.s8 q5, d2, d4 \n\t" + "vldr d4, [%[a], #32] \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vmull.s8 q7, d2, d6 \n\t" + "vldr d6, [%[a], #48] \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vmlal.s8 q5, d3, d5 \n\t" + "vldr d5, [%[a], #40] \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + "vldr d7, [%[a], #56] \n\t" + + "vpadal.s16 q8, q4 \n\t" // pairwise-add + "add %[a], %[a], #64 \n\t" + "vpadal.s16 q9, q5 \n\t" + "subs %[k], %[k], #16 \n\t" + "vpadal.s16 q10, q6 \n\t" + "vpadal.s16 q11, q7 \n\t" + + "beq " PADDLE_LABEL_AFTER_LOOP + "f \n\t" + + "vmull.s8 q4, d0, d4 \n\t" // first half + "vmull.s8 q5, d2, d4 \n\t" + "vldr d4, [%[a], #0] \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vldr d0, [%[b], #0] \n\t" + "vmull.s8 q7, d2, d6 \n\t" + "vldr d2, [%[b], #16] \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vldr d6, [%[a], #16] \n\t" + "vmlal.s8 q5, d3, d5 \n\t" + "vldr d5, [%[a], #8] \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vldr d1, [%[b], #8] \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + 
"vldr d3, [%[b], #24] \n\t" + + "vpadal.s16 q12, q4 \n\t" // pairwise-add + "vldr d7, [%[a], #24] \n\t" + "vpadal.s16 q13, q5 \n\t" + "vpadal.s16 q14, q6 \n\t" + "vpadal.s16 q15, q7 \n\t" + + "b " PADDLE_LABEL_LOOP "b \n\t" + + PADDLE_LABEL_AFTER_LOOP + ": \n\t" + "vmull.s8 q4, d0, d4 \n\t" // first half + "vmull.s8 q5, d2, d4 \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vmull.s8 q7, d2, d6 \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vmlal.s8 q5, d3, d5 \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + + "vpadal.s16 q12, q4 \n\t" // pairwise-add + "vpadal.s16 q13, q5 \n\t" + "vpadal.s16 q14, q6 \n\t" + "vpadal.s16 q15, q7 \n\t" + + "vpadd.s32 d0, d16, d17 \n\t" // reduce to int32 + "vpadd.s32 d1, d18, d19 \n\t" + "vpadd.s32 d2, d20, d21 \n\t" + "vpadd.s32 d3, d22, d23 \n\t" + "vpadd.s32 d4, d24, d25 \n\t" + "vpadd.s32 d5, d26, d27 \n\t" + "vpadd.s32 d6, d28, d29 \n\t" + "vpadd.s32 d7, d30, d31 \n\t" + + "vpadd.s32 d8, d0, d1 \n\t" // reduce to int32 again + "vpadd.s32 d9, d2, d3 \n\t" + "vpadd.s32 d10, d4, d5 \n\t" + "vpadd.s32 d11, d6, d7 \n\t" + + "vst1.32 {d8}, [%[c]], %[ldc] \n\t" + "vst1.32 {d9}, [%[c]], %[ldc] \n\t" + "vst1.32 {d10}, [%[c]], %[ldc] \n\t" + "vst1.32 {d11}, [%[c]] \n\t" + + : [k] "+r"(k), [a] "+r"(a), [b] "+r"(b), [c] "+r"(c) + : [ldc] "r"(ldc) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#undef PADDLE_LABEL_AFTER_LOOP +#undef PADDLE_LABEL_LOOP + +#endif // __aarch64__ +#endif // __ARM_NEON +} + // 8 bits int small block inner product void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { @@ -539,51 +667,213 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, } // 8 bits int inner product -void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, - const int8_t *a, const int8_t *b, int8_t beta, - int32_t *c, int32_t *C, int32_t ldc, bool relu, - int8_t *bias) { +void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, int32_t *C, + int32_t ldc, bool relu) { #pragma omp parallel for - for (int32_t j = 0; j < nc; j += NR) { + for (int32_t j = 0; j < nc; j += NR_INT8) { for (int32_t i = 0; i < mc; i += MR_INT8) { #if __aarch64__ // TODO(wzzju) #else // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); - AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); #endif // __aarch64__ } } - if (alpha != 1) { - WriteWithAlphaBeta(mc, nc, c, C, ldc); - return; - } - if (beta == 0) { + if (!relu) { WriteBasic(mc, nc, c, C, ldc); return; } - if (beta == 1 && !relu) { - if (bias == nullptr) { - WriteWithAdd(mc, nc, c, C, ldc); - } else { - WriteWithAddV1(mc, nc, c, C, ldc, bias); +} + +void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, + const int8_t *a, const int8_t *b, float beta, + int32_t *c, int8_t *C, int32_t ldc, bool relu, + int32_t *bias) { +#pragma omp parallel for + for (int32_t j = 0; j < nc; j += NR_INT8) { + for (int32_t i = 0; i < mc; i += MR_INT8) { +#if __aarch64__ + // TODO(wzzju) +#else + // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); +#endif // __aarch64__ } + } + if (relu) { + WriteWithAddReluScale(mc, nc, c, C, ldc, bias, alpha); 
return; + } else { + WriteWithAddScale(mc, nc, c, C, ldc, bias, alpha); } - if (beta == 1 && relu) { - if (bias == nullptr) { - WriteWithAddRelu(mc, nc, c, C, ldc); - } else { - WriteWithAddReluV1(mc, nc, c, C, ldc, bias); +} + +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; + + for (int32_t i = 0; i < i_length; i += 4) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + int8_t *local_buffer = buffer + i * KC; + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + int8_t *local_buffer = buffer + i_length * KC; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } } - return; } } + // 8 bits int PackMatrixA_4r void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { const int8_t *a0, *a1, *a2, *a3; - for (int32_t i = 0; i < m - m_tail; i += MR_INT8) { + for (int32_t i = 0; i < m - m_tail; i += 4) { a0 = A + i * lda; a1 = A + (i + 1) * lda; a2 = A + (i + 2) * lda; @@ -625,7 +915,7 @@ void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { const int32_t i_length = m - m_tail; - for (int32_t i = 0; i < i_length; i += MR_INT8) { + for (int32_t i = 0; i < i_length; i += 6) { const int8_t *a0 = A + i * lda; const int8_t *a1 = A + (i + 1) * lda; const int8_t *a2 = A + (i + 2) * lda; @@ -676,11 +966,79 @@ void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, } } +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; + for (int32_t j = 0; j < j_length; j += 2) { + int8_t *local_buffer = buffer + j * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j); + const int8_t *b1 = &B((i << 4), j + 1); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b1; + b1 += ldb; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j); + const int8_t *b1 = &B((k_count << 4), j + 1); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j_length); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = 0; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j_length); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + 
} + for (int32_t j = k_count << 4; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + // 8 bits int PackMatrixB void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer) { const int32_t j_length = n - n_tail; - for (int32_t j = 0; j < j_length; j += NR) { + for (int32_t j = 0; j < j_length; j += 8) { int8_t *local_buffer = buffer + j * k; for (int32_t i = 0; i < k; ++i) { const int8_t *b0 = &B(i, j); @@ -715,7 +1073,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, for (int32_t j = j_length; j < n; ++j) { *local_buffer++ = *b0++; } - for (int32_t j = n; j < j_length + NR; ++j) { + for (int32_t j = n; j < j_length + 8; ++j) { *local_buffer++ = 0; } } @@ -723,19 +1081,20 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, } // 8 bits int matrix product (m*k x k*n) -void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, - int32_t *C, int32_t ldc, bool relu, int8_t *bias) { +void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, + int32_t *C, int32_t ldc, bool relu, int32_t *bias) { // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) // L2 cache is 0.5~4 Mib (Contex-A72 cluster) int32_t L1 = 32 * 1024; int32_t L2 = 512 * 1024; - KC = k; + const int32_t k_complete = (k + 15) - ((k + 15) & 15); + KC = k_complete; MC = L1 / (KC * sizeof(int8_t)); NC = L2 / (KC * sizeof(int8_t)); - // make sure MC is multiple of MR_INT8, and NC is multiple of NR + // make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8 if (MC == 0) { MC = MR_INT8; } else { @@ -745,52 +1104,106 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, } // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; if (NC == 0) { - NC = NR; + NC = NR_INT8; } else { int32_t nblock_num = (n + NC - 1) / NC; NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; + NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8; } // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; packedA_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); - packedC_int8 = static_cast( + packedC_int32 = static_cast( paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); zero_int8 = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k)); - memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); + memset(static_cast(zero_int8), 0, sizeof(int8_t) * k); int32_t mc, nc; for (int32_t j = 0; j < n; j += NC) { nc = s_min(n - j, NC); - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); + PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8); for (int32_t i = 0; i < m; i += MC) { mc = s_min(m - i, MC); - // PackMatrixA_6r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); - PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8); if (bias == nullptr) { + InnerKernel(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int32, &C(i, j), ldc, relu); + } + } + } + + paddle_mobile::memory::Free(packedA_int8); + paddle_mobile::memory::Free(packedB_int8); + paddle_mobile::memory::Free(packedC_int32); + 
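The new int8 Sgemm drivers round k up to a multiple of 16 (`KC = (k + 15) - ((k + 15) & 15)`) and pack A into 4-row-by-16 panels and B into 2-column-by-16 panels, zero-padding the tail so AddDot4x2 always consumes full 16-wide chunks. The standalone sketch below shows the round-up arithmetic and a reference packing of one 4-row A panel under my reading of the PackMatrixA_4r_16 layout; it is an illustration, not the library code.

```cpp
// Sketch only: KC round-up and a reference 4r_16 packing of a single A panel.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int32_t RoundUp16(int32_t k) {
  // Same arithmetic as `KC = (k + 15) - ((k + 15) & 15)` in Sgemm.
  return (k + 15) - ((k + 15) & 15);
}

// Reference packing of one 4-row-by-k A panel (row-major input) into the
// assumed 4r_16 layout: 16 k-values of row 0, then rows 1..3, then the next
// 16-wide chunk, zero-padding columns past k.
std::vector<int8_t> PackA4x16(const int8_t *A, int32_t lda, int32_t k) {
  const int32_t KC = RoundUp16(k);
  std::vector<int8_t> buf(4 * KC, 0);
  int8_t *dst = buf.data();
  for (int32_t j = 0; j < KC; j += 16) {    // walk k in chunks of 16
    for (int32_t r = 0; r < 4; ++r) {       // row 0's 16 values, then row 1, ...
      for (int32_t l = 0; l < 16; ++l) {
        const int32_t col = j + l;
        *dst++ = (col < k) ? A[r * lda + col] : static_cast<int8_t>(0);
      }
    }
  }
  return buf;
}

int main() {
  assert(RoundUp16(1) == 16);
  assert(RoundUp16(16) == 16);
  assert(RoundUp16(17) == 32);
  int8_t A[4 * 20];
  for (int i = 0; i < 4 * 20; ++i) A[i] = static_cast<int8_t>(i % 7);
  std::vector<int8_t> packed = PackA4x16(A, /*lda=*/20, /*k=*/20);
  // k = 20 is padded to KC = 32, so the panel holds 4 x 32 bytes.
  std::printf("panel bytes = %zu (4 x KC = 4 x %d)\n", packed.size(),
              RoundUp16(20));
  return 0;
}
```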
paddle_mobile::memory::Free(zero_int8); +} + +// 8 bits int matrix product (m*k x k*n) +void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, + int8_t *C, int32_t ldc, bool relu, int32_t *bias) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int32_t L1 = 32 * 1024; + int32_t L2 = 512 * 1024; + + const int32_t k_complete = (k + 15) - ((k + 15) & 15); + KC = k_complete; + MC = L1 / (KC * sizeof(int8_t)); + NC = L2 / (KC * sizeof(int8_t)); + + // make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8 + if (MC == 0) { + MC = MR_INT8; + } else { + int32_t mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR_INT8; + } else { + int32_t nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); + packedC_int32 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); + zero_int8 = + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k)); + + memset(static_cast(zero_int8), 0, sizeof(int8_t) * k); + int32_t mc, nc; + for (int32_t j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); + PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8); + for (int32_t i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8); + if (bias != nullptr) { InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, - packedC_int8, &C(i, j), ldc, relu, nullptr); - } else { - InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, - packedC_int8, &C(i, j), ldc, relu, bias + i); + packedC_int32, &C(i, j), ldc, relu, bias + i); } } } paddle_mobile::memory::Free(packedA_int8); paddle_mobile::memory::Free(packedB_int8); - paddle_mobile::memory::Free(packedC_int8); + paddle_mobile::memory::Free(packedC_int32); paddle_mobile::memory::Free(zero_int8); } // 8 bits int write back -// C = alpha * A * B + beta * C -void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc) {} -// C = A * B, 8位 int32_t +// C = A * B void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc) { #if __ARM_NEON @@ -802,7 +1215,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t step = sizeof(int32_t) * ldc; int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 4)); int32_t volatile m = mc; - + int32_t volatile n = nc1; int32_t *volatile c_ptr, *volatile C_ptr; int32_t *C0, *c0; c_ptr = c; @@ -836,7 +1249,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, "end_mc_%=: \n\t" : - : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), [step] "r"(step), [step1] "r"(step1) : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); } @@ -854,20 +1267,254 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, #endif // __ARM_NEON } -// C = A * B + C -void Gemm::WriteWithAdd(int32_t mc, int32_t nc, int32_t 
*c, int32_t *C, - int32_t ldc) {} +// C = A * B + bias, scale * C +void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale) { +#if __ARM_NEON +#if __aarch64__ +// TODO(wzzju) +#else + int32_t zero = 0; + int8_t narrow = -128; + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + int32_t step = sizeof(int8_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 3)); + int32_t volatile m = mc; + int32_t volatile n = nc1; + int32_t *volatile c_ptr, *volatile bias_ptr; + int8_t *volatile C_ptr; + c_ptr = c; + C_ptr = C; + bias_ptr = bias; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.8 d24, %[narrow] \n\t" + "loop_mc_%=: \n\t" + "vld1.32 {d26[0]}, [%[bias_ptr]]!\n\t" + "vdup.32 q13, d26[0] \n\t" + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vqadd.s32 q1, q1, q13 \n\t" + "vcvt.f32.s32 q2, q0 \n\t" + "vcvt.f32.s32 q3, q1 \n\t" + "vmul.f32 q2, q2, q15 \n\t" + "vmul.f32 q3, q3, q15 \n\t" + "vcvt.s32.f32 q4, q2 \n\t" + "vcvt.s32.f32 q5, q3 \n\t" + "vqmovn.s32 d12, q4 \n\t" + "vqmovn.s32 d13, q5 \n\t" + "vqmovn.s16 d14, q6 \n\t" + "vceq.s8 d15, d14, d24 \n\t" + "vsub.s8 d14, d14, d15 \n\t" + "vst1.8 {d14}, [r6]! \n\t" + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" -// C = A * B + bias -void Gemm::WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias) {} -// C = A * B + C, relu(C) -void Gemm::WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc) {} + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), + [step] "r"(step), [step1] "r"(step1), [bias_ptr] "r"(bias_ptr), + [scale] "r"(scale), [zero] "r"(zero), [narrow] "r"(narrow) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q12", "q13", "q14", "q15"); + } -// C = A * B + bias, relu(C) -void Gemm::WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias) {} + int32_t nc_left; + int32_t *c0; + int8_t *C0; + int32_t bias_v; + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 8 + i * ldc; + c0 = c_ptr + nc1 * 8 + i * NC; + bias_v = *(bias_ptr + i); + nc_left = _nc1; + asm volatile( + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.8 d24, %[narrow] \n\t" + "vdup.32 q13, %[bias_v] \n\t" + "cmp %[_nc1], #4 \n\t" + "blt less_four_%= \n\t" + "vld1.32 {q0}, [%[c0]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vceq.s8 d9, d8, d24 \n\t" + "vsub.s8 d8, d8, d9 \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vst1.8 {d8[1]}, [%[C0]]! \n\t" + "vst1.8 {d8[2]}, [%[C0]]! \n\t" + "vst1.8 {d8[3]}, [%[C0]]! \n\t" + "subs %[_nc1], %[_nc1], #4 \n\t" + "beq process_over_%= \n\t" + "less_four_%=: \n\t" + "vld1.32 {q0}, [%[c0]]! 
\n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vceq.s8 d9, d8, d24 \n\t" + "vsub.s8 d8, d8, d9 \n\t" + "loop_save_%=: \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vext.8 d8, d8, d8, #1 \n\t" + "subs %[_nc1], %[_nc1], #1 \n\t" + "bgt loop_save_%= \n\t" + "process_over_%=: \n\t" + : + : [_nc1] "r"(nc_left), [C0] "r"(C0), [c0] "r"(c0), + [bias_v] "r"(bias_v), [scale] "r"(scale), [zero] "r"(zero), + [narrow] "r"(narrow) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q12", "q13", "q14", + "q15"); + } + } +#endif // __aarch64__ +#endif // __ARM_NEON +} + +// C = A * B + bias, scale * relu(C) +void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale) { +#if __ARM_NEON +#if __aarch64__ +// TODO(wzzju) +#else + int32_t zero = 0; + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + int32_t step = sizeof(int8_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 3)); + int32_t volatile m = mc; + int32_t volatile n = nc1; + int32_t *volatile c_ptr, *volatile bias_ptr; + int8_t *volatile C_ptr; + c_ptr = c; + C_ptr = C; + bias_ptr = bias; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "loop_mc_%=: \n\t" + "vld1.32 {d26[0]}, [%[bias_ptr]]!\n\t" + "vdup.32 q13, d26[0] \n\t" + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vqadd.s32 q1, q1, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vmax.s32 q1, q1, q14 \n\t" + "vcvt.f32.s32 q2, q0 \n\t" + "vcvt.f32.s32 q3, q1 \n\t" + "vmul.f32 q2, q2, q15 \n\t" + "vmul.f32 q3, q3, q15 \n\t" + "vcvt.s32.f32 q4, q2 \n\t" + "vcvt.s32.f32 q5, q3 \n\t" + "vqmovn.s32 d12, q4 \n\t" + "vqmovn.s32 d13, q5 \n\t" + "vqmovn.s16 d14, q6 \n\t" + "vst1.8 {d14}, [r6]! \n\t" + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), + [step] "r"(step), [step1] "r"(step1), [bias_ptr] "r"(bias_ptr), + [scale] "r"(scale), [zero] "r"(zero) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q13", "q14", "q15"); + } + + int32_t nc_left; + int32_t *c0; + int8_t *C0; + int32_t bias_v; + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 8 + i * ldc; + c0 = c_ptr + nc1 * 8 + i * NC; + bias_v = *(bias_ptr + i); + nc_left = _nc1; + asm volatile( + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.32 q13, %[bias_v] \n\t" + "cmp %[_nc1], #4 \n\t" + "blt less_four_%= \n\t" + "vld1.32 {q0}, [%[c0]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vst1.8 {d8[1]}, [%[C0]]! \n\t" + "vst1.8 {d8[2]}, [%[C0]]! \n\t" + "vst1.8 {d8[3]}, [%[C0]]! \n\t" + "subs %[_nc1], %[_nc1], #4 \n\t" + "beq process_over_%= \n\t" + "less_four_%=: \n\t" + "vld1.32 {q0}, [%[c0]]! 
\n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "loop_save_%=: \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vext.8 d8, d8, d8, #1 \n\t" + "subs %[_nc1], %[_nc1], #1 \n\t" + "bgt loop_save_%= \n\t" + "process_over_%=: \n\t" + : + : [_nc1] "r"(nc_left), [C0] "r"(C0), [c0] "r"(c0), + [bias_v] "r"(bias_v), [scale] "r"(scale), [zero] "r"(zero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q13", "q14", "q15"); + } + } +#endif // __aarch64__ +#endif // __ARM_NEON +} } // namespace math } // namespace operators diff --git a/src/operators/math/gemm_omp_int8.cpp b/src/operators/math/gemm_omp_int8.cpp index 21256cccfcc6dcc647f34a2129616b70804d398f..d4d4c294934191ba6717716486bf857477d73b55 100644 --- a/src/operators/math/gemm_omp_int8.cpp +++ b/src/operators/math/gemm_omp_int8.cpp @@ -28,10 +28,10 @@ namespace operators { namespace math { // 8 bits int matrix product (m*k x k*n) -void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, +void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb, - int8_t beta, int32_t *C, int32_t ldc, bool relu, - int8_t *bias) { + float beta, int32_t *C, int32_t ldc, bool relu, + int32_t *bias) { #ifdef _OPENMP int32_t max_threads = omp_get_max_threads(); #else @@ -39,10 +39,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, #endif int32_t L1 = 64 / max_threads * 1024; - KC = k; + const int32_t k_complete = (k + 15) - ((k + 15) & 15); + KC = k_complete; zero_int8 = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); - memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k)); + memset(static_cast(zero_int8), 0, sizeof(int8_t) * k); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(int8_t)); @@ -54,14 +55,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; } // 补齐 B - NC = (n + NR - 1) / NR * NR; + NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8; packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); #if __aarch64__ // TODO(wzzju) #else - PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8); + PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); #endif packedA_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads)); @@ -69,11 +70,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, // 对 B 分块 NC = L1 / (KC * sizeof(int8_t)); if (NC == 0) { - NC = NR; + NC = NR_INT8; } else { int32_t nblock_num = (n + NC - 1) / NC; NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; + NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8; } // 补齐 A MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8; @@ -83,12 +84,12 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, #if __aarch64__ // TODO(wzzju) #else - PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8); + PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); #endif packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads)); } - packedC_int8 = static_cast( + packedC_int32 = static_cast( paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads)); if (m > n) { @@ -103,14 +104,19 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, int32_t 
mc; mc = s_min(m - i, MC); int8_t *local_A = packedA_int8 + MC * KC * local_threads; - int32_t *local_C = packedC_int8 + MC * NC * local_threads; + int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ // TODO(wzzju) #else - PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A); + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); #endif - InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C, - &C(i, 0), ldc, relu, bias + i); + // InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, + // local_C, + // &C(i, 0), ldc, relu, bias + i); + if (bias == nullptr) { + InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C, + &C(i, 0), ldc, relu); + } } } else { #pragma omp parallel for @@ -123,20 +129,25 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, int32_t nc; nc = s_min(n - j, NC); int8_t *local_B = packedB_int8 + KC * NC * local_threads; - int32_t *local_C = packedC_int8 + MC * NC * local_threads; + int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ // TODO(wzzju) #else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B); + PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); #endif - InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C, - &C(0, j), ldc, relu, bias); + // InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, + // local_C, + // &C(0, j), ldc, relu, bias); + if (bias == nullptr) { + InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C, + &C(0, j), ldc, relu); + } } } paddle_mobile::memory::Free(packedA_int8); paddle_mobile::memory::Free(packedB_int8); - paddle_mobile::memory::Free(packedC_int8); + paddle_mobile::memory::Free(packedC_int32); paddle_mobile::memory::Free(zero_int8); } @@ -144,7 +155,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer) { const int32_t j_length = n - n_tail; #pragma omp parallel for - for (int32_t j = 0; j < j_length; j += NR) { + for (int32_t j = 0; j < j_length; j += 8) { int8_t *local_buffer = buffer + j * k; for (int32_t i = 0; i < k; ++i) { const int8_t *b0 = &B(i, j); @@ -179,7 +190,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, for (int32_t j = j_length; j < n; ++j) { *local_buffer++ = *b0++; } - for (int32_t j = n; j < j_length + NR; ++j) { + for (int32_t j = n; j < j_length + 8; ++j) { *local_buffer++ = 0; } } @@ -188,9 +199,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { - const int i_length = m - m_tail; + const int32_t i_length = m - m_tail; #pragma omp parallel for - for (int32_t i = 0; i < i_length; i += MR_INT8) { + for (int32_t i = 0; i < i_length; i += 4) { const int8_t *a0 = A + i * lda; const int8_t *a1 = A + (i + 1) * lda; const int8_t *a2 = A + (i + 2) * lda; @@ -221,7 +232,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, default: break; } - for (int j = 0; j < k; ++j) { + for (int32_t j = 0; j < k; ++j) { *local_buffer++ = *a0++; *local_buffer++ = *a1++; *local_buffer++ = *a2++; @@ -230,6 +241,232 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, } } +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + const int32_t k_count = k >> 4; + const 
int32_t k_tail = k & 15; +#pragma omp parallel for + for (int32_t i = 0; i < i_length; i += 4) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + int8_t *local_buffer = buffer + i * KC; + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + int8_t *local_buffer = buffer + i_length * KC; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; +#pragma omp parallel for + for (int32_t j = 0; j < j_length; j += 2) { + int8_t *local_buffer = buffer + j * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j); + const int8_t *b1 = &B((i << 4), j + 1); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b1; + b1 += ldb; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j); + const int8_t *b1 = &B((k_count << 4), j + 1); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j_length); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = 0; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j_length); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index b91242c1868398e4541c3727567a905e5b0c8714..9661b2d4c22ed49ef0c078fac0872c7643057430 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -28,7 +28,12 @@ template void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu = false, - T *bias = nullptr); + float *bias = nullptr); + +void matmul_int8(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, float alpha, + framework::Tensor *matrix_out, float beta, bool relu = false, + 
int32_t *bias = nullptr); template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp index e02824b290ebc0080613e2ae2365626d79576c9e..e1998e8e12062fe02fa9140b2f4a57bd8121724a 100644 --- a/src/operators/math/math_function_int8.cpp +++ b/src/operators/math/math_function_int8.cpp @@ -20,11 +20,10 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { namespace math { -template <> -void matmul(const framework::Tensor &matrix_a, bool trans_a, - const framework::Tensor &matrix_b, bool trans_b, - int8_t alpha, framework::Tensor *matrix_out, int8_t beta, - bool relu, int8_t *bias) { +void matmul_int8(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, float alpha, + framework::Tensor *matrix_out, float beta, bool relu, + int32_t *bias) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -52,21 +51,45 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, } #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead. + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #else - gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #endif } else { #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead. + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); + } else { + gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #else - gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, matrix_out->data(), N, - relu, bias); + if (bias != nullptr) { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); + } #endif } } diff --git a/test/common/test_gemm_int8_accuracy.cpp b/test/common/test_gemm_int8_accuracy.cpp index 87f8d945648577ef1414417b57f4013d288dc043..f276cad8e657f5dd0d126fba875a46a80dc66f78 100644 --- a/test/common/test_gemm_int8_accuracy.cpp +++ b/test/common/test_gemm_int8_accuracy.cpp @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
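matmul_int8 above routes bias-carrying calls to the int8-output Sgemm overload, whose requantization is done by WriteWithAddScale / WriteWithAddReluScale: saturating bias add, optional relu, scale in float, round toward zero, then saturate to int8 (the non-relu NEON path also bumps -128 up to -127). A scalar model of that write-back, matching the qadd_int32/qscale_int32 reference used by the accuracy test below, might look like this sketch; it is my reading of the NEON code, not a library API.

```cpp
// Scalar model of the int8 write-back semantics (illustrative only).
#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t quantize_out(int32_t acc, int32_t bias, float scale, bool relu) {
  int64_t sum = static_cast<int64_t>(acc) + static_cast<int64_t>(bias);
  sum = std::min<int64_t>(INT_MAX, std::max<int64_t>(INT_MIN, sum));  // vqadd.s32
  if (relu && sum < 0) sum = 0;                                       // vmax.s32 with zero
  float scaled = static_cast<float>(sum) * scale;                     // vcvt + vmul
  float rounded = scaled > 0 ? std::floor(scaled) : std::ceil(scaled);  // round toward zero
  if (rounded > 127.f) return 127;
  if (rounded < -127.f) return -127;  // -128 is avoided, as in the kernel
  return static_cast<int8_t>(rounded);
}

int main() {
  std::printf("%d\n", quantize_out(1000, 24, 0.00628f, false));  // (1024 * 0.00628) -> 6
  std::printf("%d\n", quantize_out(-1000, 0, 0.00628f, true));   // relu clamps to 0
  return 0;
}
```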
*/ +#include #include #include #include @@ -54,6 +55,30 @@ void print_matirx(int m, int n, int ldc, int8_t *c) { std::cout << std::endl; } +int32_t qadd_int32(int32_t l, int32_t r) { + int64_t res = static_cast(l) + static_cast(r); + if (res > INT_MAX) + return INT_MAX; + else if (res < INT_MIN) + return INT_MIN; + else + return static_cast(res); +} + +int8_t qscale_int32(int32_t v, float scale) { + float res = static_cast(v) * scale; + if (res > 0) + res = std::floor(res); + else if (res < 0) + res = std::ceil(res); // round to zero + if (res > 127) + return static_cast(127); + else if (res < -127) + return static_cast(-127); + else + return static_cast(res); +} + int do_sgemm(int m, int n, int k, bool relu, int pr) { int lda = k; int ldb = n; @@ -126,10 +151,98 @@ int do_sgemm(int m, int n, int k, bool relu, int pr) { return 0; } +int do_sgemm_with_bias(int m, int n, int k, bool relu, int pr) { + int lda = k; + int ldb = n; + int ldc = n; + float scale = 0.00628; + default_random_engine e; + uniform_int_distribution pixel(-127, 127); + int8_t *a = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * k)); + int8_t *b = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * k * n)); + int8_t *c = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * n)); + int8_t *c1 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * m * n)); + + int32_t *bias = + static_cast(paddle_mobile::memory::Alloc(sizeof(int32_t) * m)); + + for (int i = 0; i < m * k; ++i) { + a[i] = pixel(e); + } + for (int i = 0; i < k * n; ++i) { + b[i] = pixel(e); + } + for (int i = 0; i < m; ++i) { + bias[i] = static_cast(pixel(e)); + } + for (int i = 0; i < m; ++i) { + int32_t bias_v = bias[i]; + for (int j = 0; j < n; ++j) { + int32_t r = 0; + for (int p = 0; p < k; p++) { + r += static_cast(a(i, p)) * static_cast(b(p, j)); + } + r = qadd_int32(r, bias_v); + if (relu) r = std::max(0, r); + c1(i, j) = qscale_int32(r, scale); + } + } + + paddle_mobile::operators::math::Gemm gemm; +#ifdef _OPENMP + // TODO(wzzju):gemm.Sgemm_omp_with_bias, now use single thread instead. 
+ gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast(0), c, ldc, + relu, bias); +#else + gemm.Sgemm(m, n, k, scale, a, lda, b, ldb, static_cast(0), c, ldc, + relu, bias); +#endif + int eq = 0; + int neq = 0; + for (int i = 0; i < m * n; ++i) { + if (c[i] == c1[i]) { + ++eq; + } else { + ++neq; + } + } + + if (pr > 0) { + std::cout << "A:" << std::endl; + print_matirx(m, k, lda, a); + std::cout << "B:" << std::endl; + print_matirx(k, n, ldb, b); + std::cout << "Bias:" << std::endl; + print_matirx(m, 1, 1, bias); + std::cout << "C:" << std::endl; + print_matirx(m, n, ldc, c); + std::cout << "C1:" << std::endl; + print_matirx(m, n, ldc, c1); + } + + std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu + << " eq=" << eq << " neq=" << neq << std::endl; + + paddle_mobile::memory::Free(a); + paddle_mobile::memory::Free(b); + paddle_mobile::memory::Free(c); + paddle_mobile::memory::Free(c1); + paddle_mobile::memory::Free(bias); + + return 0; +} + int main() { #ifdef _OPENMP - omp_set_num_threads(8); + omp_set_num_threads(4); #endif + std::cout << "\n\n******************************************************\n\n" + << std::endl; + std::cout << "Test gemm without bias:" << std::endl; do_sgemm(9, 9, 9, false, 1); do_sgemm(10, 6, 12, false, 0); do_sgemm(512, 256, 384, false, 0); @@ -140,5 +253,31 @@ int main() { do_sgemm(333, 797, 939, false, 0); do_sgemm(1024, 1024, 1024, false, 0); + std::cout << "\n\n******************************************************\n\n" + << std::endl; + std::cout << "Test gemm with bias:" << std::endl; + do_sgemm_with_bias(9, 9, 9, false, 1); + do_sgemm_with_bias(10, 6, 12, false, 0); + do_sgemm_with_bias(512, 256, 384, false, 0); + do_sgemm_with_bias(1366, 768, 256, false, 0); + do_sgemm_with_bias(1255, 755, 333, false, 0); + do_sgemm_with_bias(599, 1133, 393, false, 0); + do_sgemm_with_bias(777, 555, 999, false, 0); + do_sgemm_with_bias(333, 797, 939, false, 0); + do_sgemm_with_bias(1024, 1024, 1024, false, 0); + + std::cout << "\n\n******************************************************\n\n" + << std::endl; + std::cout << "Test gemm with relu and bias:" << std::endl; + do_sgemm_with_bias(9, 9, 9, true, 1); + do_sgemm_with_bias(10, 6, 12, true, 0); + do_sgemm_with_bias(512, 256, 384, true, 0); + do_sgemm_with_bias(1366, 768, 256, true, 0); + do_sgemm_with_bias(1255, 755, 333, true, 0); + do_sgemm_with_bias(599, 1133, 393, true, 0); + do_sgemm_with_bias(777, 555, 999, true, 0); + do_sgemm_with_bias(333, 797, 939, true, 0); + do_sgemm_with_bias(1024, 1024, 1024, true, 0); + return 0; } diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 14da4ba284b5ac7b0660bd15de871fdf5ed04cdd..5ca0b40cfcb20786ad69d1bbfbaca103b3e426e3 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -28,7 +28,7 @@ limitations under the License. 
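The accuracy test above validates whole GEMMs against a naive reference; the unit of work underneath is the AddDot4x2 micro-kernel, which produces one 4x2 tile of int32 accumulators from the packed panels. The standalone sketch below is a scalar model of that tile product under my reading of the 4r_16 / 2c_16 packed layouts, cross-checked against a naive dot product; it is illustrative, not the library kernel.

```cpp
// Scalar model of AddDot4x2 on packed panels (assumed layouts, KC % 16 == 0).
#include <cassert>
#include <cstdint>
#include <cstdio>

void AddDot4x2Ref(int32_t KC, const int8_t *a, const int8_t *b, int32_t *c,
                  int32_t ldc) {
  int32_t acc[4][2] = {{0}};
  for (int32_t k0 = 0; k0 < KC; k0 += 16) {
    const int8_t *a_chunk = a + k0 * 4;  // 4 rows x 16 k-values, row after row
    const int8_t *b_chunk = b + k0 * 2;  // 2 cols x 16 k-values, col after col
    for (int32_t r = 0; r < 4; ++r)
      for (int32_t n = 0; n < 2; ++n)
        for (int32_t l = 0; l < 16; ++l)
          acc[r][n] += static_cast<int32_t>(a_chunk[r * 16 + l]) *
                       static_cast<int32_t>(b_chunk[n * 16 + l]);
  }
  for (int32_t r = 0; r < 4; ++r)
    for (int32_t n = 0; n < 2; ++n) c[r * ldc + n] = acc[r][n];
}

int main() {
  const int32_t k = 16;  // one full chunk, so packing is a plain copy
  int8_t A[4][16], B[16][2];
  for (int32_t r = 0; r < 4; ++r)
    for (int32_t p = 0; p < k; ++p) A[r][p] = static_cast<int8_t>((r + p) % 5 - 2);
  for (int32_t p = 0; p < k; ++p)
    for (int32_t n = 0; n < 2; ++n) B[p][n] = static_cast<int8_t>((p * n) % 7 - 3);

  // Pack by hand: A panel is row0..row3 back to back, B panel is col0 then col1.
  int8_t packed_a[4 * 16], packed_b[2 * 16];
  for (int32_t r = 0; r < 4; ++r)
    for (int32_t p = 0; p < k; ++p) packed_a[r * 16 + p] = A[r][p];
  for (int32_t n = 0; n < 2; ++n)
    for (int32_t p = 0; p < k; ++p) packed_b[n * 16 + p] = B[p][n];

  int32_t C[4 * 2];
  AddDot4x2Ref(k, packed_a, packed_b, C, /*ldc=*/2);

  // Cross-check against a naive dot product on the unpacked matrices.
  for (int32_t r = 0; r < 4; ++r)
    for (int32_t n = 0; n < 2; ++n) {
      int32_t ref = 0;
      for (int32_t p = 0; p < k; ++p)
        ref += static_cast<int32_t>(A[r][p]) * static_cast<int32_t>(B[p][n]);
      assert(C[r * 2 + n] == ref);
    }
  std::printf("4x2 tile matches the naive reference\n");
  return 0;
}
```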
*/ int main() { paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(8); + paddle_mobile.SetThreadNum(4); Tensor aa, bb, cc; auto aaptr = aa.mutable_data({m, k}); auto bbptr = bb.mutable_data({k, n}); @@ -44,10 +44,12 @@ int main() { ccptr[i] = 2; } - Tensor aa_int8, bb_int8, cc_int8; + Tensor aa_int8, bb_int8, cc_int32, cc_int8; auto aaptr_int8 = aa_int8.mutable_data({m, k}); auto bbptr_int8 = bb_int8.mutable_data({k, n}); - auto ccptr_int8 = cc_int8.mutable_data({m, n}); + auto ccptr_int32 = cc_int32.mutable_data({m, n}); + auto ccptr_int8 = cc_int8.mutable_data({m, n}); + int32_t* bias_data = new int32_t[m]; for (int i = 0; i < m * k; ++i) { aaptr_int8[i] = static_cast(2); @@ -56,7 +58,11 @@ int main() { bbptr_int8[i] = static_cast(2); } for (int i = 0; i < m * n; ++i) { - ccptr_int8[i] = static_cast(2); + ccptr_int32[i] = static_cast(2); + } + + for (int i = 0; i < m; ++i) { + bias_data[i] = 2; } // float @@ -76,22 +82,41 @@ int main() { auto time2 = time(); std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; - // int8_t + // int8_t without bias // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( - aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, - static_cast(0), false, nullptr); + paddle_mobile::operators::math::matmul_int8( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, + static_cast(0), false, nullptr); } auto time3 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( - aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, - static_cast(0), false, nullptr); + paddle_mobile::operators::math::matmul_int8( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, + static_cast(0), false, nullptr); } auto time4 = time(); std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; + // int8_t with bias&relu + // warm-up 10 times + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul_int8( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), true, &bias_data[0]); + } + auto time5 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul_int8( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), true, &bias_data[0]); + } + auto time6 = time(); + std::cout << "int8_t gemm_with_bias_relu cost :" + << time_diff(time5, time6) / 10 << "ms\n"; + + delete[] bias_data; + return 0; }
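The perf test above times the real kernels by warming up for 10 iterations and then averaging 10 timed runs. A minimal standalone harness with the same shape, using std::chrono around a naive stand-in kernel (not the paddle-mobile implementation), could look like this:

```cpp
// Minimal timing-harness sketch in the spirit of test_gemm_perf.cpp:
// warm up, then report the average of 10 timed runs.
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <vector>

static void naive_gemm_int8(int m, int n, int k, const int8_t *A,
                            const int8_t *B, int32_t *C) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;
      for (int p = 0; p < k; ++p)
        acc += static_cast<int32_t>(A[i * k + p]) * static_cast<int32_t>(B[p * n + j]);
      C[i * n + j] = acc;
    }
}

int main() {
  const int m = 256, n = 256, k = 256;  // smaller than the test's 1024^3 case
  std::vector<int8_t> A(m * k, 2), B(k * n, 2);
  std::vector<int32_t> C(m * n, 0);

  for (int i = 0; i < 10; ++i)  // warm-up, as in the perf test
    naive_gemm_int8(m, n, k, A.data(), B.data(), C.data());

  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < 10; ++i)
    naive_gemm_int8(m, n, k, A.data(), B.data(), C.data());
  auto t1 = std::chrono::steady_clock::now();
  double ms = std::chrono::duration<double, std::milli>(t1 - t0).count() / 10.0;
  std::printf("naive int8 gemm cost : %.3f ms\n", ms);
  return 0;
}
```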