diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index e8ad513db28d1fdc8d63878460a428f3f0d76d4a..6ae30e8215ac0c80b120e62be018d846585679f6 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -474,6 +474,76 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern( MIDOUT_END(); return nullptr; } + +/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ +namespace { +void int8x8x32_mk4_8x12x4_dotprod_kern( + const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x32_mk4_8x12x4_dotprod_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && + (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || + kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && + (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::MK4_DOT && + !kern_size_param.trA && !kern_size_param.trB; +} + +size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + + aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type, + C_type); + return megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_mk4_s8_8x12>(M, N, K, trA, trB, + strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern( + const KernSizeParam&) const { + return int8x8x32_mk4_8x12x4_dotprod_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash, + aarch64::matmul::gemm_mk4_s8_8x12, int8_t, + int32_t); #else /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index dc46b5e8eea8a7134744f5131a89d43f29641897..672187560a1aa6a6d8e3f68dec41221c886f5ea6 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -118,6 +118,19 @@ public: AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } }; + +class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { + return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; + } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; #else class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { diff --git a/dnn/src/aarch64/matrix_mul/asm/common.h b/dnn/src/aarch64/matrix_mul/asm/common.h index 334d3f9b2426d295afd1cbd51cec6d61d5af9e87..c24395cc822432cd9d4b11097de328f03102efbc 100644 --- a/dnn/src/aarch64/matrix_mul/asm/common.h +++ b/dnn/src/aarch64/matrix_mul/asm/common.h @@ -615,6 +615,20 @@ static inline void interleave_12x4_4_b(const T*& inptr0, const T*& inptr1, reinterpret_cast(outptr)); } +static inline void interleave_2x1_4_s(const int32_t*& inptr0, + const int32_t*& inptr1, + int32_t*& outptr) { + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 + "ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3 + "st1 {v0.4s}, [%[outptr]], #16\n" + "st1 {v1.4s}, [%[outptr]], #16\n" + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) + : + : "v0", "v1", "cc", "memory"); +} + static inline void interleave_8x1_4_s( const int32_t*& inptr0, const int32_t*& inptr1, const int32_t*& inptr2, const int32_t*& inptr3, const int32_t*& inptr4, const int32_t*& inptr5, @@ -752,6 +766,17 @@ static inline void interleave_8x2_2_d( "v11", "v12", "v13", "v14", "v15", "cc", "memory"); } +template +static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1, + T*& outptr) { + static_assert( + std::is_same::value || std::is_same::value, + "interleave_2x4_4_b only support uint8_t and int8_t"); + interleave_2x1_4_s(reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(outptr)); +} + template static inline void interleave_8x4_4_b(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h new file mode 100644 index 0000000000000000000000000000000000000000..cef2569eb629f014984e79845d0cb26acfb10565 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h @@ -0,0 +1,933 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#if __ARM_FEATURE_DOTPROD + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_mk4_8x12x4 { + +// Overview of register layout: +// +// A 12x4 cell of Rhs is stored in 8bit in q2-q4. +// A 8x4x2 cell of Lhs is stored in 8bit in q0-q1,q5-q6 +// A 8x12 block of accumulators is stored in 8bit in q8--q31. +// +// +------------+------------+------------+ +// | v2[0-16]| v3[0-16]| v4[0-16]| +// Rhs +------------+------------+------------+ +// +// | | | | +// +// Lhs | | | | +// +// +--------+--------+ - - - - +------------+------------+------------+ +// |v0[0-16]|v5[0-16]| | v8 v9v10v11|v16v17v18v19|v24v25v26v27| +// |v1[0-16]|v6[0-16]| |v12v13v14v15|v20v21v22v23|v28v29v30v31| +// +--------+--------+ - - - - +------------+------------+------------+ +// +// Accumulator + +static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + int32x4_t a0; + int32x4_t a1; + int32x4_t b0; + int32x4_t b1; + int32x4_t b2; + int32x4_t a0a; + int32x4_t a1a; + LDC = LDC * sizeof(int32_t); + + int32_t* outptr0 = output; + int32_t* outptr1; + + asm volatile ( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 5f\n" + // we can not use ld1, as it can not encode {v8, v16, v24} + "ldp q8, q9, [%[outptr0]]\n" + "ldp q10, q11, [%[outptr0], #32]\n" + "ldp q16, q17, [%[outptr0], #64]\n" + "ldp q18, q19, [%[outptr0], #96]\n" + "ldp q24, q25, [%[outptr0], #128]\n" + "ldp q26, q27, [%[outptr0], #160]\n" + "ldp q12, q13, [%[outptr1]]\n" + "ldp q14, q15, [%[outptr1], #32]\n" + "ldp q20, q21, [%[outptr1], #64]\n" + "ldp q22, q23, [%[outptr1], #96]\n" + "ldp q28, q29, [%[outptr1], #128]\n" + "ldp q30, q31, [%[outptr1], #160]\n" + "b 6f\n" + + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "6: \n" + // Initialize result registers, load initial operands, prime prefetches. + "ldr %q[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + ASM_PREFETCH("[%[b_ptr], #64]") + ASM_PREFETCH("[%[a_ptr], #64]") + ASM_PREFETCH("[%[b_ptr], #128]") + ASM_PREFETCH("[%[a_ptr], #128]") + ASM_PREFETCH("[%[b_ptr], #192]") + ASM_PREFETCH("[%[b_ptr], #256]") + ASM_PREFETCH("[%[a_ptr], #192]") + ASM_PREFETCH("[%[b_ptr], #320]") + ASM_PREFETCH("[%[a_ptr], #256]") + ASM_PREFETCH("[%[b_ptr], #384]") + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + // Loop proper + "1:\n" + "sdot v8.4s , %[a0].16b, %[b0].4b[0]\n" + "sdot v9.4s , %[a0].16b, %[b0].4b[1]\n" + + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v10.4s, %[a0].16b, %[b0].4b[2]\n" + "sdot v11.4s, %[a0].16b, %[b0].4b[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "sdot v12.4s, %[a1].16b, %[b0].4b[0]\n" + "sdot v13.4s, %[a1].16b, %[b0].4b[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "sdot v14.4s, %[a1].16b, %[b0].4b[2]\n" + "sdot v15.4s, %[a1].16b, %[b0].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "sdot v16.4s, %[a0].16b, %[b1].4b[0]\n" + "sdot v17.4s, %[a0].16b, %[b1].4b[1]\n" + ASM_PREFETCH("[%[a_ptr], #320]") + "sdot v18.4s, %[a0].16b, %[b1].4b[2]\n" + "sdot v19.4s, %[a0].16b, %[b1].4b[3]\n" + "sdot v20.4s, %[a1].16b, %[b1].4b[0]\n" + "sdot v21.4s, %[a1].16b, %[b1].4b[1]\n" + "sdot v22.4s, %[a1].16b, %[b1].4b[2]\n" + "sdot v23.4s, %[a1].16b, %[b1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "sdot v24.4s, %[a0].16b, %[b2].4b[0]\n" + "sdot v25.4s, %[a0].16b, %[b2].4b[1]\n" + ASM_PREFETCH("[%[b_ptr], #448]") + "sdot v26.4s, %[a0].16b, %[b2].4b[2]\n" + "sdot v27.4s, %[a0].16b, %[b2].4b[3]\n" + "sdot v28.4s, %[a1].16b, %[b2].4b[0]\n" + "sdot v29.4s, %[a1].16b, %[b2].4b[1]\n" + "sdot v30.4s, %[a1].16b, %[b2].4b[2]\n" + "sdot v31.4s, %[a1].16b, %[b2].4b[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "sdot v8.4s , %[a0a].16b, %[b0].4b[0]\n" + "sdot v9.4s , %[a0a].16b, %[b0].4b[1]\n" + "ldr %q[a0], [%[a_ptr], #64]\n" + "sdot v10.4s, %[a0a].16b, %[b0].4b[2]\n" + "sdot v11.4s, %[a0a].16b, %[b0].4b[3]\n" + "sdot v12.4s, %[a1a].16b, %[b0].4b[0]\n" + "ldr %q[a1], [%[a_ptr], #80]\n" + "sdot v13.4s, %[a1a].16b, %[b0].4b[1]\n" + "sdot v14.4s, %[a1a].16b, %[b0].4b[2]\n" + "sdot v15.4s, %[a1a].16b, %[b0].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "sdot v16.4s, %[a0a].16b, %[b1].4b[0]\n" + "sdot v17.4s, %[a0a].16b, %[b1].4b[1]\n" + ASM_PREFETCH("[%[b_ptr], #512]") + "sdot v18.4s, %[a0a].16b, %[b1].4b[2]\n" + "sdot v19.4s, %[a0a].16b, %[b1].4b[3]\n" + "sdot v20.4s, %[a1a].16b, %[b1].4b[0]\n" + "sdot v21.4s, %[a1a].16b, %[b1].4b[1]\n" + "sdot v22.4s, %[a1a].16b, %[b1].4b[2]\n" + "sdot v23.4s, %[a1a].16b, %[b1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #112]\n" + + "sdot v24.4s, %[a0a].16b, %[b2].4b[0]\n" + "sdot v25.4s, %[a0a].16b, %[b2].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "sdot v26.4s, %[a0a].16b, %[b2].4b[2]\n" + "sdot v27.4s, %[a0a].16b, %[b2].4b[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "sdot v28.4s, %[a1a].16b, %[b2].4b[0]\n" + "sdot v29.4s, %[a1a].16b, %[b2].4b[1]\n" + "subs %w[k], %w[k], #1\n" + "sdot v30.4s, %[a1a].16b, %[b2].4b[2]\n" + "sdot v31.4s, %[a1a].16b, %[b2].4b[3]\n" + "bne 1b\n" + + // Target to use when K is 1 or 2 (i.e. zero iterations of main loop) + "4:\n" + + // Branch to alternative tail for odd K + "cbnz %w[oddk], 2f\n" + + // Detached final iteration (even K) + "sdot v8.4s , %[a0].16b, %[b0].4b[0]\n" + "sdot v9.4s , %[a0].16b, %[b0].4b[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v10.4s, %[a0].16b, %[b0].4b[2]\n" + "sdot v11.4s, %[a0].16b, %[b0].4b[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "sdot v12.4s, %[a1].16b, %[b0].4b[0]\n" + "sdot v13.4s, %[a1].16b, %[b0].4b[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "sdot v14.4s, %[a1].16b, %[b0].4b[2]\n" + "sdot v15.4s, %[a1].16b, %[b0].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "sdot v16.4s, %[a0].16b, %[b1].4b[0]\n" + "sdot v17.4s, %[a0].16b, %[b1].4b[1]\n" + "sdot v18.4s, %[a0].16b, %[b1].4b[2]\n" + "sdot v19.4s, %[a0].16b, %[b1].4b[3]\n" + "sdot v20.4s, %[a1].16b, %[b1].4b[0]\n" + "sdot v21.4s, %[a1].16b, %[b1].4b[1]\n" + "sdot v22.4s, %[a1].16b, %[b1].4b[2]\n" + "sdot v23.4s, %[a1].16b, %[b1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "sdot v24.4s, %[a0].16b, %[b2].4b[0]\n" + "sdot v25.4s, %[a0].16b, %[b2].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "sdot v26.4s, %[a0].16b, %[b2].4b[2]\n" + "sdot v27.4s, %[a0].16b, %[b2].4b[3]\n" + "sdot v28.4s, %[a1].16b, %[b2].4b[0]\n" + "sdot v29.4s, %[a1].16b, %[b2].4b[1]\n" + "sdot v30.4s, %[a1].16b, %[b2].4b[2]\n" + "sdot v31.4s, %[a1].16b, %[b2].4b[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "sdot v8.4s , %[a0a].16b, %[b0].4b[0]\n" + + "sdot v16.4s, %[a0a].16b, %[b1].4b[0]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "sdot v9.4s , %[a0a].16b, %[b0].4b[1]\n" + "str q8, [%[outptr0], #0]\n" + "sdot v17.4s, %[a0a].16b, %[b1].4b[1]\n" + "str q16, [%[outptr0], #64]\n" + "sdot v24.4s, %[a0a].16b, %[b2].4b[0]\n" + "str q24, [%[outptr0], #128]\n" + + "sdot v25.4s, %[a0a].16b, %[b2].4b[1]\n" + "str q9, [%[outptr0], #16]\n" + "sdot v10.4s, %[a0a].16b, %[b0].4b[2]\n" + "str q17, [%[outptr0], #80]\n" + "sdot v18.4s, %[a0a].16b, %[b1].4b[2]\n" + "str q25, [%[outptr0], #144]\n" + "sdot v26.4s, %[a0a].16b, %[b2].4b[2]\n" + "str q10, [%[outptr0], #32]\n" + + "sdot v11.4s, %[a0a].16b, %[b0].4b[3]\n" + "str q18, [%[outptr0], #96]\n" + "sdot v19.4s, %[a0a].16b, %[b1].4b[3]\n" + "str q26, [%[outptr0], #160]\n" + "sdot v27.4s, %[a0a].16b, %[b2].4b[3]\n" + "str q11, [%[outptr0], #48]\n" + + "sdot v12.4s, %[a1a].16b, %[b0].4b[0]\n" + "str q19, [%[outptr0], #112]\n" + "sdot v20.4s, %[a1a].16b, %[b1].4b[0]\n" + "str q27, [%[outptr0], #176]\n" + "sdot v28.4s, %[a1a].16b, %[b2].4b[0]\n" + "str q12, [%[outptr1], #0]\n" + + "sdot v13.4s, %[a1a].16b, %[b0].4b[1]\n" + "str q20, [%[outptr1], #64]\n" + "sdot v21.4s, %[a1a].16b, %[b1].4b[1]\n" + "str q28, [%[outptr1], #128]\n" + "sdot v29.4s, %[a1a].16b, %[b2].4b[1]\n" + "str q13, [%[outptr1], #16]\n" + + "sdot v14.4s, %[a1a].16b, %[b0].4b[2]\n" + "str q21, [%[outptr1], #80]\n" + "sdot v22.4s, %[a1a].16b, %[b1].4b[2]\n" + "str q29, [%[outptr1], #144]\n" + "sdot v30.4s, %[a1a].16b, %[b2].4b[2]\n" + "str q14, [%[outptr1], #32]\n" + + "sdot v15.4s, %[a1a].16b, %[b0].4b[3]\n" + "str q22, [%[outptr1], #96]\n" + "sdot v23.4s, %[a1a].16b, %[b1].4b[3]\n" + "str q30, [%[outptr1], #160]\n" + "sdot v31.4s, %[a1a].16b, %[b2].4b[3]\n" + "str q15, [%[outptr1], #48]\n" + + "b 3f\n" + + // Detached final iteration (odd K) + "2:\n" + "sdot v8.4s , %[a0].16b, %[b0].4b[0]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v16.4s, %[a0].16b, %[b1].4b[0]\n" + "sdot v9.4s , %[a0].16b, %[b0].4b[1]\n" + "str q8, [%[outptr0], #0]\n" + "sdot v17.4s, %[a0].16b, %[b1].4b[1]\n" + "str q16, [%[outptr0], #64]\n" + "sdot v24.4s, %[a0].16b, %[b2].4b[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "add %[a_ptr], %[a_ptr], #32\n" + "str q24, [%[outptr0], #128]\n" + "sdot v25.4s, %[a0].16b, %[b2].4b[1]\n" + "str q9, [%[outptr0], #16]\n" + + "sdot v10.4s, %[a0].16b, %[b0].4b[2]\n" + "str q17, [%[outptr0], #80]\n" + "sdot v18.4s, %[a0].16b, %[b1].4b[2]\n" + "str q25, [%[outptr0], #144]\n" + "sdot v26.4s, %[a0].16b, %[b2].4b[2]\n" + "str q10, [%[outptr0], #32]\n" + + "sdot v11.4s, %[a0].16b, %[b0].4b[3]\n" + "str q18, [%[outptr0], #96]\n" + "sdot v19.4s, %[a0].16b, %[b1].4b[3]\n" + "str q26, [%[outptr0], #160]\n" + "sdot v27.4s, %[a0].16b, %[b2].4b[3]\n" + "str q11, [%[outptr0], #48]\n" + + "sdot v12.4s, %[a1].16b, %[b0].4b[0]\n" + "str q19, [%[outptr0], #112]\n" + "sdot v20.4s, %[a1].16b, %[b1].4b[0]\n" + "str q27, [%[outptr0], #176]\n" + "sdot v28.4s, %[a1].16b, %[b2].4b[0]\n" + "str q12, [%[outptr1], #0]\n" + + "sdot v13.4s, %[a1].16b, %[b0].4b[1]\n" + "str q20, [%[outptr1], #64]\n" + "sdot v21.4s, %[a1].16b, %[b1].4b[1]\n" + "str q28, [%[outptr1], #128]\n" + "sdot v29.4s, %[a1].16b, %[b2].4b[1]\n" + "str q13, [%[outptr1], #16]\n" + + "sdot v14.4s, %[a1].16b, %[b0].4b[2]\n" + "str q21, [%[outptr1], #80]\n" + "sdot v22.4s, %[a1].16b, %[b1].4b[2]\n" + "str q29, [%[outptr1], #144]\n" + "sdot v30.4s, %[a1].16b, %[b2].4b[2]\n" + "str q14, [%[outptr1], #32]\n" + + "sdot v15.4s, %[a1].16b, %[b0].4b[3]\n" + "str q22, [%[outptr1], #96]\n" + "sdot v23.4s, %[a1].16b, %[b1].4b[3]\n" + "str q30, [%[outptr1], #160]\n" + "sdot v31.4s, %[a1].16b, %[b2].4b[3]\n" + "str q15, [%[outptr1], #48]\n" + + + // Common tail + "3:\n" + "str q23, [%[outptr1], #112]\n" + "str q31, [%[outptr1], #176]\n" + : + [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr),[oddk] "+r" (oddk), + [is_first_k] "+r" (is_first_k), [k] "+r" (k), [LDC] "+r" (LDC), + [a0] "=w" (a0), [a1] "=w" (a1), [a0a] "=w" (a0a), [a1a] "=w" (a1a), + [b0] "=w" (b0), [b1] "=w" (b1), [b2] "=w" (b2), + [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1) + : + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", + "memory" + ); +} + +// Overview of register layout: +// +// A (12x4)x2 cell of Rhs is stored in 8bit in q2-q7. +// A (4x4)x2 cell of Lhs is stored in 8bit in q0-q1 +// A 4x12 block of accumulators is stored in 8bit in q8--q19. +// +// +------------+------------+------------+ +// | v2[0-16]| v3[0-16]| v4[0-16]| +// Rhs +------------+------------+------------+ +// | v5[0-16]| v6[0-16]| v7[0-16]| +// +------------+------------+------------+ +// Lhs | | | | +// +// +--------+--------+ - - - - +------------+------------+------------+ +// |v0[0-16]|v1[0-16]| | v8 v9v10v11|v12v13v14v15|v16v17v18v19| +// +--------+--------+ - - - - +------------+------------+------------+ +// +// Accumulator + +static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t b0; + int32x4_t b1; + int32x4_t b2; + int32x4_t a0a; + int32x4_t b0a; + int32x4_t b1a; + int32x4_t b2a; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "ldp q8, q9, [%[outptr0]]\n" + "ldp q10, q11, [%[outptr0], #32]\n" + "ldp q12, q13, [%[outptr0], #64]\n" + "ldp q14, q15, [%[outptr0], #96]\n" + "ldp q16, q17, [%[outptr0], #128]\n" + "ldp q18, q19, [%[outptr0], #160]\n" + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "ldr %q[b2], [%[b_ptr]], #16\n" + "sdot v8.4s, %[a0].16b, %[b0].4b[0]\n" + "sdot v9.4s, %[a0].16b, %[b0].4b[1]\n" + "sdot v10.4s, %[a0].16b, %[b0].4b[2]\n" + "sdot v11.4s, %[a0].16b, %[b0].4b[3]\n" + "sdot v12.4s, %[a0].16b, %[b1].4b[0]\n" + "sdot v13.4s, %[a0].16b, %[b1].4b[1]\n" + "sdot v14.4s, %[a0].16b, %[b1].4b[2]\n" + "sdot v15.4s, %[a0].16b, %[b1].4b[3]\n" + "sdot v16.4s, %[a0].16b, %[b2].4b[0]\n" + "sdot v17.4s, %[a0].16b, %[b2].4b[1]\n" + "sdot v18.4s, %[a0].16b, %[b2].4b[2]\n" + "sdot v19.4s, %[a0].16b, %[b2].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[b1], [%[b_ptr]], #16\n" + "ldr %q[b2], [%[b_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "ldr %q[b1a], [%[b_ptr]], #16\n" + "ldr %q[b2a], [%[b_ptr]], #16\n" + + "sdot v8.4s, %[a0].16b, %[b0].4b[0]\n" + "sdot v9.4s, %[a0].16b, %[b0].4b[1]\n" + "sdot v10.4s, %[a0].16b, %[b0].4b[2]\n" + "sdot v11.4s, %[a0].16b, %[b0].4b[3]\n" + "sdot v12.4s, %[a0].16b, %[b1].4b[0]\n" + "sdot v13.4s, %[a0].16b, %[b1].4b[1]\n" + "sdot v14.4s, %[a0].16b, %[b1].4b[2]\n" + "sdot v15.4s, %[a0].16b, %[b1].4b[3]\n" + "sdot v16.4s, %[a0].16b, %[b2].4b[0]\n" + "sdot v17.4s, %[a0].16b, %[b2].4b[1]\n" + "sdot v18.4s, %[a0].16b, %[b2].4b[2]\n" + "sdot v19.4s, %[a0].16b, %[b2].4b[3]\n" + "sdot v8.4s , %[a0a].16b, %[b0a].4b[0]\n" + "sdot v9.4s , %[a0a].16b, %[b0a].4b[1]\n" + "sdot v10.4s, %[a0a].16b, %[b0a].4b[2]\n" + "sdot v11.4s, %[a0a].16b, %[b0a].4b[3]\n" + "sdot v12.4s, %[a0a].16b, %[b1a].4b[0]\n" + "sdot v13.4s, %[a0a].16b, %[b1a].4b[1]\n" + "sdot v14.4s, %[a0a].16b, %[b1a].4b[2]\n" + "sdot v15.4s, %[a0a].16b, %[b1a].4b[3]\n" + "sdot v16.4s, %[a0a].16b, %[b2a].4b[0]\n" + "sdot v17.4s, %[a0a].16b, %[b2a].4b[1]\n" + "sdot v18.4s, %[a0a].16b, %[b2a].4b[2]\n" + "sdot v19.4s, %[a0a].16b, %[b2a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" + "stp q8, q9, [%[outptr0]]\n" + "stp q10, q11, [%[outptr0], #32]\n" + "stp q12, q13, [%[outptr0], #64]\n" + "stp q14, q15, [%[outptr0], #96]\n" + "stp q16, q17, [%[outptr0], #128]\n" + "stp q18, q19, [%[outptr0], #160]\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), + [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a), + [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), + [b1a] "=w"(b1a), [b2a] "=w"(b2a) + : + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "memory", "cc"); +} + +// Overview of register layout: +// +// A (4x4)x2 cell of Rhs is stored in 8bit in q2-q7. +// A (8x4)x2 cell of Lhs is stored in 8bit in q0-q1, q4-q5 +// A 8x4 block of accumulators is stored in 8bit in q6-q13. +// +// +------------+ +// | v2[0-16]| +// Rhs +------------+ +// | v3[0-16]| +// +------------+ +// Lhs | | +// +// +--------+--------+ - - - - +------------+ +// |v0[0-16]|v4[0-16]| | v6 v7 v8 v9| +// +--------+--------+ - - - - +------------+ +// |v1[0-16]|v5[0-16]| |v10v11v12v13| +// +--------+--------+ - - - - +------------+ +// Accumulator + +static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int n_remain) { + K /= 4; + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t a1; + int32x4_t b0; + int32x4_t b0a; + int32x4_t a0a; + int32x4_t a1a; + + LDC = LDC * sizeof(int32_t); + int32_t* outptr0 = output; + int32_t* outptr1; + + size_t x0; + +// clang-format off +#define LOAD_LINE(v0, v1, v2, v3, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" v0 ", [%[x0]] \n" \ + "ldr q" v1 ", [%[x0], #16] \n" \ + "ldr q" v2 ", [%[x0], #32] \n" \ + "ldr q" v3 ", [%[x0], #48] \n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ldr q" v0 ", [%[x0]], #16\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ldr q" v1 ", [%[x0]], #16\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ldr q" v2 ", [%[x0]], #16\n" \ + "101" n ":\n" + + +#define LOAD_C \ + LOAD_LINE("6", "7", "8", "9", "0") \ + LOAD_LINE("10", "11", "12", "13", "1") \ + +#define STORE_LINE(v0, v1, v2, v3, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "str q" v0 ", [%[x0]] \n" \ + "str q" v1 ", [%[x0], #16] \n" \ + "str q" v2 ", [%[x0], #32] \n" \ + "str q" v3 ", [%[x0], #48] \n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "str q" v0 ", [%[x0]], #16\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "str q" v1 ", [%[x0]], #16\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "str q" v2 ", [%[x0]], #16\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("6", "7", "8", "9", "0") \ + STORE_LINE("10", "11", "12", "13", "1") + + // clang-format on + + asm volatile( + // load accumulator C + "add %[outptr1], %[outptr0], %x[LDC]\n" + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "sdot v6.4s , %[a0].16b, %[b0].4b[0]\n" + "sdot v7.4s , %[a0].16b, %[b0].4b[1]\n" + "sdot v8.4s, %[a0].16b, %[b0].4b[2]\n" + "sdot v9.4s, %[a0].16b, %[b0].4b[3]\n" + "sdot v10.4s, %[a1].16b, %[b0].4b[0]\n" + "sdot v11.4s, %[a1].16b, %[b0].4b[1]\n" + "sdot v12.4s, %[a1].16b, %[b0].4b[2]\n" + "sdot v13.4s, %[a1].16b, %[b0].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a1], [%[a_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[a1a], [%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "sdot v6.4s , %[a0].16b, %[b0].4b[0]\n" + "sdot v7.4s , %[a0].16b, %[b0].4b[1]\n" + "sdot v8.4s, %[a0].16b, %[b0].4b[2]\n" + "sdot v9.4s, %[a0].16b, %[b0].4b[3]\n" + "sdot v10.4s, %[a1].16b, %[b0].4b[0]\n" + "sdot v11.4s, %[a1].16b, %[b0].4b[1]\n" + "sdot v12.4s, %[a1].16b, %[b0].4b[2]\n" + "sdot v13.4s, %[a1].16b, %[b0].4b[3]\n" + + "sdot v6.4s , %[a0a].16b, %[b0a].4b[0]\n" + "sdot v7.4s , %[a0a].16b, %[b0a].4b[1]\n" + "sdot v8.4s, %[a0a].16b, %[b0a].4b[2]\n" + "sdot v9.4s, %[a0a].16b, %[b0a].4b[3]\n" + "sdot v10.4s, %[a1a].16b, %[b0a].4b[0]\n" + "sdot v11.4s, %[a1a].16b, %[b0a].4b[1]\n" + "sdot v12.4s, %[a1a].16b, %[b0a].4b[2]\n" + "sdot v13.4s, %[a1a].16b, %[b0a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC), + [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), + [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), + [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), + [x0] "=r"(x0) + : + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", + "cc"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +// Overview of register layout: +// +// A (4x4)x2 cell of Rhs is stored in 8bit in q2-q3. +// A (4x4)x2 cell of Lhs is stored in 8bit in q0-q1 +// A 4x4 block of accumulators is stored in 8bit in q4-q7. +// +// +------------+ +// | v2[0-16]| +// Rhs +------------+ +// | v3[0-16]| +// +------------+ +// Lhs | | +// +// +--------+--------+ - - - - +------------+ +// |v0[0-16]|v4[0-16]| | v4 v5 v6 v7| +// +--------+--------+ - - - - +------------+ +// Accumulator + +static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int32_t* output, int LDC, bool is_first_k, int n_remain) { + K /= 4; + const int32_t* a_ptr = reinterpret_cast(packA); + const int32_t* b_ptr = reinterpret_cast(packB); + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = K / 2; + int32x4_t a0; + int32x4_t a0a; + int32x4_t b0; + int32x4_t b0a; + LDC = LDC * sizeof(int32_t); + + int32_t* outptr0 = output; + size_t x0; + +// clang-format off +#define LOAD_LINE(v0, v1, v2, v3, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ldr q" v0 ", [%[x0]] \n" \ + "ldr q" v1 ", [%[x0], #16] \n" \ + "ldr q" v2 ", [%[x0], #32] \n" \ + "ldr q" v3 ", [%[x0], #48] \n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ldr q" v0 ", [%[x0]], #16\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ldr q" v1 ", [%[x0]], #16\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ldr q" v2 ", [%[x0]], #16\n" \ + "101" n ":\n" + +#define LOAD_C \ + LOAD_LINE("4", "5", "6", "7", "0") + +#define STORE_LINE(v0, v1, v2, v3, n) \ + "mov %[x0], %[outptr" n "]\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 102" n "f\n" \ + "str q" v0 ", [%[x0]] \n" \ + "str q" v1 ", [%[x0], #16] \n" \ + "str q" v2 ", [%[x0], #32] \n" \ + "str q" v3 ", [%[x0], #48] \n" \ + "b 103" n "f\n" \ + "102" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 103" n "f\n" \ + "str q" v0 ", [%[x0]], #16\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 103" n "f\n" \ + "str q" v1 ", [%[x0]], #16\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 103" n "f\n" \ + "str q" v2 ", [%[x0]], #16\n" \ + "103" n ":\n" + +#define STORE_C \ + STORE_LINE("4", "5", "6", "7", "0") + // clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" // + LOAD_C // + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + + "2: \n" + "cbz %w[oddk], 3f\n" + + // parse the oddk + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "sdot v4.4s , %[a0].16b, %[b0].4b[0]\n" + "sdot v5.4s , %[a0].16b, %[b0].4b[1]\n" + "sdot v6.4s, %[a0].16b, %[b0].4b[2]\n" + "sdot v7.4s, %[a0].16b, %[b0].4b[3]\n" + + "cbz %w[k], 4f\n" + // Loop proper + "3:\n" + "ldr %q[a0], [%[a_ptr]], #16\n" + "ldr %q[b0], [%[b_ptr]], #16\n" + "ldr %q[a0a], [%[a_ptr]], #16\n" + "ldr %q[b0a], [%[b_ptr]], #16\n" + "sdot v4.4s , %[a0].16b, %[b0].4b[0]\n" + "sdot v5.4s , %[a0].16b, %[b0].4b[1]\n" + "sdot v6.4s, %[a0].16b, %[b0].4b[2]\n" + "sdot v7.4s, %[a0].16b, %[b0].4b[3]\n" + "sdot v4.4s , %[a0a].16b, %[b0a].4b[0]\n" + "sdot v5.4s , %[a0a].16b, %[b0a].4b[1]\n" + "sdot v6.4s, %[a0a].16b, %[b0a].4b[2]\n" + "sdot v7.4s, %[a0a].16b, %[b0a].4b[3]\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), + [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k), + [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), + [x0] "=r"(x0) + : + : "v4", "v5", "v6", "v7", "memory", "cc"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C +} + +static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, + int ldin, int y0, int ymax, int k0, + int kmax) { + megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, + "mk4 matmul with m is not times of 4"); + megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, + "mk4 matmul with k is not times of 4"); + int y = y0; + int start_y = y0 / 4; + for (; y + 7 < ymax; y += 8, start_y += 2) { + const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); + const int8_t* inptr1 = inptr0 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + + int K = kmax - k0; + //! read 2 * 4 in each row + for (; K > 3; K -= 4) { + interleave_2x4_4_b(inptr0, inptr1, outptr); + } + } + for (; y + 3 < ymax; y += 4, start_y ++) { + int K = kmax - k0; + const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); + std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); + } +} + +static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax) { + const int ksize = kmax - k0; + const int ksize12 = ksize * 12; + const int ksize4 = ksize * 4; + int8_t* outptr = out; + int8_t* outptr_base = out; + //! 4x4 block output start pos + int8_t* outptr_base4 = out + ((xmax - x0) / 12) * ksize12; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const int8_t* inptr = in + (k >> 2) * ldin + (x0 << 2); + prefetch_2x(inptr); + + int x = x0; + outptr = outptr_base; + for (; x + 11 < xmax; x += 12) { + std::memcpy(outptr, inptr, 48); + outptr += ksize12; + inptr += 48; + } + + outptr = outptr_base4; + for (; x + 3 < xmax; x += 4) { + std::memcpy(outptr, inptr, 16); + outptr += ksize4; + inptr += 16; + } + + if (x < xmax) { + int i = 0; + for (; i < xmax - x; i++) { + *outptr++ = *inptr++; + *outptr++ = *inptr++; + *outptr++ = *inptr++; + *outptr++ = *inptr++; + } + for (; i < 4; i++) { + *outptr++ = *inptr++; + *outptr++ = *inptr++; + *outptr++ = *inptr++; + *outptr++ = *inptr++; + } + } + + outptr_base += 48; + outptr_base4 += 16; + } +} + +} // namespace matmul_mk4_8x12x4 +} // namespace aarch64 +} // namespace megdnn + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp index ca3d006169faa80ccbc55534af5bc9c8b5ee988e..bf4a2f3aa26de77b2f7ea96846088b04c4daa4d9 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp @@ -14,12 +14,14 @@ #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" +#include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" #if __ARM_FEATURE_DOTPROD using namespace megdnn; using namespace aarch64; using namespace aarch64::matmul; +/* ====================== gemm_s8_8x12 ===========================*/ MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, @@ -109,5 +111,91 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, packA += K4; } } + +/* ====================== gemm_mk4_s8_8x12 ===========================*/ +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); + +void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported"); + matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, + kmax); +} + +void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, + bool transpose) const { + megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported"); + matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); +} + +void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int32* C, + size_t LDC, bool is_first_k, const dt_int32*, + dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t B_INTERLEAVE = 12; + //! K is packed to times of 4 + K = round_up(K, 4); + const int K8 = (K << 3); + const int K12 = K * 12; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { + int32_t* output = C + ((m >> 2) * LDC); + + size_t n = 0; + const dt_int8* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, + is_first_k); + output += (B_INTERLEAVE << 2); + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(N - n, 4)); + output += 16; + cur_packB += K4; + } + packA += K8; + } + + for (; m < M; m += 4) { + int32_t* output = C + ((m >> 2) * LDC); + const dt_int8* cur_packB = packB; + size_t n = 0; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, + is_first_k); + output += (B_INTERLEAVE << 2); + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(N - n, 4)); + output += 16; + cur_packB += K4; + } + packA += K4; + } +} + #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h index c633d9dc14f370f5d473ecc07ef671e98501ca18..f91a37e9cd7cfbc59a20c303c843b84f80900562 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h @@ -19,6 +19,9 @@ namespace matmul { MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12); +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, + gemm_mk4_s8_8x12); + } // namespace aarch64 } // namespace matmul } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index 58ec3ef6b3c162d5be887156bea9a357db1f438a..d616e665ad5fcdf667dfacd918c0fe9ad50e0f5f 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -29,6 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #if __ARM_FEATURE_DOTPROD AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod; + AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; #else AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; @@ -64,6 +65,7 @@ public: #if __ARM_FEATURE_DOTPROD all_algos.emplace_back(&int8x8x32_gemv_dotprod); all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); + all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); #else all_algos.emplace_back(&int8x8x32_gemv); all_algos.emplace_back(&int8x8x32_k4x4x16); diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index 1d4e0ab36f3e27186f76a2ead6906f54c9dc9118..dd8e01f9a4dd4b31c6f6c725e0b92a9522bc89cb 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -35,6 +35,8 @@ private: class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel // 8x12x4 DotProduct class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct + class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel + // 8x12x4 DotProduct #else class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp index 71d338f1636a4e2a89e439858cc341af9dccf119..d4dfae3717e7f2048d6b198525df61dd5c7eacbf 100644 --- a/dnn/test/aarch64/matrix_mul.cpp +++ b/dnn/test/aarch64/matrix_mul.cpp @@ -64,6 +64,18 @@ TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD"); } + +TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_MK4_8X12X4_DOTPROD) { + std::vector args; + for (size_t m : {1, 2, 3, 4, 5, 6, 7, 10, 11}) + for (size_t n : {2, 3, 4, 5, 8, 12, 13, 14, 15, 16, 31}) + for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34}) + args.emplace_back(m, n, k, 0); + matrix_mul::check_matrix_mul( + dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), + "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD", + param::MatrixMul::Format::MK4_DOT, 1, 1e-3, std::move(args)); +} #else TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K4X4X16) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, @@ -460,6 +472,54 @@ TEST_F(AARCH64, BENCHMARK_GEMV_INT_8X8X32) { run(M, N, K); } +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT8X8X32_MK4_8X12X4) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_mk4(handle()); + benchmarker.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + benchmarker.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X32_K8X12X4")); + + param.format = MatrixMul::Param::Format::MK4_DOT; + benchmarker_mk4.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD")); + benchmarker_mk4.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_display(false); + + auto run = [&](size_t M, size_t N, size_t K) { + auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; + auto mk_used = benchmarker_mk4.exec( + {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / + RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " + "%f Gflops speedup_vs_normal: %f\n", + M, K, N, default_used, computations / default_used, mk_used, + computations / mk_used, default_used / mk_used); + }; + + run(256, 256, 128); + for (size_t k = 4; k <= 512; k *= 2) { + for (size_t m = 4; m <= 512; m *= 2) { + for (size_t n = 4; n <= 512; n *= 2) { + run(m, n, k); + } + } + std::cout << std::endl; + } +} #endif // __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC