From 5b62acfa01153e4c525cdbd71c70335de86932df Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Wed, 27 Jan 2021 16:59:55 +0800
Subject: [PATCH] feat(dnn/armv7): add new matmul strategy k8x8x4

GitOrigin-RevId: 0c6b7fa1b2ad8724a5c68036d58b3c1e13c3bb42
---
 dnn/src/armv7/matrix_mul/algos.cpp            |  68 ++
 dnn/src/armv7/matrix_mul/algos.h              |  12 +
 .../armv7/matrix_mul/int8x8x16/kernel_8x8x4.h | 675 ++++++++++++++++++
 .../armv7/matrix_mul/int8x8x16/strategy.cpp   |  76 +-
 dnn/src/armv7/matrix_mul/int8x8x16/strategy.h |   3 +
 dnn/src/armv7/matrix_mul/opr_impl.cpp         |   3 +-
 dnn/src/armv7/matrix_mul/opr_impl.h           |   3 +-
 dnn/src/fallback/matrix_mul/opr_impl.h        |   3 +-
 dnn/test/armv7/matrix_mul.cpp                 |  87 +++
 dnn/test/common/convolution.cpp               |  15 +-
 10 files changed, 938 insertions(+), 7 deletions(-)
 create mode 100644 dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h

diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp
index e88efa6ed..f0834a88a 100644
--- a/dnn/src/armv7/matrix_mul/algos.cpp
+++ b/dnn/src/armv7/matrix_mul/algos.cpp
@@ -541,6 +541,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8,
                                      armv7::matmul::gemm_s8x8x16_4x8, int8_t,
                                      int16_t, AlgoDataType::INT8X8X16, DEFAULT);
 
+/* ===================== Int8x8x16 Kernel 8x8x4 algo ===================== */
+
+namespace {
+void kern_int8x8x16_k8x8x4(const MatrixMulImpl::KernParam& kern_param) {
+    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
+                 midout_iv("kern_int8x8x16_k8x8x4"_hash)) {
+        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
+        auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
+        auto Cptr = kern_param.C<dt_int16>();
+        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
+        auto trA = kern_param.trA, trB = kern_param.trB;
+
+        armv7::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, kern_param.A_type,
+                                                 kern_param.B_type,
+                                                 kern_param.C_type);
+        megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_s8x8x16_8x8>(
+                M, N, K, trA, trB, strategy)
+                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
+                         kern_param.workspace_ptr);
+    }
+    MIDOUT_END();
+}
+}  // anonymous namespace
+
+bool MatrixMulImpl::AlgoInt8x8x16K8x8x4::usable(
+        const KernSizeParam& kern_size_param) const {
+    return kern_size_param.A_type == kern_size_param.B_type &&
+           kern_size_param.A_type == dtype::Int8() &&
+           kern_size_param.C_type == dtype::Int16() &&
+           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
+           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
+}
+
+size_t MatrixMulImpl::AlgoInt8x8x16K8x8x4::get_workspace(
+        const KernSizeParam& kern_size_param) const {
+    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
+                 midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) {
+        auto M = kern_size_param.M, N = kern_size_param.N,
+             K = kern_size_param.K;
+        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
+             C_type = kern_size_param.C_type;
+        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
+        matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, C_type);
+        return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_8x8>(
+                       M, N, K, trA, trB, strategy)
+                .get_workspace_size();
+    }
+    MIDOUT_END();
+    return 0;
+}
+
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x4::get_kern(
+        const KernSizeParam&) const {
+    return kern_int8x8x16_k8x8x4;
+}
+
+bool MatrixMulImpl::AlgoInt8x8x16K8x8x4::preferred(
+        const KernSizeParam& kern_size_param) const {
+    return kern_size_param.K >= 8 && kern_size_param.K <= 128;
+}
+
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x4,
+                                     megdnn_armv7_matmul_kern,
+                                     "AlgoInt8x8x16K8x8x4"_hash,
+                                     armv7::matmul::gemm_s8x8x16_8x8, int8_t,
+                                     int16_t, AlgoDataType::INT8X8X16, DEFAULT);
+
 /* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/
 
 namespace {
diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h
index fc1ae72bb..d3b28e801 100644
--- a/dnn/src/armv7/matrix_mul/algos.h
+++ b/dnn/src/armv7/matrix_mul/algos.h
@@ -181,6 +181,18 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8)
 };
 
+class MatrixMulImpl::AlgoInt8x8x16K8x8x4 final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "ARMV7_INT8X8X16_K8X8X4"; }
+    bool usable(const KernSizeParam&) const override;
+    bool preferred(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+    MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K8X8X4)
+};
+
 class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
 public:
     bool is_reproducible() const override { return true; }
diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h
new file mode 100644
index 000000000..0550e978b
--- /dev/null
+++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h
@@ -0,0 +1,675 @@
+/**
+ * \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied.
+ */
+
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/armv7/matrix_mul/asm/common.h"
+
+namespace megdnn {
+namespace armv7 {
+namespace matmul_8x8x4 {
+/*                        +--------+---------------------------------+
+ *                        |   q4   | b00 b01 b02 b03 b04 b05 b06 b07 |
+ *                        +--------+---------------------------------+
+ *                        |   q5   | b10 b11 b12 b13 b14 b15 b16 b17 |
+ *                        +--------+---------------------------------+
+ *                        |   q6   | b20 b21 b22 b23 b24 b25 b26 b27 |
+ *                        +--------+---------------------------------+
+ *                        |   q7   | b30 b31 b32 b33 b34 b35 b36 b37 |
+ *                        +--------+---------------------------------+
+ * +----+-----------------+ +--------+---------------------------------+
+ * | d0 | a00 a01 a02 a03 | |   q8   | c00 c01 c02 c03 c04 c05 c06 c07 |
+ * | d1 | a10 a11 a12 a13 | |   q9   | c10 c11 c12 c13 c14 c15 c16 c17 |
+ * | d2 | a20 a21 a22 a23 | |  q10   | c20 c21 c22 c23 c24 c25 c26 c27 |
+ * | d3 | a30 a31 a32 a33 | |  q11   | c30 c31 c32 c33 c34 c35 c36 c37 |
+ * | d4 | a40 a41 a42 a43 | |  q12   | c40 c41 c42 c43 c44 c45 c46 c47 |
+ * | d5 | a50 a51 a52 a53 | |  q13   | c50 c51 c52 c53 c54 c55 c56 c57 |
+ * | d6 | a60 a61 a62 a63 | |  q14   | c60 c61 c62 c63 c64 c65 c66 c67 |
+ * | d7 | a70 a71 a72 a73 | |  q15   | c70 c71 c72 c73 c74 c75 c76 c77 |
+ * +----+-----------------+ +--------+---------------------------------+
+ *
+ */
+
+static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
+                     int16_t* output, int LDC, bool is_first_k,
+                     size_t n_remain) {
+    K /= 4;
+    const int8_t* a_ptr = packA;
+    const int8_t* b_ptr = packB;
+    LDC = LDC * sizeof(int16_t);
+    size_t nr = n_remain;
+
+    // clang-format off
+
+#define LOAD_C                               \
+    "mov r1, r0\n"                           \
+    "vld1.16 {d16, d17}, [r1], %[LDC]\n"     \
+    "vld1.16 {d18, d19}, [r1], %[LDC]\n"     \
+    "vld1.16 {d20, d21}, [r1], %[LDC]\n"     \
+    "vld1.16 {d22, d23}, [r1], 
%[LDC]\n" \ + "vld1.16 {d24, d25}, [r1], %[LDC]\n" \ + "vld1.16 {d26, d27}, [r1], %[LDC]\n" \ + "vld1.16 {d28, d29}, [r1], %[LDC]\n" \ + "vld1.16 {d30, d31}, [r1], %[LDC]\n" + + +#define STORE_LINE(id1, id2) \ + "mov r2, r1\n" \ + "cmp %[nr], #8\n" \ + "bne 100f\n" \ + "vst1.16 {d" id1 ", d" id2 "}, [r2]!\n" \ + "b 101f\n" \ + "100:\n" \ + "cmp %[nr], #0\n" \ + "beq 101f\n" \ + "vst1.16 {d" id1 "[0]}, [r2]!\n" \ + "cmp %[nr], #1\n" \ + "beq 101f\n" \ + "vst1.16 {d" id1 "[1]}, [r2]!\n" \ + "cmp %[nr], #2\n" \ + "beq 101f\n" \ + "vst1.16 {d" id1 "[2]}, [r2]!\n" \ + "cmp %[nr], #3\n" \ + "beq 101f\n" \ + "vst1.16 {d" id1 "[3]}, [r2]!\n" \ + "cmp %[nr], #4\n" \ + "beq 101f\n" \ + "vst1.16 {d" id2 "[0]}, [r2]!\n" \ + "cmp %[nr], #5\n" \ + "beq 101f\n" \ + "vst1.16 {d" id2 "[1]}, [r2]!\n" \ + "cmp %[nr], #6\n" \ + "beq 101f\n" \ + "vst1.16 {d" id2 "[2]}, [r2]!\n" \ + "101:\n" +#define STORE_C \ + "mov r1, r0\n" \ + STORE_LINE("16", "17") \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("18", "19") \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("20", "21") \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("22", "23") \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("24", "25") \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("26", "27") \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("28", "29") \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("30", "31") + // clang-format on + + register int16_t* outptr asm("r0") = output; + asm volatile( + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.s32 q8, q8, q8\n" + "veor.s32 q9, q9, q9\n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + "veor.s32 q12, q12, q12\n" + "veor.s32 q13, q13, q13\n" + "veor.s32 q14, q14, q14\n" + "veor.s32 q15, q15, q15\n" + + "2:\n" + "vld1.8 {d0}, [%[a_ptr]]!\n" + "vld1.8 {d2}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[a_ptr]]!\n" + "vmovl.s8 q0, d0\n" + "vmovl.s8 q1, d2\n" + "vmovl.s8 q2, d4\n" + "vmovl.s8 q3, d6\n" + + "vld1.8 {d8}, [%[b_ptr]]!\n" + "vld1.8 {d10}, [%[b_ptr]]!\n" + "vld1.8 {d12}, [%[b_ptr]]!\n" + "vld1.8 {d14}, [%[b_ptr]]!\n" + "vmovl.s8 q4, d8\n" + "vmovl.s8 q5, d10\n" + "vmovl.s8 q6, d12\n" + "vmovl.s8 q7, d14\n" + + "vmla.s16 q8, q4, d0[0]\n" + "vmla.s16 q9, q4, d1[0]\n" + "vmla.s16 q10, q4, d2[0]\n" + "vmla.s16 q11, q4, d3[0]\n" + "vmla.s16 q12, q4, d4[0]\n" + "vmla.s16 q13, q4, d5[0]\n" + "vmla.s16 q14, q4, d6[0]\n" + "vmla.s16 q15, q4, d7[0]\n" + + "vmla.s16 q8, q5, d0[1]\n" + "vmla.s16 q9, q5, d1[1]\n" + "vmla.s16 q10, q5, d2[1]\n" + "vmla.s16 q11, q5, d3[1]\n" + "vmla.s16 q12, q5, d4[1]\n" + "vmla.s16 q13, q5, d5[1]\n" + "vmla.s16 q14, q5, d6[1]\n" + "vmla.s16 q15, q5, d7[1]\n" + + "vmla.s16 q8, q6, d0[2]\n" + "vmla.s16 q9, q6, d1[2]\n" + "vmla.s16 q10, q6, d2[2]\n" + "vmla.s16 q11, q6, d3[2]\n" + "vmla.s16 q12, q6, d4[2]\n" + "vmla.s16 q13, q6, d5[2]\n" + "vmla.s16 q14, q6, d6[2]\n" + "vmla.s16 q15, q6, d7[2]\n" + + "vmla.s16 q8, q7, d0[3]\n" + "vmla.s16 q9, q7, d1[3]\n" + "vmla.s16 q10, q7, d2[3]\n" + "vmla.s16 q11, q7, d3[3]\n" + "vmla.s16 q12, q7, d4[3]\n" + "vmla.s16 q13, q7, d5[3]\n" + "vmla.s16 q14, q7, d6[3]\n" + "vmla.s16 q15, q7, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), + [ LDC ] "+r"(LDC), [ is_first_k ] "+r"(is_first_k), + [ outptr ] "+r"(outptr), [ nr ] "+r"(nr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "r1", "r2", "cc", "memory"); +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +/* 
+--------+---------------------------------+ + * | q2 | b00 b01 b02 b03 b04 b05 b06 b07 | + * +--------+---------------------------------+ + * | q3 | b10 b11 b12 b13 b14 b15 b16 b17 | + * +--------+---------------------------------+ + * | q4 | b20 b21 b22 b23 b24 b25 b26 b27 | + * +--------+---------------------------------+ + * | q5 | b30 b31 b32 b33 b34 b35 b36 b37 | + * +--------+---------------------------------+ + * +----+-----------------+ +--------+---------------------------------+ + * | d0 | a00 a01 a02 a03 | | q6 | c00 c01 c02 c03 c04 c05 c06 c07 | + * | d1 | a10 a11 a12 a13 | | q7 | c10 c11 c12 c13 c14 c15 c16 c17 | + * | d2 | a20 a21 a22 a23 | | q8 | c20 c21 c22 c23 c24 c25 c26 c27 | + * | d3 | a30 a31 a32 a33 | | q9 | c30 c31 c32 c33 c34 c35 c36 c37 | + * +----+-----------------+ +--------+---------------------------------+ + * + */ +static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, size_t m_remain, + size_t n_remain) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + LDC = LDC * sizeof(int16_t); + size_t mr = m_remain; + size_t nr = n_remain; + + // clang-format off + +#define LOAD_C \ + "cmp %[mr], #0\n" \ + "beq 100f\n" \ + "mov r1, r0\n" \ + "vld1.16 {d12, d13}, [r1], %[LDC]\n" \ + "cmp %[mr], #1\n" \ + "beq 100f\n" \ + "vld1.16 {d14, d15}, [r1], %[LDC]\n" \ + "cmp %[mr], #2\n" \ + "beq 100f\n" \ + "vld1.16 {d16, d17}, [r1], %[LDC]\n" \ + "cmp %[mr], #3\n" \ + "beq 100f\n" \ + "vld1.16 {d18, d19}, [r1], %[LDC]\n" \ + "100:\n" \ + +#define STORE_LINE(id1, id2) \ + "mov r2, r1\n" \ + "cmp %[nr], #0\n" \ + "beq 101f\n" \ + "vst1.16 {d" id1 "[0]}, [r2]!\n" \ + "cmp %[nr], #1\n" \ + "beq 101f\n" \ + "vst1.16 {d" id1 "[1]}, [r2]!\n" \ + "cmp %[nr], #2\n" \ + "beq 101f\n" \ + "vst1.16 {d" id1 "[2]}, [r2]!\n" \ + "cmp %[nr], #3\n" \ + "beq 101f\n" \ + "vst1.16 {d" id1 "[3]}, [r2]!\n" \ + "cmp %[nr], #4\n" \ + "beq 101f\n" \ + "vst1.16 {d" id2 "[0]}, [r2]!\n" \ + "cmp %[nr], #5\n" \ + "beq 101f\n" \ + "vst1.16 {d" id2 "[1]}, [r2]!\n" \ + "cmp %[nr], #6\n" \ + "beq 101f\n" \ + "vst1.16 {d" id2 "[2]}, [r2]!\n" \ + "cmp %[nr], #7\n" \ + "beq 101f\n" \ + "vst1.16 {d" id2 "[3]}, [r2]!\n" \ + "101:\n" +#define STORE_C \ + "cmp %[mr], #0\n" \ + "beq 102f\n" \ + "mov r1, r0\n" \ + STORE_LINE("12", "13") \ + "cmp %[mr], #1\n" \ + "beq 102f\n" \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("14", "15") \ + "cmp %[mr], #2\n" \ + "beq 102f\n" \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("16", "17") \ + "cmp %[mr], #3\n" \ + "beq 102f\n" \ + "add r1, r1, %[LDC]\n" \ + STORE_LINE("18", "19") \ + "102:\n" + + // clang-format on + + register int16_t* outptr asm("r0") = output; + asm volatile( + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.s32 q6, q6, q6\n" + "veor.s32 q7, q7, q7\n" + "veor.s32 q8, q8, q8\n" + "veor.s32 q9, q9, q9\n" + + "2:\n" + "vld1.8 {d0}, [%[a_ptr]]!\n" + "vld1.8 {d2}, [%[a_ptr]]!\n" + "vmovl.s8 q0, d0\n" + "vmovl.s8 q1, d2\n" + + "vld1.8 {d4}, [%[b_ptr]]!\n" + "vld1.8 {d6}, [%[b_ptr]]!\n" + "vld1.8 {d8}, [%[b_ptr]]!\n" + "vld1.8 {d10}, [%[b_ptr]]!\n" + "vmovl.s8 q2, d4\n" + "vmovl.s8 q3, d6\n" + "vmovl.s8 q4, d8\n" + "vmovl.s8 q5, d10\n" + + "vmla.s16 q6, q2, d0[0]\n" + "vmla.s16 q7, q2, d1[0]\n" + "vmla.s16 q8, q2, d2[0]\n" + "vmla.s16 q9, q2, d3[0]\n" + + "vmla.s16 q6, q3, d0[1]\n" + "vmla.s16 q7, q3, d1[1]\n" + "vmla.s16 q8, q3, d2[1]\n" + "vmla.s16 q9, q3, d3[1]\n" + + "vmla.s16 q6, q4, d0[2]\n" + "vmla.s16 q7, q4, d1[2]\n" + "vmla.s16 q8, q4, d2[2]\n" + 
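+                /* one vmla.s16 per K step: C row r accumulates the widened
+                 * B row k scaled by A element a(r,k); q4 here is B row 2 */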
"vmla.s16 q9, q4, d3[2]\n" + + "vmla.s16 q6, q5, d0[3]\n" + "vmla.s16 q7, q5, d1[3]\n" + "vmla.s16 q8, q5, d2[3]\n" + "vmla.s16 q9, q5, d3[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" STORE_C + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), + [ LDC ] "+r"(LDC), [ is_first_k ] "+r"(is_first_k), + [ outptr ] "+r"(outptr), [ mr ] "+r"(mr), [ nr ] "+r"(nr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "r1", + "r2", "cc", "memory"); +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* out, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + int8_t* outptr = out; + int y = y0; + for (; y + 7 < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int K = kmax - k0; + for (; K > 15; K -= 16) { + interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + } + + if (K > 0) { + for (; K > 0; K -= 4) + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr, 4, std::min(K, 4)); + } + } + for (; y < ymax; y += 4) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = kmax - k0; + for (; K > 0; K -= 4) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, + std::min(K, 4)); + } + } +} + +static void gemm_s8x8x16_8x8_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + int8_t* outbase = out; + size_t K = round_up(kmax - k0, 4); + + int k = k0; + for (; k < kmax; k += 4) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = xmax - x0; + int8_t* outptr = outbase; + int8_t* out_tmp = outptr; + for (; x > 7; x -= 8) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + out_tmp = outptr; + transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, out_tmp); + outptr += (K - k) * 8 + (x > 15 ? 
8 : 4) * k; + } + + if (x > 0) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + out_tmp = outptr; + if (x > 4) { + transpose_4(inptr0, inptr1, inptr2, inptr3, out_tmp, 4, 4); + x -= 4; + out_tmp = outptr + K * 4; + } + transpose_4(inptr0, inptr1, inptr2, inptr3, out_tmp, 4, x); + } + outbase += 4 * ((xmax - x0) > 7 ? 8 : 4); + } +} + +static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + int8_t* outbase = out; + int8_t* out_interleave = out; + const size_t K8 = round_up(kmax - k0, 4) * 8; + int k = k0; + for (; k < kmax; k += 4) { + const int8_t* inptr0 = in + k * ldin + x0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int x = xmax - x0; + + int8_t* outptr = outbase; + for (; x > 7; x -= 8) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + out_interleave = outptr; + asm volatile( + "vld1.32 {d0}, [%[inptr0]]!\n" + "vld1.32 {d1}, [%[inptr1]]!\n" + "vld1.32 {d2}, [%[inptr2]]!\n" + "vld1.32 {d3}, [%[inptr3]]!\n" + "vst1.32 {d0}, [%[out_interleave]]!\n" + "vst1.32 {d1}, [%[out_interleave]]!\n" + "vst1.32 {d2}, [%[out_interleave]]!\n" + "vst1.32 {d3}, [%[out_interleave]]!\n" + : [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), + [ inptr2 ] "+r"(inptr2), [ inptr3 ] "+r"(inptr3), + [ out_interleave ] "+r"(out_interleave) + : + : "q0", "q1", "cc", "memory"); + outptr += K8; + } + + if (x > 0) { + if (k + 3 >= kmax) { + switch (k + 3 - kmax) { + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + out_interleave = outptr; + interleave_4(inptr0, inptr1, inptr2, inptr3, out_interleave, 8, x); + } + outbase += 4 * 8; + } +} + +static void gemm_s8x8x16_8x8_pack_B_t(dt_int8* out, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + int8_t zerobuff[16]; + std::memset(zerobuff, 0, sizeof(int8_t) * 16); + int8_t* outptr = out; + + int y = y0; + for (; y < ymax; y += 8) { + const int8_t* inptr0 = inptr + y * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr1 + ldin; + const int8_t* inptr3 = inptr2 + ldin; + const int8_t* inptr4 = inptr3 + ldin; + const int8_t* inptr5 = inptr4 + ldin; + const int8_t* inptr6 = inptr5 + ldin; + const int8_t* inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int k = k0; + + for (; k + 3 < kmax; k += 4) { + if (y + 7 >= ymax) { + switch (y + 7 - ymax) { + case 6: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 5: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 4: + inptr3 = zerobuff; + MEGDNN_FALLTHRU + case 3: + inptr4 = zerobuff; + MEGDNN_FALLTHRU + case 2: + inptr5 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr6 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr7 = zerobuff; + break; 
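+                        /* row pointers past ymax were redirected to zerobuff
+                         * above, so the packed output tail is zero-padded */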
+ default: + megdnn_assert(0); + } + } + transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + inptr6, inptr7, outptr); + outptr += 4 * 8; + } + + if (k < kmax) { + if (y + 7 >= ymax) { + switch (y + 7 - ymax) { + case 6: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 5: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 4: + inptr3 = zerobuff; + MEGDNN_FALLTHRU + case 3: + inptr4 = zerobuff; + MEGDNN_FALLTHRU + case 2: + inptr5 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr6 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr7 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 4, kmax - k); + outptr += 4 * 8; + } + } +} +} // namespace matmul_8x8x4 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp index 0f9fb4dbf..6dd199684 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp @@ -10,12 +10,13 @@ * implied. */ +#include "src/armv7/matrix_mul/int8x8x16/strategy.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/armv7/matrix_mul/asm/common.h" #include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h" #include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h" +#include "src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h" #include "src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h" -#include "src/armv7/matrix_mul/int8x8x16/strategy.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_common.h" @@ -181,6 +182,79 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB, } } +// ===========================gemm_s8x8x16_8x8================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); + +void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, + int ymax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_t(out, in, ldin, y0, ymax, k0, + kmax); + } else { + matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); + } +} + +void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, + bool transpose) const { + if (transpose) { + matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_t(out, in, ldin, x0, xmax, k0, + kmax); + } else { + matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); + } +} + +void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 8; + //! 
K is rounded up to a multiple of 4
+    K = round_up(K, 4);
+    size_t m = 0;
+    for (; m + 7 < M; m += A_INTERLEAVE) {
+        int16_t* output = C + (m * LDC);
+        const dt_int8* cur_packB = packB;
+        size_t n = 0;
+        for (; n < N; n += B_INTERLEAVE) {
+            matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC,
+                                   is_first_k, std::min<size_t>(N - n, 8));
+            output += B_INTERLEAVE;
+            cur_packB += K * 8;
+        }
+        packA += K * 8;
+    }
+    for (; m < M; m += 4) {
+        int16_t* output = C + (m * LDC);
+        const dt_int8* cur_packB = packB;
+        size_t n = 0;
+        for (; n < N; n += B_INTERLEAVE) {
+            matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC,
+                                   is_first_k, std::min<size_t>(M - m, 4),
+                                   std::min<size_t>(N - n, 8));
+            output += B_INTERLEAVE;
+            cur_packB += K * 8;
+        }
+        packA += K * 4;
+    }
+}
+
 // ===========================gemm_s8x8x16_mk4_8x8==================================
 MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8);
 
diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h
index d395cbcd3..d17bd647e 100644
--- a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h
+++ b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h
@@ -22,6 +22,9 @@ MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true,
 MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true,
                          gemm_s8x8x16_4x8);
 
+MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 8, 8, 4, false, true,
+                         gemm_s8x8x16_8x8);
+
 MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(int8_t, int16_t, int16_t, int16_t, 8,
                                           8, 4, false, false,
                                           gemm_s8x8x16_mk4_8x8);
diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp
index 6db2ee137..c0cc21e9a 100644
--- a/dnn/src/armv7/matrix_mul/opr_impl.cpp
+++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp
@@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoQuint8K4x8x8 quint8_k4x8x8;
     AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16;
     AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8;
+    AlgoInt8x8x16K8x8x4 int8x8x16_k8x8x4;
     AlgoInt8x8x16MK4_8x8x4 int8x8x16_mk4_8x8x4;
     AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1;
     AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8;
@@ -47,7 +48,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
 
 public:
-
     AlgoPack() {
         m_all_algos.emplace_back(&f32_gemv);
         m_all_algos.emplace_back(&f32);
@@ -69,6 +69,7 @@ public:
         m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4);
         m_all_algos.emplace_back(&int8x8x16_k4x2x16);
         m_all_algos.emplace_back(&int8x8x16_k4x8x8);
+        m_all_algos.emplace_back(&int8x8x16_k8x8x4);
 
         m_all_algos.emplace_back(&int16x16x32_k12x4x1);
         m_all_algos.emplace_back(&int16x16x32_mk8_4x8);
diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h
index 20448525d..085bc8c25 100644
--- a/dnn/src/armv7/matrix_mul/opr_impl.h
+++ b/dnn/src/armv7/matrix_mul/opr_impl.h
@@ -41,7 +41,8 @@ private:
     class AlgoQuint8K4x8x8;       // Armv7 Quint8 Kernel 4x8x8
     class AlgoInt8x8x16K4x2x16;   // Armv7 Int8x8x16 Kernel 4x2x16
     class AlgoInt8x8x16K4x8x8;    // Armv7 Int8x8x16 Kernel 4x8x8
-    class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel 8x8x8
+    class AlgoInt8x8x16K8x8x4;    // Armv7 Int8x8x16 Kernel 8x8x4
+    class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel mk4_8x8x4
     class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1
     class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h
index 48d23b9fb..930e7a0dd 100644
--- a/dnn/src/fallback/matrix_mul/opr_impl.h
+++ b/dnn/src/fallback/matrix_mul/opr_impl.h
@@ -174,7 +174,8 @@ public:
         ARMV7_INT8X8X16_MK4_K8X8X4,
         ARMV7_INT16X16X32_K12X4X1,
         ARMV7_INT16X16X32_MK8_4X8,
-        ARMV7_INT8X8X32_MK4_4X2X16
+        ARMV7_INT8X8X32_MK4_4X2X16,
+        ARMV7_INT8X8X16_K8X8X4
 #endif
 #endif
     };
diff --git a/dnn/test/armv7/matrix_mul.cpp b/dnn/test/armv7/matrix_mul.cpp
index 97a37fa88..7d014ee97 100644
--- a/dnn/test/armv7/matrix_mul.cpp
+++ b/dnn/test/armv7/matrix_mul.cpp
@@ -52,6 +52,12 @@ TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) {
                                  handle(), "ARMV7_INT8X8X16_K4X8X8");
 }
 
+TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K8x8x4) {
+    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
+                                 handle(), "ARMV7_INT8X8X16_K8X8X4");
+}
+
+
 TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_MK4_K8x8x4) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
                                  handle(), "ARMV7_INT8X8X16_MK4_K8X8X4",
@@ -183,6 +189,68 @@ void run_8x8x16_benchmark(
         }
     }
 }
+
+void run_8x8x16_contrast(
+        const char* algo0, const char* algo, Handle* handle,
+        MatrixMul::Param::Format format = MatrixMul::Param::Format::DEFAULT) {
+    constexpr size_t RUNS = 100;
+    param::MatrixMul param;
+    Benchmarker<MatrixMul> benchmarker_int(handle);
+    Benchmarker<MatrixMul> benchmarker_int_kern(handle);
+    benchmarker_int.set_before_exec_callback(AlgoChecker<MatrixMul>(algo0));
+
+    benchmarker_int.set_times(RUNS)
+            .set_dtype(0, dtype::Int8{})
+            .set_dtype(1, dtype::Int8{})
+            .set_dtype(2, dtype::Int16{})
+            .set_param(param)
+            .set_display(false);
+    param::MatrixMul target_param;
+    target_param.format = format;
+
+    benchmarker_int_kern.set_before_exec_callback(
+            AlgoChecker<MatrixMul>(algo));
+    benchmarker_int_kern.set_times(RUNS)
+            .set_dtype(0, dtype::Int8{})
+            .set_dtype(1, dtype::Int8{})
+            .set_dtype(2, dtype::Int16{})
+            .set_param(target_param)
+            .set_display(false);
+
+    auto run = [&](size_t M, size_t N, size_t K) {
+        auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
+        auto int_kern_used = 1e10;
+        double computation = 2.0 * M * N * K * 1e-6;
+        if (format == MatrixMul::Param::Format::MK4) {
+            int_kern_used = benchmarker_int_kern.exec(
+                                    {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
+                            RUNS;
+        } else {
+            int_kern_used =
+                    benchmarker_int_kern.exec({{M, K}, {K, N}, {}}) / RUNS;
+        }
+
+        printf(" %f(%f)\t %f(%f)\t %f\n", int_used, computation / int_used,
+               int_kern_used, computation / int_kern_used,
+               int_used / int_kern_used);
+    };
+    printf("\nN\t K\t M\t %s ms(GFlops)\t %s ms(GFlops)\t SPEEDUP\n", algo0,
+           algo);
+
+    for (size_t M : {8}) {
+        for (size_t K : {72}) {
+            for (size_t N : {8, 16, 32, 64, 72, 128, 256, 512, 1024, 4096,
+                             8192, 16384, 32768, 65536}) {
+                printf("%zu\t %zu\t %zu\t", N, K, M);
+                run(M, N, K);
+            }
+        }
+    }
+    printf("512\t 512\t 512\t");
+    run(512, 512, 512);
+}
+
 void run_16x16x32_benchmark(const char* algo, Handle* handle) {
     constexpr size_t RUNS = 50;
     param::MatrixMul param;
@@ -383,6 +451,10 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) {
     run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle());
 }
 
+TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K8x8x4) {
+    run_8x8x16_benchmark("ARMV7_INT8X8X16_K8X8X4", handle());
+}
+
 TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K4x8x8) {
     run_8x8x16_benchmark("ARMV7_INT8X8X16_MK4_K8X8X4", handle(),
                          MatrixMul::Param::Format::MK4);
@@ -392,6 +464,21 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) {
     run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle());
 }
 
+TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K8x8x4_CONTRAST) {
run_8x8x16_contrast("ARM_COMMON_INT8X8X16", "ARMV7_INT8X8X16_K8X8X4", + handle()); +} + +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8_CONTRAST) { + run_8x8x16_contrast("ARM_COMMON_INT8X8X16", "ARMV7_INT8X8X16_K4X8X8", + handle()); +} + +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8_K8x8x4_CONTRAST) { + run_8x8x16_contrast("ARMV7_INT8X8X16_K4X8X8", "ARMV7_INT8X8X16_K8X8X4", + handle()); +} + #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_FP16) { constexpr size_t RUNS = 50; diff --git a/dnn/test/common/convolution.cpp b/dnn/test/common/convolution.cpp index 1a8ae04b3..5caebcf86 100644 --- a/dnn/test/common/convolution.cpp +++ b/dnn/test/common/convolution.cpp @@ -517,9 +517,18 @@ void convolution::test_conv_config_combinations(int k_size, param.compute_mode = Param::ComputeMode::FLOAT32; } size_t IC = 6, OC = 9, G = 3, FH = ksize, FW = ksize; - TensorShape ishp = format ? - TensorShape{2, 18, 18, IC} : TensorShape{2, IC, 18, 18}, - fshp; + TensorShape ishp = TensorShape{2, 18, 18, IC}, fshp; + if (format) { + ishp.shape[0] = 2; + ishp.shape[1] = 18; + ishp.shape[2] = 18; + ishp.shape[3] = IC; + } else { + ishp.shape[0] = 2; + ishp.shape[1] = IC; + ishp.shape[2] = 18; + ishp.shape[3] = 18; + } if (padding) { param.pad_h = 2 + non_square; param.pad_w = 2 - non_square; -- GitLab
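
For reference, the arithmetic contract that the kern_8x8/kern_4x8 micro-kernels in this patch implement is a plain GEMM with int8 inputs and 16-bit accumulation: like the vmla.s16 instructions they use, sums wrap on int16 overflow rather than saturate or widen. Below is a minimal scalar sketch of that contract; gemm_s8x8x16_ref and its layout parameters are illustrative names for this note only, not part of the megdnn API.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Reference GEMM with int8 inputs and int16 accumulation (row-major).
// is_first_k mirrors the kernels above: when true, C is overwritten;
// when false, new products are accumulated onto the existing C values.
static void gemm_s8x8x16_ref(const int8_t* A, const int8_t* B, int16_t* C,
                             size_t M, size_t N, size_t K, size_t LDA,
                             size_t LDB, size_t LDC, bool is_first_k) {
    for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
            int16_t acc = is_first_k ? int16_t(0) : C[m * LDC + n];
            for (size_t k = 0; k < K; ++k) {
                // widen both int8 operands, multiply, accumulate in int16
                // (wraps on overflow, matching the vmla.s16 data path)
                acc = int16_t(acc + int16_t(A[m * LDA + k]) *
                                            int16_t(B[k * LDB + n]));
            }
            C[m * LDC + n] = acc;
        }
    }
}

int main() {
    constexpr size_t M = 8, N = 8, K = 4;  // one 8x8x4 micro-tile
    int8_t A[M * K], B[K * N];
    int16_t C[M * N];
    for (size_t i = 0; i < M * K; ++i) A[i] = int8_t(int(i % 7) - 3);
    for (size_t i = 0; i < K * N; ++i) B[i] = int8_t(int(i % 5) - 2);
    gemm_s8x8x16_ref(A, B, C, M, N, K, K, N, N, /*is_first_k=*/true);
    printf("C[0][0]=%d  C[7][7]=%d\n", C[0], C[M * N - 1]);
    return 0;
}

Keeping the accumulator at 16 bits is what lets the assembly stay in vmla.s16 without widening to 32-bit lanes, so each Q register holds eight C elements instead of four.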