From 36e3bb6ea7c41ec4a77671e2b8b98bf58c2eecc6 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 24 May 2020 19:32:12 +0800
Subject: [PATCH] feat(mgb/dnn): add armv7 mk4_dot matmul

GitOrigin-RevId: d4206f8e21d1f58a7e07e1345d2738dc76e7bfbd
---
 dnn/src/armv7/matrix_mul/algos.cpp            |  67 ++
 dnn/src/armv7/matrix_mul/algos.h              |  12 +
 dnn/src/armv7/matrix_mul/asm/common.h         |  25 +
 .../matrix_mul/int8/kernel_mk4_dot_8x6x4.h    | 747 ++++++++++++++++++
 dnn/src/armv7/matrix_mul/int8/strategy.cpp    |  84 ++
 dnn/src/armv7/matrix_mul/int8/strategy.h      |   3 +
 dnn/src/armv7/matrix_mul/opr_impl.cpp         |   2 +
 dnn/src/armv7/matrix_mul/opr_impl.h           |   2 +
 dnn/test/armv7/matrix_mul.cpp                 |  59 ++
 9 files changed, 1001 insertions(+)
 create mode 100644 dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h

diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp
index c9a669460..3a4a12fd0 100644
--- a/dnn/src/armv7/matrix_mul/algos.cpp
+++ b/dnn/src/armv7/matrix_mul/algos.cpp
@@ -706,6 +706,73 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4,
                                      "AlgoQuint8DotK4x8x4"_hash,
                                      armv7::matmul::gemm_dot_quint8_4x8,
                                      uint8_t, int32_t);
+
+/* ======================== Int8 MK4 8x6x4 dot algo ======================== */
+namespace {
+void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
+    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
+                 midout_iv("int8_mk4_8x6x4_dotprod_kern"_hash)) {
+        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
+        auto trA = kern_param.trA, trB = kern_param.trB;
+        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
+        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
+             C_type = kern_param.C_type;
+        const auto Aptr = kern_param.A<dt_int8>(),
+                   Bptr = kern_param.B<dt_int8>();
+        auto Cptr = kern_param.C<dt_int32>();
+        armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type,
+                                                   C_type);
+        megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_mk4_dots8_8x6>(
+                M, N, K, trA, trB, strategy)
+                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
+                         kern_param.workspace_ptr);
+    }
+    MIDOUT_END();
+}
+}  // namespace
+
+bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable(
+        const KernSizeParam& kern_size_param) const {
+    return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
+           (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
+            kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
+           (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
+            kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
+           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+           kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
+           !kern_size_param.trA && !kern_size_param.trB;
+}
+
+size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace(
+        const KernSizeParam& kern_size_param) const {
+    MIDOUT_BEGIN(
+            megdnn_armv7_matmul_kern,
+            midout_iv("AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace"_hash)) {
+        auto M = kern_size_param.M, N = kern_size_param.N,
+             K = kern_size_param.K;
+        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
+        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
+             C_type = kern_size_param.C_type;
+        armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type,
+                                                   C_type);
+        return megdnn::matmul::GemmInterleaved<
+                       armv7::matmul::gemm_mk4_dots8_8x6>(M, N, K, trA, trB,
+                                                          strategy)
+                .get_workspace_size();
+    }
+    MIDOUT_END();
+}
+
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_kern(
+        const KernSizeParam&) const {
+    return int8_mk4_8x6x4_dotprod_kern;
+}
+
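The MK4_DOT operand layout that usable() insists on is easiest to see in scalar code. The sketch below is an illustration, not part of the patch, and the helper name is hypothetical: it repacks a plain row-major int8 A into the (M/4, K/4, 4, 4) layout implied by the benchmark shapes at the end of this patch, where each 16-byte tile holds four rows of four consecutive k values, exactly what one vsdot.s8 consumes from a q register.

    #include <cstddef>
    #include <cstdint>

    // Hypothetical helper, not part of the patch: repack a row-major M x K
    // int8 matrix (M, K both multiples of 4) into MK4_DOT order.
    void pack_a_mk4_dot_ref(const int8_t* src, int8_t* dst,
                            size_t M, size_t K) {
        for (size_t mo = 0; mo < M / 4; ++mo)          // 4-row block
            for (size_t ko = 0; ko < K / 4; ++ko)      // 4-deep k block
                for (size_t mi = 0; mi < 4; ++mi)      // row inside block
                    for (size_t ki = 0; ki < 4; ++ki)  // k inside block
                        *dst++ = src[(mo * 4 + mi) * K + (ko * 4 + ki)];
    }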
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x6x4DotProd,
+                                     megdnn_armv7_matmul_kern,
+                                     "AlgoInt8x8x32MK4_8x6x4DotProd"_hash,
+                                     armv7::matmul::gemm_mk4_dots8_8x6, int8_t,
+                                     int32_t);
 #endif
 
 /* ===================== F32 algo K4x8 ===================== */
diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h
index 388074c26..923be14a1 100644
--- a/dnn/src/armv7/matrix_mul/algos.h
+++ b/dnn/src/armv7/matrix_mul/algos.h
@@ -93,6 +93,18 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
 };
+
+class MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override {
+        return "AARCH32_INT8_MK4_8X6X4_DOTPROD";
+    }
+    bool usable(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+};
 #endif
 
 class MatrixMulImpl::AlgoF32Gemv final
diff --git a/dnn/src/armv7/matrix_mul/asm/common.h b/dnn/src/armv7/matrix_mul/asm/common.h
index a7aad47e9..820322e52 100644
--- a/dnn/src/armv7/matrix_mul/asm/common.h
+++ b/dnn/src/armv7/matrix_mul/asm/common.h
@@ -125,6 +125,20 @@ static inline void interleave_4x1_2_d(const int64_t*& inptr0,
             : "q0", "q1", "q2", "q3", "cc", "memory");
 }
 
+static inline void interleave_2x1_4_s(const int32_t*& inptr0,
+                                      const int32_t*& inptr1,
+                                      int32_t*& outptr) {
+    asm volatile(
+            "vld1.32 {d0, d1}, [%[inptr0]]!\n"  // A0A1A2A3
+            "vld1.32 {d2, d3}, [%[inptr1]]!\n"  // B0B1B2B3
+            "vst1.32 {d0, d1}, [%[outptr]]!\n"
+            "vst1.32 {d2, d3}, [%[outptr]]!\n"
+            :
+            [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr)
+            :
+            : "d0", "d1", "d2", "d3", "cc", "memory");
+}
+
 template <typename T>
 static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
                                       const T*& inptr2, const T*& inptr3,
@@ -188,6 +202,17 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1,
             : "q0", "q1", "q2", "q3", "memory");
 }
 
+template <typename T>
+static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1,
+                                      T*& outptr) {
+    static_assert(
+            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+            "interleave_2x4_4_b only supports uint8_t and int8_t");
+    interleave_2x1_4_s(reinterpret_cast<const int32_t*&>(inptr0),
+                       reinterpret_cast<const int32_t*&>(inptr1),
+                       reinterpret_cast<int32_t*&>(outptr));
+}
+
 template <typename T>
 static inline void interleave_6x4_4_b(const T*& inptr0, const T*& inptr1,
                                       const T*& inptr2, const T*& inptr3,
diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h
new file mode 100644
index 000000000..214c94c5f
--- /dev/null
+++ b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h
@@ -0,0 +1,747 @@
+/**
+ * \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied.
+ */
+
+#if __ARM_FEATURE_DOTPROD
+
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/armv7/matrix_mul/asm/common.h"
+
+namespace megdnn {
+namespace armv7 {
+namespace matmul_mk4_dot_8x6x4 {
+
+// Overview of register layout:
+//
+// A 1x6x4 cell of Rhs is stored in 8bit in q0, q1.
+// A 2x1x4x4 cell of Lhs is stored in 8bit in q2, q3
+// A 2x6x4 block of accumulators is stored in 32bit in q4-q15
+//
+//                      +--------+
+// Rhs                  |q0[0-16]|
+//                      |q1[0-16]|
+//                      +--------+
+// Lhs                  |        |
+//  +---------+ - - - - +--------+
+//  | q2[0-16]|         | q4[0-4]|
+//  | q3[0-16]|         | q5[0-4]|
+//  +---------+         | q6[0-4]|
+//                      | q7[0-4]|
+//                      | q8[0-4]|
+//                      | q9[0-4]|
+//                      |q10[0-4]|
+//                      |q11[0-4]|
+//                      |q12[0-4]|
+//                      |q13[0-4]|
+//                      |q14[0-4]|
+//                      |q15[0-4]|
+//                      +--------+
+//                      Accumulator
+
+static void kern_8x6(const int8_t* packA, const int8_t* packB, int K,
+                     int32_t* output, int LDC, bool is_first_k) {
+    K /= 4;
+    const int8_t* a_ptr = packA;
+    const int8_t* b_ptr = packB;
+    // Fix up for odd lengths - set a flag if K is odd, but make
+    // sure we round up the iteration count.
+    int oddk = (K & 1);
+    int k = (K + 1) / 2 - 1;
+
+    LDC = LDC * sizeof(int32_t);
+    int32_t* outptr0 = output;
+    int32_t* outptr1;
+
+    asm volatile(
+            // load accumulator C
+            "add %[outptr1], %[outptr0], %[LDC]\n"
+            "cmp %[is_first_k], #1\n"
+            "beq 1f\n"
+            "vld1.32 {d8, d9}, [%[outptr0]]!\n"
+            "vld1.32 {d10, d11}, [%[outptr0]]!\n"
+            "vld1.32 {d12, d13}, [%[outptr0]]!\n"
+            "vld1.32 {d14, d15}, [%[outptr0]]!\n"
+            "vld1.32 {d16, d17}, [%[outptr0]]!\n"
+            "vld1.32 {d18, d19}, [%[outptr0]]!\n"
+            "vld1.32 {d20, d21}, [%[outptr1]]!\n"
+            "vld1.32 {d22, d23}, [%[outptr1]]!\n"
+            "vld1.32 {d24, d25}, [%[outptr1]]!\n"
+            "vld1.32 {d26, d27}, [%[outptr1]]!\n"
+            "vld1.32 {d28, d29}, [%[outptr1]]!\n"
+            "vld1.32 {d30, d31}, [%[outptr1]]!\n"
+
+            "b 2f\n"
+
+            "1:\n"
+            "veor.s32 q4, q4, q4\n"
+            "veor.s32 q5, q5, q5\n"
+            "veor.s32 q6, q6, q6\n"
+            "veor.s32 q7, q7, q7\n"
+            "veor.s32 q8, q8, q8\n"
+            "veor.s32 q9, q9, q9\n"
+            "veor.s32 q10, q10, q10\n"
+            "veor.s32 q11, q11, q11\n"
+            "veor.s32 q12, q12, q12\n"
+            "veor.s32 q13, q13, q13\n"
+            "veor.s32 q14, q14, q14\n"
+            "veor.s32 q15, q15, q15\n"
+
+            "2: \n"
+            "vld1.s8 {q0}, [%[b_ptr]]!\n"
+            "vld1.s8 {d2}, [%[b_ptr]]!\n"
+            "vld1.s8 {q2}, [%[a_ptr]]!\n"
+            "vld1.s8 {q3}, [%[a_ptr]]!\n"
+
+            "cmp %[k], #0 \n"
+            "beq 4f \n"
+
+            "3:\n"
+            "vsdot.s8 q4 , q2, d0[0]\n"
+            "vsdot.s8 q5 , q2, d0[1]\n"
+            "vsdot.s8 q6 , q2, d1[0]\n"
+            "vsdot.s8 q7 , q2, d1[1]\n"
+            "vsdot.s8 q8 , q2, d2[0]\n"
+            "vsdot.s8 q9 , q2, d2[1]\n"
+            "vsdot.s8 q10 , q3, d0[0]\n"
+            "vsdot.s8 q11 , q3, d0[1]\n"
+            "vsdot.s8 q12 , q3, d1[0]\n"
+            "vsdot.s8 q13 , q3, d1[1]\n"
+            "vsdot.s8 q14 , q3, d2[0]\n"
+            "vsdot.s8 q15 , q3, d2[1]\n"
+
+            "vld1.s8 {q0}, [%[b_ptr]]!\n"
+            "vld1.s8 {d2}, [%[b_ptr]]!\n"
+            "vld1.s8 {q2}, [%[a_ptr]]!\n"
+            "vld1.s8 {q3}, [%[a_ptr]]!\n"
+            "vsdot.s8 q4 , q2, d0[0]\n"
+            "vsdot.s8 q5 , q2, d0[1]\n"
+            "vsdot.s8 q6 , q2, d1[0]\n"
+            "vsdot.s8 q7 , q2, d1[1]\n"
+            "vsdot.s8 q8 , q2, d2[0]\n"
+            "vsdot.s8 q9 , q2, d2[1]\n"
+            "vsdot.s8 q10 , q3, d0[0]\n"
+            "vsdot.s8 q11 , q3, d0[1]\n"
+            "vsdot.s8 q12 , q3, d1[0]\n"
+            "vsdot.s8 q13 , q3, d1[1]\n"
+            "vsdot.s8 q14 , q3, d2[0]\n"
+            "vsdot.s8 q15 , q3, d2[1]\n"
+
+            "vld1.s8 {q0}, [%[b_ptr]]!\n"
+            "vld1.s8 {d2}, [%[b_ptr]]!\n"
+            "vld1.s8 {q2}, [%[a_ptr]]!\n"
+            "vld1.s8 {q3}, [%[a_ptr]]!\n"
+
+            "subs %[k], %[k], #1\n"
+            "bne 3b\n"
+
+            // Target to use when K is 1 or 2 (i.e. zero iterations of main
+            // loop)
+            "4:\n"
+
+            "cmp %[oddk], #0 \n"
+            "bne 5f \n"
+            "vsdot.s8 q4 , q2, d0[0]\n"
+            "vsdot.s8 q5 , q2, d0[1]\n"
+            "vsdot.s8 q6 , q2, d1[0]\n"
+            "vsdot.s8 q7 , q2, d1[1]\n"
+            "vsdot.s8 q8 , q2, d2[0]\n"
+            "vsdot.s8 q9 , q2, d2[1]\n"
+            "vsdot.s8 q10 , q3, d0[0]\n"
+            "vsdot.s8 q11 , q3, d0[1]\n"
+            "vsdot.s8 q12 , q3, d1[0]\n"
+            "vsdot.s8 q13 , q3, d1[1]\n"
+            "vsdot.s8 q14 , q3, d2[0]\n"
+            "vsdot.s8 q15 , q3, d2[1]\n"
+
+            "vld1.s8 {q0}, [%[b_ptr]]!\n"
+            "vld1.s8 {d2}, [%[b_ptr]]!\n"
+            "vld1.s8 {q2}, [%[a_ptr]]!\n"
+            "vld1.s8 {q3}, [%[a_ptr]]!\n"
+            "vsdot.s8 q4 , q2, d0[0]\n"
+            "vsdot.s8 q5 , q2, d0[1]\n"
+            "vsdot.s8 q6 , q2, d1[0]\n"
+            "vst1.32 {d8, d9}, [%[outptr0]]!\n"
+            "vsdot.s8 q7 , q2, d1[1]\n"
+            "vsdot.s8 q8 , q2, d2[0]\n"
+            "vsdot.s8 q9 , q2, d2[1]\n"
+            "vst1.32 {d10, d11}, [%[outptr0]]!\n"
+            "vsdot.s8 q10 , q3, d0[0]\n"
+            "vsdot.s8 q11 , q3, d0[1]\n"
+            "vsdot.s8 q12 , q3, d1[0]\n"
+            "vst1.32 {d12, d13}, [%[outptr0]]!\n"
+            "vsdot.s8 q13 , q3, d1[1]\n"
+            "vsdot.s8 q14 , q3, d2[0]\n"
+            "vsdot.s8 q15 , q3, d2[1]\n"
+
+            "b 6f\n"
+            "5: \n"
+            "vsdot.s8 q4 , q2, d0[0]\n"
+            "vsdot.s8 q5 , q2, d0[1]\n"
+            "vsdot.s8 q6 , q2, d1[0]\n"
+            "vst1.32 {d8, d9}, [%[outptr0]]!\n"
+            "vsdot.s8 q7 , q2, d1[1]\n"
+            "vsdot.s8 q8 , q2, d2[0]\n"
+            "vsdot.s8 q9 , q2, d2[1]\n"
+            "vst1.32 {d10, d11}, [%[outptr0]]!\n"
+            "vsdot.s8 q10 , q3, d0[0]\n"
+            "vsdot.s8 q11 , q3, d0[1]\n"
+            "vsdot.s8 q12 , q3, d1[0]\n"
+            "vst1.32 {d12, d13}, [%[outptr0]]!\n"
+            "vsdot.s8 q13 , q3, d1[1]\n"
+            "vsdot.s8 q14 , q3, d2[0]\n"
+            "vsdot.s8 q15 , q3, d2[1]\n"
+
+            "6: \n"
+            "vst1.32 {d14, d15}, [%[outptr0]]!\n"
+            "vst1.32 {d16, d17}, [%[outptr0]]!\n"
+            "vst1.32 {d18, d19}, [%[outptr0]]!\n"
+            "vst1.32 {d20, d21}, [%[outptr1]]!\n"
+            "vst1.32 {d22, d23}, [%[outptr1]]!\n"
+            "vst1.32 {d24, d25}, [%[outptr1]]!\n"
+            "vst1.32 {d26, d27}, [%[outptr1]]!\n"
+            "vst1.32 {d28, d29}, [%[outptr1]]!\n"
+            "vst1.32 {d30, d31}, [%[outptr1]]!\n"
+            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
+              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k),
+              [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1)
+            :
+            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+              "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
+}
+
+// Overview of register layout:
+//
+// A 2x4x4 cell of Rhs is stored in 8bit in q1, q3.
+// A 2x2x4x4 ping-pong cell of Lhs is stored in 8bit in q5, q7, q9, q11
+// A 2x4x4 block of accumulators is stored in 32bit in q0, q2, q4, q6, q8,
+// q10, q12, q14
+//
+//                      +--------+
+// Rhs                  |q1[0-16]|
+//                      |q3[0-16]|
+//                      +--------+
+// Lhs                  |        |
+//  +---------+ - - - - +--------+
+//  | q5[0-16]|         | q0[0-4]|
+//  | q7[0-16]|         | q2[0-4]|
+//  | q9[0-16]|         | q4[0-4]|
+//  |q11[0-16]|         | q6[0-4]|
+//  +---------+         | q8[0-4]|
+//                      |q10[0-4]|
+//                      |q12[0-4]|
+//                      |q14[0-4]|
+//                      +--------+
+//                      Accumulator
+
+static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
+                     int32_t* output, int LDC, bool is_first_k, int n_remain) {
+    K /= 4;
+    const int8_t* a_ptr = packA;
+    const int8_t* b_ptr = packB;
+    // Fix up for odd lengths - set a flag if K is odd, but make
+    // sure we round up the iteration count.
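+    // Unlike kern_8x6 above, which preloads one group and drains the tail
+    // after the main loop, this kernel consumes the odd group of four k
+    // values first (the "parse the oddk" block below), so k counts only the
+    // remaining full pairs of groups.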
+    int oddk = (K & 1);
+    int k = K / 2;
+
+    LDC = LDC * sizeof(int32_t);
+    int32_t* outptr0 = output;
+    int32_t* outptr1;
+
+    size_t x0;
+
+// clang-format off
+#define LOAD_LINE(dr0, dr1, dr2, dr3, dr4, dr5, dr6, dr7, n) \
+    "mov %[x0], %[outptr" n "]\n"                            \
+    "cmp %[n_remain], #4\n"                                  \
+    "blt 100" n "f\n"                                        \
+    "vld1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"               \
+    "vld1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"               \
+    "vld1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"               \
+    "vld1.32 {d" dr6 ", d" dr7 "}, [%[x0]]!\n"               \
+    "b 101" n "f\n"                                          \
+    "100" n ":\n"                                            \
+    "cmp %[n_remain], #0\n"                                  \
+    "beq 101" n "f\n"                                        \
+    "vld1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"               \
+    "cmp %[n_remain], #1\n"                                  \
+    "beq 101" n "f\n"                                        \
+    "vld1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"               \
+    "cmp %[n_remain], #2\n"                                  \
+    "beq 101" n "f\n"                                        \
+    "vld1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"               \
+    "101" n ":\n"

+#define LOAD_C                                                \
+    LOAD_LINE("0", "1", "4", "5", "8", "9", "12", "13", "0")  \
+    LOAD_LINE("16", "17", "20", "21", "24", "25", "28", "29", "1")

+#define STORE_LINE(dr0, dr1, dr2, dr3, dr4, dr5, dr6, dr7, n) \
+    "mov %[x0], %[outptr" n "]\n"                             \
+    "cmp %[n_remain], #4\n"                                   \
+    "blt 102" n "f\n"                                         \
+    "vst1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"                \
+    "vst1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"                \
+    "vst1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"                \
+    "vst1.32 {d" dr6 ", d" dr7 "}, [%[x0]]!\n"                \
+    "b 103" n "f\n"                                           \
+    "102" n ":\n"                                             \
+    "cmp %[n_remain], #0\n"                                   \
+    "beq 103" n "f\n"                                         \
+    "vst1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"                \
+    "cmp %[n_remain], #1\n"                                   \
+    "beq 103" n "f\n"                                         \
+    "vst1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"                \
+    "cmp %[n_remain], #2\n"                                   \
+    "beq 103" n "f\n"                                         \
+    "vst1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"                \
+    "103" n ":\n"

+#define STORE_C                                                \
+    STORE_LINE("0", "1", "4", "5", "8", "9", "12", "13", "0")  \
+    STORE_LINE("16", "17", "20", "21", "24", "25", "28", "29", "1")
+    // clang-format on
+
+    asm volatile(
+            // load accumulator C
+            "add %[outptr1], %[outptr0], %[LDC]\n"
+            "cmp %[is_first_k], #1\n"
+            "beq 1f\n" LOAD_C
+
+            "b 2f\n"
+
+            "1:\n"
+            "veor.s32 q0, q0, q0\n"
+            "veor.s32 q2, q2, q2\n"
+            "veor.s32 q4, q4, q4\n"
+            "veor.s32 q6, q6, q6\n"
+            "veor.s32 q8, q8, q8\n"
+            "veor.s32 q10, q10, q10\n"
+            "veor.s32 q12, q12, q12\n"
+            "veor.s32 q14, q14, q14\n"
+
+            "2: \n"
+            "cmp %[oddk], #0 \n"
+            "beq 3f \n"
+
+            // parse the oddk
+            "vld1.s8 {q1}, [%[b_ptr]]!\n"
+            "vld1.s8 {q3}, [%[a_ptr]]!\n"
+            "vld1.s8 {q5}, [%[a_ptr]]!\n"
+            "vsdot.s8 q0 , q3, d2[0]\n"
+            "vsdot.s8 q2 , q3, d2[1]\n"
+            "vsdot.s8 q4 , q3, d3[0]\n"
+            "vsdot.s8 q6 , q3, d3[1]\n"
+            "vsdot.s8 q8 , q5, d2[0]\n"
+            "vsdot.s8 q10 , q5, d2[1]\n"
+            "vsdot.s8 q12 , q5, d3[0]\n"
+            "vsdot.s8 q14 , q5, d3[1]\n"
+
+            "cmp %[k], #0 \n"
+            "beq 4f \n"
+            // Loop proper
+            "3:\n"
+            "vld1.s8 {q1}, [%[b_ptr]]!\n"
+            "vld1.s8 {q3}, [%[b_ptr]]!\n"
+            "vld1.s8 {q5}, [%[a_ptr]]!\n"
+            "vld1.s8 {q7}, [%[a_ptr]]!\n"
+            "vsdot.s8 q0 , q5, d2[0]\n"
+            "vsdot.s8 q2 , q5, d2[1]\n"
+            "vsdot.s8 q4 , q5, d3[0]\n"
+            "vsdot.s8 q6 , q5, d3[1]\n"
+            "vld1.s8 {q9}, [%[a_ptr]]!\n"
+            "vld1.s8 {q11}, [%[a_ptr]]!\n"
+            "vsdot.s8 q8 , q7, d2[0]\n"
+            "vsdot.s8 q10 , q7, d2[1]\n"
+            "vsdot.s8 q12 , q7, d3[0]\n"
+            "vsdot.s8 q14 , q7, d3[1]\n"
+
+            "vsdot.s8 q0 , q9, d6[0]\n"
+            "vsdot.s8 q2 , q9, d6[1]\n"
+            "vsdot.s8 q4 , q9, d7[0]\n"
+            "vsdot.s8 q6 , q9, d7[1]\n"
+            "vsdot.s8 q8 , q11, d6[0]\n"
+            "vsdot.s8 q10 , q11, d6[1]\n"
+            "vsdot.s8 q12 , q11, d7[0]\n"
+            "vsdot.s8 q14 , q11, d7[1]\n"
+
+            "subs %[k], %[k], #1\n"
+            "bne 3b\n"
+
+            "4:\n" STORE_C
+            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
+              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [x0] "+r"(x0) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q14", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A 1x6x4 pingpong cell of Rhs is stored in 8bit in q0-q3. +// A 1x1x4x4 pingpong cell of Lhs is stored in 8bit in q4-q5 +// A 2x6x4 block of accumulators is stored in 8bit in q10-q15 +// +// +--------+ +// Rhs |q0[0-16]| +// |q1[0-16]| +// +--------+ +// Lhs | | +// +-------+-------+ - - - - +--------+ +// | q4[0-16]| |q10[0-4]| +// | q5[0-16]| |q11[0-4]| +// +---------+ |q12[0-4]| +// |q13[0-4]| +// |q14[0-4]| +// |q15[0-4]| +// +--------+ +// Accumulator + +static void kern_4x6(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = (K + 1) / 2 - 1; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + + asm volatile( + // load accumulator C + "cmp %[is_first_k], #1\n" + "beq 1f\n" + "vld1.32 {d20, d21}, [%[outptr0]]!\n" + "vld1.32 {d22, d23}, [%[outptr0]]!\n" + "vld1.32 {d24, d25}, [%[outptr0]]!\n" + "vld1.32 {d26, d27}, [%[outptr0]]!\n" + "vld1.32 {d28, d29}, [%[outptr0]]!\n" + "vld1.32 {d30, d31}, [%[outptr0]]!\n" + + "b 2f\n" + + "1:\n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + "veor.s32 q12, q12, q12\n" + "veor.s32 q13, q13, q13\n" + "veor.s32 q14, q14, q14\n" + "veor.s32 q15, q15, q15\n" + + "2: \n" + "vld1.s8 {q0}, [%[b_ptr]]!\n" + "vld1.s8 {d2}, [%[b_ptr]]!\n" + "vld1.s8 {q4}, [%[a_ptr]]!\n" + + "cmp %[k], #0 \n" + "beq 4f \n" + + "3:\n" + "vsdot.s8 q10 , q4, d0[0]\n" + "vsdot.s8 q11 , q4, d0[1]\n" + "vsdot.s8 q12 , q4, d1[0]\n" + "vld1.s8 {q2}, [%[b_ptr]]!\n" + "vld1.s8 {d6}, [%[b_ptr]]!\n" + "vld1.s8 {q5}, [%[a_ptr]]!\n" + "vsdot.s8 q13 , q4, d1[1]\n" + "vsdot.s8 q14 , q4, d2[0]\n" + "vsdot.s8 q15 , q4, d2[1]\n" + + "vld1.s8 {q0}, [%[b_ptr]]!\n" + "vsdot.s8 q10 , q5, d4[0]\n" + "vsdot.s8 q11 , q5, d4[1]\n" + "vsdot.s8 q12 , q5, d5[0]\n" + "vld1.s8 {d2}, [%[b_ptr]]!\n" + "vsdot.s8 q13 , q5, d5[1]\n" + "vsdot.s8 q14 , q5, d6[0]\n" + "vsdot.s8 q15 , q5, d6[1]\n" + "vld1.s8 {q4}, [%[a_ptr]]!\n" + + "subs %[k], %[k], #1\n" + "bne 3b\n" + + // Target to use when K is 1 or 2 (i.e. 
+            // loop)
+            "4:\n"
+
+            "cmp %[oddk], #0 \n"
+            "bne 5f \n"
+
+            "vsdot.s8 q10 , q4, d0[0]\n"
+            "vsdot.s8 q11 , q4, d0[1]\n"
+            "vsdot.s8 q12 , q4, d1[0]\n"
+            "vld1.s8 {q2}, [%[b_ptr]]!\n"
+            "vld1.s8 {d6}, [%[b_ptr]]!\n"
+            "vld1.s8 {q5}, [%[a_ptr]]!\n"
+            "vsdot.s8 q13 , q4, d1[1]\n"
+            "vsdot.s8 q14 , q4, d2[0]\n"
+            "vsdot.s8 q15 , q4, d2[1]\n"
+
+            "vsdot.s8 q10 , q5, d4[0]\n"
+            "vsdot.s8 q11 , q5, d4[1]\n"
+            "vsdot.s8 q12 , q5, d5[0]\n"
+            "vst1.32 {d20, d21}, [%[outptr0]]!\n"
+            "vsdot.s8 q13 , q5, d5[1]\n"
+            "vsdot.s8 q14 , q5, d6[0]\n"
+            "vsdot.s8 q15 , q5, d6[1]\n"
+            "vst1.32 {d22, d23}, [%[outptr0]]!\n"
+
+            "b 6f\n"
+            "5: \n"
+
+            "vsdot.s8 q10 , q4, d0[0]\n"
+            "vsdot.s8 q11 , q4, d0[1]\n"
+            "vsdot.s8 q12 , q4, d1[0]\n"
+            "vst1.32 {d20, d21}, [%[outptr0]]!\n"
+            "vsdot.s8 q13 , q4, d1[1]\n"
+            "vsdot.s8 q14 , q4, d2[0]\n"
+            "vsdot.s8 q15 , q4, d2[1]\n"
+            "vst1.32 {d22, d23}, [%[outptr0]]!\n"
+
+            "6: \n"
+            "vst1.32 {d24, d25}, [%[outptr0]]!\n"
+            "vst1.32 {d26, d27}, [%[outptr0]]!\n"
+            "vst1.32 {d28, d29}, [%[outptr0]]!\n"
+            "vst1.32 {d30, d31}, [%[outptr0]]!\n"
+            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
+              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k),
+              [outptr0] "+r"(outptr0)
+            :
+            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+              "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
+}
+
+// Overview of register layout:
+//
+// A 2x4x4 cell of Rhs is stored in 8bit in q1, q3.
+// A 1x2x4x4 ping-pong cell of Lhs is stored in 8bit in q5, q7
+// A 1x4x4 block of accumulators is stored in 32bit in q0, q2, q4, q6
+//
+//                      +--------+
+// Rhs                  |q1[0-16]|
+                        |q3[0-16]|
+//                      +--------+
+// Lhs                  |        |
+//  +---------+ - - - - +--------+
+//  | q5[0-16]|         | q0[0-4]|
+//  | q7[0-16]|         | q2[0-4]|
+//  +---------+         | q4[0-4]|
+//                      | q6[0-4]|
+//                      +--------+
+//                      Accumulator
+
+static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
+                     int32_t* output, int LDC, bool is_first_k, int n_remain) {
+    K /= 4;
+    const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
+    const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
+    // Fix up for odd lengths - set a flag if K is odd, but make
+    // sure we round up the iteration count.
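+    // As in kern_8x4, the odd group of four k values is consumed before the
+    // main loop, so k counts pairs of groups. a_ptr/b_ptr are viewed as
+    // int32_t pointers purely so one element spans one packed group of four
+    // int8 values; the data itself is still int8.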
+    int oddk = (K & 1);
+    int k = K / 2;
+    LDC = LDC * sizeof(int32_t);
+
+    int32_t* outptr0 = output;
+    size_t x0;
+
+// clang-format off
+#define LOAD_LINE(dr0, dr1, dr2, dr3, dr4, dr5, dr6, dr7, n) \
+    "mov %[x0], %[outptr" n "]\n"                            \
+    "cmp %[n_remain], #4\n"                                  \
+    "blt 100" n "f\n"                                        \
+    "vld1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"               \
+    "vld1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"               \
+    "vld1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"               \
+    "vld1.32 {d" dr6 ", d" dr7 "}, [%[x0]]!\n"               \
+    "b 101" n "f\n"                                          \
+    "100" n ":\n"                                            \
+    "cmp %[n_remain], #0\n"                                  \
+    "beq 101" n "f\n"                                        \
+    "vld1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"               \
+    "cmp %[n_remain], #1\n"                                  \
+    "beq 101" n "f\n"                                        \
+    "vld1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"               \
+    "cmp %[n_remain], #2\n"                                  \
+    "beq 101" n "f\n"                                        \
+    "vld1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"               \
+    "101" n ":\n"

+#define LOAD_C \
+    LOAD_LINE("0", "1", "4", "5", "8", "9", "12", "13", "0")

+#define STORE_LINE(dr0, dr1, dr2, dr3, dr4, dr5, dr6, dr7, n) \
+    "mov %[x0], %[outptr" n "]\n"                             \
+    "cmp %[n_remain], #4\n"                                   \
+    "blt 102" n "f\n"                                         \
+    "vst1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"                \
+    "vst1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"                \
+    "vst1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"                \
+    "vst1.32 {d" dr6 ", d" dr7 "}, [%[x0]]!\n"                \
+    "b 103" n "f\n"                                           \
+    "102" n ":\n"                                             \
+    "cmp %[n_remain], #0\n"                                   \
+    "beq 103" n "f\n"                                         \
+    "vst1.32 {d" dr0 ", d" dr1 "}, [%[x0]]!\n"                \
+    "cmp %[n_remain], #1\n"                                   \
+    "beq 103" n "f\n"                                         \
+    "vst1.32 {d" dr2 ", d" dr3 "}, [%[x0]]!\n"                \
+    "cmp %[n_remain], #2\n"                                   \
+    "beq 103" n "f\n"                                         \
+    "vst1.32 {d" dr4 ", d" dr5 "}, [%[x0]]!\n"                \
+    "103" n ":\n"

+#define STORE_C \
+    STORE_LINE("0", "1", "4", "5", "8", "9", "12", "13", "0")
+    // clang-format on
+
+    asm volatile(
+            // load accumulator C
+            "cmp %[is_first_k], #1\n"
+            "beq 1f\n"  //
+            LOAD_C      //
+
+            "b 2f\n"
+
+            "1:\n"
+            "veor.s32 q0, q0, q0\n"
+            "veor.s32 q2, q2, q2\n"
+            "veor.s32 q4, q4, q4\n"
+            "veor.s32 q6, q6, q6\n"
+
+            "2: \n"
+            "cmp %[oddk], #0 \n"
+            "beq 3f \n"
+
+            // parse the oddk
+            "vld1.s8 {q1}, [%[a_ptr]]!\n"
+            "vld1.s8 {q3}, [%[b_ptr]]!\n"
+            "vsdot.s8 q0 , q1, d6[0]\n"
+            "vsdot.s8 q2 , q1, d6[1]\n"
+            "vsdot.s8 q4 , q1, d7[0]\n"
+            "vsdot.s8 q6 , q1, d7[1]\n"
+
+            "cmp %[k], #0 \n"
+            "beq 4f \n"
+            // Loop proper
+            "3:\n"
+            "vld1.s8 {q1}, [%[b_ptr]]!\n"
+            "vld1.s8 {q5}, [%[a_ptr]]!\n"
+            "vsdot.s8 q0 , q5, d2[0]\n"
+            "vsdot.s8 q2 , q5, d2[1]\n"
+            "vld1.s8 {q3}, [%[b_ptr]]!\n"
+            "vld1.s8 {q7}, [%[a_ptr]]!\n"
+            "vsdot.s8 q4 , q5, d3[0]\n"
+            "vsdot.s8 q6 , q5, d3[1]\n"
+            "vsdot.s8 q0 , q7, d6[0]\n"
+            "vsdot.s8 q2 , q7, d6[1]\n"
+            "vsdot.s8 q4 , q7, d7[0]\n"
+            "vsdot.s8 q6 , q7, d7[1]\n"
+
+            "subs %[k], %[k], #1\n"
+            "bne 3b\n"
+
+            "4:\n" STORE_C
+            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
+              [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
+              [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k),
+              [x0] "+r"(x0)
+            :
+            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "cc", "memory");
+
+#undef LOAD_LINE
+#undef LOAD_C
+#undef STORE_LINE
+#undef STORE_C
+}
+
+static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr,
+                                  int ldin, int y0, int ymax, int k0,
+                                  int kmax) {
+    int y = y0, y_start = y0 / 4;
+    for (; y + 7 < ymax; y += 8, y_start += 2) {
+        const int8_t* inptr0 = inptr + y_start * ldin + k0 * 4;
+        const int8_t* inptr1 = inptr0 + ldin;
+        prefetch_2x(inptr0);
+        prefetch_2x(inptr1);
+        int K = kmax - k0;
+        for (; K > 3; K -= 4) {
+            interleave_2x4_4_b(inptr0, inptr1, outptr);
+        }
+    }
+    for (; y + 3 < ymax; y += 4, ++y_start) {
+        int K = kmax - k0;
+        const int8_t* inptr0 = inptr + y_start * ldin + k0 * 4;
+        std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4);
+    }
+}
+
+static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
+                                  int x0, int xmax, int k0, int kmax) {
+    const int ksize = kmax - k0;
+    const int ksize4 = ksize * 4;
+    const int ksize6 = ksize * 6;
+    int8_t* outptr = out;
+    int8_t* outptr_base = out;
+    int8_t* outptr_base4 = out + ((xmax - x0) / 6) * ksize6;
+
+    int k = k0;
+    for (; k + 3 < kmax; k += 4) {
+        const int8_t* inptr = in + k / 4 * ldin + x0 * 4;
+        prefetch_2x(inptr);
+
+        outptr = outptr_base;
+        int x = x0;
+        for (; x + 5 < xmax; x += 6) {
+            memcpy(outptr, inptr, sizeof(int8_t) * 24);
+            outptr += ksize6;
+            inptr += 24;
+        }
+
+        outptr = outptr_base4;
+        for (; x + 3 < xmax; x += 4) {
+            memcpy(outptr, inptr, sizeof(int8_t) * 16);
+            outptr += ksize4;
+            inptr += 16;
+        }
+        if (x < xmax) {
+            int i = 0;
+            for (; i < xmax - x; i++) {
+                *outptr++ = *inptr++;
+                *outptr++ = *inptr++;
+                *outptr++ = *inptr++;
+                *outptr++ = *inptr++;
+            }
+            // zero-pad the residual columns instead of reading past xmax
+            for (; i < 4; i++) {
+                *outptr++ = 0;
+                *outptr++ = 0;
+                *outptr++ = 0;
+                *outptr++ = 0;
+            }
+        }
+        outptr_base += 24;
+        outptr_base4 += 16;
+    }
+}
+
+}  // namespace matmul_mk4_dot_8x6x4
+}  // namespace armv7
+}  // namespace megdnn
+#endif
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.cpp b/dnn/src/armv7/matrix_mul/int8/strategy.cpp
index 275002a02..bdca391b9 100644
--- a/dnn/src/armv7/matrix_mul/int8/strategy.cpp
+++ b/dnn/src/armv7/matrix_mul/int8/strategy.cpp
@@ -16,6 +16,7 @@
 #include "src/armv7/matrix_mul/int8/kernel_4x8x8.h"
 #include "src/armv7/matrix_mul/int8/kernel_6x8x4.h"
 #include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h"
+#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h"
 #include "src/common/utils.h"
 #include "src/fallback/matrix_mul/gemm_common.h"
 
@@ -252,6 +253,89 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
         }
     }
 }
+
+// ===========================gemm_mk4_dots8_8x6======================================
+MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x6);
+
+void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin,
+                                int y0, int ymax, int k0, int kmax,
+                                bool transpose) const {
+    megdnn_assert(!transpose,
+                  "matrix mul mk4 with transposed matrix A is not supported.");
+    megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0,
+                  "mk4 format matmul requires m to be a multiple of 4.");
+    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
+                  "mk4 format matmul requires k to be a multiple of 4.");
+    matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_A(out, in, ldin, y0, ymax, k0,
+                                                kmax);
+}
+
+void gemm_mk4_dots8_8x6::pack_B(dt_int8* out, const dt_int8* in, int ldin,
+                                int x0, int xmax, int k0, int kmax,
+                                bool transpose) const {
+    megdnn_assert(!transpose,
+                  "matrix mul mk4 with transposed matrix B is not supported");
+    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
+                  "mk4 format matmul requires k to be a multiple of 4.");
+    matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_B(out, in, ldin, x0, xmax, k0,
+                                                kmax);
+}
+
+void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB,
+                              size_t M, size_t N, size_t K, dt_int32* C,
+                              size_t LDC, bool is_first_k,
+                              const dt_int32* bias,
+                              dt_int32* workspace) const {
+    MEGDNN_MARK_USED_VAR(bias);
+    constexpr size_t A_INTERLEAVE = 8;
+    constexpr size_t B_INTERLEAVE = 6;
+    //! K is rounded up to a multiple of 4
+    K = round_up<size_t>(K, 4);
+    const int K4 = K * 4;
+    const int K6 = K * 6;
+    const int K8 = K * 8;
+
+    size_t m = 0;
+    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
+        int32_t* output = C + ((m >> 2) * LDC);
+        const dt_int8* cur_packB = packB;
+        size_t n = 0;
+        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
+            matmul_mk4_dot_8x6x4::kern_8x6(packA, cur_packB, K, output, LDC,
+                                           is_first_k);
+            output += 24;
+            cur_packB += K6;
+        }
+
+        for (; n < N; n += 4) {
+            size_t n_remain = std::min<size_t>(N - n, 4);
+            matmul_mk4_dot_8x6x4::kern_8x4(packA, cur_packB, K, output, LDC,
+                                           is_first_k, n_remain);
+            output += 16;
+            cur_packB += K4;
+        }
+        packA += K8;
+    }
+    for (; m < M; m += 4) {
+        int32_t* output = C + ((m >> 2) * LDC);
+        const dt_int8* cur_packB = packB;
+        size_t n = 0;
+        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
+            matmul_mk4_dot_8x6x4::kern_4x6(packA, cur_packB, K, output, LDC,
+                                           is_first_k);
+            output += 24;
+            cur_packB += K6;
+        }
+        for (; n < N; n += 4) {
+            size_t n_remain = std::min<size_t>(N - n, 4);
+            matmul_mk4_dot_8x6x4::kern_4x4(packA, cur_packB, K, output, LDC,
+                                           is_first_k, n_remain);
+            output += 16;
+            cur_packB += K4;
+        }
+        packA += K4;
+    }
+}
+
 #endif
 
 // ===========================gemm_mk4_s8_4x2======================================
diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.h b/dnn/src/armv7/matrix_mul/int8/strategy.h
index 516a99318..bcd9588f3 100644
--- a/dnn/src/armv7/matrix_mul/int8/strategy.h
+++ b/dnn/src/armv7/matrix_mul/int8/strategy.h
@@ -26,6 +26,9 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false,
 #if __ARM_FEATURE_DOTPROD
 MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false,
                          gemm_dots8_6x8);
+
+MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 6, 4, false, false,
+                         gemm_mk4_dots8_8x6);
 #endif
 }  // namespace matmul
 }  // namespace armv7
diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp
index 736e2f7c7..ca157fc56 100644
--- a/dnn/src/armv7/matrix_mul/opr_impl.cpp
+++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp
@@ -29,6 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
 #if __ARM_FEATURE_DOTPROD
     AlgoInt8x8x32K6x8x4 int8_k6x8x4;
     AlgoQuint8DotK4x8x4 quint8_k4x8x4;
+    AlgoInt8x8x32MK4_8x6x4DotProd int8x8x32_mk4_8x6x4_dotprod;
 #endif
     AlgoF32Gemv f32_gemv;
     AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16;
@@ -56,6 +57,7 @@ public:
         all_algos.emplace_back(&f16_mk8_4x8);
 #endif
 #if __ARM_FEATURE_DOTPROD
+        all_algos.emplace_back(&int8x8x32_mk4_8x6x4_dotprod);
         all_algos.emplace_back(&int8_k6x8x4);
         all_algos.emplace_back(&quint8_k4x8x4);
 #endif
diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h
index fe3d22a3b..4b11b5ba8 100644
--- a/dnn/src/armv7/matrix_mul/opr_impl.h
+++ b/dnn/src/armv7/matrix_mul/opr_impl.h
@@ -42,6 +42,8 @@ private:
 #if __ARM_FEATURE_DOTPROD
     class AlgoInt8x8x32K6x8x4;  // Armv7 Int8 Kernel 6x8x4
    class AlgoQuint8DotK4x8x4;  // Armv7 Quint8 Kernel 6x8x4
+    class AlgoInt8x8x32MK4_8x6x4DotProd;  // Armv7 nchw44 Int8x8x32 Kernel
+                                          // 8x6x4 DotProduct
 #endif
     class AlgoPack;
 };
diff --git a/dnn/test/armv7/matrix_mul.cpp b/dnn/test/armv7/matrix_mul.cpp
index 78d82bc0b..6ef2d23e2 100644
--- a/dnn/test/armv7/matrix_mul.cpp
+++ b/dnn/test/armv7/matrix_mul.cpp
@@ -86,6 +86,18 @@ TEST_F(ARMV7, MATRIX_MUL_UDOT) {
             dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)),
             dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)),
             dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4");
 }
+
+TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
+    std::vector<matrix_mul::TestArg> args;
+    for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
+        for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24, 25, 32})
+            for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34})
+                args.emplace_back(m, n, k, 0);
+    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
+                                 handle(), "AARCH32_INT8_MK4_8X6X4_DOTPROD",
+                                 param::MatrixMul::Format::MK4_DOT, 1, 1e-3,
+                                 std::move(args));
+}
 #endif
 
 #if MEGDNN_WITH_BENCHMARK
@@ -286,6 +298,53 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_K6x8x4) {
 TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_QUINT8x8x32_K4x8x4) {
     run_8x8x32_quint_benchmark(handle());
 }
+
+TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) {
+    constexpr size_t RUNS = 50;
+    param::MatrixMul param;
+    Benchmarker<MatrixMul> benchmarker_default(handle());
+    benchmarker_default.set_times(RUNS)
+            .set_dtype(0, dtype::Int8())
+            .set_dtype(1, dtype::Int8())
+            .set_dtype(2, dtype::Int32())
+            .set_param(param)
+            .set_display(false);
+    benchmarker_default.set_before_exec_callback(
+            AlgoChecker<MatrixMul>("AARCH32_INT8_K6X8X4"));
+
+    param.format = MatrixMul::Param::Format::MK4_DOT;
+    Benchmarker<MatrixMul> benchmarker_mk4_dot(handle());
+    benchmarker_mk4_dot.set_before_exec_callback(
+            AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X6X4_DOTPROD"));
+    benchmarker_mk4_dot.set_param(param)
+            .set_dtype(0, dtype::Int8())
+            .set_dtype(1, dtype::Int8())
+            .set_dtype(2, dtype::Int32())
+            .set_display(false)
+            .set_times(RUNS);
+
+    auto run = [&](size_t M, size_t N, size_t K) {
+        auto default_used =
+                benchmarker_default.exec({{M, K}, {K, N}, {}}) / RUNS;
+        auto mk4_dot_used = benchmarker_mk4_dot.exec(
+                                    {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
+                            RUNS;
+        float computations = 2.f * M * K * N * 1e-6;
+        printf("run: {%zu{M} %zu{K} %zu{N}} default: %f ms %f Gflops "
+               "mk4_dot: %f ms %f Gflops speedup: %f\n",
+               M, K, N, default_used, computations / default_used,
+               mk4_dot_used, computations / mk4_dot_used,
+               default_used / mk4_dot_used);
+    };
+
+    for (size_t M = 4; M < 512; M *= 2) {
+        for (size_t K = 4; K < 512; K *= 2) {
+            for (size_t N : {4, 8, 33, 113, 128}) {
+                run(M, N, K);
+            }
+        }
+    }
+}
 #endif
 
 TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) {
-- 
GitLab
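For reference, the arithmetic the whole patch implements can be written as a few scalar loops. This is a sketch under the layout assumptions above (A as (M/4, K/4, 4, 4) with the four k values innermost, B as (K/4, N, 4), C as (M/4, N, 4)); the function name is hypothetical, and this is not the reference implementation used by check_matrix_mul:

    #include <cstddef>
    #include <cstdint>

    // Hypothetical scalar model of the MK4_DOT GEMM computed by the kernels
    // in this patch: C[mo][n][mi] = sum over ko, ki of
    // A[mo][ko][mi][ki] * B[ko][n][ki]. M and K must be multiples of 4, as
    // asserted by pack_A/pack_B.
    void mk4_dot_gemm_ref(const int8_t* A, const int8_t* B, int32_t* C,
                          size_t M, size_t N, size_t K) {
        for (size_t mo = 0; mo < M / 4; ++mo)
            for (size_t n = 0; n < N; ++n)
                for (size_t mi = 0; mi < 4; ++mi) {
                    int32_t acc = 0;  // 32-bit accumulator, as in q4-q15
                    for (size_t ko = 0; ko < K / 4; ++ko)
                        for (size_t ki = 0; ki < 4; ++ki)
                            acc += int32_t(A[((mo * (K / 4) + ko) * 4 + mi) *
                                                     4 +
                                             ki]) *
                                   int32_t(B[(ko * N + n) * 4 + ki]);
                    C[(mo * N + n) * 4 + mi] = acc;
                }
    }

Each vsdot.s8 in kern_8x6 evaluates the inner ki loop for four mi lanes at once, which is why the packed A tile keeps the four k values of each row contiguous.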