From 6e70fa7a11fb35b29692757e3f3b73d22d1226a5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 19 Aug 2020 00:22:30 +0800 Subject: [PATCH] feat(dnn/arm): add fp32 asm gemm for a53 a55 and i8i8i16 gemm for a72 a53 GitOrigin-RevId: a049c33f2bf1e737630de263161d7e32be2ba645 --- dnn/src/aarch64/matrix_mul/algos.cpp | 157 ++ dnn/src/aarch64/matrix_mul/algos.h | 38 +- dnn/src/aarch64/matrix_mul/asm/common.h | 159 +- .../matrix_mul/fp32/kernel_general_8x12.h | 2031 +++++++++-------- .../matrix_mul/fp32/kernel_general_8x12_a53.h | 1331 +++++++++++ .../matrix_mul/fp32/kernel_general_8x12_a55.h | 1170 ++++++++++ .../aarch64/matrix_mul/fp32/kernel_mk4_8x12.h | 1623 ++++++------- .../matrix_mul/fp32/kernel_mk4_8x12_a53.h | 1260 ++++++++++ .../matrix_mul/fp32/kernel_mk4_8x12_a55.h | 1160 ++++++++++ dnn/src/aarch64/matrix_mul/fp32/strategy.cpp | 174 +- dnn/src/aarch64/matrix_mul/fp32/strategy.h | 3 +- .../int8x8x16/kernel_mk4_16x12x4_a53.h | 1265 ++++++++++ .../int8x8x16/kernel_mk4_4x4x8_a72.h | 387 ++++ .../aarch64/matrix_mul/int8x8x16/strategy.cpp | 162 +- .../aarch64/matrix_mul/int8x8x16/strategy.h | 8 +- dnn/src/aarch64/matrix_mul/opr_impl.cpp | 9 +- dnn/src/aarch64/matrix_mul/opr_impl.h | 33 +- dnn/src/arm_common/conv_bias/opr_impl.cpp | 1 - dnn/src/common/cpuinfo_arch_vendor.cpp | 4 +- dnn/src/common/cpuinfo_arch_vendor.h | 4 +- dnn/test/aarch64/matrix_mul.cpp | 111 +- dnn/test/arm_common/conv_bias.cpp | 10 +- .../arm_common/conv_bias_multi_thread.cpp | 167 +- dnn/test/arm_common/cpuinfo.cpp | 7 +- dnn/test/arm_common/cpuinfo_help.cpp | 17 + dnn/test/arm_common/cpuinfo_help.h | 47 + dnn/test/x86/cpuinfo.cpp | 6 +- 27 files changed, 9380 insertions(+), 1964 deletions(-) create mode 100644 dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h create mode 100644 dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h create mode 100644 dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h create mode 100644 dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h create mode 100644 dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h create mode 100644 dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h create mode 100644 dnn/test/arm_common/cpuinfo_help.cpp create mode 100644 dnn/test/arm_common/cpuinfo_help.h diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index 7d88e3e93..090cadce5 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -23,6 +23,9 @@ #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_impl.h" +#if MGB_ENABLE_CPUINFO +#include "cpuinfo.h" +#endif #include "midout.h" MIDOUT_DECL(megdnn_aarch64_matmul_kern) @@ -80,6 +83,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( } MIDOUT_END(); }; + return f32_kern_8x12; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, @@ -837,6 +841,159 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16, aarch64::matmul::gemm_s8x8x16_4x4, int8_t, int16_t); +/* ===================== Int8x8x16 K16x12x4 algo ===================== */ +namespace { +void int8x8x16_mk4_16x12x4_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = 
kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type, + B_type, C_type); + megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB, + strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x16(kern_size_param) && + kern_size_param.format == param::MatrixMul::Format::MK4 && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; +} + +bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred( + const KernSizeParam&) const { +#if !MGB_ENABLE_CPUINFO + return false; +#else + auto arch = cpuinfo_get_current_core()->uarch; + bool little_core = arch == cpuinfo_uarch_cortex_a53 || + arch == cpuinfo_uarch_cortex_a55; + return little_core; +#endif +} + +size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type, + B_type, C_type); + return megdnn::matmul::GemmInterleaved< + matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB, + strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( + const KernSizeParam&) const { + return int8x8x16_mk4_16x12x4_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( + AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern, + "AlgoInt8x8x16MK4_16x12x4Impl"_hash, + aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t); + +/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */ +namespace { +void int8x8x16_mk4_4x4x8_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type, + B_type, C_type); + megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB, + strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x16(kern_size_param) && + kern_size_param.format == param::MatrixMul::Format::MK4 && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.M % 4 == 0 
&& kern_size_param.K % 4 == 0; +} + +bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred( + const KernSizeParam&) const { +#if !MGB_ENABLE_CPUINFO + return false; +#else + auto arch = cpuinfo_get_current_core()->uarch; + bool little_core = arch == cpuinfo_uarch_cortex_a53 || + arch == cpuinfo_uarch_cortex_a55; + return !little_core; +#endif +} + +size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type, + B_type, C_type); + return megdnn::matmul::GemmInterleaved< + matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB, + strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern( + const KernSizeParam&) const { + return int8x8x16_mk4_4x4x8_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x16MK4_4x4x8_Impl"_hash, + aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, + int8_t, int16_t); + /* ===================== Int16x16x32 K12x8x1 algo ===================== */ namespace { void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) { diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index d4b6ff6bb..54b6734d8 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once @@ -121,12 +122,9 @@ public: #else class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { - public: bool is_reproducible() const override { return true; } - const char* name() const override { - return "AARCH64_INT8X8X32_MK4_4X4X16"; - } + const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -188,6 +186,36 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; +class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { + return "AARCH64_INT8X8X16_MK4_16X12X4"; + } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::DEFAULT; } + + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + +class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::DEFAULT; } + + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { public: bool is_reproducible() const override { return true; } diff --git a/dnn/src/aarch64/matrix_mul/asm/common.h b/dnn/src/aarch64/matrix_mul/asm/common.h index c24395cc8..771129a9b 100644 --- a/dnn/src/aarch64/matrix_mul/asm/common.h +++ b/dnn/src/aarch64/matrix_mul/asm/common.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include @@ -993,8 +994,8 @@ static inline void interleave_4x1_4_s(const int32_t*& inptr0, template static inline void interleave_4x8_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { + const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert(sizeof(T) == 4, "only support size == 4"); asm volatile( "ld1 {v0.4s, v1.4s}, [%[inptr0]], #32\n" @@ -1140,8 +1141,8 @@ static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, "stp q2, q6, [%[outptr], #64]\n" "stp q3, q7, [%[outptr], #96]\n" - : [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), - [ outptr ] "+r"(outptr) + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); } @@ -1153,7 +1154,7 @@ static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) { "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" - : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "memory"); } @@ -1550,7 +1551,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { "stp q2, q6, [%[outptr], #96] \n" "stp q10, q3, [%[outptr], #128] \n" "stp q7, q11, [%[outptr], #160] \n" - : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory"); @@ -1564,7 +1565,7 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { asm volatile( "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" - : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "memory"); } @@ -1681,13 +1682,12 @@ static inline void transpose_12x4_1_s(const T*& inptr0, const T*& inptr1, "st1 {v3.4s,v4.4s,v5.4s}, [%[outptr]], #48\n" "st1 {v6.4s,v7.4s,v8.4s}, [%[outptr]], #48\n" "st1 {v24.4s,v25.4s,v26.4s}, [%[outptr]], #48\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), - [inptr8] "+r"(inptr8), [inptr9] "+r"(inptr9), - [inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11), - [outptr] "+r"(outptr) + : + [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), + [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", @@ -1972,6 +1972,135 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, : "v0", "v1", "v2", "v3", "v4", "memory"); } +static inline void interleave_4x4_16x4_s8_s16(const int8_t* inptr0, + const int8_t* inptr1, + const int8_t* inptr2, + const int8_t* inptr3, + int16_t* outptr) { + int8x16_t row0 = vld1q_s8(inptr0); + int16x8_t row0_01 = vmovl_low_s8(row0); + int16x8_t row0_23 = vmovl_high_s8(row0); + int16x4_t row0_0 = vget_low_s16(row0_01); + int16x4_t row0_1 = vget_high_s16(row0_01); + int16x4_t row0_2 = vget_low_s16(row0_23); + int16x4_t row0_3 = vget_high_s16(row0_23); + + int8x16_t row1 = vld1q_s8(inptr1); + int16x8_t row1_01 = 
vmovl_low_s8(row1); + int16x8_t row1_23 = vmovl_high_s8(row1); + int16x4_t row1_0 = vget_low_s16(row1_01); + int16x4_t row1_1 = vget_high_s16(row1_01); + int16x4_t row1_2 = vget_low_s16(row1_23); + int16x4_t row1_3 = vget_high_s16(row1_23); + + int8x16_t row2 = vld1q_s8(inptr2); + int16x8_t row2_01 = vmovl_low_s8(row2); + int16x8_t row2_23 = vmovl_high_s8(row2); + int16x4_t row2_0 = vget_low_s16(row2_01); + int16x4_t row2_1 = vget_high_s16(row2_01); + int16x4_t row2_2 = vget_low_s16(row2_23); + int16x4_t row2_3 = vget_high_s16(row2_23); + + int8x16_t row3 = vld1q_s8(inptr3); + int16x8_t row3_01 = vmovl_low_s8(row3); + int16x8_t row3_23 = vmovl_high_s8(row3); + int16x4_t row3_0 = vget_low_s16(row3_01); + int16x4_t row3_1 = vget_high_s16(row3_01); + int16x4_t row3_2 = vget_low_s16(row3_23); + int16x4_t row3_3 = vget_high_s16(row3_23); + + vst1_s16(outptr, row0_0); + vst1_s16(outptr + 1 * 4, row1_0); + vst1_s16(outptr + 2 * 4, row2_0); + vst1_s16(outptr + 3 * 4, row3_0); + vst1_s16(outptr + 4 * 4, row0_1); + vst1_s16(outptr + 5 * 4, row1_1); + vst1_s16(outptr + 6 * 4, row2_1); + vst1_s16(outptr + 7 * 4, row3_1); + vst1_s16(outptr + 8 * 4, row0_2); + vst1_s16(outptr + 9 * 4, row1_2); + vst1_s16(outptr + 10 * 4, row2_2); + vst1_s16(outptr + 11 * 4, row3_2); + vst1_s16(outptr + 12 * 4, row0_3); + vst1_s16(outptr + 13 * 4, row1_3); + vst1_s16(outptr + 14 * 4, row2_3); + vst1_s16(outptr + 15 * 4, row3_3); +}; +static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0, + const int8_t* inptr1, + int16_t* outptr) { + int8x16_t row0 = vld1q_s8(inptr0); + int16x8_t row0_01 = vmovl_low_s8(row0); + int16x8_t row0_23 = vmovl_high_s8(row0); + int16x4_t row0_0 = vget_low_s16(row0_01); + int16x4_t row0_1 = vget_high_s16(row0_01); + int16x4_t row0_2 = vget_low_s16(row0_23); + int16x4_t row0_3 = vget_high_s16(row0_23); + + int8x16_t row1 = vld1q_s8(inptr1); + int16x8_t row1_01 = vmovl_low_s8(row1); + int16x8_t row1_23 = vmovl_high_s8(row1); + int16x4_t row1_0 = vget_low_s16(row1_01); + int16x4_t row1_1 = vget_high_s16(row1_01); + int16x4_t row1_2 = vget_low_s16(row1_23); + int16x4_t row1_3 = vget_high_s16(row1_23); + + vst1_s16(outptr, row0_0); + vst1_s16(outptr + 1 * 4, row1_0); + vst1_s16(outptr + 2 * 4, row0_1); + vst1_s16(outptr + 3 * 4, row1_1); + vst1_s16(outptr + 4 * 4, row0_2); + vst1_s16(outptr + 5 * 4, row1_2); + vst1_s16(outptr + 6 * 4, row0_3); + vst1_s16(outptr + 7 * 4, row1_3); +}; + +static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, + int count) { + for (; count >= 32; count -= 32) { + int8x8_t in0 = vld1_s8(inptr); + int8x8_t in1 = vld1_s8(inptr + 1 * 8); + int8x8_t in2 = vld1_s8(inptr + 2 * 8); + int8x8_t in3 = vld1_s8(inptr + 3 * 8); + vst1q_s16(outptr, vmovl_s8(in0)); + vst1q_s16(outptr + 1 * 8, vmovl_s8(in1)); + vst1q_s16(outptr + 2 * 8, vmovl_s8(in2)); + vst1q_s16(outptr + 3 * 8, vmovl_s8(in3)); + inptr += 32; + outptr += 32; + } + for (; count >= 8; count -= 8) { + int8x8_t in0 = vld1_s8(inptr); + vst1q_s16(outptr, vmovl_s8(in0)); + inptr += 8; + outptr += 8; + } + for (; count > 0; --count) { + *outptr++ = (int16_t)(*inptr++); + } +} + +static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) { + static const uint8_t src_idx_buffer[16] = {0, 4, 8, 12, 1, 5, 9, 13, + 2, 6, 10, 14, 3, 7, 11, 15}; + static const uint8x16_t vtbl = vld1q_u8(&src_idx_buffer[0]); + int8x8x4_t input = vld4_s8(inptr0); + int8x16_t input2 = vqtbl1q_s8(vld1q_s8(inptr0 + 4 * 8), vtbl); + + vst1_s8(outptr, input.val[0]); + vst1q_lane_s32(reinterpret_cast(outptr + 8), + 
vreinterpretq_s32_s8(input2), 0); + vst1_s8(outptr + 1 * 12, input.val[1]); + vst1q_lane_s32(reinterpret_cast(outptr + 1 * 12 + 8), + vreinterpretq_s32_s8(input2), 1); + vst1_s8(outptr + 2 * 12, input.val[2]); + vst1q_lane_s32(reinterpret_cast(outptr + 2 * 12 + 8), + vreinterpretq_s32_s8(input2), 2); + vst1_s8(outptr + 3 * 12, input.val[3]); + vst1q_lane_s32(reinterpret_cast(outptr + 3 * 12 + 8), + vreinterpretq_s32_s8(input2), 3); +} + } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h index 985cbff4b..d317e88fb 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h @@ -6,53 +6,52 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" - namespace megdnn { namespace aarch64 { -namespace matmul_general_8x12 { - -// Overview of register layout: -// -// A 1x12 cell of Rhs is stored in 32bit in v2-v7 -// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) -// A 8x12 block of accumulators is stored in 32bit in v8-v31. -// -// +--------+--------+--------+ -// | v2[0-3]| v3[0-3]| v4[0-3]| -// | v5[0-3]| v6[0-3]| v7[0-3]| -// Rhs +--------+--------+--------+ -// -// | | | | -// -// Lhs | | | | -// -// +--+ --- - +--------+--------+--------+ -// |v0| | v8[0-3]| v9[0-3]|v10[0-3]| -// |v0| |v11[0-3]|v12[0-3]|v13[0-3]| -// |v0| |v14[0-3]|v15[0-3]|v16[0-3]| -// |v0| |v17[0-3]|v18[0-3]|v19[0-3]| -// |v1| |v20[0-3]|v21[0-3]|v22[0-3]| -// |v1| |v23[0-3]|v24[0-3]|v25[0-3]| -// |v1| |v26[0-3]|v27[0-3]|v28[0-3]| -// |v1| |v29[0-3]|v30[0-3]|v31[0-3]| -// +--+ --- - +--------+--------+--------+ -// -// Accumulator -void kern_8x12(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k) { - const float* a_ptr = packA; - const float* b_ptr = packB; - int oddk = (K & 1); - K = ((K + 1) / 2) - 1; - - LDC = LDC * sizeof(float); - register float* outptr asm("x0") = reinterpret_cast(output); +struct matmul_general_8x12 { + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
+ // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // |v1| |v20[0-3]|v21[0-3]|v22[0-3]| + // |v1| |v23[0-3]|v24[0-3]|v25[0-3]| + // |v1| |v26[0-3]|v27[0-3]|v28[0-3]| + // |v1| |v29[0-3]|v30[0-3]|v31[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_8x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = reinterpret_cast(output); // clang-format off #define LOAD_LINE(v0, v1, v2, n) \ @@ -79,285 +78,286 @@ void kern_8x12(const float* packA, const float* packB, int K, float* output, STORE_LINE("20", "21", "22", "4") \ STORE_LINE("23", "24", "25", "5") \ STORE_LINE("26", "27", "28", "6") \ - STORE_LINE("29", "30", "31", "7") \ + STORE_LINE("29", "30", "31", "7") \ // clang-format on - asm volatile( - // load accumulator C - "add x1, x0, %x[LDC]\n" - "add x2, x1, %x[LDC]\n" - "add x3, x2, %x[LDC]\n" - "add x4, x3, %x[LDC]\n" - "add x5, x4, %x[LDC]\n" - "add x6, x5, %x[LDC]\n" - "add x7, x6, %x[LDC]\n" - - "cmp %w[is_first_k], #1\n" - "beq 1f\n" LOAD_C - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "b 2f\n" - - "1:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v9.16b, v9.16b, v9.16b\n" - "eor v10.16b, v10.16b, v10.16b\n" - "prfm pstl1keep, [x0]\n" - "eor v11.16b, v11.16b, v11.16b\n" - "eor v12.16b, v12.16b, v12.16b\n" - "eor v13.16b, v13.16b, v13.16b\n" - "prfm pstl1keep, [x1]\n" - "eor v14.16b, v14.16b, v14.16b\n" - "eor v15.16b, v15.16b, v15.16b\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" - "eor v16.16b, v16.16b, v16.16b\n" - "eor v17.16b, v17.16b, v17.16b\n" - "prfm pstl1keep, [x2]\n" - "eor v18.16b, v18.16b, v18.16b\n" - "eor v19.16b, v19.16b, v19.16b\n" - "eor v20.16b, v20.16b, v20.16b\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "eor v21.16b, v21.16b, v21.16b\n" - "eor v22.16b, v22.16b, v22.16b\n" - "prfm pstl1keep, [x3]\n" - "eor v23.16b, v23.16b, v23.16b\n" - "eor v24.16b, v24.16b, v24.16b\n" - "eor v25.16b, v25.16b, v25.16b\n" - "prfm pstl1keep, [x4]\n" - "eor v26.16b, v26.16b, v26.16b\n" - "eor v27.16b, v27.16b, v27.16b\n" - "eor v28.16b, v28.16b, v28.16b\n" - "prfm pstl1keep, [x5]\n" - "eor v29.16b, v29.16b, v29.16b\n" - "eor v30.16b, v30.16b, v30.16b\n" - "eor v31.16b, v31.16b, v31.16b\n" - "prfm pstl1keep, [x6]\n" - - "2: \n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "prfm pldl1keep, [%[a_ptr], #64]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - "fmla v20.4s, v2.4s, v1.s[0]\n" - "fmla v21.4s, v3.4s, v1.s[0]\n" - "fmla v22.4s, v4.4s, v1.s[0]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "fmla v24.4s, v3.4s, v1.s[1]\n" - "fmla v25.4s, v4.4s, 
v1.s[1]\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" - "fmla v27.4s, v3.4s, v1.s[2]\n" - "fmla v28.4s, v4.4s, v1.s[2]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - "prfm pldl1keep, [%[b_ptr], #64]\n" - "fmla v30.4s, v3.4s, v1.s[3]\n" - "fmla v31.4s, v4.4s, v1.s[3]\n" - - "fmla v8.4s, v5.4s, v0.s[0]\n" - "fmla v9.4s, v6.4s, v0.s[0]\n" - "fmla v10.4s, v7.4s, v0.s[0]\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "fmla v12.4s, v6.4s, v0.s[1]\n" - "fmla v13.4s, v7.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v5.4s, v0.s[2]\n" - "fmla v15.4s, v6.4s, v0.s[2]\n" - "fmla v16.4s, v7.4s, v0.s[2]\n" - "fmla v17.4s, v5.4s, v0.s[3]\n" - "fmla v18.4s, v6.4s, v0.s[3]\n" - "fmla v19.4s, v7.4s, v0.s[3]\n" - "fmla v20.4s, v5.4s, v1.s[0]\n" - "fmla v21.4s, v6.4s, v1.s[0]\n" - "fmla v22.4s, v7.4s, v1.s[0]\n" - "fmla v23.4s, v5.4s, v1.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v24.4s, v6.4s, v1.s[1]\n" - "fmla v25.4s, v7.4s, v1.s[1]\n" - "fmla v26.4s, v5.4s, v1.s[2]\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" - "fmla v27.4s, v6.4s, v1.s[2]\n" - "fmla v28.4s, v7.4s, v1.s[2]\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - "fmla v30.4s, v6.4s, v1.s[3]\n" - "prfm pldl1keep, [%[b_ptr], #64]\n" - "fmla v31.4s, v7.4s, v1.s[3]\n" - - "subs %w[K], %w[K], #1\n" - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - "fmla v20.4s, v2.4s, v1.s[0]\n" - "fmla v21.4s, v3.4s, v1.s[0]\n" - "fmla v22.4s, v4.4s, v1.s[0]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "fmla v24.4s, v3.4s, v1.s[1]\n" - "fmla v25.4s, v4.4s, v1.s[1]\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" - "fmla v27.4s, v3.4s, v1.s[2]\n" - "fmla v28.4s, v4.4s, v1.s[2]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - "fmla v30.4s, v3.4s, v1.s[3]\n" - "fmla v31.4s, v4.4s, v1.s[3]\n" - - "fmla v8.4s, v5.4s, v0.s[0]\n" - "fmla v9.4s, v6.4s, v0.s[0]\n" - "fmla v10.4s, v7.4s, v0.s[0]\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "fmla v12.4s, v6.4s, v0.s[1]\n" - "fmla v13.4s, v7.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v5.4s, v0.s[2]\n" - "fmla v15.4s, v6.4s, v0.s[2]\n" - "fmla v16.4s, v7.4s, v0.s[2]\n" - "fmla v17.4s, v5.4s, v0.s[3]\n" - "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" - "fmla v18.4s, v6.4s, v0.s[3]\n" - "fmla v19.4s, v7.4s, v0.s[3]\n" - "fmla v20.4s, v5.4s, v1.s[0]\n" - "fmla v21.4s, v6.4s, v1.s[0]\n" - "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" - "fmla v22.4s, v7.4s, v1.s[0]\n" - "fmla v23.4s, v5.4s, v1.s[1]\n" - "fmla v24.4s, v6.4s, v1.s[1]\n" - "fmla v25.4s, v7.4s, v1.s[1]\n" - "st1 {v14.4s, v15.4s, v16.4s}, [x2]\n" - "fmla v26.4s, v5.4s, v1.s[2]\n" - "fmla v27.4s, v6.4s, v1.s[2]\n" - "fmla v28.4s, v7.4s, v1.s[2]\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - "fmla v30.4s, v6.4s, v1.s[3]\n" - "fmla v31.4s, v7.4s, v1.s[3]\n" - "st1 {v17.4s, v18.4s, v19.4s}, [x3]\n" - "st1 {v20.4s, v21.4s, v22.4s}, [x4]\n" - "st1 {v23.4s, v24.4s, v25.4s}, [x5]\n" - "st1 {v26.4s, v27.4s, v28.4s}, [x6]\n" - "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" - "b 6f\n" - - // odd tail - "5:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - 
"fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - "fmla v20.4s, v2.4s, v1.s[0]\n" - "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" - "fmla v21.4s, v3.4s, v1.s[0]\n" - "fmla v22.4s, v4.4s, v1.s[0]\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "fmla v24.4s, v3.4s, v1.s[1]\n" - "st1 {v14.4s, v15.4s, v16.4s}, [x2]\n" - "fmla v25.4s, v4.4s, v1.s[1]\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "fmla v27.4s, v3.4s, v1.s[2]\n" - "fmla v28.4s, v4.4s, v1.s[2]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - "st1 {v17.4s, v18.4s, v19.4s}, [x3]\n" - "fmla v30.4s, v3.4s, v1.s[3]\n" - "fmla v31.4s, v4.4s, v1.s[3]\n" - "st1 {v20.4s, v21.4s, v22.4s}, [x4]\n" - "st1 {v23.4s, v24.4s, v25.4s}, [x5]\n" - "st1 {v26.4s, v27.4s, v28.4s}, [x6]\n" - "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" - - "6:\n" - - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), - [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", - "cc", "memory"); + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [x0]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "prfm pstl1keep, [x1]\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "prfm pstl1keep, [x2]\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "prfm pstl1keep, [x3]\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "prfm pstl1keep, [x4]\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "prfm pstl1keep, [x5]\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + "prfm pstl1keep, [x6]\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "prfm pldl1keep, 
[%[a_ptr], #64]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "prfm pldl1keep, [%[b_ptr], #64]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "fmla v9.4s, v6.4s, v0.s[0]\n" + "fmla v10.4s, v7.4s, v0.s[0]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v15.4s, v6.4s, v0.s[2]\n" + "fmla v16.4s, v7.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v18.4s, v6.4s, v0.s[3]\n" + "fmla v19.4s, v7.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v21.4s, v6.4s, v1.s[0]\n" + "fmla v22.4s, v7.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v24.4s, v6.4s, v1.s[1]\n" + "fmla v25.4s, v7.4s, v1.s[1]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "fmla v27.4s, v6.4s, v1.s[2]\n" + "fmla v28.4s, v7.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "prfm pldl1keep, [%[b_ptr], #64]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "fmla v9.4s, v6.4s, v0.s[0]\n" + "fmla v10.4s, v7.4s, v0.s[0]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v15.4s, v6.4s, v0.s[2]\n" + "fmla v16.4s, v7.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" + "fmla v18.4s, v6.4s, v0.s[3]\n" + "fmla v19.4s, v7.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v21.4s, v6.4s, v1.s[0]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "fmla v22.4s, v7.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "fmla v24.4s, v6.4s, v1.s[1]\n" + "fmla v25.4s, v7.4s, v1.s[1]\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x2]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "fmla v27.4s, v6.4s, v1.s[2]\n" + "fmla v28.4s, v7.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v30.4s, v6.4s, 
v1.s[3]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x3]\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x4]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x5]\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x6]\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x2]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x3]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x4]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x5]\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x6]\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" + + "6:\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C -} - -// Overview of register layout: -// -// A 1x12 cell of Rhs is stored in 32bit in v2-v7 -// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) -// A 8x12 block of accumulators is stored in 32bit in v8-v31. -// -// +--------+ -// | v2[0-3]| -// | v5[0-3]| -// Rhs +--------+ -// -// | | -// -// Lhs | | -// -// +--+ --- - +--------+ -// |v0| | v8[0-3]| -// |v0| |v11[0-3]| -// |v0| |v14[0-3]| -// |v0| |v17[0-3]| -// |v1| |v20[0-3]| -// |v1| |v23[0-3]| -// |v1| |v26[0-3]| -// |v1| |v29[0-3]| -// +--+ --- - +--------+ -// -// Accumulator -void kern_8x4(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int n_remain) { - const float* a_ptr = packA; - const float* b_ptr = packB; - int oddk = (K & 1); - K = ((K + 1) / 2) - 1; - - LDC = LDC * sizeof(float); - register float* outptr asm("x0") = reinterpret_cast(output); + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
+ // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // |v1| |v20[0-3]| + // |v1| |v23[0-3]| + // |v1| |v26[0-3]| + // |v1| |v29[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_8x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = reinterpret_cast(output); // clang-format off #define LOAD_LINE(v0, n) \ @@ -414,158 +414,159 @@ void kern_8x4(const float* packA, const float* packB, int K, float* output, STORE_LINE("20", "4") \ STORE_LINE("23", "5") \ STORE_LINE("26", "6") \ - STORE_LINE("29", "7") \ + STORE_LINE("29", "7") \ // clang-format on - asm volatile( - // load accumulator C - "add x1, x0, %x[LDC]\n" - "add x2, x1, %x[LDC]\n" - "add x3, x2, %x[LDC]\n" - "add x4, x3, %x[LDC]\n" - "add x5, x4, %x[LDC]\n" - "add x6, x5, %x[LDC]\n" - "add x7, x6, %x[LDC]\n" - - "cmp %w[is_first_k], #1\n" - "beq 1f\n" LOAD_C - - "b 2f\n" - - "1:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v11.16b, v11.16b, v11.16b\n" - "eor v14.16b, v14.16b, v14.16b\n" - "eor v17.16b, v17.16b, v17.16b\n" - "eor v20.16b, v20.16b, v20.16b\n" - "eor v23.16b, v23.16b, v23.16b\n" - "eor v26.16b, v26.16b, v26.16b\n" - "eor v29.16b, v29.16b, v29.16b\n" - - "2: \n" - "ld1 {v2.4s}, [%[b_ptr]], 16\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "ld1 {v5.4s}, [%[b_ptr]], 16\n" - "fmla v20.4s, v2.4s, v1.s[0]\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - - "fmla v8.4s, v5.4s, v0.s[0]\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v5.4s, v0.s[2]\n" - "fmla v17.4s, v5.4s, v0.s[3]\n" - "fmla v20.4s, v5.4s, v1.s[0]\n" - "fmla v23.4s, v5.4s, v1.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v26.4s, v5.4s, v1.s[2]\n" - "ld1 {v2.4s}, [%[b_ptr]], 16\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - - "subs %w[K], %w[K], #1\n" - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "fmla v8.4s, v2.4s, v0.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "ld1 {v5.4s}, [%[b_ptr]], 16\n" - "fmla v20.4s, v2.4s, v1.s[0]\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - - "fmla v8.4s, v5.4s, v0.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "fmla v14.4s, v5.4s, v0.s[2]\n" - "fmla v17.4s, v5.4s, v0.s[3]\n" - "fmla v20.4s, v5.4s, v1.s[0]\n" - "fmla v23.4s, v5.4s, v1.s[1]\n" - "fmla v26.4s, v5.4s, v1.s[2]\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - - "b 6f\n" - - // odd tail - "5:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v20.4s, v2.4s, v1.s[0]\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "fmla v29.4s, v2.4s, 
v1.s[3]\n" - - "6:\n" STORE_C - - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), - [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) - : - : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", - "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", - "memory"); + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + + "2: \n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", + "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C -} - - -// Overview of register layout: -// -// A 1x12 cell of Rhs is stored in 32bit in v2-v7 -// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) -// A 8x12 block of accumulators is stored in 
32bit in v8-v31. -// -// +--------+--------+--------+ -// | v2[0-3]| v3[0-3]| v4[0-3]| -// | v5[0-3]| v6[0-3]| v7[0-3]| -// Rhs +--------+--------+--------+ -// -// | | | | -// -// Lhs | | | | -// -// +--+ --- - +--------+--------+--------+ -// |v0| | v8[0-3]| v9[0-3]|v10[0-3]| -// |v0| |v11[0-3]|v12[0-3]|v13[0-3]| -// |v0| |v14[0-3]|v15[0-3]|v16[0-3]| -// |v0| |v17[0-3]|v18[0-3]|v19[0-3]| -// +--+ --- - +--------+--------+--------+ -// -// Accumulator -void kern_4x12(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int m_remain) { - const float* a_ptr = packA; - const float* b_ptr = packB; - int oddk = (K & 1); - K = ((K + 1) / 2) - 1; - - LDC = LDC * sizeof(float); - register float* outptr asm("x0") = output; + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. + // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_4x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int m_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = output; // clang-format off #define LOAD_LINE(v0, v1, v2, n) \ @@ -596,172 +597,173 @@ void kern_4x12(const float* packA, const float* packB, int K, float* output, STORE_LINE("14","15","16", "2") \ STORE_LINE("17","18","19", "3") \ "105:\n" - // clang-format on - - asm volatile( - // load accumulator C - "add x1, x0, %x[LDC]\n" - "add x2, x1, %x[LDC]\n" - "add x3, x2, %x[LDC]\n" - - "cmp %w[is_first_k], #1\n" - "beq 1f\n" LOAD_C - - "b 2f\n" - - "1:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v9.16b, v9.16b, v9.16b\n" - "eor v10.16b, v10.16b, v10.16b\n" - "eor v11.16b, v11.16b, v11.16b\n" - "eor v12.16b, v12.16b, v12.16b\n" - "eor v13.16b, v13.16b, v13.16b\n" - "eor v14.16b, v14.16b, v14.16b\n" - "eor v15.16b, v15.16b, v15.16b\n" - "eor v16.16b, v16.16b, v16.16b\n" - "eor v17.16b, v17.16b, v17.16b\n" - "eor v18.16b, v18.16b, v18.16b\n" - "eor v19.16b, v19.16b, v19.16b\n" - - "2: \n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - - "fmla v8.4s, v5.4s, v1.s[0]\n" - "fmla v9.4s, v6.4s, v1.s[0]\n" - "fmla v10.4s, v7.4s, v1.s[0]\n" - "fmla v11.4s, v5.4s, v1.s[1]\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" - "fmla v12.4s, v6.4s, v1.s[1]\n" - "fmla v13.4s, v7.4s, v1.s[1]\n" - "fmla v14.4s, v5.4s, v1.s[2]\n" - "fmla v15.4s, 
v6.4s, v1.s[2]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v16.4s, v7.4s, v1.s[2]\n" - "fmla v17.4s, v5.4s, v1.s[3]\n" - "fmla v18.4s, v6.4s, v1.s[3]\n" - "fmla v19.4s, v7.4s, v1.s[3]\n" - - "subs %w[K], %w[K], #1\n" - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - - "fmla v8.4s, v5.4s, v1.s[0]\n" - "fmla v9.4s, v6.4s, v1.s[0]\n" - "fmla v10.4s, v7.4s, v1.s[0]\n" - "fmla v11.4s, v5.4s, v1.s[1]\n" - "fmla v12.4s, v6.4s, v1.s[1]\n" - "fmla v13.4s, v7.4s, v1.s[1]\n" - "fmla v14.4s, v5.4s, v1.s[2]\n" - "fmla v15.4s, v6.4s, v1.s[2]\n" - "fmla v16.4s, v7.4s, v1.s[2]\n" - "fmla v17.4s, v5.4s, v1.s[3]\n" - "fmla v18.4s, v6.4s, v1.s[3]\n" - "fmla v19.4s, v7.4s, v1.s[3]\n" - - "b 6f\n" - - // odd tail - "5:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - - "6:\n" STORE_C - - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), - [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "x1", "x2", "x3", "x10", "cc", "memory"); + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v9.4s, v6.4s, v1.s[0]\n" + "fmla v10.4s, v7.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v14.4s, 
v5.4s, v1.s[2]\n" + "fmla v15.4s, v6.4s, v1.s[2]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v16.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], 48\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v9.4s, v6.4s, v1.s[0]\n" + "fmla v10.4s, v7.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v15.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "x1", "x2", "x3", "x10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C -} - - -// Overview of register layout: -// -// A 2x4 cell of Rhs is stored in 32bit in v2 - v3 -// A 4x2 cell of Lhs is stored in 32bit in v0 - v1 -// A 4x4 block of accumulators is stored in 32bit in v4-v6 -// -// +--------+ -// | v2[0-3]| -// | v5[0-3]| -// Rhs +--------+ -// -// | | -// -// Lhs | | -// -// +--+ --- - +--------+ -// |v0| | v8[0-3]| -// |v0| |v11[0-3]| -// |v0| |v14[0-3]| -// |v0| |v17[0-3]| -// +--+ --- - +--------+ -// -// Accumulator -void kern_4x4(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int m_remain, int n_remain) { - const float* a_ptr = packA; - const float* b_ptr = packB; - int oddk = (K & 1); - K = ((K + 1) / 2) - 1; - - LDC = LDC * sizeof(float); - register float* outptr asm("x0") = output; + } + + // Overview of register layout: + // + // A 2x4 cell of Rhs is stored in 32bit in v2 - v3 + // A 4x2 cell of Lhs is stored in 32bit in v0 - v1 + // A 4x4 block of accumulators is stored in 32bit in v4-v6 + // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_4x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + const float* a_ptr = packA; 
+ const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = output; // clang-format off #define LOAD_LINE(v0, n) \ @@ -820,427 +822,436 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, STORE_LINE("14", "2") \ STORE_LINE("17", "3") \ "105:\n" - // clang-format on - - asm volatile( - // load accumulator C - "add x1, x0, %x[LDC]\n" - "add x2, x1, %x[LDC]\n" - "add x3, x2, %x[LDC]\n" - - "cmp %w[is_first_k], #1\n" - "beq 1f\n" LOAD_C - - "b 2f\n" - - "1:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v11.16b, v11.16b, v11.16b\n" - "eor v14.16b, v14.16b, v14.16b\n" - "eor v17.16b, v17.16b, v17.16b\n" - - "2: \n" - "ld1 {v2.4s}, [%[b_ptr]], 16\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "ld1 {v5.4s}, [%[b_ptr]], 16\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v8.4s, v5.4s, v1.s[0]\n" - "fmla v11.4s, v5.4s, v1.s[1]\n" - "ld1 {v2.4s}, [%[b_ptr]], 16\n" - "fmla v14.4s, v5.4s, v1.s[2]\n" - "fmla v17.4s, v5.4s, v1.s[3]\n" - - "subs %w[K], %w[K], #1\n" - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "ld1 {v5.4s}, [%[b_ptr]], 16\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - - "fmla v8.4s, v5.4s, v1.s[0]\n" - "fmla v11.4s, v5.4s, v1.s[1]\n" - "fmla v14.4s, v5.4s, v1.s[2]\n" - "fmla v17.4s, v5.4s, v1.s[3]\n" - - "b 6f\n" - - // odd tail - "5:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - - "6:\n" STORE_C - - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), - [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), - [m_remain] "+r"(m_remain) - : - : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", - "x3", "x10", "cc", "memory"); + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + + "2: \n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v17.4s, 
v5.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", + "x3", "x10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C -} - -void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, int y0, - int ymax, int k0, int kmax) { - float zerobuff[8]; - std::memset(zerobuff, 0, sizeof(float) * 8); - constexpr int PACK_SIZE_32 = 4*8; - constexpr int PACK_SIZE_16 = 4*4; - int y = y0; - for (; y + 7 < ymax; y += 8) { - const float* inptr0 = inptr + y * ldin + k0; - const float* inptr1 = inptr0 + ldin; - const float* inptr2 = inptr1 + ldin; - const float* inptr3 = inptr2 + ldin; - const float* inptr4 = inptr3 + ldin; - const float* inptr5 = inptr4 + ldin; - const float* inptr6 = inptr5 + ldin; - const float* inptr7 = inptr6 + ldin; - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - prefetch_2x(inptr4); - prefetch_2x(inptr5); - prefetch_2x(inptr6); - prefetch_2x(inptr7); - int x = (kmax - k0); - for (; x > 3; x -= 4) { - transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); - outptr += PACK_SIZE_32; - } - for (; x > 0; x--) { - *outptr++ = *inptr0++; - *outptr++ = *inptr1++; - *outptr++ = *inptr2++; - *outptr++ = *inptr3++; - *outptr++ = *inptr4++; - *outptr++ = *inptr5++; - *outptr++ = *inptr6++; - *outptr++ = *inptr7++; - } } - for (; y < ymax; y += 4) { - const float* inptr0 = inptr + y * ldin + k0; - const float* inptr1 = inptr0 + ldin; - const float* inptr2 = inptr1 + ldin; - const float* inptr3 = inptr2 + ldin; - - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - - int K = (kmax - k0); - for (; K > 3; K -= 4) { - if ((y + 3) >= ymax) { - switch ((y + 3) - ymax) { - /* Everything falls through in here */ - case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU - case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU - case 0: - inptr3 = zerobuff; - break; - default: - megdnn_assert(0); - } + static void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, + int y0, int ymax, int k0, int kmax) { + float zerobuff[8]; + std::memset(zerobuff, 0, sizeof(float) * 8); + constexpr int PACK_SIZE_32 = 4 * 8; + constexpr int PACK_SIZE_16 = 4 * 4; + int y = y0; + for (; y + 7 < ymax; y += 8) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + const float* inptr4 = inptr3 + ldin; + const float* inptr5 = inptr4 + ldin; + const float* inptr6 = inptr5 + ldin; + const float* inptr7 = inptr6 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + int x = (kmax - k0); + for (; x > 3; x -= 4) { + transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, + inptr5, inptr6, inptr7, outptr); + outptr += PACK_SIZE_32; + } + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ 
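A detail that is easy to misread in all of these kernels: K is rewritten to the main-loop trip count before the asm runs, and the remaining iterations are absorbed by the even or odd tail. A standalone sketch of the bookkeeping (check_k_slicing is illustrative, not from the patch):

```cpp
#include <cassert>
// The kernels issue one pipelined fma-step per k: ((K+1)/2 - 1) double steps
// in the "3:" loop, then a 2-step even tail or a 1-step odd tail. This just
// checks that the slicing covers every k exactly once.
void check_k_slicing(int K) {          // K >= 1, as guaranteed by the caller
    int oddk  = K & 1;
    int loops = ((K + 1) / 2) - 1;     // the value the asm counts down with subs
    int steps = 2 * loops + (oddk ? 1 : 2);
    assert(steps == K);
}
```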
= *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; } - - transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); - outptr += PACK_SIZE_16; } - if (K > 0) { - if (y + 3 >= ymax) { - switch (y + 3 - ymax) { - case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU - case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU - case 0: - inptr3 = zerobuff; - break; - default: - megdnn_assert(0); + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } } + + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += PACK_SIZE_16; } - interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); - } - } -} - -void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax) { - int ksize = kmax - k0; - int ksize8 = (ksize << 3); - int ksize4 = (ksize << 2); - float* outptr_base = out; - float* outptr_base4 = outptr_base + (xmax - x0) / 8 * ksize8; - - int k = k0; - for (; k + 3 < kmax; k += 4) { - const float* inptr = in + k * ldin + x0; - const float* inptr1 = inptr + ldin; - const float* inptr2 = inptr1 + ldin; - const float* inptr3 = inptr2 + ldin; - - prefetch_3x(inptr); - prefetch_3x(inptr1); - prefetch_3x(inptr2); - prefetch_3x(inptr3); - - int x = x0; - auto outptr = outptr_base; - for (; x + 8 <= xmax; x += 8) { - auto outptr_interleave = outptr; - interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); - outptr += ksize8; - } - outptr = outptr_base4; - for (; x + 4 <= xmax; x += 4) { - auto outptr_interleave = outptr; - interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); - outptr += ksize4; - } - if (x < xmax) { - interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); - } - outptr_base += 4 * 8; - outptr_base4 += 4 * 4; - } - for (; k < kmax; k++) { - const float* inptr = in + k * ldin + x0; - prefetch_3x(inptr); - int x = x0; - auto outptr = outptr_base; - for (; x + 8 <= xmax; x += 8) { - auto outptr_interleave = outptr; - interleave_1x8_1_s(inptr, outptr_interleave); - outptr += ksize8; - } - outptr = outptr_base4; - for (; x + 4 <= xmax; x += 4) { - auto outptr_interleave = outptr; - interleave_1x4_1_s(inptr, outptr_interleave); - outptr += ksize4; - } - if (x < xmax) { - interleave_1(inptr, outptr, 4, xmax - x); + if (K > 0) { + if (y + 3 >= ymax) { + switch (y + 3 - ymax) { + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } } - outptr_base += 8; - outptr_base4 += 4; } -} - -void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax) { - int ksize = kmax - k0; - int ksize12 = ksize * 12; - int ksize4 = (ksize << 2); - float* outptr_base = out; - float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; - - int k = k0; - for (; k + 3 < kmax; k += 4) { - const float* inptr = in + k * ldin + x0; - const 
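sgemm_8x12_pack_A_n above interleaves eight rows of A k-major so kern_8x12 can stream the LHS linearly, and pads short row blocks through zerobuff so the fall-through switch never reads past ymax. Element for element, the 8-row branch amounts to the following sketch (pack_A_8rows_ref is illustrative, assuming eight whole rows are available):

```cpp
// Reference for the 8-row branch of sgemm_8x12_pack_A_n (sketch): element
// (y + r, k) of A lands at out[(k - k0) * 8 + r], so the kernel reads eight
// LHS values per k from consecutive memory.
void pack_A_8rows_ref(float* out, const float* in, int ldin,
                      int y, int k0, int kmax) {
    for (int k = k0; k < kmax; ++k)   // k-major ...
        for (int r = 0; r < 8; ++r)   // ... with the 8 rows interleaved
            *out++ = in[(y + r) * ldin + k];
}
```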
float* inptr1 = inptr + ldin; - const float* inptr2 = inptr1 + ldin; - const float* inptr3 = inptr2 + ldin; - - prefetch_3x(inptr); - prefetch_3x(inptr1); - prefetch_3x(inptr2); - prefetch_3x(inptr3); - - int x = x0; - auto outptr = outptr_base; - for (; x + 12 <= xmax; x += 12) { - auto outptr_interleave = outptr; - interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); - outptr += ksize12; - } - outptr = outptr_base4; - for (; x + 4 <= xmax; x += 4) { - auto outptr_interleave = outptr; - interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); - outptr += ksize4; + + static void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize8 = (ksize << 3); + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 8 * ksize8; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 8 <= xmax; x += 8) { + auto outptr_interleave = outptr; + interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize8; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, + xmax - x); + } + outptr_base += 4 * 8; + outptr_base4 += 4 * 4; } - if (x < xmax) { - interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 8 <= xmax; x += 8) { + auto outptr_interleave = outptr; + interleave_1x8_1_s(inptr, outptr_interleave); + outptr += ksize8; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + outptr_base += 8; + outptr_base4 += 4; } - outptr_base += 12 * 4; - outptr_base4 += 4 * 4; } - for (; k < kmax; k++) { - const float* inptr = in + k * ldin + x0; - prefetch_3x(inptr); - int x = x0; - auto outptr = outptr_base; - for (; x + 12 <= xmax; x += 12) { - auto outptr_interleave = outptr; - interleave_1x12_1_s(inptr, outptr_interleave); - outptr += ksize12; - } - outptr = outptr_base4; - for (; x + 4 <= xmax; x += 4) { - auto outptr_interleave = outptr; - interleave_1x4_1_s(inptr, outptr_interleave); - outptr += ksize4; + static void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin, + int x0, int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize12 = ksize * 12; + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_3x(inptr); + prefetch_3x(inptr1); + prefetch_3x(inptr2); + prefetch_3x(inptr3); + + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto 
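sgemm_8x12_pack_A_t and sgemm_8x12_pack_B_n share the same two-region output scheme: full-width panels first (8 columns for transposed A, 12 for B), then the 4-wide and ragged leftovers starting at outptr_base4. A sketch of the resulting offset arithmetic for the 12-wide case (packed_panel_base is illustrative, not from the patch):

```cpp
// Where column x of B starts in the packed buffer written by
// sgemm_8x12_pack_B_n (sketch): 12-wide panels occupy ksize * 12 floats
// each, and the leftover 4-wide panels follow, ksize * 4 floats each.
int packed_panel_base(int x, int x0, int xmax, int ksize) {
    int full12 = (xmax - x0) / 12;     // number of complete 12-wide panels
    int col = x - x0;
    if (col < full12 * 12)
        return (col / 12) * ksize * 12;
    return full12 * 12 * ksize + ((col - full12 * 12) / 4) * ksize * 4;
}
```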
outptr_interleave = outptr; + interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, + outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, + xmax - x); + } + outptr_base += 12 * 4; + outptr_base4 += 4 * 4; } - if (x < xmax) { - interleave_1(inptr, outptr, 4, xmax - x); + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + prefetch_3x(inptr); + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + interleave_1x12_1_s(inptr, outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + outptr_base += 12; + outptr_base4 += 4; } - outptr_base += 12; - outptr_base4 += 4; } -} - -void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin, - int y0, int ymax, int k0, int kmax) { - float* outptr = out; - const float* inptr = in; - float zerobuff[12]; - std::memset(zerobuff, 0, sizeof(float) * 12); - int y = y0; - for (; y + 12 <= ymax; y += 12) { - const float* inptr0 = inptr + y * ldin + k0; - const float* inptr1 = inptr0 + ldin; - const float* inptr2 = inptr1 + ldin; - const float* inptr3 = inptr2 + ldin; - const float* inptr4 = inptr3 + ldin; - const float* inptr5 = inptr4 + ldin; - const float* inptr6 = inptr5 + ldin; - const float* inptr7 = inptr6 + ldin; - const float* inptr8 = inptr7 + ldin; - const float* inptr9 = inptr8 + ldin; - const float* inptr10 = inptr9 + ldin; - const float* inptr11 = inptr10 + ldin; - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - prefetch_2x(inptr4); - prefetch_2x(inptr5); - prefetch_2x(inptr6); - prefetch_2x(inptr7); - prefetch_2x(inptr8); - prefetch_2x(inptr9); - prefetch_2x(inptr10); - prefetch_2x(inptr11); - int x = (kmax - k0); - for (; x > 3; x -= 4) { - transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, - outptr); - outptr += 48; - } - for (; x > 0; x--) { - *outptr++ = *inptr0++; - *outptr++ = *inptr1++; - *outptr++ = *inptr2++; - *outptr++ = *inptr3++; - *outptr++ = *inptr4++; - *outptr++ = *inptr5++; - *outptr++ = *inptr6++; - *outptr++ = *inptr7++; - *outptr++ = *inptr8++; - *outptr++ = *inptr9++; - *outptr++ = *inptr10++; - *outptr++ = *inptr11++; + + static void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin, + int y0, int ymax, int k0, int kmax) { + float* outptr = out; + const float* inptr = in; + float zerobuff[12]; + std::memset(zerobuff, 0, sizeof(float) * 12); + int y = y0; + for (; y + 12 <= ymax; y += 12) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + const float* inptr4 = inptr3 + ldin; + const float* inptr5 = inptr4 + ldin; + const float* inptr6 = inptr5 + ldin; + const float* inptr7 = inptr6 + ldin; + const float* inptr8 = inptr7 + ldin; + const float* inptr9 = inptr8 + ldin; + const float* inptr10 = inptr9 + ldin; + const float* inptr11 = inptr10 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + 
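sgemm_8x12_pack_B_t, which begins above, handles a transposed B by walking twelve rows at a time; element for element, each transpose_12x4_1_s step reduces to this sketch (transpose_12x4_ref is illustrative, inferred from the scalar tail loop that follows it):

```cpp
// Element-wise meaning of one transpose_12x4_1_s step in pack_B_t (sketch):
// four k-slices are emitted, 12 consecutive values per k, which is exactly
// the 12-wide panel order kern_8x12 consumes.
void transpose_12x4_ref(float* out, const float* const row[12]) {
    for (int k = 0; k < 4; ++k)
        for (int r = 0; r < 12; ++r)
            *out++ = row[r][k];
}
```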
prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + prefetch_2x(inptr8); + prefetch_2x(inptr9); + prefetch_2x(inptr10); + prefetch_2x(inptr11); + int x = (kmax - k0); + for (; x > 3; x -= 4) { + transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, + inptr5, inptr6, inptr7, inptr8, inptr9, + inptr10, inptr11, outptr); + outptr += 48; + } + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + *outptr++ = *inptr8++; + *outptr++ = *inptr9++; + *outptr++ = *inptr10++; + *outptr++ = *inptr11++; + } } - } - for (; y < ymax; y += 4) { - const float* inptr0 = inptr + y * ldin + k0; - const float* inptr1 = inptr0 + ldin; - const float* inptr2 = inptr1 + ldin; - const float* inptr3 = inptr2 + ldin; - - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - - /* Cope with ragged cases by copying from a buffer of zeroes instead - */ - int x = (kmax - k0); - for (; x > 3; x -= 4) { - if ((y + 3) >= ymax) { - switch ((y + 3) - ymax) { - /* Everything falls through in here */ - case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU - case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU - case 0: - inptr3 = zerobuff; - break; - default: - megdnn_assert(0); + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + + /* Cope with ragged cases by copying from a buffer of zeroes instead + */ + int x = (kmax - k0); + for (; x > 3; x -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } } - } - transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); - outptr += 16; - } + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + outptr += 16; + } - if (x > 0) { - if ((y + 3) >= ymax) { - switch ((y + 3) - ymax) { - /* Everything falls through in here */ - case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU - case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU - case 0: - inptr3 = zerobuff; - break; - default: - megdnn_assert(0); + if (x > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); } - interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); } } -} - -} // matmul_general_4x16 -} // aarch64 -} // megdnn +}; +} // namespace aarch64 +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h new file mode 100644 index 000000000..46632f57f --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h @@ -0,0 +1,1331 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. 
All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +struct matmul_general_8x12_a53 { + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. + // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // |v1| |v20[0-3]|v21[0-3]|v22[0-3]| + // |v1| |v23[0-3]|v24[0-3]|v25[0-3]| + // |v1| |v26[0-3]|v27[0-3]|v28[0-3]| + // |v1| |v29[0-3]|v30[0-3]|v31[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_8x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + LDC = LDC * sizeof(float); + + register float* outptr asm("x0") = reinterpret_cast<float*>(output); +// clang-format off +#define LOAD_LINE(v0, v1, v2, n) \ + "ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + +#define LOAD_C \ + LOAD_LINE("8", "9", "10", "0") \ + LOAD_LINE("11", "12", "13", "1") \ + LOAD_LINE("14", "15", "16", "2") \ + LOAD_LINE("17", "18", "19", "3") \ + LOAD_LINE("20", "21", "22", "4") \ + LOAD_LINE("23", "24", "25", "5") \ + LOAD_LINE("26", "27", "28", "6") \ + LOAD_LINE("29", "30", "31", "7") + + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr]]\n" + "add x2, x1, %x[LDC]\n" + "prfm pldl1keep, [%[b_ptr]]\n" + "add x3, x2, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr], #64]\n" + "add x4, x3, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr], #128]\n" + "add x5, x4, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr], #192]\n" + "add x6, x5, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr], #256]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ldr d2, [%[b_ptr]]\n" + + "eor v9.16b, v9.16b, v9.16b\n" + "ldr x10, [%[b_ptr], #8]\n" + + "eor v10.16b, v10.16b, v10.16b\n" + "ldr d3, [%[b_ptr], #16]\n" + + "eor v11.16b, v11.16b, v11.16b\n" + "ldr x11, [%[b_ptr], #24]\n" + + "eor v12.16b, v12.16b, v12.16b\n" + "ldr d4, [%[b_ptr], #32]\n" + + "eor v13.16b, v13.16b, v13.16b\n" + "ldr x12, [%[b_ptr], #40]\n" + + "eor v14.16b, v14.16b, v14.16b\n" + "ldr d0, [%[a_ptr]]\n" + + "eor v15.16b, v15.16b, v15.16b\n" + "ldr x9, [%[a_ptr], #8]\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "add %[b_ptr], %[b_ptr], #48\n" + + "eor v17.16b, v17.16b, v17.16b\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "eor v18.16b, v18.16b, v18.16b\n" + "ins v2.d[1], x10\n" + + "eor v19.16b, v19.16b, v19.16b\n" + "ins v3.d[1], x11\n" + + "eor v20.16b, v20.16b, v20.16b\n" + "ins v4.d[1], x12\n" + + "eor v21.16b, v21.16b, v21.16b\n" + "ins v0.d[1], x9\n" + + "eor v22.16b,
v22.16b, v22.16b\n" + "prfm pldl1keep, [%[a_ptr], #384]\n" + + "eor v23.16b, v23.16b, v23.16b\n" + "prfm pldl1keep, [%[b_ptr]]\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "prfm pldl1keep, [%[b_ptr], #64]\n" + + "eor v25.16b, v25.16b, v25.16b\n" + "prfm pldl1keep, [%[b_ptr], #128]\n" + + "eor v26.16b, v26.16b, v26.16b\n" + "prfm pldl1keep, [%[b_ptr], #192]\n" + + "eor v27.16b, v27.16b, v27.16b\n" + "prfm pldl1keep, [%[b_ptr], #256]\n" + + "eor v28.16b, v28.16b, v28.16b\n" + "prfm pldl1keep, [%[b_ptr], #320]\n" + + "eor v29.16b, v29.16b, v29.16b\n" + "prfm pldl1keep, [%[b_ptr], #384]\n" + + "eor v30.16b, v30.16b, v30.16b\n" + "prfm pldl1keep, [%[b_ptr], #448]\n" + + "eor v31.16b, v31.16b, v31.16b\n" + "prfm pldl1keep, [%[b_ptr], #512]\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + + "ldr d1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v9.4s, v3.4s, v0.s[0]\n" + "subs %w[K], %w[K], #1\n" + + "fmla v10.4s, v4.4s, v0.s[0]\n" + + "ldr d5, [%[b_ptr]]\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + + "ldr d6, [%[b_ptr], #16]\n" + "ins v5.d[1], x10\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + + "fmla v16.4s, v4.4s, v0.s[2]\n" + + "ldr d7, [%[b_ptr], #32]\n" + "ins v6.d[1], x11\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v18.4s, v3.4s, v0.s[3]\n" + + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v7.d[1], x12\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v21.4s, v3.4s, v1.s[0]\n" + + "fmla v22.4s, v4.4s, v1.s[0]\n" + + "prfm pldl1keep, [%[a_ptr], #448]\n" + "ins v0.d[1], x9\n" + + "fmla v23.4s, v2.4s, v1.s[1]\n" + + "fmla v24.4s, v3.4s, v1.s[1]\n" + + "fmla v25.4s, v4.4s, v1.s[1]\n" + + "prfm pldl1keep, [%[b_ptr], #576]\n" + "nop\n" + + "fmla v26.4s, v2.4s, v1.s[2]\n" + + "fmla v27.4s, v3.4s, v1.s[2]\n" + + "fmla v28.4s, v4.4s, v1.s[2]\n" + + "nop\n" + "nop\n" + + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v30.4s, v3.4s, v1.s[3]\n" + + "fmla v31.4s, v4.4s, v1.s[3]\n" + + //! 
UNROLL + "ldr d1, [%[a_ptr], #32]\n" + "nop\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "fmla v9.4s, v6.4s, v0.s[0]\n" + + "fmla v10.4s, v7.4s, v0.s[0]\n" + + "ldr d2, [%[b_ptr], #48]\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v5.4s, v0.s[1]\n" + "ldr x10, [%[b_ptr], #56]\n" + + "fmla v12.4s, v6.4s, v0.s[1]\n" + + "fmla v13.4s, v7.4s, v0.s[1]\n" + + "ldr d3, [%[b_ptr], #64]\n" + "ins v2.d[1], x10\n" + + "fmla v14.4s, v5.4s, v0.s[2]\n" + "ldr x11, [%[b_ptr], #72]\n" + + "fmla v15.4s, v6.4s, v0.s[2]\n" + + "fmla v16.4s, v7.4s, v0.s[2]\n" + + "ldr d4, [%[b_ptr], #80]\n" + "ins v3.d[1], x11\n" + + "fmla v17.4s, v5.4s, v0.s[3]\n" + "ldr x12, [%[b_ptr], #88]\n" + + "fmla v18.4s, v6.4s, v0.s[3]\n" + + "fmla v19.4s, v7.4s, v0.s[3]\n" + + "ldr d0, [%[a_ptr], #48]\n" + "ins v4.d[1], x12\n" + + "fmla v20.4s, v5.4s, v1.s[0]\n" + "ldr x9, [%[a_ptr], #56]\n" + + "fmla v21.4s, v6.4s, v1.s[0]\n" + "add %[b_ptr], %[b_ptr], #96\n" + + "fmla v22.4s, v7.4s, v1.s[0]\n" + + "nop\n" + "ins v0.d[1], x9\n" + + "fmla v23.4s, v5.4s, v1.s[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + + "fmla v24.4s, v6.4s, v1.s[1]\n" + + "fmla v25.4s, v7.4s, v1.s[1]\n" + + "prfm pldl1keep, [%[b_ptr], #640]\n" + "nop\n" + + "fmla v26.4s, v5.4s, v1.s[2]\n" + + "fmla v27.4s, v6.4s, v1.s[2]\n" + + "fmla v28.4s, v7.4s, v1.s[2]\n" + + "nop\n" + "nop\n" + + "fmla v29.4s, v5.4s, v1.s[3]\n" + + "fmla v30.4s, v6.4s, v1.s[3]\n" + + "fmla v31.4s, v7.4s, v1.s[3]\n" + + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + + "ldr d1, [%[a_ptr]] \n" + "prfm pstl1keep, [x0]\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #8] \n" + + "fmla v9.4s, v3.4s, v0.s[0]\n" + + "fmla v10.4s, v4.4s, v0.s[0]\n" + "nop\n" + + "ldr d5, [%[b_ptr]]\n" + "prfm pstl1keep, [x1]\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + + "ldr d6, [%[b_ptr], #16]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + + "fmla v16.4s, v4.4s, v0.s[2]\n" + + "prfm pstl1keep, [x2]\n" + "ins v5.d[1], x10\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "fmla v18.4s, v3.4s, v0.s[3]\n" + + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v6.d[1], x11\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v21.4s, v3.4s, v1.s[0]\n" + + "fmla v22.4s, v4.4s, v1.s[0]\n" + + "ldr d7, [%[b_ptr], #32]\n" + "ins v0.d[1], x9\n" + + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v24.4s, v3.4s, v1.s[1]\n" + + "fmla v25.4s, v4.4s, v1.s[1]\n" + + "nop\n" + "ins v7.d[1], x12\n" + + "fmla v26.4s, v2.4s, v1.s[2]\n" + + "fmla v27.4s, v3.4s, v1.s[2]\n" + + "fmla v28.4s, v4.4s, v1.s[2]\n" + + "nop\n" + "nop\n" + + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v30.4s, v3.4s, v1.s[3]\n" + + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "ldr d1, [%[a_ptr], #32]\n" + "nop\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "fmla v9.4s, v6.4s, v0.s[0]\n" + + "fmla v10.4s, v7.4s, v0.s[0]\n" + + "nop\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v5.4s, v0.s[1]\n" + + "fmla v12.4s, v6.4s, v0.s[1]\n" + + "fmla v13.4s, v7.4s, v0.s[1]\n" + + "fmla v14.4s, v5.4s, v0.s[2]\n" + + "fmla v15.4s, v6.4s, v0.s[2]\n" + "str q8, [x0]\n" + + "fmla v16.4s, v7.4s, v0.s[2]\n" + "str q9, [x0, #16]\n" + + "fmla v17.4s, v5.4s, v0.s[3]\n" + "str q10, [x0, #32]\n" + + "fmla v18.4s, v6.4s, v0.s[3]\n" + + "fmla v19.4s, v7.4s, v0.s[3]\n" + "str 
q11, [x1]\n" + + "fmla v20.4s, v5.4s, v1.s[0]\n" + "str q12, [x1, #16]\n" + + "fmla v21.4s, v6.4s, v1.s[0]\n" + "str q13, [x1, #32]\n" + + "fmla v22.4s, v7.4s, v1.s[0]\n" + "str q14, [x2]\n" + + "fmla v23.4s, v5.4s, v1.s[1]\n" + "str q15, [x2, #16]\n" + + "fmla v24.4s, v6.4s, v1.s[1]\n" + "str q16, [x2, #32]\n" + + "fmla v25.4s, v7.4s, v1.s[1]\n" + "str q17, [x3]\n" + + "fmla v26.4s, v5.4s, v1.s[2]\n" + "str q18, [x3, #16]\n" + + "fmla v27.4s, v6.4s, v1.s[2]\n" + "str q19, [x3, #32]\n" + + "fmla v28.4s, v7.4s, v1.s[2]\n" + "str q20, [x4]\n" + + "fmla v29.4s, v5.4s, v1.s[3]\n" + "str q21, [x4, #16]\n" + + "fmla v30.4s, v6.4s, v1.s[3]\n" + "str q22, [x4, #32]\n" + + "fmla v31.4s, v7.4s, v1.s[3]\n" + "str q23, [x5]\n" + + "str q24, [x5, #16]\n" + "str q25, [x5, #32]\n" + + "st1 {v26.4s, v27.4s, v28.4s}, [x6]\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" + "b 6f\n" + + // odd tail + "5:\n" + "ldr d1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v9.4s, v3.4s, v0.s[0]\n" + + "fmla v10.4s, v4.4s, v0.s[0]\n" + + "nop\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + + "fmla v16.4s, v4.4s, v0.s[2]\n" + "str q8, [x0]\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "str q9, [x0, #16]\n" + + "fmla v18.4s, v3.4s, v0.s[3]\n" + "str q10, [x0, #32]\n" + + "fmla v19.4s, v4.4s, v0.s[3]\n" + "str q11, [x1]\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "str q12, [x1, #16]\n" + + "fmla v21.4s, v3.4s, v1.s[0]\n" + "str q13, [x1, #32]\n" + + "fmla v22.4s, v4.4s, v1.s[0]\n" + "str q14, [x2]\n" + + "fmla v23.4s, v2.4s, v1.s[1]\n" + "str q15, [x2, #16]\n" + + "fmla v24.4s, v3.4s, v1.s[1]\n" + "str q16, [x2, #32]\n" + + "fmla v25.4s, v4.4s, v1.s[1]\n" + "str q17, [x3]\n" + + "fmla v26.4s, v2.4s, v1.s[2]\n" + "str q18, [x3, #16]\n" + + "fmla v27.4s, v3.4s, v1.s[2]\n" + "str q19, [x3, #32]\n" + + "fmla v28.4s, v4.4s, v1.s[2]\n" + "str q20, [x4]\n" + + "fmla v29.4s, v2.4s, v1.s[3]\n" + "str q21, [x4, #16]\n" + + "fmla v30.4s, v3.4s, v1.s[3]\n" + "str q22, [x4, #32]\n" + + "fmla v31.4s, v4.4s, v1.s[3]\n" + "str q23, [x5]\n" + + "str q24, [x5, #16]\n" + "str q25, [x5, #32]\n" + "str q26, [x6]\n" + "str q27, [x6, #16]\n" + "str q28, [x6, #32]\n" + + "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" + + "6:\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory"); +#undef LOAD_LINE +#undef LOAD_C + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
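The defining trick of the _a53 kernels is visible throughout the asm above: every 128-bit ld1 is split into a 64-bit ldr d, a 64-bit integer ldr x and an ins, with fmla work threaded between them, because Cortex-A53 can dual-issue a 64-bit load next to NEON arithmetic while a 128-bit NEON load cannot pair (the nops appear to keep the two-wide issue pattern aligned). The idiom in isolation, as a sketch assuming an AArch64 toolchain rather than code taken from the kernel:

```cpp
#include <arm_neon.h>
#include <cstdint>
// Build a q register from two 64-bit loads so that each load can pair with
// an fmla in the surrounding code (illustrative sketch of the A53 idiom).
float32x4_t load_q_split(const float* p) {
    float32x4_t v;
    uint64_t hi;
    asm("ldr %d[v], [%[p]]\n"      // low half through the FP/SIMD pipe
        "ldr %[hi], [%[p], #8]\n"  // high half through the integer pipe
        "ins %[v].d[1], %[hi]\n"   // merge; no 128-bit load is issued
        : [v] "=&w"(v), [hi] "=&r"(hi)
        : [p] "r"(p)
        : "memory");
    return v;
}
```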
+ // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // |v1| |v20[0-3]| + // |v1| |v23[0-3]| + // |v1| |v26[0-3]| + // |v1| |v29[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_8x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = reinterpret_cast<float*>(output); + +// clang-format off +#define LOAD_LINE(v0, n) \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" v0 ".4s}, [x" n "],#16\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[0], [x" n "],#4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[1], [x" n "],#4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[2], [x" n "],#4\n" \ + "101" n ":\n" \ + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("11", "1") \ + LOAD_LINE("14", "2") \ + LOAD_LINE("17", "3") \ + LOAD_LINE("20", "4") \ + LOAD_LINE("23", "5") \ + LOAD_LINE("26", "6") \ + LOAD_LINE("29", "7") \ + + +#define STORE_LINE(v0, n) \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "st1 {v" v0 ".4s}, [x" n " ],#16\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[0], [x" n "],#4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[1], [x" n "],#4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[2], [x" n "],#4\n" \ + "104" n ":\n" \ + + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("11", "1") \ + STORE_LINE("14", "2") \ + STORE_LINE("17", "3") \ + STORE_LINE("20", "4") \ + STORE_LINE("23", "5") \ + STORE_LINE("26", "6") \ + STORE_LINE("29", "7") \ + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + + "2: \n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ldr q1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + + "ldr d5, [%[b_ptr]]\n" + "ins v1.d[1], x8\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + + "fmla v23.4s, v2.4s, v1.s[1]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v5.d[1], x10\n" + + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "ldr d1, [%[a_ptr], #32]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "ldr d2, [%[b_ptr], #16]\n" + "ins v0.d[1], x9\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ldr x10, [%[b_ptr], #24]\n" + + "fmla v11.4s, v5.4s, v0.s[1]\n" + + "fmla v14.4s, v5.4s,
v0.s[2]\n" + "nop\n" + + "ins v2.d[1], x10\n" + "ins v1.d[1], x8\n" + + "fmla v17.4s, v5.4s, v0.s[3]\n" + + "fmla v20.4s, v5.4s, v1.s[0]\n" + + "fmla v23.4s, v5.4s, v1.s[1]\n" + + "ldr d0, [%[a_ptr], #48]\n" + "nop\n" + + "fmla v26.4s, v5.4s, v1.s[2]\n" + "ldr x9, [%[a_ptr], #56]\n" + + "fmla v29.4s, v5.4s, v1.s[3]\n" + "add %[b_ptr], %[b_ptr], #32\n" + + "add %[a_ptr], %[a_ptr], #64\n" + "subs %w[K], %w[K], #1\n" + + "ins v0.d[1], x9\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ldr d1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + + "ldr d5, [%[b_ptr]]\n" + "ins v1.d[1], x8\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v5.d[1], x10\n" + + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "ldr d1, [%[a_ptr], #32]\n" + "ins v0.d[1], x9\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + + "fmla v14.4s, v5.4s, v0.s[2]\n" + + "nop\n" + "ins v1.d[1], x8\n" + + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "ldr q1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + + "nop\n" + "ins v1.d[1], x8\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", + "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
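The LOAD_LINE/STORE_LINE macros above implement edge handling as compare-and-branch chains because NEON has no masked vector store; per row they reduce to this sketch (store_line is illustrative, not from the patch):

```cpp
#include <arm_neon.h>
// C-level meaning of one STORE_LINE in kern_8x4 (sketch): a full 128-bit
// store when the row has 4 live columns, otherwise lane-by-lane stores.
void store_line(float* dst, float32x4_t v, int n_remain) {
    if (n_remain >= 4) { vst1q_f32(dst, v); return; }
    if (n_remain >= 1) vst1q_lane_f32(dst + 0, v, 0);
    if (n_remain >= 2) vst1q_lane_f32(dst + 1, v, 1);
    if (n_remain >= 3) vst1q_lane_f32(dst + 2, v, 2);
}
```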
+ // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_4x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int m_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = output; + +// clang-format off +#define LOAD_LINE(v0, v1, v2, n) \ + "cmp x10, #0\n" \ + "beq 102f\n" \ + "ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + "subs x10, x10, #1\n" + +#define LOAD_C \ + "mov x10, %x[m_remain]\n" \ + LOAD_LINE("8","9","10", "0") \ + LOAD_LINE("11","12","13", "1") \ + LOAD_LINE("14","15","16", "2") \ + LOAD_LINE("17","18","19", "3") \ + "102:\n" + +#define STORE_LINE(v0, v1, v2, n) \ + "cmp x10, #0 \n" \ + "beq 105f\n" \ + "st1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + "subs x10, x10, #1\n" + + +#define STORE_C \ + "mov x10, %x[m_remain]\n" \ + STORE_LINE("8","9","10", "0") \ + STORE_LINE("11","12","13", "1") \ + STORE_LINE("14","15","16", "2") \ + STORE_LINE("17","18","19", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ldr d5, [%[b_ptr]]\n" + "nop\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr x20, [%[b_ptr], #8]\n" + + "fmla v9.4s, v3.4s, v0.s[0]\n" + + "fmla v10.4s, v4.4s, v0.s[0]\n" + + "ldr d6, [%[b_ptr], #16]\n" + "ins v5.d[1], x20\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ldr x21, [%[b_ptr], #24]\n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + + "ldr d1, [%[a_ptr]]\n" + "ins v6.d[1], x21\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + + "fmla v16.4s, v4.4s, v0.s[2]\n" + + "ldr d7, [%[b_ptr], #32]\n" + "ins v1.d[1], x8\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ldr x22, [%[b_ptr], #40]\n" + + "fmla v18.4s, v3.4s, v0.s[3]\n" + + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "ldr d2, [%[b_ptr], #48]\n" + "ins v7.d[1], x22\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "ldr x20, [%[b_ptr], #56]\n" + + "fmla v9.4s, v6.4s, v1.s[0]\n" + + "fmla v10.4s, v7.4s, v1.s[0]\n" + + "ldr d3, [%[b_ptr], #64]\n" + "ins v2.d[1], x20\n" + + "fmla v11.4s, v5.4s, v1.s[1]\n" + "ldr x21, [%[b_ptr], #72]\n" + + "fmla v12.4s, v6.4s, v1.s[1]\n" + + "fmla v13.4s, v7.4s, v1.s[1]\n" + + "ldr d4, [%[b_ptr], #80]\n" + "ins v3.d[1], x21\n" + + "fmla v14.4s, v5.4s, v1.s[2]\n" + "ldr x22, [%[b_ptr], #88]\n" + + 
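kern_4x12 predicates whole rows rather than lanes: LOAD_C/STORE_C seed x10 with m_remain and each row macro decrements it, skipping the remaining rows once it reaches zero. In C terms (store_block_4x12 is illustrative, not from the patch):

```cpp
#include <arm_neon.h>
// C-level meaning of kern_4x12's STORE_C row predication (sketch): each of
// the four rows holds three 4-wide column groups, and rows past m_remain
// are skipped entirely.
void store_block_4x12(float* C, int ldc, const float32x4_t acc[4][3],
                      int m_remain) {
    for (int r = 0; r < m_remain; ++r)   // rows beyond m_remain: not stored
        for (int c = 0; c < 3; ++c)      // three 4-wide column groups
            vst1q_f32(C + r * ldc + c * 4, acc[r][c]);
}
```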
"fmla v15.4s, v6.4s, v1.s[2]\n" + "add %[b_ptr], %[b_ptr], #96\n" + + "fmla v16.4s, v7.4s, v1.s[2]\n" + + "ldr q0, [%[a_ptr], #16]\n" + "ins v4.d[1], x22\n" + + "fmla v17.4s, v5.4s, v1.s[3]\n" + "ldr x8, [%[a_ptr], #24]\n" + + "fmla v18.4s, v6.4s, v1.s[3]\n" + "add %[a_ptr], %[a_ptr], #32\n" + + "fmla v19.4s, v7.4s, v1.s[3]\n" + "subs %w[K], %w[K], #1\n" + + "ins v0.d[1], x8\n" + + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ldr d5, [%[b_ptr]]\n" + "nop\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr x20, [%[b_ptr], #8]\n" + + "fmla v9.4s, v3.4s, v0.s[0]\n" + + "fmla v10.4s, v4.4s, v0.s[0]\n" + + "ldr d6, [%[b_ptr], #16]\n" + "ins v5.d[1], x20\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ldr x21, [%[b_ptr], #24]\n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + + "ldr d1, [%[a_ptr]]\n" + "ins v6.d[1], x21\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + + "ldr d7, [%[b_ptr], #32]\n" + "ins v1.d[1], x8\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ldr x22, [%[b_ptr], #40]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "nop\n" + "ins v7.d[1], x22\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v9.4s, v6.4s, v1.s[0]\n" + "fmla v10.4s, v7.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v15.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "x1", "x2", "x3", "x8", "x9", "x10", "x20", "x21", + "x22", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C + } + + // Overview of register layout: + // + // A 2x4 cell of Rhs is stored in 32bit in v2 - v3 + // A 4x2 cell of Lhs is stored in 32bit in v0 - v1 + // A 4x4 block of accumulators is stored in 32bit in v4-v6 + // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_4x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = output; + +// clang-format off +#define LOAD_LINE(v0, n) \ + "cmp x10, #0\n" \ + "beq 102f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" v0 ".4s}, [x" n 
"], 16\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[0], [x" n "], 4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[1], [x" n "], 4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[2], [x" n "], 4\n" \ + "101" n ":\n" \ + "subs x10, x10, #1\n" + +#define LOAD_C \ + "mov x10, %x[m_remain]\n" \ + LOAD_LINE("8", "0") \ + LOAD_LINE("11", "1") \ + LOAD_LINE("14", "2") \ + LOAD_LINE("17", "3") \ + "102:\n" + +#define STORE_LINE(v0, n) \ + "cmp x10, #0 \n" \ + "beq 105f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "st1 {v" v0 ".4s}, [x" n " ], 16\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[0], [x" n "], 4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[1], [x" n "], 4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[2], [x" n "], 4\n" \ + "104" n ":\n" \ + "subs x10, x10, #1\n" + + +#define STORE_C \ + "mov x10, %x[m_remain]\n" \ + STORE_LINE("8", "0") \ + STORE_LINE("11", "1") \ + STORE_LINE("14", "2") \ + STORE_LINE("17", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + + "2: \n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", + "x3", "x10", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C + } +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h new file mode 100644 index 000000000..c205a48d4 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h @@ -0,0 +1,1170 @@ +/** + * \file 
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +struct matmul_general_8x12_a55 { + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. + // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // |v1| |v20[0-3]|v21[0-3]|v22[0-3]| + // |v1| |v23[0-3]|v24[0-3]|v25[0-3]| + // |v1| |v26[0-3]|v27[0-3]|v28[0-3]| + // |v1| |v29[0-3]|v30[0-3]|v31[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_8x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = reinterpret_cast<float*>(output); +// clang-format off +#define LOAD_LINE(v0, v1, v2, n) \ + "ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + +#define LOAD_C \ + LOAD_LINE("8", "9", "10", "0") \ + LOAD_LINE("11", "12", "13", "1") \ + LOAD_LINE("14", "15", "16", "2") \ + LOAD_LINE("17", "18", "19", "3") \ + LOAD_LINE("20", "21", "22", "4") \ + LOAD_LINE("23", "24", "25", "5") \ + LOAD_LINE("26", "27", "28", "6") \ + LOAD_LINE("29", "30", "31", "7") + + // clang-format on + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr]]\n" + "add x2, x1, %x[LDC]\n" + "prfm pldl1keep, [%[b_ptr]]\n" + "add x3, x2, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr], #64]\n" + "add x4, x3, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr], #128]\n" + "add x5, x4, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr], #192]\n" + "add x6, x5, %x[LDC]\n" + "prfm pldl1keep, [%[a_ptr], #256]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ldr d2, [%[b_ptr]]\n" + + "eor v9.16b, v9.16b, v9.16b\n" + "ldr x10, [%[b_ptr], #8]\n" + + "eor v10.16b, v10.16b, v10.16b\n" + "ldr d3, [%[b_ptr], #16]\n" + + "eor v11.16b, v11.16b, v11.16b\n" + "ldr x11, [%[b_ptr], #24]\n" + + "eor v12.16b, v12.16b, v12.16b\n" + "ldr d4, [%[b_ptr], #32]\n" + + "eor v13.16b, v13.16b, v13.16b\n" + "ldr x12, [%[b_ptr], #40]\n" + + "eor v14.16b, v14.16b, v14.16b\n" + "ldr d0, [%[a_ptr]]\n" + + "eor v15.16b, v15.16b, v15.16b\n" + "ldr x9, [%[a_ptr], #8]\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "add %[b_ptr], %[b_ptr], #48\n" + + "eor v17.16b, v17.16b, v17.16b\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "eor v18.16b, v18.16b, v18.16b\n" + "ins v2.d[1], x10\n" + + "eor v19.16b,
v19.16b, v19.16b\n" + "ins v3.d[1], x11\n" + + "eor v20.16b, v20.16b, v20.16b\n" + "ins v4.d[1], x12\n" + + "eor v21.16b, v21.16b, v21.16b\n" + "ins v0.d[1], x9\n" + + "eor v22.16b, v22.16b, v22.16b\n" + "prfm pldl1keep, [%[a_ptr], #384]\n" + + "eor v23.16b, v23.16b, v23.16b\n" + "prfm pldl1keep, [%[b_ptr]]\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "prfm pldl1keep, [%[b_ptr], #64]\n" + + "eor v25.16b, v25.16b, v25.16b\n" + "prfm pldl1keep, [%[b_ptr], #128]\n" + + "eor v26.16b, v26.16b, v26.16b\n" + "prfm pldl1keep, [%[b_ptr], #192]\n" + + "eor v27.16b, v27.16b, v27.16b\n" + "prfm pldl1keep, [%[b_ptr], #256]\n" + + "eor v28.16b, v28.16b, v28.16b\n" + "prfm pldl1keep, [%[b_ptr], #320]\n" + + "eor v29.16b, v29.16b, v29.16b\n" + "prfm pldl1keep, [%[b_ptr], #384]\n" + + "eor v30.16b, v30.16b, v30.16b\n" + "prfm pldl1keep, [%[b_ptr], #448]\n" + + "eor v31.16b, v31.16b, v31.16b\n" + "prfm pldl1keep, [%[b_ptr], #512]\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr d1, [%[a_ptr]]\n" + + "fmla v9.4s, v3.4s, v0.s[0]\n" + "subs %w[K], %w[K], #1\n" + + "fmla v10.4s, v4.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + "ldr d5, [%[b_ptr]]\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + + "fmla v16.4s, v4.4s, v0.s[2]\n" + "ldr d6, [%[b_ptr], #16]\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ins v5.d[1], x10\n" + + "fmla v18.4s, v3.4s, v0.s[3]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "ldr d7, [%[b_ptr], #32]\n" + + "fmla v21.4s, v3.4s, v1.s[0]\n" + "ins v6.d[1], x11\n" + + "fmla v22.4s, v4.4s, v1.s[0]\n" + "ldr d0, [%[a_ptr], #16]\n" + + "fmla v23.4s, v2.4s, v1.s[1]\n" + + "fmla v24.4s, v3.4s, v1.s[1]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v25.4s, v4.4s, v1.s[1]\n" + + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v27.4s, v3.4s, v1.s[2]\n" + "ins v7.d[1], x12\n" + + "fmla v28.4s, v4.4s, v1.s[2]\n" + "prfm pldl1keep, [%[a_ptr], #448]\n" + + "fmla v29.4s, v2.4s, v1.s[3]\n" + "ins v0.d[1], x9\n" + + "fmla v30.4s, v3.4s, v1.s[3]\n" + "prfm pldl1keep, [%[b_ptr], #576]\n" + + "fmla v31.4s, v4.4s, v1.s[3]\n" + + //! 
UNROLL + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ldr d1, [%[a_ptr], #32]\n" + + "fmla v9.4s, v6.4s, v0.s[0]\n" + + "fmla v10.4s, v7.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "fmla v11.4s, v5.4s, v0.s[1]\n" + + "fmla v12.4s, v6.4s, v0.s[1]\n" + "ldr d2, [%[b_ptr], #48]\n" + + "fmla v13.4s, v7.4s, v0.s[1]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v5.4s, v0.s[2]\n" + "ldr x10, [%[b_ptr], #56]\n" + + "fmla v15.4s, v6.4s, v0.s[2]\n" + + "fmla v16.4s, v7.4s, v0.s[2]\n" + "ldr d3, [%[b_ptr], #64]\n" + + "fmla v17.4s, v5.4s, v0.s[3]\n" + "ins v2.d[1], x10\n" + + "fmla v18.4s, v6.4s, v0.s[3]\n" + "ldr x11, [%[b_ptr], #72]\n" + + "fmla v19.4s, v7.4s, v0.s[3]\n" + + "fmla v20.4s, v5.4s, v1.s[0]\n" + "ldr d4, [%[b_ptr], #80]\n" + + "fmla v21.4s, v6.4s, v1.s[0]\n" + "ins v3.d[1], x11\n" + + "fmla v22.4s, v7.4s, v1.s[0]\n" + "ldr x12, [%[b_ptr], #88]\n" + + "fmla v23.4s, v5.4s, v1.s[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + + "fmla v24.4s, v6.4s, v1.s[1]\n" + "ldr d0, [%[a_ptr], #48]\n" + + "fmla v25.4s, v7.4s, v1.s[1]\n" + "ins v4.d[1], x12\n" + + "fmla v26.4s, v5.4s, v1.s[2]\n" + "ldr x9, [%[a_ptr], #56]\n" + + "fmla v27.4s, v6.4s, v1.s[2]\n" + "add %[a_ptr], %[a_ptr], #64\n" + + "fmla v28.4s, v7.4s, v1.s[2]\n" + "prfm pldl1keep, [%[b_ptr], #640]\n" + + "fmla v29.4s, v5.4s, v1.s[3]\n" + "ins v0.d[1], x9\n" + + "fmla v30.4s, v6.4s, v1.s[3]\n" + + "fmla v31.4s, v7.4s, v1.s[3]\n" + + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "prfm pstl1keep, [x0]\n" + + "fmla v9.4s, v3.4s, v0.s[0]\n" + "ldr d1, [%[a_ptr]] \n" + + "fmla v10.4s, v4.4s, v0.s[0]\n" + "prfm pstl1keep, [x1]\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ldr x8, [%[a_ptr], #8] \n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + "prfm pstl1keep, [x2]\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ldr d5, [%[b_ptr]]\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + "ins v1.d[1], x8\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v16.4s, v4.4s, v0.s[2]\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ldr d6, [%[b_ptr], #16]\n" + + "fmla v18.4s, v3.4s, v0.s[3]\n" + "ins v5.d[1], x10\n" + + "fmla v19.4s, v4.4s, v0.s[3]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + + "fmla v21.4s, v3.4s, v1.s[0]\n" + "ldr d0, [%[a_ptr], #16]\n" + + "fmla v22.4s, v4.4s, v1.s[0]\n" + "ins v6.d[1], x11\n" + + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v24.4s, v3.4s, v1.s[1]\n" + + "fmla v25.4s, v4.4s, v1.s[1]\n" + "ldr d7, [%[b_ptr], #32]\n" + + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ins v0.d[1], x9\n" + + "fmla v27.4s, v3.4s, v1.s[2]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v28.4s, v4.4s, v1.s[2]\n" + + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v30.4s, v3.4s, v1.s[3]\n" + "ins v7.d[1], x12\n" + + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ldr d1, [%[a_ptr], #32]\n" + + "fmla v9.4s, v6.4s, v0.s[0]\n" + + "fmla v10.4s, v7.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "fmla v11.4s, v5.4s, v0.s[1]\n" + + "fmla v12.4s, v6.4s, v0.s[1]\n" + "str q8, [x0]\n" + + "fmla v13.4s, v7.4s, v0.s[1]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v5.4s, v0.s[2]\n" + "str q9, [x0, #16]\n" + + "fmla v15.4s, v6.4s, v0.s[2]\n" + "str q10, [x0, #32]\n" + + "fmla v16.4s, v7.4s, v0.s[2]\n" + "str q11, [x1]\n" + + "fmla v17.4s, v5.4s, v0.s[3]\n" + "str q12, [x1, #16]\n" + + "fmla v18.4s, v6.4s, v0.s[3]\n" + "str q13, [x1, #32]\n" + + "fmla v19.4s, v7.4s, v0.s[3]\n" + "str q14, [x2]\n" + + "fmla v20.4s, v5.4s, v1.s[0]\n" + "str q15, [x2, #16]\n" + + 
"fmla v21.4s, v6.4s, v1.s[0]\n" + "str q16, [x2, #32]\n" + + "fmla v22.4s, v7.4s, v1.s[0]\n" + "str q17, [x3]\n" + + "fmla v23.4s, v5.4s, v1.s[1]\n" + "str q18, [x3, #16]\n" + + "fmla v24.4s, v6.4s, v1.s[1]\n" + "str q19, [x3, #32]\n" + + "fmla v25.4s, v7.4s, v1.s[1]\n" + "str q20, [x4]\n" + + "fmla v26.4s, v5.4s, v1.s[2]\n" + "str q21, [x4, #16]\n" + + "fmla v27.4s, v6.4s, v1.s[2]\n" + "str q22, [x4, #32]\n" + + "fmla v28.4s, v7.4s, v1.s[2]\n" + "str q23, [x5]\n" + + "fmla v29.4s, v5.4s, v1.s[3]\n" + "str q24, [x5, #16]\n" + + "fmla v30.4s, v6.4s, v1.s[3]\n" + "str q25, [x5, #32]\n" + + "fmla v31.4s, v7.4s, v1.s[3]\n" + + "st1 {v26.4s, v27.4s, v28.4s}, [x6]\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr d1, [%[a_ptr]]\n" + + "fmla v9.4s, v3.4s, v0.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v10.4s, v4.4s, v0.s[0]\n" + "str q8, [x0]\n" + + "fmla v11.4s, v2.4s, v0.s[1]\n" + "str q9, [x0, #16]\n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + "str q10, [x0, #32]\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v2.4s, v0.s[2]\n" + "str q11, [x1]\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + "str q12, [x1, #16]\n" + + "fmla v16.4s, v4.4s, v0.s[2]\n" + "str q13, [x1, #32]\n" + + "fmla v17.4s, v2.4s, v0.s[3]\n" + "str q14, [x2]\n" + + "fmla v18.4s, v3.4s, v0.s[3]\n" + "str q15, [x2, #16]\n" + + "fmla v19.4s, v4.4s, v0.s[3]\n" + "str q16, [x2, #32]\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "str q17, [x3]\n" + + "fmla v21.4s, v3.4s, v1.s[0]\n" + "str q18, [x3, #16]\n" + + "fmla v22.4s, v4.4s, v1.s[0]\n" + "str q19, [x3, #32]\n" + + "fmla v23.4s, v2.4s, v1.s[1]\n" + "str q20, [x4]\n" + + "fmla v24.4s, v3.4s, v1.s[1]\n" + "str q21, [x4, #16]\n" + + "fmla v25.4s, v4.4s, v1.s[1]\n" + "str q22, [x4, #32]\n" + + "fmla v26.4s, v2.4s, v1.s[2]\n" + "str q23, [x5]\n" + + "fmla v27.4s, v3.4s, v1.s[2]\n" + "str q24, [x5, #16]\n" + + "fmla v28.4s, v4.4s, v1.s[2]\n" + "str q25, [x5, #32]\n" + + "fmla v29.4s, v2.4s, v1.s[3]\n" + "str q26, [x6]\n" + + "fmla v30.4s, v3.4s, v1.s[3]\n" + "str q27, [x6, #16]\n" + + "fmla v31.4s, v4.4s, v1.s[3]\n" + "str q28, [x6, #32]\n" + + "st1 {v29.4s, v30.4s, v31.4s}, [x7]\n" + + "6:\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory"); +#undef LOAD_LINE +#undef LOAD_C + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
+ // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // |v1| |v20[0-3]| + // |v1| |v23[0-3]| + // |v1| |v26[0-3]| + // |v1| |v29[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_8x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = reinterpret_cast<float*>(output); + +// clang-format off +#define LOAD_LINE(v0, n) \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" v0 ".4s}, [x" n "],#16\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[0], [x" n "],#4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[1], [x" n "],#4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[2], [x" n "],#4\n" \ + "101" n ":\n" \ + +#define LOAD_C \ + LOAD_LINE("8", "0") \ + LOAD_LINE("11", "1") \ + LOAD_LINE("14", "2") \ + LOAD_LINE("17", "3") \ + LOAD_LINE("20", "4") \ + LOAD_LINE("23", "5") \ + LOAD_LINE("26", "6") \ + LOAD_LINE("29", "7") \ + + +#define STORE_LINE(v0, n) \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "st1 {v" v0 ".4s}, [x" n " ],#16\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[0], [x" n "],#4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[1], [x" n "],#4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[2], [x" n "],#4\n" \ + "104" n ":\n" \ + + +#define STORE_C \ + STORE_LINE("8", "0") \ + STORE_LINE("11", "1") \ + STORE_LINE("14", "2") \ + STORE_LINE("17", "3") \ + STORE_LINE("20", "4") \ + STORE_LINE("23", "5") \ + STORE_LINE("26", "6") \ + STORE_LINE("29", "7") \ + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + "add x4, x3, %x[LDC]\n" + "add x5, x4, %x[LDC]\n" + "add x6, x5, %x[LDC]\n" + "add x7, x6, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + + "2: \n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ldr d5, [%[b_ptr]]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "ins v5.d[1], x10\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ldr d2, [%[b_ptr], #16]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "ldr x10, [%[b_ptr], #24]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "ins v2.d[1],
x10\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "add %[b_ptr], %[b_ptr], #32\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", + "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x10", "cc", "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
+ // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_4x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int m_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = output; + +// clang-format off +#define LOAD_LINE(v0, v1, v2, n) \ + "cmp x10, #0\n" \ + "beq 102f\n" \ + "ld1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + "subs x10, x10, #1\n" + +#define LOAD_C \ + "mov x10, %x[m_remain]\n" \ + LOAD_LINE("8","9","10", "0") \ + LOAD_LINE("11","12","13", "1") \ + LOAD_LINE("14","15","16", "2") \ + LOAD_LINE("17","18","19", "3") \ + "102:\n" + +#define STORE_LINE(v0, v1, v2, n) \ + "cmp x10, #0 \n" \ + "beq 105f\n" \ + "st1 {v" v0 ".4s, v" v1 ".4s, v" v2 ".4s}, [x" n "]\n" \ + "subs x10, x10, #1\n" + + +#define STORE_C \ + "mov x10, %x[m_remain]\n" \ + STORE_LINE("8","9","10", "0") \ + STORE_LINE("11","12","13", "1") \ + STORE_LINE("14","15","16", "2") \ + STORE_LINE("17","18","19", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr d5, [%[b_ptr]]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "ldr x20, [%[b_ptr], #8]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "ldr d6, [%[b_ptr], #16]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ldr x21, [%[b_ptr], #24]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "ins v5.d[1], x20\n" + + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ldr d7, [%[b_ptr], #32]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "ldr x22, [%[b_ptr], #40]\n" + + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + + "fmla v15.4s, v3.4s, v0.s[2]\n" + "ins v6.d[1], x21\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "ins v7.d[1], x22\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "ldr d2, [%[b_ptr], #48]\n" + "fmla v9.4s, v6.4s, v1.s[0]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v10.4s, v7.4s, v1.s[0]\n" + "ldr d3, [%[b_ptr], #64]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "ldr x21, [%[b_ptr], #72]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "ldr d4, [%[b_ptr], #80]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "ldr x22, [%[b_ptr], #88]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "ins v2.d[1], x20\n" + "fmla v15.4s, v6.4s, v1.s[2]\n" + "ins v3.d[1], x21\n" + + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + + "fmla v16.4s, v7.4s, 
v1.s[2]\n" + "ins v4.d[1], x22\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v18.4s, v6.4s, v1.s[3]\n" + "subs %w[K], %w[K], #1\n" + "fmla v19.4s, v7.4s, v1.s[3]\n" + + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v2.4s, v0.s[0]\n" + "ldr d5, [%[b_ptr]]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "ldr x20, [%[b_ptr], #8]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "ldr d6, [%[b_ptr], #16]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ldr x21, [%[b_ptr], #24]\n" + + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + + "fmla v12.4s, v3.4s, v0.s[1]\n" + "ldr d7, [%[b_ptr], #32]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ins v5.d[1], x20\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "ldr x22, [%[b_ptr], #40]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "ins v6.d[1], x21\n" + + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "ins v7.d[1], x22\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v9.4s, v6.4s, v1.s[0]\n" + "fmla v10.4s, v7.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v15.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "x1", "x2", "x3", "x10", "x20", "x21", "x22", "cc", + "memory"); + +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C + } + + // Overview of register layout: + // + // A 2x4 cell of Rhs is stored in 32bit in v2 - v3 + // A 4x2 cell of Lhs is stored in 32bit in v0 - v1 + // A 4x4 block of accumulators is stored in 32bit in v4-v6 + // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_4x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + LDC = LDC * sizeof(float); + register float* outptr asm("x0") = output; + +// clang-format off +#define LOAD_LINE(v0, n) \ + "cmp x10, #0\n" \ + "beq 102f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 100" n "f\n" \ + "ld1 {v" v0 ".4s}, [x" n "], 16\n" \ + "b 101" n "f\n" \ + "100" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[0], [x" n "], 4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 101" n "f\n" \ + "ld1 {v" v0 ".s}[1], [x" n "], 4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 101" n 
"f\n" \ + "ld1 {v" v0 ".s}[2], [x" n "], 4\n" \ + "101" n ":\n" \ + "subs x10, x10, #1\n" + +#define LOAD_C \ + "mov x10, %x[m_remain]\n" \ + LOAD_LINE("8", "0") \ + LOAD_LINE("11", "1") \ + LOAD_LINE("14", "2") \ + LOAD_LINE("17", "3") \ + "102:\n" + +#define STORE_LINE(v0, n) \ + "cmp x10, #0 \n" \ + "beq 105f\n" \ + "cmp %w[n_remain], #4\n" \ + "blt 103" n "f\n" \ + "st1 {v" v0 ".4s}, [x" n " ], 16\n" \ + "b 104" n "f\n" \ + "103" n ":\n" \ + "cmp %w[n_remain], #0\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[0], [x" n "], 4\n" \ + "cmp %w[n_remain], #1\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[1], [x" n "], 4\n" \ + "cmp %w[n_remain], #2\n" \ + "beq 104" n "f\n" \ + "st1 {v" v0 ".s}[2], [x" n "], 4\n" \ + "104" n ":\n" \ + "subs x10, x10, #1\n" + + +#define STORE_C \ + "mov x10, %x[m_remain]\n" \ + STORE_LINE("8", "0") \ + STORE_LINE("11", "1") \ + STORE_LINE("14", "2") \ + STORE_LINE("17", "3") \ + "105:\n" + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + + "2: \n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ld1 {v5.4s}, [%[b_ptr]], 16\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + + "fmla v8.4s, v5.4s, v1.s[0]\n" + "fmla v11.4s, v5.4s, v1.s[1]\n" + "fmla v14.4s, v5.4s, v1.s[2]\n" + "fmla v17.4s, v5.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [oddk] "+r"(oddk), [outptr] "+r"(outptr), + [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) + : + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", + "x3", "x10", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C + } +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h index 0e2bab774..2a230562b 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h @@ -6,319 +6,352 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" - namespace megdnn { namespace aarch64 { -namespace matmul_mk4_8x12 { - -// Overview of register layout: -// -// A 1x12 cell of Rhs is stored in 32bit in v2-v7 -// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) -// A 8x12 block of accumulators is stored in 32bit in v8-v31. -// -// +--------+--------+--------+ -// | v2[0-3]| v3[0-3]| v4[0-3]| -// | v5[0-3]| v6[0-3]| v7[0-3]| -// Rhs +--------+--------+--------+ -// -// | | | | -// -// Lhs | | | | -// -// +--+ --- - +--------+--------+--------+ -// |v0| | v8[0-3]| v9[0-3]|v10[0-3]| -// |v0| |v11[0-3]|v12[0-3]|v13[0-3]| -// |v0| |v14[0-3]|v15[0-3]|v16[0-3]| -// |v0| |v17[0-3]|v18[0-3]|v19[0-3]| -// |v1| |v20[0-3]|v21[0-3]|v22[0-3]| -// |v1| |v23[0-3]|v24[0-3]|v25[0-3]| -// |v1| |v26[0-3]|v27[0-3]|v28[0-3]| -// |v1| |v29[0-3]|v30[0-3]|v31[0-3]| -// +--+ --- - +--------+--------+--------+ -// -// Accumulator -void kern_8x12(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k) { - const float* a_ptr = packA; - const float* b_ptr = packB; - float* output0 = output; - float* output1 = output0 + LDC; - - int oddk = (K & 1); - K = ((K + 1) / 2) - 1; - - asm volatile( - "cmp %w[is_first_k], #1\n" - "beq 1f\n" - "mov x1, %[output0]\n" - "mov x2, %[output1]\n" - "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" - "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" - "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" - "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64\n" - "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64\n" - "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64\n" - - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" - "b 2f\n" - - "1:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v9.16b, v9.16b, v9.16b\n" - "eor v10.16b, v10.16b, v10.16b\n" - "prfm pstl1keep, [%[output0]]\n" - "eor v11.16b, v11.16b, v11.16b\n" - "eor v12.16b, v12.16b, v12.16b\n" - "eor v13.16b, v13.16b, v13.16b\n" - "prfm pstl1keep, [%[output1]]\n" - "eor v14.16b, v14.16b, v14.16b\n" - "eor v15.16b, v15.16b, v15.16b\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" - "eor v16.16b, v16.16b, v16.16b\n" - "eor v17.16b, v17.16b, v17.16b\n" - "eor v18.16b, v18.16b, v18.16b\n" - "eor v19.16b, v19.16b, v19.16b\n" - "eor v20.16b, v20.16b, v20.16b\n" - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "eor v21.16b, v21.16b, v21.16b\n" - "eor v22.16b, v22.16b, v22.16b\n" - "eor v23.16b, v23.16b, v23.16b\n" - "eor v24.16b, v24.16b, v24.16b\n" - "eor v25.16b, v25.16b, v25.16b\n" - "eor v26.16b, v26.16b, v26.16b\n" - "eor v27.16b, v27.16b, v27.16b\n" - "eor v28.16b, v28.16b, v28.16b\n" - "eor v29.16b, v29.16b, v29.16b\n" - "eor v30.16b, v30.16b, v30.16b\n" - "eor v31.16b, v31.16b, v31.16b\n" - - "2: \n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "fmla v8.4s, v0.4s, v2.s[0]\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "fmla v12.4s, v0.4s, v3.s[0]\n" - "fmla v13.4s, v0.4s, v3.s[1]\n" - "fmla v14.4s, v0.4s, v3.s[2]\n" - "fmla v15.4s, v0.4s, v3.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v17.4s, v0.4s, v4.s[1]\n" - "fmla v18.4s, v0.4s, v4.s[2]\n" - "fmla v19.4s, v0.4s, v4.s[3]\n" - - "fmla v20.4s, v1.4s, v2.s[0]\n" - "fmla v21.4s, v1.4s, v2.s[1]\n" - "fmla v22.4s, v1.4s, v2.s[2]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "fmla v24.4s, v1.4s, v3.s[0]\n" - "fmla v25.4s, v1.4s, 
v3.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v26.4s, v1.4s, v3.s[2]\n" - "fmla v27.4s, v1.4s, v3.s[3]\n" - "fmla v28.4s, v1.4s, v4.s[0]\n" - "fmla v29.4s, v1.4s, v4.s[1]\n" - "fmla v30.4s, v1.4s, v4.s[2]\n" - "fmla v31.4s, v1.4s, v4.s[3]\n" - - "fmla v8.4s, v0.4s, v5.s[0]\n" - "fmla v9.4s, v0.4s, v5.s[1]\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" - "fmla v10.4s, v0.4s, v5.s[2]\n" - "fmla v11.4s, v0.4s, v5.s[3]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v12.4s, v0.4s, v6.s[0]\n" - "fmla v13.4s, v0.4s, v6.s[1]\n" - "fmla v14.4s, v0.4s, v6.s[2]\n" - "fmla v15.4s, v0.4s, v6.s[3]\n" - "fmla v16.4s, v0.4s, v7.s[0]\n" - "fmla v17.4s, v0.4s, v7.s[1]\n" - "fmla v18.4s, v0.4s, v7.s[2]\n" - "fmla v19.4s, v0.4s, v7.s[3]\n" - - "fmla v20.4s, v1.4s, v5.s[0]\n" - "fmla v21.4s, v1.4s, v5.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v22.4s, v1.4s, v5.s[2]\n" - "fmla v23.4s, v1.4s, v5.s[3]\n" - "fmla v24.4s, v1.4s, v6.s[0]\n" - "subs %w[K], %w[K], #1\n" - "fmla v25.4s, v1.4s, v6.s[1]\n" - "fmla v26.4s, v1.4s, v6.s[2]\n" - "fmla v27.4s, v1.4s, v6.s[3]\n" - "fmla v28.4s, v1.4s, v7.s[0]\n" - "fmla v29.4s, v1.4s, v7.s[1]\n" - "fmla v30.4s, v1.4s, v7.s[2]\n" - "fmla v31.4s, v1.4s, v7.s[3]\n" - - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "fmla v8.4s, v0.4s, v2.s[0]\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "fmla v12.4s, v0.4s, v3.s[0]\n" - "fmla v13.4s, v0.4s, v3.s[1]\n" - "fmla v14.4s, v0.4s, v3.s[2]\n" - "fmla v15.4s, v0.4s, v3.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v17.4s, v0.4s, v4.s[1]\n" - "fmla v18.4s, v0.4s, v4.s[2]\n" - "fmla v19.4s, v0.4s, v4.s[3]\n" - - "fmla v20.4s, v1.4s, v2.s[0]\n" - "fmla v21.4s, v1.4s, v2.s[1]\n" - "fmla v22.4s, v1.4s, v2.s[2]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "fmla v24.4s, v1.4s, v3.s[0]\n" - "fmla v25.4s, v1.4s, v3.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v26.4s, v1.4s, v3.s[2]\n" - "fmla v27.4s, v1.4s, v3.s[3]\n" - "fmla v28.4s, v1.4s, v4.s[0]\n" - "fmla v29.4s, v1.4s, v4.s[1]\n" - "fmla v30.4s, v1.4s, v4.s[2]\n" - "fmla v31.4s, v1.4s, v4.s[3]\n" - - "fmla v8.4s, v0.4s, v5.s[0]\n" - "fmla v9.4s, v0.4s, v5.s[1]\n" - "fmla v10.4s, v0.4s, v5.s[2]\n" - "fmla v11.4s, v0.4s, v5.s[3]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v12.4s, v0.4s, v6.s[0]\n" - "fmla v13.4s, v0.4s, v6.s[1]\n" - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" - "fmla v14.4s, v0.4s, v6.s[2]\n" - "fmla v15.4s, v0.4s, v6.s[3]\n" - "fmla v16.4s, v0.4s, v7.s[0]\n" - "fmla v17.4s, v0.4s, v7.s[1]\n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" - "fmla v18.4s, v0.4s, v7.s[2]\n" - "fmla v19.4s, v0.4s, v7.s[3]\n" - - "fmla v20.4s, v1.4s, v5.s[0]\n" - "fmla v21.4s, v1.4s, v5.s[1]\n" - "fmla v22.4s, v1.4s, v5.s[2]\n" - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" - "fmla v23.4s, v1.4s, v5.s[3]\n" - "fmla v24.4s, v1.4s, v6.s[0]\n" - "fmla v25.4s, v1.4s, v6.s[1]\n" - "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output1]], #64\n" - "fmla v26.4s, v1.4s, v6.s[2]\n" - "fmla v27.4s, v1.4s, v6.s[3]\n" - "fmla v28.4s, v1.4s, v7.s[0]\n" - "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n" - "fmla v29.4s, v1.4s, v7.s[1]\n" - "fmla v30.4s, v1.4s, v7.s[2]\n" - "fmla v31.4s, v1.4s, v7.s[3]\n" - "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n" - "b 6f\n" - - // odd tail - "5:\n" - "fmla v8.4s, v0.4s, v2.s[0]\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "fmla 
v10.4s, v0.4s, v2.s[2]\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "fmla v12.4s, v0.4s, v3.s[0]\n" - "fmla v13.4s, v0.4s, v3.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v14.4s, v0.4s, v3.s[2]\n" - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" - "fmla v15.4s, v0.4s, v3.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v17.4s, v0.4s, v4.s[1]\n" - "fmla v18.4s, v0.4s, v4.s[2]\n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" - "fmla v19.4s, v0.4s, v4.s[3]\n" - - "fmla v20.4s, v1.4s, v2.s[0]\n" - "fmla v21.4s, v1.4s, v2.s[1]\n" - "fmla v22.4s, v1.4s, v2.s[2]\n" - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "fmla v24.4s, v1.4s, v3.s[0]\n" - "fmla v25.4s, v1.4s, v3.s[1]\n" - "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output1]], #64\n" - "fmla v26.4s, v1.4s, v3.s[2]\n" - "fmla v27.4s, v1.4s, v3.s[3]\n" - "fmla v28.4s, v1.4s, v4.s[0]\n" - "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n" - "fmla v29.4s, v1.4s, v4.s[1]\n" - "fmla v30.4s, v1.4s, v4.s[2]\n" - "fmla v31.4s, v1.4s, v4.s[3]\n" - "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n" - - "6:\n" - : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), - [ output0 ] "+r"(output0), [ output1 ] "+r"(output1) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x1", "x2", "cc", "memory"); -} - -// Overview of register layout: -// -// A 1x12 cell of Rhs is stored in 32bit in v2-v7 -// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) -// A 8x12 block of accumulators is stored in 32bit in v8-v31. -// -// +--------+ -// | v2[0-3]| -// | v3[0-3]| -// Rhs +--------+ -// -// | | -// -// Lhs | | -// -// +--+ --- - +--------+ -// |v0| | v8[0-3]| -// |v0| |v11[0-3]| -// |v0| |v14[0-3]| -// |v0| |v17[0-3]| -// |v1| |v20[0-3]| -// |v1| |v23[0-3]| -// |v1| |v26[0-3]| -// |v1| |v29[0-3]| -// +--+ --- - +--------+ -// -// Accumulator -void kern_8x4(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int n_remain) { - const float* a_ptr = packA; - const float* b_ptr = packB; - float* output0 = output; - float* output1 = output0 + LDC; - - int oddk = (K & 1); - K = ((K + 1) / 2) - 1; - - //clang-format off +struct matmul_mk4_8x12 { + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
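The kernels below assume MegDNN's MK4 storage: M and K are blocked by 4, so each `ld1` of Lhs pulls in a contiguous 4-row column and the output is written in 4-row panels. An indexing sketch as a reading aid (the helper names are made up; the formulas follow the MK4 convention these kernels rely on):

inline float mk4_a_at(const float* A, int K, int m, int k) {
    // element (m, k) of the MK4-packed Lhs: 4x4 cells, 4 rows contiguous
    return A[(m / 4) * (K * 4) + k * 4 + (m % 4)];
}
inline float mk4_c_at(const float* C, int LDC, int m, int n) {
    // element (m, n) of the MK4 output; LDC counts floats per 4-row panel
    return C[(m / 4) * LDC + n * 4 + (m % 4)];
}

That contiguity is what lets the MK4 kernels broadcast one Rhs lane against a whole Lhs register (`fmla v8.4s, v0.4s, v2.s[0]`) instead of the lane-by-lane layout used by the general kernels.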
+ // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // |v1| |v20[0-3]|v21[0-3]|v22[0-3]| + // |v1| |v23[0-3]|v24[0-3]|v25[0-3]| + // |v1| |v26[0-3]|v27[0-3]|v28[0-3]| + // |v1| |v29[0-3]|v30[0-3]|v31[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_8x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + float* output1 = output0 + LDC; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + asm volatile( + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "mov x1, %[output0]\n" + "mov x2, %[output1]\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "prfm pstl1keep, [%[output1]]\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "eor v16.16b, v16.16b, v16.16b\n" + "ld1 {v3.4s}, [%[b_ptr]], #16\n" + "eor v17.16b, v17.16b, v17.16b\n" + "ld1 {v4.4s}, [%[b_ptr]], #16\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ld1 {v5.4s}, [%[b_ptr]], #16\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ld1 {v6.4s}, [%[b_ptr]], #16\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "ld1 {v7.4s}, [%[b_ptr]], #16\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "fmla v21.4s, v1.4s, v2.s[1]\n" + "fmla v22.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "fmla v24.4s, v1.4s, v3.s[0]\n" + "fmla v25.4s, v1.4s, v3.s[1]\n" + "fmla v26.4s, v1.4s, v3.s[2]\n" + "fmla v27.4s, v1.4s, v3.s[3]\n" + "fmla v28.4s, v1.4s, v4.s[0]\n" + "fmla v29.4s, v1.4s, v4.s[1]\n" + "fmla v30.4s, v1.4s, v4.s[2]\n" + "fmla v31.4s, v1.4s, v4.s[3]\n" + 
+ "fmla v8.4s, v0.4s, v5.s[0]\n" + "fmla v9.4s, v0.4s, v5.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v5.s[2]\n" + "fmla v11.4s, v0.4s, v5.s[3]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v12.4s, v0.4s, v6.s[0]\n" + "fmla v13.4s, v0.4s, v6.s[1]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v14.4s, v0.4s, v6.s[2]\n" + "fmla v15.4s, v0.4s, v6.s[3]\n" + "ld1 {v4.4s}, [%[b_ptr]], 16\n" + "fmla v16.4s, v0.4s, v7.s[0]\n" + "fmla v17.4s, v0.4s, v7.s[1]\n" + "fmla v18.4s, v0.4s, v7.s[2]\n" + "fmla v19.4s, v0.4s, v7.s[3]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + + "fmla v20.4s, v1.4s, v5.s[0]\n" + "fmla v21.4s, v1.4s, v5.s[1]\n" + "fmla v22.4s, v1.4s, v5.s[2]\n" + "fmla v23.4s, v1.4s, v5.s[3]\n" + "fmla v24.4s, v1.4s, v6.s[0]\n" + "subs %w[K], %w[K], #1\n" + "fmla v25.4s, v1.4s, v6.s[1]\n" + "fmla v26.4s, v1.4s, v6.s[2]\n" + "fmla v27.4s, v1.4s, v6.s[3]\n" + "fmla v28.4s, v1.4s, v7.s[0]\n" + "fmla v29.4s, v1.4s, v7.s[1]\n" + "fmla v30.4s, v1.4s, v7.s[2]\n" + "fmla v31.4s, v1.4s, v7.s[3]\n" + + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "ld1 {v5.4s}, [%[b_ptr]], #16\n" + "fmla v21.4s, v1.4s, v2.s[1]\n" + "fmla v22.4s, v1.4s, v2.s[2]\n" + "ld1 {v6.4s}, [%[b_ptr]], #16\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "fmla v24.4s, v1.4s, v3.s[0]\n" + "ld1 {v7.4s}, [%[b_ptr]], #16\n" + "fmla v25.4s, v1.4s, v3.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v1.4s, v3.s[2]\n" + "fmla v27.4s, v1.4s, v3.s[3]\n" + "fmla v28.4s, v1.4s, v4.s[0]\n" + "fmla v29.4s, v1.4s, v4.s[1]\n" + "fmla v30.4s, v1.4s, v4.s[2]\n" + "fmla v31.4s, v1.4s, v4.s[3]\n" + + "fmla v8.4s, v0.4s, v5.s[0]\n" + "fmla v9.4s, v0.4s, v5.s[1]\n" + "fmla v10.4s, v0.4s, v5.s[2]\n" + "fmla v11.4s, v0.4s, v5.s[3]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v12.4s, v0.4s, v6.s[0]\n" + "fmla v13.4s, v0.4s, v6.s[1]\n" + + "fmla v14.4s, v0.4s, v6.s[2]\n" + "fmla v15.4s, v0.4s, v6.s[3]\n" + "st1 {v8.4s}, [%[output0]], #16\n" + "fmla v16.4s, v0.4s, v7.s[0]\n" + "st1 {v9.4s}, [%[output0]], #16\n" + "fmla v17.4s, v0.4s, v7.s[1]\n" + "st1 {v10.4s}, [%[output0]], #16\n" + "fmla v18.4s, v0.4s, v7.s[2]\n" + "st1 {v11.4s}, [%[output0]], #16\n" + "fmla v19.4s, v0.4s, v7.s[3]\n" + "st1 {v12.4s}, [%[output0]], #16\n" + + "fmla v20.4s, v1.4s, v5.s[0]\n" + "st1 {v13.4s}, [%[output0]], #16\n" + "fmla v21.4s, v1.4s, v5.s[1]\n" + "st1 {v14.4s}, [%[output0]], #16\n" + "fmla v22.4s, v1.4s, v5.s[2]\n" + "st1 {v15.4s}, [%[output0]], #16\n" + "fmla v23.4s, v1.4s, v5.s[3]\n" + "st1 {v16.4s}, [%[output0]], #16\n" + "fmla v24.4s, v1.4s, v6.s[0]\n" + "st1 {v17.4s}, [%[output0]], #16\n" + "fmla v25.4s, v1.4s, v6.s[1]\n" + "st1 {v18.4s}, [%[output0]], #16\n" + "fmla v26.4s, v1.4s, v6.s[2]\n" + "st1 {v19.4s}, [%[output0]], #16\n" + "fmla v27.4s, v1.4s, v6.s[3]\n" + "st1 {v20.4s}, [%[output1]], #16\n" + "fmla v28.4s, v1.4s, v7.s[0]\n" + "st1 {v21.4s}, [%[output1]], #16\n" + "fmla v29.4s, v1.4s, v7.s[1]\n" + "st1 {v22.4s}, [%[output1]], #16\n" + "fmla v30.4s, v1.4s, v7.s[2]\n" + "st1 {v23.4s}, [%[output1]], #16\n" + "fmla v31.4s, v1.4s, v7.s[3]\n" + "st1 
{v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "st1 {v8.4s}, [%[output0]], #16\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "st1 {v9.4s}, [%[output0]], #16\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "st1 {v10.4s}, [%[output0]], #16\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "st1 {v11.4s}, [%[output0]], #16\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "st1 {v12.4s}, [%[output0]], #16\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + "st1 {v13.4s}, [%[output0]], #16\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "st1 {v14.4s}, [%[output0]], #16\n" + "fmla v21.4s, v1.4s, v2.s[1]\n" + "st1 {v15.4s}, [%[output0]], #16\n" + "fmla v22.4s, v1.4s, v2.s[2]\n" + "st1 {v16.4s}, [%[output0]], #16\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "st1 {v17.4s}, [%[output0]], #16\n" + "fmla v24.4s, v1.4s, v3.s[0]\n" + "st1 {v18.4s}, [%[output0]], #16\n" + "fmla v25.4s, v1.4s, v3.s[1]\n" + "st1 {v19.4s}, [%[output0]], #16\n" + "fmla v26.4s, v1.4s, v3.s[2]\n" + "st1 {v20.4s}, [%[output1]], #16\n" + "fmla v27.4s, v1.4s, v3.s[3]\n" + "st1 {v21.4s}, [%[output1]], #16\n" + "fmla v28.4s, v1.4s, v4.s[0]\n" + "st1 {v22.4s}, [%[output1]], #16\n" + "fmla v29.4s, v1.4s, v4.s[1]\n" + "st1 {v23.4s}, [%[output1]], #16\n" + "fmla v30.4s, v1.4s, v4.s[2]\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n" + "fmla v31.4s, v1.4s, v4.s[3]\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n" + + "6:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [output1] "+r"(output1) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory"); + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
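Every kernel in this patch reshapes K the same way before entering the asm: `oddk = K & 1; K = ((K + 1) / 2) - 1;`. In plain C++ the control flow behind labels 3/4/5/6 looks roughly like this (a shape-only sketch; the comments stand in for the unrolled fmla groups):

void k_loop_shape(int K) {
    int oddk = K & 1;                // selects the tail at label 4 vs 5
    int loops = ((K + 1) / 2) - 1;   // main-loop trip count, %w[K] in the asm
    for (int i = 0; i < loops; ++i) {
        // label 3: two K steps, next A/B blocks loaded while fmlas retire
    }
    if (oddk) {
        // label 5: one final K step, stores interleaved with the last fmlas
    } else {
        // label 4: two final K steps, stores interleaved with the last fmlas
    }
    // label 6: STORE_C / store epilogue
    // Total K steps: 2 * loops + (oddk ? 1 : 2) == K for every K >= 1.
}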
+ // + // +--------+ + // | v2[0-3]| + // | v3[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // |v1| |v20[0-3]| + // |v1| |v23[0-3]| + // |v1| |v26[0-3]| + // |v1| |v29[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_8x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + float* output1 = output0 + LDC; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + //clang-format off #define LOAD_C \ "cmp %w[n_remain], #4\n" \ "blt 11f\n" \ @@ -364,511 +397,509 @@ void kern_8x4(const float* packA, const float* packB, int K, float* output, "st1 {v8.4s}, [%[output0]]\n" \ "st1 {v12.4s},[%[output1]]\n" \ "24:\n" - //clang-format on - - asm volatile( - // load accumulator C - "cmp %w[is_first_k], #1\n" - "beq 1f\n" LOAD_C - - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "ld1 {v2.4s}, [%[b_ptr]], #16\n" - "b 2f\n" - - "1:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "eor v9.16b, v9.16b, v9.16b\n" - "eor v10.16b, v10.16b, v10.16b\n" - "prfm pstl1keep, [%[output0]]\n" - "eor v11.16b, v11.16b, v11.16b\n" - "eor v12.16b, v12.16b, v12.16b\n" - "prfm pstl1keep, [%[output1]]\n" - "eor v13.16b, v13.16b, v13.16b\n" - "eor v14.16b, v14.16b, v14.16b\n" - "eor v15.16b, v15.16b, v15.16b\n" - "ld1 {v2.4s}, [%[b_ptr]], #16\n" - - "2: \n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "fmla v8.4s, v0.4s, v2.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], #16\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "ld1 {v3.4s}, [%[b_ptr]], #16\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "fmla v12.4s, v1.4s, v2.s[0]\n" - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "fmla v13.4s, v1.4s, v2.s[1]\n" - "fmla v14.4s, v1.4s, v2.s[2]\n" - "fmla v15.4s, v1.4s, v2.s[3]\n" - - "fmla v8.4s, v0.4s, v3.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], #16\n" - "fmla v9.4s, v0.4s, v3.s[1]\n" - "fmla v10.4s, v0.4s, v3.s[2]\n" - "fmla v11.4s, v0.4s, v3.s[3]\n" - "ld1 {v2.4s}, [%[b_ptr]], #16\n" - "fmla v12.4s, v1.4s, v3.s[0]\n" - "subs %w[K], %w[K], #1\n" - "fmla v13.4s, v1.4s, v3.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "fmla v14.4s, v1.4s, v3.s[2]\n" - "fmla v15.4s, v1.4s, v3.s[3]\n" - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "fmla v8.4s, v0.4s, v2.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], #16\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "ld1 {v3.4s}, [%[b_ptr]], #16\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "fmla v12.4s, v1.4s, v2.s[0]\n" - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "fmla v13.4s, v1.4s, v2.s[1]\n" - "fmla v14.4s, v1.4s, v2.s[2]\n" - "fmla v15.4s, v1.4s, v2.s[3]\n" - - "fmla v8.4s, v0.4s, v3.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], #16\n" - "fmla v9.4s, v0.4s, v3.s[1]\n" - "fmla v10.4s, v0.4s, v3.s[2]\n" - "fmla v11.4s, v0.4s, v3.s[3]\n" - "fmla v12.4s, v1.4s, v3.s[0]\n" - "fmla v13.4s, v1.4s, v3.s[1]\n" - "fmla v14.4s, v1.4s, v3.s[2]\n" - "fmla v15.4s, v1.4s, v3.s[3]\n" - "b 6f\n" - - // odd tail - "5:\n" - "fmla v8.4s, v0.4s, v2.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], #16\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "fmla v12.4s, v1.4s, v2.s[0]\n" - "fmla v13.4s, v1.4s, v2.s[1]\n" - "fmla v14.4s, v1.4s, v2.s[2]\n" - "fmla v15.4s, v1.4s, v2.s[3]\n" - - "6:\n" STORE_C - - : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] 
"+r"(K), - [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), - [ output0 ] "+r"(output0), [ output1 ] "+r"(output1), - [ n_remain ] "+r"(n_remain) - : - : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15", "cc", "memory"); + //clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "prfm pstl1keep, [%[output1]]\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], #16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "fmla v8.4s, v0.4s, v3.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v3.s[1]\n" + "fmla v10.4s, v0.4s, v3.s[2]\n" + "fmla v11.4s, v0.4s, v3.s[3]\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "fmla v12.4s, v1.4s, v3.s[0]\n" + "subs %w[K], %w[K], #1\n" + "fmla v13.4s, v1.4s, v3.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "fmla v14.4s, v1.4s, v3.s[2]\n" + "fmla v15.4s, v1.4s, v3.s[3]\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], #16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "fmla v8.4s, v0.4s, v3.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v3.s[1]\n" + "fmla v10.4s, v0.4s, v3.s[2]\n" + "fmla v11.4s, v0.4s, v3.s[3]\n" + "fmla v12.4s, v1.4s, v3.s[0]\n" + "fmla v13.4s, v1.4s, v3.s[1]\n" + "fmla v14.4s, v1.4s, v3.s[2]\n" + "fmla v15.4s, v1.4s, v3.s[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [output1] "+r"(output1), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "cc", "memory"); #undef LOAD_C #undef STORE_C -} - - -// Overview of register layout: -// -// A 1x12 cell of Rhs is stored in 32bit in v2-v7 -// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) -// A 8x12 block of accumulators is stored in 32bit in v8-v31. 
-// -// +--------+--------+--------+ -// | v2[0-3]| v3[0-3]| v4[0-3]| -// | v5[0-3]| v6[0-3]| v7[0-3]| -// Rhs +--------+--------+--------+ -// -// | | | | -// -// Lhs | | | | -// -// +--+ --- - +--------+--------+--------+ -// |v0| | v8[0-3]| v9[0-3]|v10[0-3]| -// |v0| |v11[0-3]|v12[0-3]|v13[0-3]| -// |v0| |v14[0-3]|v15[0-3]|v16[0-3]| -// |v0| |v17[0-3]|v18[0-3]|v19[0-3]| -// +--+ --- - +--------+--------+--------+ -// -// Accumulator - -void kern_4x12(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k) { - MEGDNN_MARK_USED_VAR(LDC); - const float* a_ptr = packA; - const float* b_ptr = packB; - float* output0 = output; - - int oddk = (K & 1); - K = ((K + 1) / 2) - 1; - - asm volatile( - "cmp %w[is_first_k], #1\n" - "beq 1f\n" - "mov x1, %[output0]\n" - "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" - "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" - "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" - - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" - "b 2f\n" - - "1:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v9.16b, v9.16b, v9.16b\n" - "eor v10.16b, v10.16b, v10.16b\n" - "prfm pstl1keep, [%[output0]]\n" - "eor v11.16b, v11.16b, v11.16b\n" - "eor v12.16b, v12.16b, v12.16b\n" - "eor v13.16b, v13.16b, v13.16b\n" - "eor v14.16b, v14.16b, v14.16b\n" - "eor v15.16b, v15.16b, v15.16b\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" - "eor v16.16b, v16.16b, v16.16b\n" - "eor v17.16b, v17.16b, v17.16b\n" - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "eor v18.16b, v18.16b, v18.16b\n" - "eor v19.16b, v19.16b, v19.16b\n" - - "2: \n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "fmla v8.4s, v0.4s, v2.s[0]\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "fmla v12.4s, v0.4s, v3.s[0]\n" - "fmla v13.4s, v0.4s, v3.s[1]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" - "fmla v14.4s, v0.4s, v3.s[2]\n" - "fmla v15.4s, v0.4s, v3.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v17.4s, v0.4s, v4.s[1]\n" - "fmla v18.4s, v0.4s, v4.s[2]\n" - "fmla v19.4s, v0.4s, v4.s[3]\n" - - "fmla v8.4s, v1.4s, v5.s[0]\n" - "fmla v9.4s, v1.4s, v5.s[1]\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" - "fmla v10.4s, v1.4s, v5.s[2]\n" - "fmla v11.4s, v1.4s, v5.s[3]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v12.4s, v1.4s, v6.s[0]\n" - "fmla v13.4s, v1.4s, v6.s[1]\n" - "subs %w[K], %w[K], #1\n" - "fmla v14.4s, v1.4s, v6.s[2]\n" - "fmla v15.4s, v1.4s, v6.s[3]\n" - "fmla v16.4s, v1.4s, v7.s[0]\n" - "fmla v17.4s, v1.4s, v7.s[1]\n" - "fmla v18.4s, v1.4s, v7.s[2]\n" - "fmla v19.4s, v1.4s, v7.s[3]\n" - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "fmla v8.4s, v0.4s, v2.s[0]\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" - "fmla v12.4s, v0.4s, v3.s[0]\n" - "fmla v13.4s, v0.4s, v3.s[1]\n" - "fmla v14.4s, v0.4s, v3.s[2]\n" - "fmla v15.4s, v0.4s, v3.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v17.4s, v0.4s, v4.s[1]\n" - "fmla v18.4s, v0.4s, v4.s[2]\n" - "fmla v19.4s, v0.4s, v4.s[3]\n" - - "fmla v8.4s, v1.4s, v5.s[0]\n" - "fmla v9.4s, v1.4s, v5.s[1]\n" - "fmla v10.4s, v1.4s, v5.s[2]\n" - "fmla v11.4s, v1.4s, v5.s[3]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v12.4s, v1.4s, v6.s[0]\n" - "fmla v13.4s, v1.4s, v6.s[1]\n" - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" - "fmla 
v14.4s, v1.4s, v6.s[2]\n" - "fmla v15.4s, v1.4s, v6.s[3]\n" - "fmla v16.4s, v1.4s, v7.s[0]\n" - "fmla v17.4s, v1.4s, v7.s[1]\n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" - "fmla v18.4s, v1.4s, v7.s[2]\n" - "fmla v19.4s, v1.4s, v7.s[3]\n" - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" - - "b 6f\n" - - // odd tail - "5:\n" - "fmla v8.4s, v0.4s, v2.s[0]\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - "fmla v12.4s, v0.4s, v3.s[0]\n" - "fmla v13.4s, v0.4s, v3.s[1]\n" - "fmla v14.4s, v0.4s, v3.s[2]\n" - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" - "fmla v15.4s, v0.4s, v3.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v17.4s, v0.4s, v4.s[1]\n" - "fmla v18.4s, v0.4s, v4.s[2]\n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" - "fmla v19.4s, v0.4s, v4.s[3]\n" - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" - - "6:\n" - : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), - [ output0 ] "+r"(output0) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "x1", "cc", "memory"); -} - - - - -// Overview of register layout: -// -// A 2x4 cell of Rhs is stored in 32bit in v2 - v3 -// A 4x2 cell of Lhs is stored in 32bit in v0 - v1 -// A 4x4 block of accumulators is stored in 32bit in v4-v6 -// -// +--------+ -// | v2[0-3]| -// | v5[0-3]| -// Rhs +--------+ -// -// | | -// -// Lhs | | -// -// +--+ --- - +--------+ -// |v0| | v8[0-3]| -// |v0| |v11[0-3]| -// |v0| |v14[0-3]| -// |v0| |v17[0-3]| -// +--+ --- - +--------+ -// -// Accumulator -void kern_4x4(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int n_remain) { - MEGDNN_MARK_USED_VAR(LDC); - const float* a_ptr = packA; - const float* b_ptr = packB; - float* output0 = output; - - int oddk = (K & 1); - K = ((K + 1) / 2) - 1; - - //clang-format off -#define LOAD_C \ - "cmp %w[n_remain], #4\n" \ - "blt 11f\n" \ - "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ - "b 14f\n" \ - "11:\n" \ - "cmp %w[n_remain], #3\n" \ - "blt 12f\n" \ - "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ - "b 14f\n" \ - "12:\n" \ - "cmp %w[n_remain], #2\n" \ - "blt 13f\n" \ - "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \ - "b 14f\n" \ - "13:\n" \ - "ld1 {v8.4s}, [%[output0]]\n" \ + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
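The `is_first_k` prologue shared by these kernels either zeroes the accumulators with `eor` (first K-slice) or reloads the partial C block through LOAD_C (later slices), so a K dimension split across several packing blocks still accumulates correctly. In scalar form, a sketch:

// Scalar model of the is_first_k prologue used by the kernels above.
void accumulate_block(float* C, const float* contribution, int len,
                      bool is_first_k) {
    for (int i = 0; i < len; ++i) {
        float acc = is_first_k ? 0.0f : C[i];  // "eor ..." vs LOAD_C
        C[i] = acc + contribution[i];          // the fmla chains
    }
}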
+ // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + + static void kern_4x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + asm volatile( + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "mov x1, %[output0]\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v8.4s, v1.4s, v5.s[0]\n" + "fmla v9.4s, v1.4s, v5.s[1]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "fmla v10.4s, v1.4s, v5.s[2]\n" + "fmla v11.4s, v1.4s, v5.s[3]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v12.4s, v1.4s, v6.s[0]\n" + "fmla v13.4s, v1.4s, v6.s[1]\n" + "subs %w[K], %w[K], #1\n" + "fmla v14.4s, v1.4s, v6.s[2]\n" + "fmla v15.4s, v1.4s, v6.s[3]\n" + "fmla v16.4s, v1.4s, v7.s[0]\n" + "fmla v17.4s, v1.4s, v7.s[1]\n" + "fmla v18.4s, v1.4s, v7.s[2]\n" + "fmla v19.4s, v1.4s, v7.s[3]\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v8.4s, v1.4s, v5.s[0]\n" + "fmla v9.4s, v1.4s, v5.s[1]\n" + "fmla v10.4s, v1.4s, v5.s[2]\n" + "fmla v11.4s, v1.4s, v5.s[3]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v12.4s, v1.4s, v6.s[0]\n" + "fmla v13.4s, v1.4s, v6.s[1]\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, 
[%[output0]], #64\n" + "fmla v14.4s, v1.4s, v6.s[2]\n" + "fmla v15.4s, v1.4s, v6.s[3]\n" + "fmla v16.4s, v1.4s, v7.s[0]\n" + "fmla v17.4s, v1.4s, v7.s[1]\n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" + "fmla v18.4s, v1.4s, v7.s[2]\n" + "fmla v19.4s, v1.4s, v7.s[3]\n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" + + "6:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "x1", "cc", "memory"); + } + + // Overview of register layout: + // + // A 2x4 cell of Rhs is stored in 32bit in v2 - v3 + // A 4x2 cell of Lhs is stored in 32bit in v0 - v1 + // A 4x4 block of accumulators is stored in 32bit in v4-v6 + // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_4x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + //clang-format off +#define LOAD_C \ + "cmp %w[n_remain], #4\n" \ + "blt 11f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "11:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 12f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "12:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 13f\n" \ + "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "13:\n" \ + "ld1 {v8.4s}, [%[output0]]\n" \ "14:\n" -#define STORE_C \ - "cmp %w[n_remain], #4\n" \ - "blt 21f\n" \ - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ - "b 24f\n" \ - "21:\n" \ - "cmp %w[n_remain], #3\n" \ - "blt 22f\n" \ - "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ - "b 24f\n" \ - "22:\n" \ - "cmp %w[n_remain], #2\n" \ - "blt 23f\n" \ - "st1 {v8.4s, v9.4s}, [%[output0]]\n" \ - "b 24f\n" \ - "23:\n" \ - "st1 {v8.4s}, [%[output0]]\n" \ +#define STORE_C \ + "cmp %w[n_remain], #4\n" \ + "blt 21f\n" \ + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "21:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 22f\n" \ + "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "22:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 23f\n" \ + "st1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "23:\n" \ + "st1 {v8.4s}, [%[output0]]\n" \ "24:\n" - //clang-format on - - asm volatile( - // load accumulator C - "cmp %w[is_first_k], #1\n" - "beq 1f\n" LOAD_C - - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "ld1 {v2.4s}, [%[b_ptr]], #16\n" - "b 2f\n" - - "1:\n" - 
"eor v8.16b, v8.16b, v8.16b\n" - "ld1 {v2.4s}, [%[b_ptr]], #16\n" - "eor v9.16b, v9.16b, v9.16b\n" - "ld1 {v0.4s}, [%[a_ptr]], #16\n" - "eor v10.16b, v10.16b, v10.16b\n" - "prfm pstl1keep, [%[output0]]\n" - "eor v11.16b, v11.16b, v11.16b\n" - - "2: \n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "fmla v8.4s, v0.4s, v2.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "ld1 {v3.4s}, [%[b_ptr]], 16\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - - "fmla v8.4s, v1.4s, v3.s[0]\n" - "fmla v9.4s, v1.4s, v3.s[1]\n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "fmla v10.4s, v1.4s, v3.s[2]\n" - "fmla v11.4s, v1.4s, v3.s[3]\n" - "ld1 {v2.4s}, [%[b_ptr]], 16\n" - "subs %w[K], %w[K], #1\n" - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "fmla v8.4s, v0.4s, v2.s[0]\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "ld1 {v3.4s}, [%[b_ptr]], 16\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - - "fmla v8.4s, v1.4s, v3.s[0]\n" - "fmla v9.4s, v1.4s, v3.s[1]\n" - "fmla v10.4s, v1.4s, v3.s[2]\n" - "fmla v11.4s, v1.4s, v3.s[3]\n" - "b 6f\n" - - // odd tail - "5:\n" - "fmla v8.4s, v0.4s, v2.s[0]\n" - "fmla v9.4s, v0.4s, v2.s[1]\n" - "fmla v10.4s, v0.4s, v2.s[2]\n" - "fmla v11.4s, v0.4s, v2.s[3]\n" - - "6:\n" STORE_C - - : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), - [ output0 ] "+r"(output0), [ n_remain ] "+r"(n_remain) - : - : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); + //clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "eor v9.16b, v9.16b, v9.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "fmla v8.4s, v1.4s, v3.s[0]\n" + "fmla v9.4s, v1.4s, v3.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v1.4s, v3.s[2]\n" + "fmla v11.4s, v1.4s, v3.s[3]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "fmla v8.4s, v1.4s, v3.s[0]\n" + "fmla v9.4s, v1.4s, v3.s[1]\n" + "fmla v10.4s, v1.4s, v3.s[2]\n" + "fmla v11.4s, v1.4s, v3.s[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", + "memory"); #undef LOAD_C #undef STORE_C -} - -void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, int y0, - int ymax, int k0, int kmax) { - megdnn_assert(y0 % 4 == 0 && ymax 
% 4 == 0, "M must be time of 4"); - megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); - constexpr int PACK_SIZE_32 = 4 * 8; - constexpr int PACK_SIZE_16 = 4 * 4; - constexpr int PACK_C_SIZE = 4; - int y = y0; - for (; y + 7 < ymax; y += 8) { - const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0; - const float* inptr1 = inptr0 + ldin; - prefetch_2x(inptr0); - prefetch_2x(inptr1); - int k = (kmax - k0); - for (; k > 3; k -= 4) { - interleave_2x4_4_s(inptr0, inptr1, outptr); - outptr += PACK_SIZE_32; - } } - for (; y < ymax; y += 4) { - const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0; - prefetch_2x(inptr0); - int K = (kmax - k0); - for (; K > 3; K -= 4) { - interleave_1x4_4_s(inptr0, outptr); - outptr += PACK_SIZE_16; - } - } -} - -void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax) { - megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); - float tmpbuff[16] = {0.0f}; - - constexpr int PACK_C_SIZE = 4; - int ksize = kmax - k0; - int ksize12 = ksize * 12; - int ksize4 = (ksize << 2); - float* outptr_base = out; - float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; - - int k = k0; - for (; k + 3 < kmax; k += 4) { - const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE; - prefetch_3x(inptr); - - int x = x0; - auto outptr = outptr_base; - for (; x + 12 <= xmax; x += 12) { - auto outptr_interleave = outptr; - transpose_1x12_4_s(inptr, outptr_interleave); - outptr += ksize12; + + static void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, + int y0, int ymax, int k0, int kmax) { + megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + constexpr int PACK_SIZE_32 = 4 * 8; + constexpr int PACK_SIZE_16 = 4 * 4; + constexpr int PACK_C_SIZE = 4; + int y = y0; + for (; y + 7 < ymax; y += 8) { + const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0; + const float* inptr1 = inptr0 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + int k = (kmax - k0); + for (; k > 3; k -= 4) { + interleave_2x4_4_s(inptr0, inptr1, outptr); + outptr += PACK_SIZE_32; + } } - outptr = outptr_base4; - for (; x + 4 <= xmax; x += 4) { - auto outptr_interleave = outptr; - transpose_1x4_4_s(inptr, outptr_interleave); - outptr += ksize4; + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0; + prefetch_2x(inptr0); + int K = (kmax - k0); + for (; K > 3; K -= 4) { + interleave_1x4_4_s(inptr0, outptr); + outptr += PACK_SIZE_16; + } } - if (x < xmax) { - std::memcpy(tmpbuff, inptr, - sizeof(float) * (xmax - x) * PACK_C_SIZE); - auto outptr_interleave = outptr; - const float* tmp_ptr = &tmpbuff[0]; - transpose_1x4_4_s(tmp_ptr, outptr_interleave); - outptr += ksize4; + } + + static void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0, + int xmax, int k0, int kmax) { + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + float tmpbuff[16] = {0.0f}; + + constexpr int PACK_C_SIZE = 4; + int ksize = kmax - k0; + int ksize12 = ksize * 12; + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE; + prefetch_3x(inptr); + + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + transpose_1x12_4_s(inptr, 
outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + transpose_1x4_4_s(inptr, outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + std::memcpy(tmpbuff, inptr, + sizeof(float) * (xmax - x) * PACK_C_SIZE); + auto outptr_interleave = outptr; + const float* tmp_ptr = &tmpbuff[0]; + transpose_1x4_4_s(tmp_ptr, outptr_interleave); + outptr += ksize4; + } + outptr_base += 12 * 4; + outptr_base4 += 4 * 4; } - outptr_base += 12 * 4; - outptr_base4 += 4 * 4; } -} +}; -} // namespace matmul_mk4_8x12 -} // aarch64 -} // megdnn +} // namespace aarch64 +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h new file mode 100644 index 000000000..50263bd40 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h @@ -0,0 +1,1260 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +struct matmul_mk4_8x12_a53 { + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
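+    //
+    // Scheduling note (A53): the 128-bit vector loads in the kernel body
+    // are split into "ldr d" / "ldr x" / "ins" triples so that each
+    // 64-bit half can dual-issue with an fmla on the in-order pipeline;
+    // the interspersed nops pad issue slots to keep those pairs aligned
+    // (this describes the intent of the schedule, exact behaviour may
+    // vary across core revisions).
+    //
+    // In plain C terms, one k-step of this kernel accumulates the outer
+    // product below (illustration only, not compiled code):
+    //
+    //   for (int n = 0; n < 12; ++n)        // columns, B held in v2-v7
+    //       for (int m = 0; m < 8; ++m)     // rows, A held in v0-v1
+    //           C[m][n] += A[m] * B[n];
+    //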
+ // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // |v1| |v20[0-3]|v21[0-3]|v22[0-3]| + // |v1| |v23[0-3]|v24[0-3]|v25[0-3]| + // |v1| |v26[0-3]|v27[0-3]|v28[0-3]| + // |v1| |v29[0-3]|v30[0-3]|v31[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_8x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + float* output1 = output0 + LDC; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + asm volatile( + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "mov x1, %[output0]\n" + "mov x2, %[output1]\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ldr d2, [%[b_ptr]]\n" + + "eor v9.16b, v9.16b, v9.16b\n" + "ldr x10, [%[b_ptr], #8]\n" + + "eor v10.16b, v10.16b, v10.16b\n" + "ldr d3, [%[b_ptr], #16]\n" + + "eor v11.16b, v11.16b, v11.16b\n" + "ldr x11, [%[b_ptr], #24]\n" + + "eor v12.16b, v12.16b, v12.16b\n" + "ldr d4, [%[b_ptr], #32]\n" + + "eor v13.16b, v13.16b, v13.16b\n" + "ldr x12, [%[b_ptr], #40]\n" + + "eor v14.16b, v14.16b, v14.16b\n" + "ldr d0, [%[a_ptr]]\n" + + "eor v15.16b, v15.16b, v15.16b\n" + "ldr x9, [%[a_ptr], #8]\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "add %[b_ptr], %[b_ptr], #48\n" + + "eor v17.16b, v17.16b, v17.16b\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "eor v18.16b, v18.16b, v18.16b\n" + "ins v2.d[1], x10\n" + + "eor v19.16b, v19.16b, v19.16b\n" + "ins v3.d[1], x11\n" + + "eor v20.16b, v20.16b, v20.16b\n" + "ins v4.d[1], x12\n" + + "eor v21.16b, v21.16b, v21.16b\n" + "ins v0.d[1], x9\n" + + "eor v22.16b, v22.16b, v22.16b\n" + "prfm pldl1keep, [%[a_ptr], #384]\n" + + "eor v23.16b, v23.16b, v23.16b\n" + "prfm pldl1keep, [%[b_ptr]]\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "prfm pldl1keep, [%[b_ptr], #64]\n" + + "eor v25.16b, v25.16b, v25.16b\n" + "prfm pldl1keep, [%[b_ptr], #128]\n" + + "eor v26.16b, v26.16b, v26.16b\n" + "prfm pldl1keep, [%[b_ptr], #192]\n" + + "eor v27.16b, v27.16b, v27.16b\n" + "prfm pldl1keep, [%[b_ptr], #256]\n" + + "eor v28.16b, v28.16b, v28.16b\n" + "prfm pldl1keep, [%[b_ptr], #320]\n" + + "eor v29.16b, v29.16b, v29.16b\n" + "prfm pldl1keep, [%[b_ptr], #384]\n" + + "eor v30.16b, v30.16b, v30.16b\n" + "prfm pldl1keep, [%[b_ptr], #448]\n" + + "eor v31.16b, v31.16b, v31.16b\n" + "prfm pldl1keep, [%[b_ptr], #512]\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ldr d1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + "subs %w[K], %w[K], #1\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + + "ldr d5, [%[b_ptr]]\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + + "fmla v13.4s, v0.4s, 
v3.s[1]\n" + + "ldr d6, [%[b_ptr], #16]\n" + "ins v5.d[1], x10\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + + "ldr d7, [%[b_ptr], #32]\n" + "ins v6.d[1], x11\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v18.4s, v0.4s, v4.s[2]\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v7.d[1], x12\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v21.4s, v1.4s, v2.s[1]\n" + + "fmla v22.4s, v1.4s, v2.s[2]\n" + + "prfm pldl1keep, [%[a_ptr], #448]\n" + "ins v0.d[1], x9\n" + + "fmla v23.4s, v1.4s, v2.s[3]\n" + + "fmla v24.4s, v1.4s, v3.s[0]\n" + + "fmla v25.4s, v1.4s, v3.s[1]\n" + + "prfm pldl1keep, [%[b_ptr], #576]\n" + "nop\n" + + "fmla v26.4s, v1.4s, v3.s[2]\n" + + "fmla v27.4s, v1.4s, v3.s[3]\n" + + "fmla v28.4s, v1.4s, v4.s[0]\n" + + "nop\n" + "nop\n" + + "fmla v29.4s, v1.4s, v4.s[1]\n" + + "fmla v30.4s, v1.4s, v4.s[2]\n" + + "fmla v31.4s, v1.4s, v4.s[3]\n" + //! UNROLL + "ldr d1, [%[a_ptr], #32]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v5.s[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + "fmla v9.4s, v0.4s, v5.s[1]\n" + + "fmla v10.4s, v0.4s, v5.s[2]\n" + + "ldr d2, [%[b_ptr], #48]\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v0.4s, v5.s[3]\n" + "ldr x10, [%[b_ptr], #56]\n" + + "fmla v12.4s, v0.4s, v6.s[0]\n" + + "fmla v13.4s, v0.4s, v6.s[1]\n" + + "ldr d3, [%[b_ptr], #64]\n" + "ins v2.d[1], x10\n" + + "fmla v14.4s, v0.4s, v6.s[2]\n" + "ldr x11, [%[b_ptr], #72]\n" + "fmla v15.4s, v0.4s, v6.s[3]\n" + + "fmla v16.4s, v0.4s, v7.s[0]\n" + + "ldr d4, [%[b_ptr], #80]\n" + "ins v3.d[1], x11\n" + + "fmla v17.4s, v0.4s, v7.s[1]\n" + "ldr x12, [%[b_ptr], #88]\n" + "fmla v18.4s, v0.4s, v7.s[2]\n" + + "fmla v19.4s, v0.4s, v7.s[3]\n" + + "ldr d0, [%[a_ptr], #48]\n" + "ins v4.d[1], x12\n" + + "fmla v20.4s, v1.4s, v5.s[0]\n" + "ldr x9, [%[a_ptr], #56]\n" + + "fmla v21.4s, v1.4s, v5.s[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + + "fmla v22.4s, v1.4s, v5.s[2]\n" + + "nop\n" + "ins v0.d[1], x9\n" + + "fmla v23.4s, v1.4s, v5.s[3]\n" + "add %[a_ptr], %[a_ptr], #64\n" + + "fmla v24.4s, v1.4s, v6.s[0]\n" + + "fmla v25.4s, v1.4s, v6.s[1]\n" + + "prfm pldl1keep, [%[b_ptr], #640]\n" + "nop\n" + + "fmla v26.4s, v1.4s, v6.s[2]\n" + + "fmla v27.4s, v1.4s, v6.s[3]\n" + + "fmla v28.4s, v1.4s, v7.s[0]\n" + + "nop\n" + "nop\n" + + "fmla v29.4s, v1.4s, v7.s[1]\n" + + "fmla v30.4s, v1.4s, v7.s[2]\n" + + "fmla v31.4s, v1.4s, v7.s[3]\n" + + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ldr d1, [%[a_ptr]] \n" + "prfm pstl1keep, [%[output0]]\n" + + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr x8, [%[a_ptr], #8] \n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + + "ldr d5, [%[b_ptr]]\n" + "prfm pstl1keep, [%[output1]]\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + + "ldr d6, [%[b_ptr], #16]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + + "nop\n" + "ins v5.d[1], x10\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + + "fmla v18.4s, v0.4s, v4.s[2]\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v6.d[1], x11\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v21.4s, v1.4s, v2.s[1]\n" + + "fmla v22.4s, v1.4s, v2.s[2]\n" + + "ldr d7, [%[b_ptr], #32]\n" + "ins 
v0.d[1], x9\n" + + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v24.4s, v1.4s, v3.s[0]\n" + + "fmla v25.4s, v1.4s, v3.s[1]\n" + + "nop\n" + "ins v7.d[1], x12\n" + + "fmla v26.4s, v1.4s, v3.s[2]\n" + + "fmla v27.4s, v1.4s, v3.s[3]\n" + + "fmla v28.4s, v1.4s, v4.s[0]\n" + + "nop\n" + "nop\n" + + "fmla v29.4s, v1.4s, v4.s[1]\n" + + "fmla v30.4s, v1.4s, v4.s[2]\n" + + "fmla v31.4s, v1.4s, v4.s[3]\n" + + "ldr d1, [%[a_ptr], #32]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v5.s[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "fmla v9.4s, v0.4s, v5.s[1]\n" + + "fmla v10.4s, v0.4s, v5.s[2]\n" + + "nop\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v0.4s, v5.s[3]\n" + + "fmla v12.4s, v0.4s, v6.s[0]\n" + + "fmla v13.4s, v0.4s, v6.s[1]\n" + + "fmla v14.4s, v0.4s, v6.s[2]\n" + + "fmla v15.4s, v0.4s, v6.s[3]\n" + "str q8, [%[output0]]\n" + + "fmla v16.4s, v0.4s, v7.s[0]\n" + "str q9, [%[output0], #16]\n" + + "fmla v17.4s, v0.4s, v7.s[1]\n" + "str q10, [%[output0], #32]\n" + + "fmla v18.4s, v0.4s, v7.s[2]\n" + + "fmla v19.4s, v0.4s, v7.s[3]\n" + "str q11, [%[output0], #48]\n" + + "fmla v20.4s, v1.4s, v5.s[0]\n" + "str q12, [%[output0], #64]\n" + + "fmla v21.4s, v1.4s, v5.s[1]\n" + "str q13, [%[output0], #80]\n" + + "fmla v22.4s, v1.4s, v5.s[2]\n" + "str q14, [%[output0], #96]\n" + + "fmla v23.4s, v1.4s, v5.s[3]\n" + "str q15, [%[output0], #112]\n" + + "fmla v24.4s, v1.4s, v6.s[0]\n" + "str q16, [%[output0], #128]\n" + + "fmla v25.4s, v1.4s, v6.s[1]\n" + "str q17, [%[output0], #144]\n" + + "fmla v26.4s, v1.4s, v6.s[2]\n" + "str q18, [%[output0], #160]\n" + + "fmla v27.4s, v1.4s, v6.s[3]\n" + "str q19, [%[output0], #176]\n" + + "fmla v28.4s, v1.4s, v7.s[0]\n" + "str q20, [%[output1]]\n" + + "fmla v29.4s, v1.4s, v7.s[1]\n" + "str q21, [%[output1], #16]\n" + + "fmla v30.4s, v1.4s, v7.s[2]\n" + "str q22, [%[output1], #32]\n" + + "fmla v31.4s, v1.4s, v7.s[3]\n" + "str q23, [%[output1], #48]\n" + + "str q24, [%[output1], #64]\n" + "str q25, [%[output1], #80]\n" + "str q26, [%[output1], #96]\n" + "str q27, [%[output1], #112]\n" + "str q28, [%[output1], #128]\n" + "str q29, [%[output1], #144]\n" + "str q30, [%[output1], #160]\n" + "str q31, [%[output1], #176]\n" + "b 6f\n" + + // odd tail + "5:\n" + "ldr d1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + + "nop\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + "str q8, [%[output0]]\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "str q9, [%[output0], #16]\n" + + "fmla v18.4s, v0.4s, v4.s[2]\n" + "str q10, [%[output0], #32]\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + "str q11, [%[output0], #48]\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "str q12, [%[output0], #64]\n" + + "fmla v21.4s, v1.4s, v2.s[1]\n" + "str q13, [%[output0], #80]\n" + + "fmla v22.4s, v1.4s, v2.s[2]\n" + "str q14, [%[output0], #96]\n" + + "fmla v23.4s, v1.4s, v2.s[3]\n" + "str q15, [%[output0], #112]\n" + + "fmla v24.4s, v1.4s, v3.s[0]\n" + "str q16, [%[output0], #128]\n" + + "fmla v25.4s, v1.4s, v3.s[1]\n" + "str q17, [%[output0], #144]\n" + + "fmla v26.4s, v1.4s, v3.s[2]\n" + "str q18, [%[output0], #160]\n" + + "fmla v27.4s, v1.4s, v3.s[3]\n" + "str q19, [%[output0], #176]\n" + + "fmla v28.4s, v1.4s, v4.s[0]\n" + "str q20, [%[output1]]\n" + + "fmla v29.4s, v1.4s, v4.s[1]\n" + "str q21, 
[%[output1], #16]\n" + + "fmla v30.4s, v1.4s, v4.s[2]\n" + "str q22, [%[output1], #32]\n" + + "fmla v31.4s, v1.4s, v4.s[3]\n" + "str q23, [%[output1], #48]\n" + + "str q24, [%[output1], #64]\n" + "str q25, [%[output1], #80]\n" + "str q26, [%[output1], #96]\n" + "str q27, [%[output1], #112]\n" + "str q28, [%[output1], #128]\n" + "str q29, [%[output1], #144]\n" + "str q30, [%[output1], #160]\n" + "str q31, [%[output1], #176]\n" + + "6:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [output1] "+r"(output1) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", + "x11", "x12", "x13", "cc", "memory"); + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. + // + // +--------+ + // | v2[0-3]| + // | v3[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // |v1| |v20[0-3]| + // |v1| |v23[0-3]| + // |v1| |v26[0-3]| + // |v1| |v29[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_8x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + float* output1 = output0 + LDC; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + //clang-format off +#define LOAD_C \ + "cmp %w[n_remain], #4\n" \ + "blt 11f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \ + "b 14f\n" \ + "11:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 12f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \ + "b 14f\n" \ + "12:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 13f\n" \ + "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s},[%[output1]]\n" \ + "b 14f\n" \ + "13:\n" \ + "ld1 {v8.4s}, [%[output0]]\n" \ + "ld1 {v12.4s},[%[output1]]\n" \ + "14:\n" + +#define STORE_C \ + "cmp %w[n_remain], #4\n" \ + "blt 21f\n" \ + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \ + "b 24f\n" \ + "21:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 22f\n" \ + "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \ + "b 23f\n" \ + "22:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 23f\n" \ + "st1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s},[%[output1]]\n" \ + "b 24f\n" \ + "23:\n" \ + "st1 {v8.4s}, [%[output0]]\n" \ + "st1 {v12.4s},[%[output1]]\n" \ + "24:\n" + //clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ldr d0, [%[a_ptr]]\n" + + "eor v9.16b, v9.16b, v9.16b\n" + "ldr x8, [%[a_ptr], #8]\n" + + "eor v10.16b, v10.16b, v10.16b\n" + "ldr d2, [%[b_ptr]]\n" + + "eor v11.16b, v11.16b, v11.16b\n" + "ldr x9, [%[b_ptr], #8]\n" + + "eor v12.16b, v12.16b, v12.16b\n" + "ins v0.d[1], x8\n" + + "eor v13.16b, v13.16b, 
v13.16b\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "eor v14.16b, v14.16b, v14.16b\n" + "ins v2.d[1], x9\n" + + "eor v15.16b, v15.16b, v15.16b\n" + "add %[b_ptr], %[b_ptr], #16\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ldr q1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + + "ldr d3, [%[b_ptr]]\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v12.4s, v1.4s, v2.s[0]\n" + + "fmla v13.4s, v1.4s, v2.s[1]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v3.d[1], x10\n" + + "fmla v14.4s, v1.4s, v2.s[2]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "ldr d1, [%[a_ptr], #32]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "ldr d2, [%[b_ptr], #16]\n" + "ins v0.d[1], x9\n" + + "fmla v8.4s, v0.4s, v3.s[0]\n" + "ldr x10, [%[b_ptr], #24]\n" + + "fmla v9.4s, v0.4s, v3.s[1]\n" + + "fmla v10.4s, v0.4s, v3.s[2]\n" + + "ins v1.d[1], x8\n" + "ins v2.d[1], x10\n" + + "fmla v11.4s, v0.4s, v3.s[3]\n" + + "fmla v12.4s, v1.4s, v3.s[0]\n" + + "fmla v13.4s, v1.4s, v3.s[1]\n" + + "ldr d0, [%[a_ptr], #48]\n" + "nop\n" + + "fmla v14.4s, v1.4s, v3.s[2]\n" + "ldr x9, [%[a_ptr], #56]\n" + + "fmla v15.4s, v1.4s, v3.s[3]\n" + "add %[b_ptr], %[b_ptr], #32\n" + + "add %[a_ptr], %[a_ptr], #64\n" + "subs %w[K], %w[K], #1\n" + + "ins v0.d[1], x9\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ldr d1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + "prfm pstl1keep, [%[output1]]\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + + "ldr d3, [%[b_ptr]]\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v12.4s, v1.4s, v2.s[0]\n" + + "fmla v13.4s, v1.4s, v2.s[1]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v3.d[1], x10\n" + + "fmla v14.4s, v1.4s, v2.s[2]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v15.4s, v1.4s, v2.s[3]\n" + "prfm pstl1keep, [%[output1]]\n" + + "ldr d1, [%[a_ptr], #32]\n" + "ins v0.d[1], x9\n" + + "fmla v8.4s, v0.4s, v3.s[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + "fmla v9.4s, v0.4s, v3.s[1]\n" + + "fmla v10.4s, v0.4s, v3.s[2]\n" + + "nop\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v0.4s, v3.s[3]\n" + "fmla v12.4s, v1.4s, v3.s[0]\n" + "fmla v13.4s, v1.4s, v3.s[1]\n" + "fmla v14.4s, v1.4s, v3.s[2]\n" + "fmla v15.4s, v1.4s, v3.s[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "ldr q1, [%[a_ptr]]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + + "nop\n" + "ins v1.d[1], x8\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [output1] "+r"(output1), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); + +#undef LOAD_C +#undef STORE_C + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
+ // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + + static void kern_4x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + asm volatile( + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "mov x1, %[output0]\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ldr d5, [%[b_ptr]]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + + "ldr d1, [%[a_ptr]]\n" + "ins v5.d[1], x10\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + + "ldr d6, [%[b_ptr], #16]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + + "ldr d7, [%[b_ptr], #32]\n" + "ins v6.d[1], x11\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v18.4s, v0.4s, v4.s[2]\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "ldr d2, [%[b_ptr], #48]\n" + "ins v7.d[1], x12\n" + + "fmla v8.4s, v1.4s, v5.s[0]\n" + "ldr x10, [%[b_ptr], #56]\n" + + "fmla v9.4s, v1.4s, v5.s[1]\n" + + "fmla v10.4s, v1.4s, v5.s[2]\n" + + "ldr d3, [%[b_ptr], #64]\n" + "ins v2.d[1], x10\n" + + "fmla v11.4s, v1.4s, v5.s[3]\n" + "ldr x11, [%[b_ptr], #72]\n" + + "fmla v12.4s, v1.4s, v6.s[0]\n" + "subs %w[K], %w[K], #1\n" + + "fmla v13.4s, v1.4s, v6.s[1]\n" + + "ldr d4, [%[b_ptr], #80]\n" + "ins v3.d[1], x11\n" + + "fmla v14.4s, v1.4s, v6.s[2]\n" + "ldr x12, [%[b_ptr], #88]\n" + + "fmla v15.4s, v1.4s, v6.s[3]\n" + + "fmla v16.4s, v1.4s, v7.s[0]\n" + + "ldr d0, [%[a_ptr], #16]\n" + "ins v4.d[1], x12\n" + + "fmla v17.4s, v1.4s, v7.s[1]\n" + "ldr x10, [%[a_ptr], #24]\n" + + "fmla v18.4s, v1.4s, v7.s[2]\n" + "add %[b_ptr], %[b_ptr], #96\n" + + "fmla v19.4s, v1.4s, v7.s[3]\n" + "add %[a_ptr], %[a_ptr], #32\n" + + "ins v0.d[1], x10\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ldr d5, [%[b_ptr]]\n" + "nop\n" + + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + + 
"fmla v10.4s, v0.4s, v2.s[2]\n" + + "ldr d6, [%[b_ptr], #16]\n" + "ins v5.d[1], x10\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + + "ldr d1, [%[a_ptr]]\n" + "ins v6.d[1], x11\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + + "ldr d7, [%[b_ptr], #32]\n" + "ins v1.d[1], x8\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "ldr x12, [%[b_ptr], #40]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "nop\n" + "ins v7.d[1], x12\n" + + "fmla v8.4s, v1.4s, v5.s[0]\n" + "fmla v9.4s, v1.4s, v5.s[1]\n" + "fmla v10.4s, v1.4s, v5.s[2]\n" + "fmla v11.4s, v1.4s, v5.s[3]\n" + "fmla v12.4s, v1.4s, v6.s[0]\n" + "fmla v13.4s, v1.4s, v6.s[1]\n" + "fmla v14.4s, v1.4s, v6.s[2]\n" + "fmla v15.4s, v1.4s, v6.s[3]\n" + "fmla v16.4s, v1.4s, v7.s[0]\n" + "fmla v17.4s, v1.4s, v7.s[1]\n" + "str q8, [%[output0]]\n" + "fmla v18.4s, v1.4s, v7.s[2]\n" + "str q9, [%[output0], #16]\n" + "fmla v19.4s, v1.4s, v7.s[3]\n" + "str q10, [%[output0], #32]\n" + "str q11, [%[output0], #48]\n" + "str q12, [%[output0], #64]\n" + "str q13, [%[output0], #80]\n" + "str q14, [%[output0], #96]\n" + "str q15, [%[output0], #112]\n" + "str q16, [%[output0], #128]\n" + "str q17, [%[output0], #144]\n" + "str q18, [%[output0], #160]\n" + "str q19, [%[output0], #176]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "str q8, [%[output0]]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "str q9, [%[output0], #16]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + "str q10, [%[output0], #32]\n" + "str q11, [%[output0], #48]\n" + "str q12, [%[output0], #64]\n" + "str q13, [%[output0], #80]\n" + "str q14, [%[output0], #96]\n" + "str q15, [%[output0], #112]\n" + "str q16, [%[output0], #128]\n" + "str q17, [%[output0], #144]\n" + "str q18, [%[output0], #160]\n" + "str q19, [%[output0], #176]\n" + + "6:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); + } + + // Overview of register layout: + // + // A 2x4 cell of Rhs is stored in 32bit in v2 - v3 + // A 4x2 cell of Lhs is stored in 32bit in v0 - v1 + // A 4x4 block of accumulators is stored in 32bit in v4-v6 + // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_4x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + //clang-format off +#define LOAD_C \ + "cmp %w[n_remain], #4\n" \ + "blt 11f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "b 
14f\n" \ + "11:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 12f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "12:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 13f\n" \ + "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "13:\n" \ + "ld1 {v8.4s}, [%[output0]]\n" \ + "14:\n" + +#define STORE_C \ + "cmp %w[n_remain], #4\n" \ + "blt 21f\n" \ + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "21:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 22f\n" \ + "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "22:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 23f\n" \ + "st1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "23:\n" \ + "st1 {v8.4s}, [%[output0]]\n" \ + "24:\n" + //clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "eor v9.16b, v9.16b, v9.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v8.4s, v1.4s, v3.s[0]\n" + "fmla v9.4s, v1.4s, v3.s[1]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v10.4s, v1.4s, v3.s[2]\n" + "fmla v11.4s, v1.4s, v3.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v8.4s, v1.4s, v3.s[0]\n" + "fmla v9.4s, v1.4s, v3.s[1]\n" + "fmla v10.4s, v1.4s, v3.s[2]\n" + "fmla v11.4s, v1.4s, v3.s[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", + "memory"); + +#undef LOAD_C +#undef STORE_C + } +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h new file mode 100644 index 000000000..23bed9509 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h @@ -0,0 +1,1160 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#pragma once +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +struct matmul_mk4_8x12_a55 { + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. + // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // |v1| |v20[0-3]|v21[0-3]|v22[0-3]| + // |v1| |v23[0-3]|v24[0-3]|v25[0-3]| + // |v1| |v26[0-3]|v27[0-3]|v28[0-3]| + // |v1| |v29[0-3]|v30[0-3]|v31[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + static void kern_8x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + float* output1 = output0 + LDC; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + asm volatile( + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "mov x1, %[output0]\n" + "mov x2, %[output1]\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ldr d2, [%[b_ptr]]\n" + + "eor v9.16b, v9.16b, v9.16b\n" + "ldr x10, [%[b_ptr], #8]\n" + + "eor v10.16b, v10.16b, v10.16b\n" + "ldr d3, [%[b_ptr], #16]\n" + + "eor v11.16b, v11.16b, v11.16b\n" + "ldr x11, [%[b_ptr], #24]\n" + + "eor v12.16b, v12.16b, v12.16b\n" + "ldr d4, [%[b_ptr], #32]\n" + + "eor v13.16b, v13.16b, v13.16b\n" + "ldr x12, [%[b_ptr], #40]\n" + + "eor v14.16b, v14.16b, v14.16b\n" + "ldr d0, [%[a_ptr]]\n" + + "eor v15.16b, v15.16b, v15.16b\n" + "ldr x9, [%[a_ptr], #8]\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "add %[b_ptr], %[b_ptr], #48\n" + + "eor v17.16b, v17.16b, v17.16b\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "eor v18.16b, v18.16b, v18.16b\n" + "ins v2.d[1], x10\n" + + "eor v19.16b, v19.16b, v19.16b\n" + "ins v3.d[1], x11\n" + + "eor v20.16b, v20.16b, v20.16b\n" + "ins v4.d[1], x12\n" + + "eor v21.16b, v21.16b, v21.16b\n" + "ins v0.d[1], x9\n" + + "eor v22.16b, v22.16b, v22.16b\n" + "prfm pldl1keep, [%[a_ptr], #384]\n" + + "eor v23.16b, v23.16b, v23.16b\n" + "prfm pldl1keep, [%[b_ptr]]\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "prfm pldl1keep, [%[b_ptr], #64]\n" + + "eor v25.16b, v25.16b, v25.16b\n" + "prfm pldl1keep, [%[b_ptr], #128]\n" + + "eor v26.16b, v26.16b, v26.16b\n" + "prfm pldl1keep, [%[b_ptr], #192]\n" + + "eor v27.16b, v27.16b, v27.16b\n" + "prfm pldl1keep, [%[b_ptr], #256]\n" + + "eor v28.16b, v28.16b, v28.16b\n" + "prfm pldl1keep, [%[b_ptr], #320]\n" + + "eor v29.16b, v29.16b, v29.16b\n" + "prfm pldl1keep, [%[b_ptr], #384]\n" + + "eor v30.16b, v30.16b, v30.16b\n" + "prfm pldl1keep, [%[b_ptr], #448]\n" + + "eor v31.16b, v31.16b, v31.16b\n" + "prfm pldl1keep, [%[b_ptr], #512]\n" + + "2: \n" + "cmp %w[K], #0\n" + 
"beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr d1, [%[a_ptr]]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + "subs %w[K], %w[K], #1\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + "ldr d5, [%[b_ptr]]\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + "ldr d6, [%[b_ptr], #16]\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "ins v5.d[1], x10\n" + + "fmla v18.4s, v0.4s, v4.s[2]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "ldr d7, [%[b_ptr], #32]\n" + + "fmla v21.4s, v1.4s, v2.s[1]\n" + "ins v6.d[1], x11\n" + + "fmla v22.4s, v1.4s, v2.s[2]\n" + "ldr d0, [%[a_ptr], #16]\n" + + "fmla v23.4s, v1.4s, v2.s[3]\n" + + "fmla v24.4s, v1.4s, v3.s[0]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v25.4s, v1.4s, v3.s[1]\n" + + "fmla v26.4s, v1.4s, v3.s[2]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v27.4s, v1.4s, v3.s[3]\n" + "ins v7.d[1], x12\n" + + "fmla v28.4s, v1.4s, v4.s[0]\n" + "prfm pldl1keep, [%[a_ptr], #448]\n" + + "fmla v29.4s, v1.4s, v4.s[1]\n" + "ins v0.d[1], x9\n" + + "fmla v30.4s, v1.4s, v4.s[2]\n" + "prfm pldl1keep, [%[b_ptr], #576]\n" + + "fmla v31.4s, v1.4s, v4.s[3]\n" + //! UNROLL + + "fmla v8.4s, v0.4s, v5.s[0]\n" + "ldr d1, [%[a_ptr], #32]\n" + + "fmla v9.4s, v0.4s, v5.s[1]\n" + + "fmla v10.4s, v0.4s, v5.s[2]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "fmla v11.4s, v0.4s, v5.s[3]\n" + + "fmla v12.4s, v0.4s, v6.s[0]\n" + "ldr d2, [%[b_ptr], #48]\n" + + "fmla v13.4s, v0.4s, v6.s[1]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v0.4s, v6.s[2]\n" + "ldr x10, [%[b_ptr], #56]\n" + + "fmla v15.4s, v0.4s, v6.s[3]\n" + + "fmla v16.4s, v0.4s, v7.s[0]\n" + "ldr d3, [%[b_ptr], #64]\n" + + "fmla v17.4s, v0.4s, v7.s[1]\n" + "ins v2.d[1], x10\n" + + "fmla v18.4s, v0.4s, v7.s[2]\n" + "ldr x11, [%[b_ptr], #72]\n" + + "fmla v19.4s, v0.4s, v7.s[3]\n" + + "fmla v20.4s, v1.4s, v5.s[0]\n" + "ldr d4, [%[b_ptr], #80]\n" + + "fmla v21.4s, v1.4s, v5.s[1]\n" + "ins v3.d[1], x11\n" + + "fmla v22.4s, v1.4s, v5.s[2]\n" + "ldr x12, [%[b_ptr], #88]\n" + + "fmla v23.4s, v1.4s, v5.s[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + + "fmla v24.4s, v1.4s, v6.s[0]\n" + "ldr d0, [%[a_ptr], #48]\n" + + "fmla v25.4s, v1.4s, v6.s[1]\n" + "ins v4.d[1], x12\n" + + "fmla v26.4s, v1.4s, v6.s[2]\n" + "ldr x9, [%[a_ptr], #56]\n" + + "fmla v27.4s, v1.4s, v6.s[3]\n" + "add %[a_ptr], %[a_ptr], #64\n" + + "fmla v28.4s, v1.4s, v7.s[0]\n" + "prfm pldl1keep, [%[b_ptr], #640]\n" + + "fmla v29.4s, v1.4s, v7.s[1]\n" + "ins v0.d[1], x9\n" + + "fmla v30.4s, v1.4s, v7.s[2]\n" + + "fmla v31.4s, v1.4s, v7.s[3]\n" + + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "prfm pstl1keep, [%[output0]]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ldr d1, [%[a_ptr]] \n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + "prfm pstl1keep, [%[output1]]\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x8, [%[a_ptr], #8] \n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ldr d5, [%[b_ptr]]\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "ins v1.d[1], x8\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "ldr d6, [%[b_ptr], #16]\n" + + "fmla v18.4s, v0.4s, v4.s[2]\n" + "ins v5.d[1], x10\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + 
"ldr x11, [%[b_ptr], #24]\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + + "fmla v21.4s, v1.4s, v2.s[1]\n" + "ldr d0, [%[a_ptr], #16]\n" + + "fmla v22.4s, v1.4s, v2.s[2]\n" + "ins v6.d[1], x11\n" + + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "fmla v24.4s, v1.4s, v3.s[0]\n" + + "fmla v25.4s, v1.4s, v3.s[1]\n" + "ldr d7, [%[b_ptr], #32]\n" + + "fmla v26.4s, v1.4s, v3.s[2]\n" + "ins v0.d[1], x9\n" + + "fmla v27.4s, v1.4s, v3.s[3]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v28.4s, v1.4s, v4.s[0]\n" + + "fmla v29.4s, v1.4s, v4.s[1]\n" + + "fmla v30.4s, v1.4s, v4.s[2]\n" + "ins v7.d[1], x12\n" + + "fmla v31.4s, v1.4s, v4.s[3]\n" + + "fmla v8.4s, v0.4s, v5.s[0]\n" + "ldr d1, [%[a_ptr], #32]\n" + + "fmla v9.4s, v0.4s, v5.s[1]\n" + + "fmla v10.4s, v0.4s, v5.s[2]\n" + "ldr x8, [%[a_ptr], #40]\n" + + "fmla v11.4s, v0.4s, v5.s[3]\n" + "str q8, [%[output0]]\n" + + "fmla v12.4s, v0.4s, v6.s[0]\n" + "str q9, [%[output0], #16]\n" + + "fmla v13.4s, v0.4s, v6.s[1]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v0.4s, v6.s[2]\n" + "str q10, [%[output0], #32]\n" + + "fmla v15.4s, v0.4s, v6.s[3]\n" + + "fmla v16.4s, v0.4s, v7.s[0]\n" + "str q11, [%[output0], #48]\n" + + "fmla v17.4s, v0.4s, v7.s[1]\n" + "str q12, [%[output0], #64]\n" + + "fmla v18.4s, v0.4s, v7.s[2]\n" + "str q13, [%[output0], #80]\n" + + "fmla v19.4s, v0.4s, v7.s[3]\n" + "str q14, [%[output0], #96]\n" + + "fmla v20.4s, v1.4s, v5.s[0]\n" + "str q15, [%[output0], #112]\n" + + "fmla v21.4s, v1.4s, v5.s[1]\n" + "str q16, [%[output0], #128]\n" + + "fmla v22.4s, v1.4s, v5.s[2]\n" + "str q17, [%[output0], #144]\n" + + "fmla v23.4s, v1.4s, v5.s[3]\n" + "str q18, [%[output0], #160]\n" + + "fmla v24.4s, v1.4s, v6.s[0]\n" + "str q19, [%[output0], #176]\n" + + "fmla v25.4s, v1.4s, v6.s[1]\n" + "str q20, [%[output1]]\n" + + "fmla v26.4s, v1.4s, v6.s[2]\n" + "str q21, [%[output1], #16]\n" + + "fmla v27.4s, v1.4s, v6.s[3]\n" + "str q22, [%[output1], #32]\n" + + "fmla v28.4s, v1.4s, v7.s[0]\n" + "str q23, [%[output1], #48]\n" + + "fmla v29.4s, v1.4s, v7.s[1]\n" + "str q24, [%[output1], #64]\n" + + "fmla v30.4s, v1.4s, v7.s[2]\n" + "str q25, [%[output1], #80]\n" + + "fmla v31.4s, v1.4s, v7.s[3]\n" + "str q26, [%[output1], #96]\n" + + "str q27, [%[output1], #112]\n" + + "str q28, [%[output1], #128]\n" + + "str q29, [%[output1], #144]\n" + + "str q30, [%[output1], #160]\n" + + "str q31, [%[output1], #176]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr d1, [%[a_ptr]]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + "str q8, [%[output0]]\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "str q9, [%[output0], #16]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + "str q10, [%[output0], #32]\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ins v1.d[1], x8\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "str q11, [%[output0], #48]\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + "str q12, [%[output0], #64]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + "str q13, [%[output0], #80]\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "str q14, [%[output0], #96]\n" + + "fmla v18.4s, v0.4s, v4.s[2]\n" + "str q15, [%[output0], #112]\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + "str q16, [%[output0], #128]\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "str q17, [%[output0], #144]\n" + + "fmla v21.4s, v1.4s, v2.s[1]\n" + "str q18, [%[output0], #160]\n" + + "fmla v22.4s, v1.4s, v2.s[2]\n" + "str q19, [%[output0], #176]\n" + + "fmla v23.4s, v1.4s, v2.s[3]\n" + "str q20, [%[output1]]\n" + + "fmla v24.4s, v1.4s, v3.s[0]\n" + "str q21, 
[%[output1], #16]\n" + + "fmla v25.4s, v1.4s, v3.s[1]\n" + "str q22, [%[output1], #32]\n" + + "fmla v26.4s, v1.4s, v3.s[2]\n" + "str q23, [%[output1], #48]\n" + + "fmla v27.4s, v1.4s, v3.s[3]\n" + "str q24, [%[output1], #64]\n" + + "fmla v28.4s, v1.4s, v4.s[0]\n" + "str q25, [%[output1], #80]\n" + + "fmla v29.4s, v1.4s, v4.s[1]\n" + "str q26, [%[output1], #96]\n" + + "fmla v30.4s, v1.4s, v4.s[2]\n" + "str q27, [%[output1], #112]\n" + + "fmla v31.4s, v1.4s, v4.s[3]\n" + "str q28, [%[output1], #128]\n" + + "str q29, [%[output1], #144]\n" + + "str q30, [%[output1], #160]\n" + + "str q31, [%[output1], #176]\n" + + "6:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [output1] "+r"(output1) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", + "x11", "x12", "x13", "cc", "memory"); + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. + // + // +--------+ + // | v2[0-3]| + // | v3[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // |v1| |v20[0-3]| + // |v1| |v23[0-3]| + // |v1| |v26[0-3]| + // |v1| |v29[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_8x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + float* output1 = output0 + LDC; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + //clang-format off +#define LOAD_C \ + "cmp %w[n_remain], #4\n" \ + "blt 11f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \ + "b 14f\n" \ + "11:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 12f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \ + "b 14f\n" \ + "12:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 13f\n" \ + "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s},[%[output1]]\n" \ + "b 14f\n" \ + "13:\n" \ + "ld1 {v8.4s}, [%[output0]]\n" \ + "ld1 {v12.4s},[%[output1]]\n" \ + "14:\n" + +#define STORE_C \ + "cmp %w[n_remain], #4\n" \ + "blt 21f\n" \ + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \ + "b 24f\n" \ + "21:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 22f\n" \ + "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \ + "b 23f\n" \ + "22:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 23f\n" \ + "st1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s},[%[output1]]\n" \ + "b 24f\n" \ + "23:\n" \ + "st1 {v8.4s}, [%[output0]]\n" \ + "st1 {v12.4s},[%[output1]]\n" \ + "24:\n" + //clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ldr d0, [%[a_ptr]]\n" + + "eor v9.16b, v9.16b, v9.16b\n" + "ldr x8, [%[a_ptr], #8]\n" + + "eor v10.16b, v10.16b, 
v10.16b\n" + "ldr d2, [%[b_ptr]]\n" + + "eor v11.16b, v11.16b, v11.16b\n" + "ldr x9, [%[b_ptr], #8]\n" + + "eor v12.16b, v12.16b, v12.16b\n" + "ins v0.d[1], x8\n" + + "eor v13.16b, v13.16b, v13.16b\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "eor v14.16b, v14.16b, v14.16b\n" + "ins v2.d[1], x9\n" + + "eor v15.16b, v15.16b, v15.16b\n" + "add %[b_ptr], %[b_ptr], #16\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ldr d3, [%[b_ptr]]\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + "subs %w[K], %w[K], #1\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v12.4s, v1.4s, v2.s[0]\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + + "fmla v13.4s, v1.4s, v2.s[1]\n" + "ins v3.d[1], x10\n" + + "fmla v14.4s, v1.4s, v2.s[2]\n" + + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "fmla v8.4s, v0.4s, v3.s[0]\n" + + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + + "fmla v9.4s, v0.4s, v3.s[1]\n" + "ldr d2, [%[b_ptr], #16]\n" + + "fmla v10.4s, v0.4s, v3.s[2]\n" + + "fmla v11.4s, v0.4s, v3.s[3]\n" + "ldr x10, [%[b_ptr], #24]\n" + + "fmla v12.4s, v1.4s, v3.s[0]\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + + "fmla v13.4s, v1.4s, v3.s[1]\n" + "ins v2.d[1], x10\n" + + "fmla v14.4s, v1.4s, v3.s[2]\n" + "add %[b_ptr], %[b_ptr], #32\n" + + "fmla v15.4s, v1.4s, v3.s[3]\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "prfm pstl1keep, [%[output1]]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], #16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "prfm pstl1keep, [%[output1]]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "fmla v8.4s, v0.4s, v3.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v3.s[1]\n" + "fmla v10.4s, v0.4s, v3.s[2]\n" + "fmla v11.4s, v0.4s, v3.s[3]\n" + "fmla v12.4s, v1.4s, v3.s[0]\n" + "fmla v13.4s, v1.4s, v3.s[1]\n" + "fmla v14.4s, v1.4s, v3.s[2]\n" + "fmla v15.4s, v1.4s, v3.s[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [output1] "+r"(output1), + [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); + +#undef LOAD_C +#undef STORE_C + } + + // Overview of register layout: + // + // A 1x12 cell of Rhs is stored in 32bit in v2-v7 + // A 8x1 cell of Lhs is stored in 32bit in (v0-v1) + // A 8x12 block of accumulators is stored in 32bit in v8-v31. 
+ // + // +--------+--------+--------+ + // | v2[0-3]| v3[0-3]| v4[0-3]| + // | v5[0-3]| v6[0-3]| v7[0-3]| + // Rhs +--------+--------+--------+ + // + // | | | | + // + // Lhs | | | | + // + // +--+ --- - +--------+--------+--------+ + // |v0| | v8[0-3]| v9[0-3]|v10[0-3]| + // |v0| |v11[0-3]|v12[0-3]|v13[0-3]| + // |v0| |v14[0-3]|v15[0-3]|v16[0-3]| + // |v0| |v17[0-3]|v18[0-3]|v19[0-3]| + // +--+ --- - +--------+--------+--------+ + // + // Accumulator + + static void kern_4x12(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + asm volatile( + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "mov x1, %[output0]\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr d5, [%[b_ptr]]\n" + + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ldr d1, [%[a_ptr]]\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + "ins v5.d[1], x10\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ldr d6, [%[b_ptr], #16]\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "ins v1.d[1], x8\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + "ldr d7, [%[b_ptr], #32]\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v18.4s, v0.4s, v4.s[2]\n" + "ins v6.d[1], x11\n" + + "fmla v19.4s, v0.4s, v4.s[3]\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v8.4s, v1.4s, v5.s[0]\n" + "ins v7.d[1], x12\n" + + "fmla v9.4s, v1.4s, v5.s[1]\n" + "ldr d2, [%[b_ptr], #48]\n" + + "fmla v10.4s, v1.4s, v5.s[2]\n" + "ldr x10, [%[b_ptr], #56]\n" + + "fmla v11.4s, v1.4s, v5.s[3]\n" + "ldr d3, [%[b_ptr], #64]\n" + + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "subs %w[K], %w[K], #1\n" + + "fmla v12.4s, v1.4s, v6.s[0]\n" + "ldr x11, [%[b_ptr], #72]\n" + + "fmla v13.4s, v1.4s, v6.s[1]\n" + "ldr d4, [%[b_ptr], #80]\n" + + "fmla v14.4s, v1.4s, v6.s[2]\n" + "ins v2.d[1], x10\n" + + "fmla v15.4s, v1.4s, v6.s[3]\n" + "ldr x12, [%[b_ptr], #88]\n" + + "fmla v16.4s, v1.4s, v7.s[0]\n" + "ins v3.d[1], x11\n" + + "fmla v17.4s, v1.4s, v7.s[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + + "fmla v18.4s, v1.4s, v7.s[2]\n" + "ins v4.d[1], x12\n" + + "fmla v19.4s, v1.4s, v7.s[3]\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ldr d5, [%[b_ptr]]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ldr x10, [%[b_ptr], #8]\n" + + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ldr d6, [%[b_ptr], 
#16]\n" + + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ldr x11, [%[b_ptr], #24]\n" + + "fmla v12.4s, v0.4s, v3.s[0]\n" + "ins v5.d[1], x10\n" + + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ldr d7, [%[b_ptr], #32]\n" + + "fmla v14.4s, v0.4s, v3.s[2]\n" + "ldr x12, [%[b_ptr], #40]\n" + + "fmla v15.4s, v0.4s, v3.s[3]\n" + "ins v6.d[1], x11\n" + + "fmla v16.4s, v0.4s, v4.s[0]\n" + + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "ins v7.d[1], x12\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v8.4s, v1.4s, v5.s[0]\n" + + "fmla v9.4s, v1.4s, v5.s[1]\n" + "str q8, [%[output0]]\n" + "fmla v10.4s, v1.4s, v5.s[2]\n" + "str q9, [%[output0], #16]\n" + "fmla v11.4s, v1.4s, v5.s[3]\n" + "str q10, [%[output0], #32]\n" + + "fmla v12.4s, v1.4s, v6.s[0]\n" + "str q11, [%[output0], #48]\n" + "fmla v13.4s, v1.4s, v6.s[1]\n" + "str q12, [%[output0], #64]\n" + "fmla v14.4s, v1.4s, v6.s[2]\n" + "str q13, [%[output0], #80]\n" + "fmla v15.4s, v1.4s, v6.s[3]\n" + "str q14, [%[output0], #96]\n" + "fmla v16.4s, v1.4s, v7.s[0]\n" + "str q15, [%[output0], #112]\n" + "fmla v17.4s, v1.4s, v7.s[1]\n" + "str q16, [%[output0], #128]\n" + "fmla v18.4s, v1.4s, v7.s[2]\n" + "str q17, [%[output0], #144]\n" + "fmla v19.4s, v1.4s, v7.s[3]\n" + "str q18, [%[output0], #160]\n" + "str q19, [%[output0], #176]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "str q8, [%[output0]]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "str q9, [%[output0], #16]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "str q10, [%[output0], #32]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "str q11, [%[output0], #48]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "str q12, [%[output0], #64]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "str q13, [%[output0], #80]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "str q14, [%[output0], #96]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "str q15, [%[output0], #112]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "str q16, [%[output0], #128]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "str q17, [%[output0], #144]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "str q18, [%[output0], #160]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + "str q19, [%[output0], #176]\n" + + "6:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); + } + + // Overview of register layout: + // + // A 2x4 cell of Rhs is stored in 32bit in v2 - v3 + // A 4x2 cell of Lhs is stored in 32bit in v0 - v1 + // A 4x4 block of accumulators is stored in 32bit in v4-v6 + // + // +--------+ + // | v2[0-3]| + // | v5[0-3]| + // Rhs +--------+ + // + // | | + // + // Lhs | | + // + // +--+ --- - +--------+ + // |v0| | v8[0-3]| + // |v0| |v11[0-3]| + // |v0| |v14[0-3]| + // |v0| |v17[0-3]| + // +--+ --- - +--------+ + // + // Accumulator + static void kern_4x4(const float* packA, const float* packB, int K, + float* output, int LDC, bool is_first_k, + int n_remain) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + //clang-format off +#define LOAD_C \ + "cmp %w[n_remain], #4\n" \ + "blt 11f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "11:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 12f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "12:\n" \ + "cmp 
%w[n_remain], #2\n" \ + "blt 13f\n" \ + "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "13:\n" \ + "ld1 {v8.4s}, [%[output0]]\n" \ + "14:\n" + +#define STORE_C \ + "cmp %w[n_remain], #4\n" \ + "blt 21f\n" \ + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "21:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 22f\n" \ + "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "22:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 23f\n" \ + "st1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "23:\n" \ + "st1 {v8.4s}, [%[output0]]\n" \ + "24:\n" + //clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "eor v9.16b, v9.16b, v9.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "fmla v8.4s, v1.4s, v3.s[0]\n" + "fmla v9.4s, v1.4s, v3.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v1.4s, v3.s[2]\n" + "fmla v11.4s, v1.4s, v3.s[3]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "fmla v8.4s, v1.4s, v3.s[0]\n" + "fmla v9.4s, v1.4s, v3.s[1]\n" + "fmla v10.4s, v1.4s, v3.s[2]\n" + "fmla v11.4s, v1.4s, v3.s[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [output0] "+r"(output0), [n_remain] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", + "memory"); + +#undef LOAD_C +#undef STORE_C + } +}; + +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp index 4fa5f6b8d..7e34ac457 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp @@ -6,42 +6,55 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ -#include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" +#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" +#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h" #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" +#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" +#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" +#include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/common/utils.h" +#if MGB_ENABLE_CPUINFO +#include "cpuinfo.h" +#endif + using namespace megdnn; using namespace aarch64; using namespace aarch64::matmul; MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); -void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, - int ymax, int k0, int kmax, bool transpose_A) const { +void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax, + int k0, int kmax, bool transpose_A) const { if (transpose_A) { - matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); + matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, + kmax); } else { - matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); + matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); } } void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, bool transpose_B) const { if (transpose_B) { - matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); + matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, + kmax); } else { - matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); } } -void sgemm_4x16::kern(const float* packA, const float* packB, - size_t M, size_t N, size_t K, float* C, size_t LDC, - bool is_first_k, const float*, float*) const { +void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, + size_t N, size_t K, float* C, size_t LDC, bool is_first_k, + const float*, float*) const { megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && A_dtype.enumv() == DTypeEnum::Float32); @@ -61,15 +74,17 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t n = 0; const float* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4)); + matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, + is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K16; } for (; n < N; n += 4) { - matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), std::min(N - n, 4)); + matmul_general_4x16::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -80,8 +95,8 @@ void sgemm_4x16::kern(const float* packA, const float* packB, MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); -void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, - int ymax, int k0, int kmax, bool transpose_A) const { +void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, + int k0, int kmax, bool transpose_A) const { if (transpose_A) { matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); @@ -102,16 +117,10 @@ void 
sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
     }
 }
 
-void sgemm_8x12::kern(const float* packA, const float* packB,
-                      size_t M, size_t N, size_t K, float* C, size_t LDC,
-                      bool is_first_k, const float*, float*) const {
-    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
-                  A_dtype.enumv() == C_dtype.enumv() &&
-                  A_dtype.enumv() == DTypeEnum::Float32);
-    MEGDNN_MARK_USED_VAR(A_dtype);
-    MEGDNN_MARK_USED_VAR(B_dtype);
-    MEGDNN_MARK_USED_VAR(C_dtype);
-
+template <typename gemm_class>
+static inline void sgemm_8x12_helper(const float* packA, const float* packB,
+                                     size_t M, size_t N, size_t K, float* C,
+                                     size_t LDC, bool is_first_k) {
     constexpr size_t A_INTERLEAVE = 8;
     constexpr size_t A_INTERLEAVE4 = 4;
     constexpr size_t B_INTERLEAVE = 12;
@@ -126,16 +135,14 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
         size_t n = 0;
         const float* cur_packB = packB;
         for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
-            matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
-                                           is_first_k);
+            gemm_class::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
             output += B_INTERLEAVE;
             cur_packB += K12;
         }
 
         for (; n < N; n += 4) {
-            matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
-                                          is_first_k,
-                                          std::min<size_t>(N - n, 4));
+            gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
+                                 std::min<size_t>(N - n, 4));
             output += 4;
             cur_packB += K4;
         }
@@ -146,17 +153,16 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
         size_t n = 0;
         const float* cur_packB = packB;
         for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
-            matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
-                                           is_first_k,
-                                           std::min<size_t>(M - m, 4));
+            gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k,
+                                  std::min<size_t>(M - m, 4));
             output += B_INTERLEAVE;
             cur_packB += K12;
         }
 
         for (; n < N; n += 4) {
-            matmul_general_8x12::kern_4x4(
-                    packA, cur_packB, K, output, LDC, is_first_k,
-                    std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
+            gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
+                                 std::min<size_t>(M - m, 4),
+                                 std::min<size_t>(N - n, 4));
             output += 4;
             cur_packB += K4;
         }
@@ -164,6 +170,33 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
     }
 }
 
+void sgemm_8x12::kern(const float* packA, const float* packB, size_t M,
+                      size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
+                      const float*, float*) const {
+    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
+                  A_dtype.enumv() == C_dtype.enumv() &&
+                  A_dtype.enumv() == DTypeEnum::Float32);
+    MEGDNN_MARK_USED_VAR(A_dtype);
+    MEGDNN_MARK_USED_VAR(B_dtype);
+    MEGDNN_MARK_USED_VAR(C_dtype);
+#if !MGB_ENABLE_CPUINFO
+    sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
+                                           is_first_k);
+#else
+    auto arch = cpuinfo_get_current_core()->uarch;
+    if (arch == cpuinfo_uarch_cortex_a53) {
+        sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C,
+                                                   LDC, is_first_k);
+    } else if (arch == cpuinfo_uarch_cortex_a55) {
+        sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C,
+                                                   LDC, is_first_k);
+    } else {
+        sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
+                                               is_first_k);
+    }
+#endif
+}
+
 MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12);
 
 void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0,
@@ -180,25 +213,17 @@ void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0,
     matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
 }
 
-void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
-                          size_t M, size_t N, size_t K, float* C, size_t LDC,
-                          bool is_first_k, const float*, float*) const {
-    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
-                  A_dtype.enumv() == C_dtype.enumv() &&
-                  A_dtype.enumv() == DTypeEnum::Float32);
-    MEGDNN_MARK_USED_VAR(A_dtype);
-    MEGDNN_MARK_USED_VAR(B_dtype);
-    MEGDNN_MARK_USED_VAR(C_dtype);
-    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
-
+template <typename gemm_name>
+static inline void sgemm_mk4_8x12_helper(const float* packA,
+                                         const float* packB, size_t M,
+                                         size_t N, size_t K, float* C,
+                                         size_t LDC, bool is_first_k) {
+    const int K12 = K * 12;
+    const int K8 = K * 8;
+    const int K4 = K * 4;
     constexpr size_t PACK_C_SIZE = 4;
     constexpr size_t A_INTERLEAVE = 8;
     constexpr size_t A_INTERLEAVE4 = 4;
     constexpr size_t B_INTERLEAVE = 12;
-    const int K12 = K * 12;
-    const int K8 = K * 8;
-    const int K4 = K * 4;
-
     size_t m = 0;
     for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
         float* output = C + (m / PACK_C_SIZE * LDC);
@@ -206,15 +231,14 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
         size_t n = 0;
         const float* cur_packB = packB;
         for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
-            matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
-                                       is_first_k);
+            gemm_name::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
             output += B_INTERLEAVE * PACK_C_SIZE;
             cur_packB += K12;
         }
-        for (; n < N; n += 4) {
-            matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
-                                      is_first_k, std::min<size_t>(N - n, 4));
+        for (; n < N; n += 4) {
+            gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
+                                std::min<size_t>(N - n, 4));
             output += 4 * PACK_C_SIZE;
             cur_packB += K4;
         }
@@ -225,19 +249,45 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
         size_t n = 0;
         const float* cur_packB = packB;
         for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
-            matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
-                                       is_first_k);
+            gemm_name::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k);
             output += B_INTERLEAVE * PACK_C_SIZE;
             cur_packB += K12;
         }
         for (; n < N; n += 4) {
-            matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC,
-                                      is_first_k, std::min<size_t>(N - n, 4));
+            gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
+                                std::min<size_t>(N - n, 4));
             output += 4 * PACK_C_SIZE;
             cur_packB += K4;
         }
         packA += K4;
     }
 }
+
+void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M,
+                          size_t N, size_t K, float* C, size_t LDC,
+                          bool is_first_k, const float*, float*) const {
+    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
+                  A_dtype.enumv() == C_dtype.enumv() &&
+                  A_dtype.enumv() == DTypeEnum::Float32);
+    MEGDNN_MARK_USED_VAR(A_dtype);
+    MEGDNN_MARK_USED_VAR(B_dtype);
+    MEGDNN_MARK_USED_VAR(C_dtype);
+    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4");
+#if !MGB_ENABLE_CPUINFO
+    sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
+                                           is_first_k);
+#else
+    auto arch = cpuinfo_get_current_core()->uarch;
+    if (arch == cpuinfo_uarch_cortex_a53) {
+        sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C,
+                                                   LDC, is_first_k);
+    } else if (arch == cpuinfo_uarch_cortex_a55) {
+        sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C,
+                                                   LDC, is_first_k);
+    } else {
+        sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
+                                               is_first_k);
+    }
+#endif
+}
 
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.h b/dnn/src/aarch64/matrix_mul/fp32/strategy.h
index a2faf6eda..b1cfe2d87 100644
--- a/dnn/src/aarch64/matrix_mul/fp32/strategy.h
+++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS
IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "src/fallback/matrix_mul/gemm_common.h" diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h new file mode 100644 index 000000000..6059c6464 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h @@ -0,0 +1,1265 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_mk4_16x12x4_a53 { + +//! optimize for A53 + +// clang-format off +/** + * Overview of register layout: + * + * A 16x12x4 cell of Lhs is stored in 16bit in q0-q3 + * A 16x12x4 cell of Rhs is stored in 8bit in q4-q7 + * A 16x12 block of accumulators is stored in 16bit in q8-q31 + * + * +------------------------------------------------------------------------+ + * | q4[0]|q4[1]|q4[2]|q4[3]|q4[4]|q4[5]|q4[6]|q4[7]|q5[0]|q5[1]|q5[2]|q5[3]| + * Rhs +------------------------------------------------------------------------+ + * Lhs | | | | | | | | | | | | | + * +--------+ - - - - +------------------------------------------------------------------------+ + * | q0 | | q8 | q9 | q10 | q11 | q12 | q13 | q14 | q15 | q16 | q17 | q18 | q19 | + * | q1 | | q20 | q21 | q22 | q23 | q24 | q25 | q26 | q27 | q28 | q29 | q30 | q31 | + * +--------+ - - - - +------------------------------------------------------------------------+ + * + * Accumulator + */ +// clang-format on +static __attribute__((noinline)) void kern_16x12(const int16_t* packA, + const int8_t* packB, int K, + int16_t* output, int LDC, + bool is_first_k, + int remain_n) { + K /= 4; + const int16_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + + // clang-format off +#define STORE_LINE(reg0, reg1) \ + "cmp w10, #0 \n" \ + "beq 101f\n" \ + "st1 {v" reg0 ".4h}, [x0], #8\n" \ + "st1 {v" reg0 ".d}[1], [x1], #8\n" \ + "st1 {v" reg1 ".4h}, [x2], #8\n" \ + "st1 {v" reg1 ".d}[1], [x3], #8\n" \ + "subs w10, w10, #1\n" + +#define STORE_C \ + "mov w10, %w[remain_n]\n" \ + STORE_LINE("8", "20") \ + STORE_LINE("9", "21") \ + STORE_LINE("10", "22") \ + STORE_LINE("11", "23") \ + STORE_LINE("12", "24") \ + STORE_LINE("13", "25") \ + STORE_LINE("14", "26") \ + STORE_LINE("15", "27") \ + STORE_LINE("16", "28") \ + STORE_LINE("17", "29") \ + STORE_LINE("18", "30") \ + STORE_LINE("19", "31") + + // clang-format on + + register int16_t* outptr asm("x0") = output; + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, 
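// Each column of the 16x12 tile lives in a register pair (e.g. v8/v20):
// lanes 0-3 and 4-7 of v8 are MK4 row blocks 0 and 1, and lanes 0-3 and
// 4-7 of v20 are row blocks 2 and 3, which is why STORE_LINE above issues
// one "st1 {v.4h}" plus one "st1 {v.d}[1]" per register. A scalar sketch
// of one column store (x0..x3 stand for the four row-block pointers):
#include <cstdint>
static void store_col_model(const int16_t lo[8], const int16_t hi[8],
                            int16_t* x0, int16_t* x1, int16_t* x2,
                            int16_t* x3) {
    for (int i = 0; i < 4; ++i) {
        x0[i] = lo[i];      // "st1 {v8.4h},  [x0], #8"
        x1[i] = lo[4 + i];  // "st1 {v8.d}[1], [x1], #8"
        x2[i] = hi[i];      // "st1 {v20.4h}, [x2], #8"
        x3[i] = hi[4 + i];  // "st1 {v20.d}[1], [x3], #8"
    }
}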
v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "2: \n" + + "ldr d4, [%[b_ptr]]\n" + "ldr d5, [%[b_ptr], #8]\n" + "ldr q0, [%[a_ptr]]\n" + "subs %w[K], %w[K], #1\n" + "ldr q1, [%[a_ptr], #16]\n" + "sshll v4.8h, v4.8b, #0\n" + "cmp %w[K], #0\n" + "sshll v5.8h, v5.8b, #0\n" + "beq 4f\n" + + "3: \n" + + //! k0 + + "ldr d2, [%[a_ptr], #32]\n" + "nop\n" + + "mla v8.8h, v0.8h, v4.h[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + "mla v9.8h, v0.8h, v4.h[1]\n" + "mla v10.8h, v0.8h, v4.h[2]\n" + "mla v11.8h, v0.8h, v4.h[3]\n" + + "ldr d3, [%[a_ptr], #48]\n" + "ins v2.d[1], x8\n" + + "mla v12.8h, v0.8h, v4.h[4]\n" + "ldr x9, [%[a_ptr], #56]\n" + + "mla v13.8h, v0.8h, v4.h[5]\n" + "mla v14.8h, v0.8h, v4.h[6]\n" + "mla v15.8h, v0.8h, v4.h[7]\n" + + "ldr d6, [%[b_ptr], #12]\n" + "ins v3.d[1], x9\n" + + "mla v16.8h, v0.8h, v5.h[0]\n" + "ldr d7, [%[b_ptr], #20]\n" + "mla v17.8h, v0.8h, v5.h[1]\n" + "mla v18.8h, v0.8h, v5.h[2]\n" + "mla v19.8h, v0.8h, v5.h[3]\n" + "mla v20.8h, v1.8h, v4.h[0]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v21.8h, v1.8h, v4.h[1]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v22.8h, v1.8h, v4.h[2]\n" + "mla v23.8h, v1.8h, v4.h[3]\n" + "mla v24.8h, v1.8h, v4.h[4]\n" + "mla v25.8h, v1.8h, v4.h[5]\n" + "mla v26.8h, v1.8h, v4.h[6]\n" + "mla v27.8h, v1.8h, v4.h[7]\n" + "mla v28.8h, v1.8h, v5.h[0]\n" + "mla v29.8h, v1.8h, v5.h[1]\n" + "mla v30.8h, v1.8h, v5.h[2]\n" + "mla v31.8h, v1.8h, v5.h[3]\n" + + //! k1 + "ldr d0, [%[a_ptr], #64]\n" + "nop\n" + + "mla v8.8h, v2.8h, v6.h[0]\n" + "ldr x8, [%[a_ptr], #72]\n" + + "mla v9.8h, v2.8h, v6.h[1]\n" + "mla v10.8h, v2.8h, v6.h[2]\n" + "mla v11.8h, v2.8h, v6.h[3]\n" + + "ldr d1, [%[a_ptr], #80]\n" + "ins v0.d[1], x8\n" + + "mla v12.8h, v2.8h, v6.h[4]\n" + "ldr x9, [%[a_ptr], #88]\n" + + "mla v13.8h, v2.8h, v6.h[5]\n" + "mla v14.8h, v2.8h, v6.h[6]\n" + "mla v15.8h, v2.8h, v6.h[7]\n" + + "ldr d4, [%[b_ptr], #24]\n" + "ins v1.d[1], x9\n" + + "mla v16.8h, v2.8h, v7.h[0]\n" + "ldr d5, [%[b_ptr], #32]\n" + "mla v17.8h, v2.8h, v7.h[1]\n" + "mla v18.8h, v2.8h, v7.h[2]\n" + "mla v19.8h, v2.8h, v7.h[3]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v20.8h, v3.8h, v6.h[0]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v21.8h, v3.8h, v6.h[1]\n" + "mla v22.8h, v3.8h, v6.h[2]\n" + "mla v23.8h, v3.8h, v6.h[3]\n" + "mla v24.8h, v3.8h, v6.h[4]\n" + "mla v25.8h, v3.8h, v6.h[5]\n" + "mla v26.8h, v3.8h, v6.h[6]\n" + "mla v27.8h, v3.8h, v6.h[7]\n" + "mla v28.8h, v3.8h, v7.h[0]\n" + "mla v29.8h, v3.8h, v7.h[1]\n" + "mla v30.8h, v3.8h, v7.h[2]\n" + "mla v31.8h, v3.8h, v7.h[3]\n" + + //! 
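// The "ldr d / ldr x / ins" triplets above are the Cortex-A53 load idiom:
// one 128-bit "ldr q" can stall the in-order dual-issue pipeline, while a
// 64-bit vector load, a 64-bit GPR load and an "ins" into the high half
// can each pair with the surrounding mla instructions (the stray "nop"s
// appear intended to keep the issue slots lined up). An intrinsics sketch
// of the equivalent data movement -- assumed equivalent, not shipped code:
#include <arm_neon.h>
#include <cstring>
static inline int16x8_t load_q_split(const int16_t* p) {
    int64x1_t lo = vreinterpret_s64_s16(vld1_s16(p));    // "ldr d2, [...]"
    uint64_t hi_bits;
    std::memcpy(&hi_bits, p + 4, sizeof(hi_bits));       // "ldr x8, [...]"
    int64x1_t hi = vcreate_s64(hi_bits);
    return vreinterpretq_s16_s64(vcombine_s64(lo, hi));  // "ins v2.d[1], x8"
}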
k2 + "ldr d2, [%[a_ptr], #96]\n" + "nop\n" + + "mla v8.8h, v0.8h, v4.h[0]\n" + "ldr x8, [%[a_ptr], #104]\n" + + "mla v9.8h, v0.8h, v4.h[1]\n" + "mla v10.8h, v0.8h, v4.h[2]\n" + "mla v11.8h, v0.8h, v4.h[3]\n" + + "ldr d3, [%[a_ptr], #112]\n" + "ins v2.d[1], x8\n" + + "mla v12.8h, v0.8h, v4.h[4]\n" + "ldr x9, [%[a_ptr], #120]\n" + "mla v13.8h, v0.8h, v4.h[5]\n" + "mla v14.8h, v0.8h, v4.h[6]\n" + "mla v15.8h, v0.8h, v4.h[7]\n" + + "ldr d6, [%[b_ptr], #36]\n" + "ins v3.d[1], x9\n" + + "mla v16.8h, v0.8h, v5.h[0]\n" + "ldr d7, [%[b_ptr], #44]\n" + "mla v17.8h, v0.8h, v5.h[1]\n" + "mla v18.8h, v0.8h, v5.h[2]\n" + + "mla v19.8h, v0.8h, v5.h[3]\n" + "add %[a_ptr], %[a_ptr], #128\n" + + "mla v20.8h, v1.8h, v4.h[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + + "mla v21.8h, v1.8h, v4.h[1]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v22.8h, v1.8h, v4.h[2]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v23.8h, v1.8h, v4.h[3]\n" + "mla v24.8h, v1.8h, v4.h[4]\n" + "mla v25.8h, v1.8h, v4.h[5]\n" + "mla v26.8h, v1.8h, v4.h[6]\n" + "mla v27.8h, v1.8h, v4.h[7]\n" + "mla v28.8h, v1.8h, v5.h[0]\n" + "mla v29.8h, v1.8h, v5.h[1]\n" + "mla v30.8h, v1.8h, v5.h[2]\n" + "mla v31.8h, v1.8h, v5.h[3]\n" + + //! k3 + "ldr d0, [%[a_ptr]]\n" + "nop\n" + + "mla v8.8h, v2.8h, v6.h[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "mla v9.8h, v2.8h, v6.h[1]\n" + "mla v10.8h, v2.8h, v6.h[2]\n" + "mla v11.8h, v2.8h, v6.h[3]\n" + + "ldr d1, [%[a_ptr], #16]\n" + "ins v0.d[1], x8\n" + + "mla v12.8h, v2.8h, v6.h[4]\n" + "ldr x9, [%[a_ptr], #24]\n" + + "mla v13.8h, v2.8h, v6.h[5]\n" + "mla v14.8h, v2.8h, v6.h[6]\n" + "mla v15.8h, v2.8h, v6.h[7]\n" + + "ldr d4, [%[b_ptr]]\n" + "ins v1.d[1], x9\n" + + "mla v16.8h, v2.8h, v7.h[0]\n" + "ldr d5, [%[b_ptr], #8]\n" + "mla v17.8h, v2.8h, v7.h[1]\n" + "mla v18.8h, v2.8h, v7.h[2]\n" + "mla v19.8h, v2.8h, v7.h[3]\n" + "mla v20.8h, v3.8h, v6.h[0]\n" + "mla v21.8h, v3.8h, v6.h[1]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v22.8h, v3.8h, v6.h[2]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v23.8h, v3.8h, v6.h[3]\n" + "mla v24.8h, v3.8h, v6.h[4]\n" + "mla v25.8h, v3.8h, v6.h[5]\n" + "mla v26.8h, v3.8h, v6.h[6]\n" + "mla v27.8h, v3.8h, v6.h[7]\n" + "mla v28.8h, v3.8h, v7.h[0]\n" + "mla v29.8h, v3.8h, v7.h[1]\n" + "mla v30.8h, v3.8h, v7.h[2]\n" + "mla v31.8h, v3.8h, v7.h[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" //! tail + //! k0 + + "ldr d2, [%[a_ptr], #32]\n" + "nop\n" + + "mla v8.8h, v0.8h, v4.h[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + "mla v9.8h, v0.8h, v4.h[1]\n" + "mla v10.8h, v0.8h, v4.h[2]\n" + "mla v11.8h, v0.8h, v4.h[3]\n" + + "ldr d3, [%[a_ptr], #48]\n" + "ins v2.d[1], x8\n" + + "mla v12.8h, v0.8h, v4.h[4]\n" + "ldr x9, [%[a_ptr], #56]\n" + + "mla v13.8h, v0.8h, v4.h[5]\n" + "mla v14.8h, v0.8h, v4.h[6]\n" + "mla v15.8h, v0.8h, v4.h[7]\n" + + "ldr d6, [%[b_ptr], #12]\n" + "ins v3.d[1], x9\n" + + "mla v16.8h, v0.8h, v5.h[0]\n" + "ldr d7, [%[b_ptr], #20]\n" + "mla v17.8h, v0.8h, v5.h[1]\n" + "mla v18.8h, v0.8h, v5.h[2]\n" + "mla v19.8h, v0.8h, v5.h[3]\n" + "mla v20.8h, v1.8h, v4.h[0]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v21.8h, v1.8h, v4.h[1]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v22.8h, v1.8h, v4.h[2]\n" + "mla v23.8h, v1.8h, v4.h[3]\n" + "mla v24.8h, v1.8h, v4.h[4]\n" + "mla v25.8h, v1.8h, v4.h[5]\n" + "mla v26.8h, v1.8h, v4.h[6]\n" + "mla v27.8h, v1.8h, v4.h[7]\n" + "mla v28.8h, v1.8h, v5.h[0]\n" + "mla v29.8h, v1.8h, v5.h[1]\n" + "mla v30.8h, v1.8h, v5.h[2]\n" + "mla v31.8h, v1.8h, v5.h[3]\n" + + //! 
k1 + "ldr d0, [%[a_ptr], #64]\n" + "nop\n" + + "mla v8.8h, v2.8h, v6.h[0]\n" + "ldr x8, [%[a_ptr], #72]\n" + + "mla v9.8h, v2.8h, v6.h[1]\n" + "mla v10.8h, v2.8h, v6.h[2]\n" + "mla v11.8h, v2.8h, v6.h[3]\n" + + "ldr d1, [%[a_ptr], #80]\n" + "ins v0.d[1], x8\n" + + "mla v12.8h, v2.8h, v6.h[4]\n" + "ldr x9, [%[a_ptr], #88]\n" + + "mla v13.8h, v2.8h, v6.h[5]\n" + "mla v14.8h, v2.8h, v6.h[6]\n" + "mla v15.8h, v2.8h, v6.h[7]\n" + + "ldr d4, [%[b_ptr], #24]\n" + "ins v1.d[1], x9\n" + + "mla v16.8h, v2.8h, v7.h[0]\n" + "ldr d5, [%[b_ptr], #32]\n" + "mla v17.8h, v2.8h, v7.h[1]\n" + "mla v18.8h, v2.8h, v7.h[2]\n" + "mla v19.8h, v2.8h, v7.h[3]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v20.8h, v3.8h, v6.h[0]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v21.8h, v3.8h, v6.h[1]\n" + "mla v22.8h, v3.8h, v6.h[2]\n" + "mla v23.8h, v3.8h, v6.h[3]\n" + "mla v24.8h, v3.8h, v6.h[4]\n" + "mla v25.8h, v3.8h, v6.h[5]\n" + "mla v26.8h, v3.8h, v6.h[6]\n" + "mla v27.8h, v3.8h, v6.h[7]\n" + "mla v28.8h, v3.8h, v7.h[0]\n" + "mla v29.8h, v3.8h, v7.h[1]\n" + "mla v30.8h, v3.8h, v7.h[2]\n" + "mla v31.8h, v3.8h, v7.h[3]\n" + + //! k2 + "ldr d2, [%[a_ptr], #96]\n" + "nop\n" + + "mla v8.8h, v0.8h, v4.h[0]\n" + "ldr x8, [%[a_ptr], #104]\n" + + "mla v9.8h, v0.8h, v4.h[1]\n" + "mla v10.8h, v0.8h, v4.h[2]\n" + "mla v11.8h, v0.8h, v4.h[3]\n" + + "ldr d3, [%[a_ptr], #112]\n" + "ins v2.d[1], x8\n" + + "mla v12.8h, v0.8h, v4.h[4]\n" + "ldr x9, [%[a_ptr], #120]\n" + + "mla v13.8h, v0.8h, v4.h[5]\n" + "mla v14.8h, v0.8h, v4.h[6]\n" + "mla v15.8h, v0.8h, v4.h[7]\n" + + "ldr d6, [%[b_ptr], #36]\n" + "ins v3.d[1], x9\n" + + "mla v16.8h, v0.8h, v5.h[0]\n" + "ldr d7, [%[b_ptr], #44]\n" + "mla v17.8h, v0.8h, v5.h[1]\n" + "mla v18.8h, v0.8h, v5.h[2]\n" + "mla v19.8h, v0.8h, v5.h[3]\n" + "mla v20.8h, v1.8h, v4.h[0]\n" + "mla v21.8h, v1.8h, v4.h[1]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v22.8h, v1.8h, v4.h[2]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v23.8h, v1.8h, v4.h[3]\n" + "mla v24.8h, v1.8h, v4.h[4]\n" + "mla v25.8h, v1.8h, v4.h[5]\n" + "mla v26.8h, v1.8h, v4.h[6]\n" + "mla v27.8h, v1.8h, v4.h[7]\n" + "mla v28.8h, v1.8h, v5.h[0]\n" + "mla v29.8h, v1.8h, v5.h[1]\n" + "mla v30.8h, v1.8h, v5.h[2]\n" + "mla v31.8h, v1.8h, v5.h[3]\n" + + //! 
k3 + + "mla v8.8h, v2.8h, v6.h[0]\n" + "mla v9.8h, v2.8h, v6.h[1]\n" + "mla v10.8h, v2.8h, v6.h[2]\n" + "mla v11.8h, v2.8h, v6.h[3]\n" + "mla v12.8h, v2.8h, v6.h[4]\n" + "mla v13.8h, v2.8h, v6.h[5]\n" + "mla v14.8h, v2.8h, v6.h[6]\n" + "mla v15.8h, v2.8h, v6.h[7]\n" + "mla v16.8h, v2.8h, v7.h[0]\n" + "mla v17.8h, v2.8h, v7.h[1]\n" + "mla v18.8h, v2.8h, v7.h[2]\n" + "mla v19.8h, v2.8h, v7.h[3]\n" + "mla v20.8h, v3.8h, v6.h[0]\n" + "mla v21.8h, v3.8h, v6.h[1]\n" + "mla v22.8h, v3.8h, v6.h[2]\n" + "mla v23.8h, v3.8h, v6.h[3]\n" + "mla v24.8h, v3.8h, v6.h[4]\n" + "mla v25.8h, v3.8h, v6.h[5]\n" + "mla v26.8h, v3.8h, v6.h[6]\n" + "mla v27.8h, v3.8h, v6.h[7]\n" + "mla v28.8h, v3.8h, v7.h[0]\n" + "mla v29.8h, v3.8h, v7.h[1]\n" + "mla v30.8h, v3.8h, v7.h[2]\n" + "cmp %w[remain_n], #12\n" + "mla v31.8h, v3.8h, v7.h[3]\n" + + "bne 6f\n" + "5:\n" + + "st1 {v8.4h, v9.4h}, [x0], #16\n" + "st1 {v10.4h, v11.4h}, [x0], #16\n" + "st1 {v12.4h, v13.4h}, [x0], #16\n" + "st1 {v14.4h, v15.4h}, [x0], #16\n" + "st1 {v16.4h, v17.4h}, [x0], #16\n" + "st1 {v18.4h, v19.4h}, [x0], #16\n" + + "st1 {v8.d} [1], [x1], #8\n" + "st1 {v9.d} [1], [x1], #8\n" + "st1 {v10.d}[1], [x1], #8\n" + "st1 {v11.d}[1], [x1], #8\n" + "st1 {v12.d}[1], [x1], #8\n" + "st1 {v13.d}[1], [x1], #8\n" + "st1 {v14.d}[1], [x1], #8\n" + "st1 {v15.d}[1], [x1], #8\n" + "st1 {v16.d}[1], [x1], #8\n" + "st1 {v17.d}[1], [x1], #8\n" + "st1 {v18.d}[1], [x1], #8\n" + "st1 {v19.d}[1], [x1], #8\n" + + "st1 {v20.4h, v21.4h}, [x2], #16\n" + "st1 {v22.4h, v23.4h}, [x2], #16\n" + "st1 {v24.4h, v25.4h}, [x2], #16\n" + "st1 {v26.4h, v27.4h}, [x2], #16\n" + "st1 {v28.4h, v29.4h}, [x2], #16\n" + "st1 {v30.4h, v31.4h}, [x2], #16\n" + + "st1 {v20.d}[1], [x3], #8\n" + "st1 {v21.d}[1], [x3], #8\n" + "st1 {v22.d}[1], [x3], #8\n" + "st1 {v23.d}[1], [x3], #8\n" + "st1 {v24.d}[1], [x3], #8\n" + "st1 {v25.d}[1], [x3], #8\n" + "st1 {v26.d}[1], [x3], #8\n" + "st1 {v27.d}[1], [x3], #8\n" + "st1 {v28.d}[1], [x3], #8\n" + "st1 {v29.d}[1], [x3], #8\n" + "st1 {v30.d}[1], [x3], #8\n" + "st1 {v31.d}[1], [x3], #8\n" + + "b 101f\n" + + "6:\n" STORE_C + + "101:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x8", "x9", "x10", "cc", "memory"); + +#undef STORE_C +#undef STORE_LINE +} + +// clang-format off +/** + * Overview of register layout: + * + * A 8x12x4 cell of Lhs is stored in 16bit in q0-q3 + * A 8x12x4 cell of Rhs is stored in 8bit in q4-q7 + * A 8x12 block of accumulators is stored in 16bit in q8-q31 + * + * +------------------------------------------------------------------------+ + * | q4[0]|q4[1]|q4[2]|q4[3]|q4[4]|q4[5]|q4[6]|q4[7]|q5[0]|q5[1]|q5[2]|q5[3]| + * Rhs +------------------------------------------------------------------------+ + * Lhs | | | | | | | | | | | | | + * +--------+ - - - - +------------------------------------------------------------------------+ + * | q0 | | q8 | q9 | q10 | q11 | q12 | q13 | q14 | q15 | q16 | q17 | q18 | q19 | + * +--------+ - - - - +------------------------------------------------------------------------+ + * + * Accumulator + */ +// clang-format on +static __attribute__((noinline)) void kern_8x12(const int16_t* packA, + const int8_t* packB, int K, + int16_t* 
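// Reduced to scalars, the mla sequences in these int8x8x16 kernels compute
// the following: A was widened to int16 once during packing, B stays int8
// and is widened in-register by sshll, and products accumulate in plain
// int16, so overflow wraps exactly as mla does. A sketch for one MK4
// output block of one column (layout assumptions as in the pack helpers
// further below):
#include <cstdint>
static void mk4_tile_model(const int16_t* A /* [K][4], packed rows */,
                           const int8_t* B /* [K], one column */, int K,
                           int16_t* C /* [4], packed channels */) {
    for (int c = 0; c < 4; ++c) {
        int16_t acc = 0;  // the is_first_k path ("eor v8.16b, ...")
        for (int k = 0; k < K; ++k)
            acc = int16_t(acc + A[k * 4 + c] * int16_t(B[k]));
        C[c] = acc;
    }
}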
output, int LDC, + bool is_first_k, int remain_n) { + K /= 4; + const int16_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + + // clang-format off +#define STORE_LINE(reg0) \ + "cmp w10, #0 \n" \ + "beq 101f\n" \ + "st1 {v" reg0 ".4h}, [x0], #8\n" \ + "st1 {v" reg0 ".d}[1], [x1], #8\n" \ + "subs w10, w10, #1\n" + +#define STORE_C \ + "mov w10, %w[remain_n]\n" \ + STORE_LINE("8" ) \ + STORE_LINE("9" ) \ + STORE_LINE("10") \ + STORE_LINE("11") \ + STORE_LINE("12") \ + STORE_LINE("13") \ + STORE_LINE("14") \ + STORE_LINE("15") \ + STORE_LINE("16") \ + STORE_LINE("17") \ + STORE_LINE("18") \ + STORE_LINE("19") + + // clang-format on + + register int16_t* outptr asm("x0") = output; + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + + "ldr d4, [%[b_ptr]]\n" + "ldr d5, [%[b_ptr], #8]\n" + "ldr q0, [%[a_ptr]]\n" + "subs %w[K], %w[K], #1\n" + "sshll v4.8h, v4.8b, #0\n" + "cmp %w[K], #0\n" + "sshll v5.8h, v5.8b, #0\n" + "beq 4f\n" + + "3: \n" + + //! k0 + + "ldr d2, [%[a_ptr], #16]\n" + "nop\n" + "mla v8.8h, v0.8h, v4.h[0]\n" + "ldr x8, [%[a_ptr], #24]\n" + "mla v9.8h, v0.8h, v4.h[1]\n" + "mla v10.8h, v0.8h, v4.h[2]\n" + "mla v11.8h, v0.8h, v4.h[3]\n" + + "ldr d6, [%[b_ptr], #12]\n" + "ins v2.d[1], x8\n" + + "mla v12.8h, v0.8h, v4.h[4]\n" + "ldr d7, [%[b_ptr], #20]\n" + "mla v13.8h, v0.8h, v4.h[5]\n" + "mla v14.8h, v0.8h, v4.h[6]\n" + "mla v15.8h, v0.8h, v4.h[7]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v16.8h, v0.8h, v5.h[0]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v17.8h, v0.8h, v5.h[1]\n" + "mla v18.8h, v0.8h, v5.h[2]\n" + "mla v19.8h, v0.8h, v5.h[3]\n" + + //! k1 + "ldr d0, [%[a_ptr], #32]\n" + "nop\n" + + "mla v8.8h, v2.8h, v6.h[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + "mla v9.8h, v2.8h, v6.h[1]\n" + "mla v10.8h, v2.8h, v6.h[2]\n" + "mla v11.8h, v2.8h, v6.h[3]\n" + + "ldr d4, [%[b_ptr], #24]\n" + "ins v0.d[1], x8\n" + + "mla v12.8h, v2.8h, v6.h[4]\n" + "ldr d5, [%[b_ptr], #32]\n" + + "mla v13.8h, v2.8h, v6.h[5]\n" + "mla v14.8h, v2.8h, v6.h[6]\n" + "mla v15.8h, v2.8h, v6.h[7]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v16.8h, v2.8h, v7.h[0]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v17.8h, v2.8h, v7.h[1]\n" + "mla v18.8h, v2.8h, v7.h[2]\n" + "mla v19.8h, v2.8h, v7.h[3]\n" + + //! k2 + "ldr d2, [%[a_ptr], #48]\n" + "nop\n" + + "mla v8.8h, v0.8h, v4.h[0]\n" + "ldr x8, [%[a_ptr], #56]\n" + + "mla v9.8h, v0.8h, v4.h[1]\n" + "mla v10.8h, v0.8h, v4.h[2]\n" + "mla v11.8h, v0.8h, v4.h[3]\n" + + "ldr d6, [%[b_ptr], #36]\n" + "ins v2.d[1], x8\n" + + "mla v12.8h, v0.8h, v4.h[4]\n" + "ldr d7, [%[b_ptr], #44]\n" + + "mla v13.8h, v0.8h, v4.h[5]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "mla v14.8h, v0.8h, v4.h[6]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "mla v15.8h, v0.8h, v4.h[7]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v16.8h, v0.8h, v5.h[0]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v17.8h, v0.8h, v5.h[1]\n" + "mla v18.8h, v0.8h, v5.h[2]\n" + "mla v19.8h, v0.8h, v5.h[3]\n" + + //! 
k3 + "ldr d0, [%[a_ptr]]\n" + "nop\n" + + "mla v8.8h, v2.8h, v6.h[0]\n" + "ldr x8, [%[a_ptr], #8]\n" + + "mla v9.8h, v2.8h, v6.h[1]\n" + "mla v10.8h, v2.8h, v6.h[2]\n" + "mla v11.8h, v2.8h, v6.h[3]\n" + + "ldr d4, [%[b_ptr]]\n" + "ins v0.d[1], x8\n" + + "mla v12.8h, v2.8h, v6.h[4]\n" + "ldr d5, [%[b_ptr], #8]\n" + "mla v13.8h, v2.8h, v6.h[5]\n" + "mla v14.8h, v2.8h, v6.h[6]\n" + "mla v15.8h, v2.8h, v6.h[7]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v16.8h, v2.8h, v7.h[0]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v17.8h, v2.8h, v7.h[1]\n" + "subs %w[K], %w[K], #1\n" + "mla v18.8h, v2.8h, v7.h[2]\n" + "mla v19.8h, v2.8h, v7.h[3]\n" + + "bne 3b\n" + + "4:\n" // tail + //! k0 + + "ldr d2, [%[a_ptr], #16]\n" + "nop\n" + "mla v8.8h, v0.8h, v4.h[0]\n" + "ldr x8, [%[a_ptr], #24]\n" + "mla v9.8h, v0.8h, v4.h[1]\n" + "mla v10.8h, v0.8h, v4.h[2]\n" + "mla v11.8h, v0.8h, v4.h[3]\n" + + "ldr d6, [%[b_ptr], #12]\n" + "ins v2.d[1], x8\n" + + "mla v12.8h, v0.8h, v4.h[4]\n" + "ldr d7, [%[b_ptr], #20]\n" + "mla v13.8h, v0.8h, v4.h[5]\n" + "mla v14.8h, v0.8h, v4.h[6]\n" + "mla v15.8h, v0.8h, v4.h[7]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v16.8h, v0.8h, v5.h[0]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v17.8h, v0.8h, v5.h[1]\n" + "mla v18.8h, v0.8h, v5.h[2]\n" + "mla v19.8h, v0.8h, v5.h[3]\n" + + //! k1 + "ldr d0, [%[a_ptr], #32]\n" + "nop\n" + + "mla v8.8h, v2.8h, v6.h[0]\n" + "ldr x8, [%[a_ptr], #40]\n" + "mla v9.8h, v2.8h, v6.h[1]\n" + "mla v10.8h, v2.8h, v6.h[2]\n" + "mla v11.8h, v2.8h, v6.h[3]\n" + + "ldr d4, [%[b_ptr], #24]\n" + "ins v0.d[1], x8\n" + + "mla v12.8h, v2.8h, v6.h[4]\n" + "ldr d5, [%[b_ptr], #32]\n" + + "mla v13.8h, v2.8h, v6.h[5]\n" + "mla v14.8h, v2.8h, v6.h[6]\n" + "mla v15.8h, v2.8h, v6.h[7]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v16.8h, v2.8h, v7.h[0]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v17.8h, v2.8h, v7.h[1]\n" + "mla v18.8h, v2.8h, v7.h[2]\n" + "mla v19.8h, v2.8h, v7.h[3]\n" + + //! k2 + "ldr d2, [%[a_ptr], #48]\n" + "nop\n" + + "mla v8.8h, v0.8h, v4.h[0]\n" + "ldr x8, [%[a_ptr], #56]\n" + + "mla v9.8h, v0.8h, v4.h[1]\n" + "mla v10.8h, v0.8h, v4.h[2]\n" + "mla v11.8h, v0.8h, v4.h[3]\n" + + "ldr d6, [%[b_ptr], #36]\n" + "ins v2.d[1], x8\n" + + "mla v12.8h, v0.8h, v4.h[4]\n" + "ldr d7, [%[b_ptr], #44]\n" + + "mla v13.8h, v0.8h, v4.h[5]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "mla v14.8h, v0.8h, v4.h[6]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "mla v15.8h, v0.8h, v4.h[7]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v16.8h, v0.8h, v5.h[0]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v17.8h, v0.8h, v5.h[1]\n" + "mla v18.8h, v0.8h, v5.h[2]\n" + "mla v19.8h, v0.8h, v5.h[3]\n" + + //! 
k3 + "mla v8.8h, v2.8h, v6.h[0]\n" + "mla v9.8h, v2.8h, v6.h[1]\n" + "mla v10.8h, v2.8h, v6.h[2]\n" + "mla v11.8h, v2.8h, v6.h[3]\n" + "mla v12.8h, v2.8h, v6.h[4]\n" + "mla v13.8h, v2.8h, v6.h[5]\n" + "mla v14.8h, v2.8h, v6.h[6]\n" + "mla v15.8h, v2.8h, v6.h[7]\n" + "mla v16.8h, v2.8h, v7.h[0]\n" + "mla v17.8h, v2.8h, v7.h[1]\n" + "cmp %w[remain_n], #12\n" + "mla v18.8h, v2.8h, v7.h[2]\n" + "mla v19.8h, v2.8h, v7.h[3]\n" + + "bne 6f\n" + "5:\n" + + "st1 {v8.4h, v9.4h}, [x0], #16\n" + "st1 {v10.4h, v11.4h}, [x0], #16\n" + "st1 {v12.4h, v13.4h}, [x0], #16\n" + "st1 {v14.4h, v15.4h}, [x0], #16\n" + "st1 {v16.4h, v17.4h}, [x0], #16\n" + "st1 {v18.4h, v19.4h}, [x0], #16\n" + + "st1 {v8.d} [1], [x1], #8\n" + "st1 {v9.d} [1], [x1], #8\n" + "st1 {v10.d}[1], [x1], #8\n" + "st1 {v11.d}[1], [x1], #8\n" + "st1 {v12.d}[1], [x1], #8\n" + "st1 {v13.d}[1], [x1], #8\n" + "st1 {v14.d}[1], [x1], #8\n" + "st1 {v15.d}[1], [x1], #8\n" + "st1 {v16.d}[1], [x1], #8\n" + "st1 {v17.d}[1], [x1], #8\n" + "st1 {v18.d}[1], [x1], #8\n" + "st1 {v19.d}[1], [x1], #8\n" + + "b 101f\n" + + "6:\n" STORE_C + + "101:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", + "memory"); + +#undef STORE_C +#undef STORE_LINE +} + +// clang-format off +/** + * Overview of register layout: + * + * A 4x12x4 cell of Lhs is stored in 16bit in q0-q3 + * A 4x12x4 cell of Rhs is stored in 8bit in q4-q7 + * A 4x12 block of accumulators is stored in 16bit in q8-q31 + * + * +------------------------------------------------------------------------+ + * | q4[0]|q4[1]|q4[2]|q4[3]|q4[4]|q4[5]|q4[6]|q4[7]|q5[0]|q5[1]|q5[2]|q5[3]| + * Rhs +------------------------------------------------------------------------+ + * Lhs | | | | | | | | | | | | | + * +--------+ - - - - +------------------------------------------------------------------------+ + * | d0 | | d8 | d9 | d10 | d11 | d12 | d13 | d14 | d15 | d16 | d17 | d18 | d19 | + * +--------+ - - - - +------------------------------------------------------------------------+ + * + * Accumulator + */ +// clang-format on +static __attribute__((noinline)) void kern_4x12(const int16_t* packA, + const int8_t* packB, int K, + int16_t* output, int LDC, + bool is_first_k, int remain_n) { + K /= 4; + const int16_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + + // clang-format off +#define STORE_LINE(reg0) \ + "cmp w10, #0 \n" \ + "beq 101f\n" \ + "st1 {v" reg0 ".4h}, [x0], #8\n" \ + "subs w10, w10, #1\n" + +#define STORE_C \ + "mov w10, %w[remain_n]\n" \ + STORE_LINE("8" ) \ + STORE_LINE("9" ) \ + STORE_LINE("10") \ + STORE_LINE("11") \ + STORE_LINE("12") \ + STORE_LINE("13") \ + STORE_LINE("14") \ + STORE_LINE("15") \ + STORE_LINE("16") \ + STORE_LINE("17") \ + STORE_LINE("18") \ + STORE_LINE("19") + + // clang-format on + + register int16_t* outptr asm("x0") = output; + asm volatile( + // load accumulator C + "add x1, x0, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, 
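// STORE_C for this kernel walks the twelve column registers and stops
// after remain_n stores: w10 counts down and every STORE_LINE bails out
// once it reaches zero. A scalar sketch of that early-exit walk, with
// cols[i] standing for the i-th 4-channel column held in v8..v19:
#include <cstdint>
static void store_countdown_model(int16_t* x0, const int16_t (*cols)[4],
                                  int remain_n) {
    for (int i = 0, w10 = remain_n; i < 12 && w10 != 0; ++i, --w10) {
        for (int j = 0; j < 4; ++j)
            x0[j] = cols[i][j];  // "st1 {v.4h}, [x0], #8"
        x0 += 4;
    }
}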
v15.16b\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + + "ldr d4, [%[b_ptr]]\n" + "ldr d5, [%[b_ptr], #8]\n" + "ldr d0, [%[a_ptr]]\n" + "subs %w[K], %w[K], #1\n" + "sshll v4.8h, v4.8b, #0\n" + "cmp %w[K], #0\n" + "sshll v5.8h, v5.8b, #0\n" + "beq 4f\n" + + "3: \n" + + //! k0 + + "ldr d2, [%[a_ptr], #8]\n" + "nop\n" + "mla v8.4h, v0.4h, v4.h[0]\n" + "mla v9.4h, v0.4h, v4.h[1]\n" + "mla v10.4h, v0.4h, v4.h[2]\n" + "mla v11.4h, v0.4h, v4.h[3]\n" + + "ldr d6, [%[b_ptr], #12]\n" + + "mla v12.4h, v0.4h, v4.h[4]\n" + "ldr d7, [%[b_ptr], #20]\n" + "mla v13.4h, v0.4h, v4.h[5]\n" + "mla v14.4h, v0.4h, v4.h[6]\n" + "mla v15.4h, v0.4h, v4.h[7]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v16.4h, v0.4h, v5.h[0]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v17.4h, v0.4h, v5.h[1]\n" + "mla v18.4h, v0.4h, v5.h[2]\n" + "mla v19.4h, v0.4h, v5.h[3]\n" + + //! k1 + "ldr d0, [%[a_ptr], #16]\n" + "nop\n" + + "mla v8.4h, v2.4h, v6.h[0]\n" + "mla v9.4h, v2.4h, v6.h[1]\n" + "mla v10.4h, v2.4h, v6.h[2]\n" + "mla v11.4h, v2.4h, v6.h[3]\n" + + "ldr d4, [%[b_ptr], #24]\n" + + "mla v12.4h, v2.4h, v6.h[4]\n" + "ldr d5, [%[b_ptr], #32]\n" + + "mla v13.4h, v2.4h, v6.h[5]\n" + "mla v14.4h, v2.4h, v6.h[6]\n" + "mla v15.4h, v2.4h, v6.h[7]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v16.4h, v2.4h, v7.h[0]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v17.4h, v2.4h, v7.h[1]\n" + "mla v18.4h, v2.4h, v7.h[2]\n" + "mla v19.4h, v2.4h, v7.h[3]\n" + + //! k2 + "ldr d2, [%[a_ptr], #24]\n" + "nop\n" + + "mla v8.4h, v0.4h, v4.h[0]\n" + "mla v9.4h, v0.4h, v4.h[1]\n" + "mla v10.4h, v0.4h, v4.h[2]\n" + "mla v11.4h, v0.4h, v4.h[3]\n" + + "ldr d6, [%[b_ptr], #36]\n" + + "mla v12.4h, v0.4h, v4.h[4]\n" + "ldr d7, [%[b_ptr], #44]\n" + + "mla v13.4h, v0.4h, v4.h[5]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "mla v14.4h, v0.4h, v4.h[6]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "mla v15.4h, v0.4h, v4.h[7]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v16.4h, v0.4h, v5.h[0]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v17.4h, v0.4h, v5.h[1]\n" + "mla v18.4h, v0.4h, v5.h[2]\n" + "mla v19.4h, v0.4h, v5.h[3]\n" + + //! k3 + "ldr d0, [%[a_ptr]]\n" + "nop\n" + + "mla v8.4h, v2.4h, v6.h[0]\n" + "mla v9.4h, v2.4h, v6.h[1]\n" + "mla v10.4h, v2.4h, v6.h[2]\n" + "mla v11.4h, v2.4h, v6.h[3]\n" + + "ldr d4, [%[b_ptr]]\n" + + "mla v12.4h, v2.4h, v6.h[4]\n" + "ldr d5, [%[b_ptr], #8]\n" + "mla v13.4h, v2.4h, v6.h[5]\n" + "mla v14.4h, v2.4h, v6.h[6]\n" + "mla v15.4h, v2.4h, v6.h[7]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v16.4h, v2.4h, v7.h[0]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v17.4h, v2.4h, v7.h[1]\n" + "subs %w[K], %w[K], #1\n" + "mla v18.4h, v2.4h, v7.h[2]\n" + "mla v19.4h, v2.4h, v7.h[3]\n" + + "bne 3b\n" + + "4:\n" // tail + //! k0 + + "ldr d2, [%[a_ptr], #8]\n" + "nop\n" + "mla v8.4h, v0.4h, v4.h[0]\n" + "mla v9.4h, v0.4h, v4.h[1]\n" + "mla v10.4h, v0.4h, v4.h[2]\n" + "mla v11.4h, v0.4h, v4.h[3]\n" + + "ldr d6, [%[b_ptr], #12]\n" + + "mla v12.4h, v0.4h, v4.h[4]\n" + "ldr d7, [%[b_ptr], #20]\n" + "mla v13.4h, v0.4h, v4.h[5]\n" + "mla v14.4h, v0.4h, v4.h[6]\n" + "mla v15.4h, v0.4h, v4.h[7]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v16.4h, v0.4h, v5.h[0]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v17.4h, v0.4h, v5.h[1]\n" + "mla v18.4h, v0.4h, v5.h[2]\n" + "mla v19.4h, v0.4h, v5.h[3]\n" + + //! 
k1 + "ldr d0, [%[a_ptr], #16]\n" + "nop\n" + + "mla v8.4h, v2.4h, v6.h[0]\n" + "mla v9.4h, v2.4h, v6.h[1]\n" + "mla v10.4h, v2.4h, v6.h[2]\n" + "mla v11.4h, v2.4h, v6.h[3]\n" + + "ldr d4, [%[b_ptr], #24]\n" + + "mla v12.4h, v2.4h, v6.h[4]\n" + "ldr d5, [%[b_ptr], #32]\n" + + "mla v13.4h, v2.4h, v6.h[5]\n" + "mla v14.4h, v2.4h, v6.h[6]\n" + "mla v15.4h, v2.4h, v6.h[7]\n" + "sshll v4.8h, v4.8b, #0\n" + "mla v16.4h, v2.4h, v7.h[0]\n" + "sshll v5.8h, v5.8b, #0\n" + "mla v17.4h, v2.4h, v7.h[1]\n" + "mla v18.4h, v2.4h, v7.h[2]\n" + "mla v19.4h, v2.4h, v7.h[3]\n" + + //! k2 + "ldr d2, [%[a_ptr], #24]\n" + "nop\n" + + "mla v8.4h, v0.4h, v4.h[0]\n" + "mla v9.4h, v0.4h, v4.h[1]\n" + "mla v10.4h, v0.4h, v4.h[2]\n" + "mla v11.4h, v0.4h, v4.h[3]\n" + + "ldr d6, [%[b_ptr], #36]\n" + + "mla v12.4h, v0.4h, v4.h[4]\n" + "ldr d7, [%[b_ptr], #44]\n" + + "mla v13.4h, v0.4h, v4.h[5]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "mla v14.4h, v0.4h, v4.h[6]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "mla v15.4h, v0.4h, v4.h[7]\n" + "sshll v6.8h, v6.8b, #0\n" + "mla v16.4h, v0.4h, v5.h[0]\n" + "sshll v7.8h, v7.8b, #0\n" + "mla v17.4h, v0.4h, v5.h[1]\n" + "mla v18.4h, v0.4h, v5.h[2]\n" + "mla v19.4h, v0.4h, v5.h[3]\n" + + //! k3 + "mla v8.4h, v2.4h, v6.h[0]\n" + "mla v9.4h, v2.4h, v6.h[1]\n" + "mla v10.4h, v2.4h, v6.h[2]\n" + "mla v11.4h, v2.4h, v6.h[3]\n" + "mla v12.4h, v2.4h, v6.h[4]\n" + "mla v13.4h, v2.4h, v6.h[5]\n" + "mla v14.4h, v2.4h, v6.h[6]\n" + "mla v15.4h, v2.4h, v6.h[7]\n" + "mla v16.4h, v2.4h, v7.h[0]\n" + "cmp %w[remain_n], #12\n" + "mla v17.4h, v2.4h, v7.h[1]\n" + "mla v18.4h, v2.4h, v7.h[2]\n" + "mla v19.4h, v2.4h, v7.h[3]\n" + + "bne 6f\n" + "5:\n" + + "st1 {v8.4h, v9.4h}, [x0], #16\n" + "st1 {v10.4h, v11.4h}, [x0], #16\n" + "st1 {v12.4h, v13.4h}, [x0], #16\n" + "st1 {v14.4h, v15.4h}, [x0], #16\n" + "st1 {v16.4h, v17.4h}, [x0], #16\n" + "st1 {v18.4h, v19.4h}, [x0], #16\n" + "b 101f\n" + + "6:\n" STORE_C + + "101:\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", + "memory"); + +#undef STORE_C +#undef STORE_LINE +} + +static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, + const dt_int8* inptr, int ldin, + int m0, int mmax, int k0, int kmax) { + megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + constexpr int pack_m = 16; + constexpr int pack_k = 4; + constexpr int pack_size = 4; + const int m_size = mmax - m0; + const int m_end = m_size / pack_m * pack_m + m0; + int remain_m = mmax - m_end; + + for (int m_idx = m0; m_idx < m_end; m_idx += pack_m) { + const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + const int8_t* inptr2 = inptr0 + 2 * ldin; + const int8_t* inptr3 = inptr0 + 3 * ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + + for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) { + interleave_4x4_16x4_s8_s16(inptr0, inptr1, inptr2, inptr3, outptr); + inptr0 += pack_size * pack_size; + inptr1 += pack_size * pack_size; + inptr2 += pack_size * pack_size; + inptr3 += pack_size * pack_size; + outptr += pack_m * pack_k; + } + } + int m_idx = m_end; + if (remain_m >= 8) { + const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0; + const 
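// pack_A peels M in a 16/8/4 cascade; because usable() guarantees
// M % 4 == 0, the remainder after the 16-wide loop is one of 0, 4, 8 or
// 12, so one optional 8-wide panel plus one optional 4-wide panel always
// finishes the job. The kern() driver in strategy.cpp relies on the same
// invariant. A minimal model:
#include <cassert>
static void pack_m_cascade_model(int M) {
    assert(M % 4 == 0);
    int remain_m = M % 16;             // left over after the pack_m 16 loop
    if (remain_m >= 8) remain_m -= 8;  // handled by the 8-wide path
    if (remain_m == 4) remain_m -= 4;  // handled by the 4-wide path
    assert(remain_m == 0);
}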
int8_t* inptr1 = inptr0 + ldin; + for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) { + interleave_4x4_8x4_s8_s16(inptr0, inptr1, outptr); + inptr0 += pack_size * pack_size; + inptr1 += pack_size * pack_size; + outptr += 8 * pack_k; + } + remain_m -= 8; + m_idx += 8; + } + if (remain_m == 4) { + const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0; + const int k_size = kmax - k0; + memcpy_s8_s16(inptr0, outptr, k_size * pack_size); + } +} + +static void gemm_s8x8x16_mk4_16x12_pack_B(dt_int8* out, const dt_int8* in, + int ldin, int n0, int nmax, int k0, + int kmax) { + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + + constexpr int pack_n = 12; + constexpr int pack_size = 4; + int8_t tmpbuff[pack_n * pack_size] = {0}; + const int ksize = kmax - k0; + const int nsize = nmax - n0; + const int n_end = nsize / pack_n * pack_n + n0; + const int remain_n = nsize % pack_n; + int output_stride = ksize * pack_n; + int8_t* outptr_base = out; + + for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) { + const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size; + prefetch_3x(inptr); + + auto outptr = outptr_base; + for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { + transpos_12x4_s8(inptr, outptr); + inptr += pack_n * pack_size; + outptr += output_stride; + } + if (remain_n > 0) { + memcpy(tmpbuff, inptr, sizeof(int8_t) * remain_n * pack_size); + transpos_12x4_s8(tmpbuff, outptr); + outptr += output_stride; + } + outptr_base += pack_n * pack_size; + } +} + +} // namespace matmul_mk4_16x12x4_a53 +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h new file mode 100644 index 000000000..417e51557 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h @@ -0,0 +1,387 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_mk4_4x4x8_a72 { + +//! 
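// The tail-N path of pack_B above stages the last, narrower column block
// in a zero-filled buffer so that transpos_12x4_s8 can always read a full
// 12x4 block; the padded columns are packed but the kernels store only
// remain_n of them. A minimal sketch of the staging step (12 and 4 mirror
// pack_n and pack_size):
#include <cstdint>
#include <cstring>
static void stage_tail_columns(const int8_t* src, int remain_n,
                               int8_t* dst /* 12 * 4 bytes */) {
    std::memset(dst, 0, 12 * 4);                  // zero the whole block
    std::memcpy(dst, src, size_t(remain_n) * 4);  // copy valid columns only
}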
optimize for A72 + +// clang-format off +/** + * Overview of register layout: + * + * A 4x4x8 cell of Lhs is stored in 8bit in q0-q3, q4-q7 + * A 4x4x8 cell of Rhs is stored in 8bit in q8-q11, q12-q15 + * A 4x4 block of accumulators is stored in 16bit in q16-q31 + * + * +------------------------+ + * | q8 | q9 | q10 | q11 | + * Rhs +------------------------+ + * Lhs | | | | | + * +--------+ - - - - +------------------------+ + * | q0 | | q16 | q20 | q24 | q28 | + * | q1 | | q17 | q21 | q25 | q29 | + * | q2 | | q18 | q22 | q26 | q30 | + * | q3 | | q19 | q23 | q27 | q31 | + * +--------+ - - - - +------------------------+ + * + * Accumulator + */ + +// clang-format on +static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool, int remain_n) { + K = div_ceil(K, 8); + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + const int8_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int8_t); +// clang-format off + #define STORE_LINE(reg0) \ + "cmp w10, #0 \n" \ + "beq 101f\n" \ + "st1 {v" reg0 ".4h}, [x0], #8\n" \ + "subs w10, w10, #1\n" + + #define STORE_C \ + "mov w10, %w[remain_n]\n" \ + STORE_LINE("16") \ + STORE_LINE("20") \ + STORE_LINE("24") \ + STORE_LINE("28") + + // clang-format on + + register int16_t* outptr asm("x0") = output; + asm volatile( + // load accumulator C + + "1:\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "2: \n" + + "ld1 {v0.8b, v1.8b}, [%[a_ptr]], #16\n" + "ld1 {v2.8b, v3.8b}, [%[a_ptr]], #16\n" + "ld1 {v8.8b, v9.8b}, [%[b_ptr]], #16\n" + "ld1 {v10.8b, v11.8b}, [%[b_ptr]], #16\n" + + "cmp %w[K], #0\n" + "beq 4f\n" + + "3: \n" + //! k = 0 + "smlal v16.8h, v0.8b, v8.8b\n" + "ld1 {v4.8b}, [%[a_ptr]], #8\n" + "smlal v17.8h, v1.8b, v8.8b\n" + "smlal v18.8h, v2.8b, v8.8b\n" + "ld1 {v5.8b}, [%[a_ptr]], #8\n" + "smlal v19.8h, v3.8b, v8.8b\n" + "smlal v20.8h, v0.8b, v9.8b\n" + "ld1 {v6.8b}, [%[a_ptr]], #8\n" + "smlal v21.8h, v1.8b, v9.8b\n" + "smlal v22.8h, v2.8b, v9.8b\n" + "ld1 {v7.8b}, [%[a_ptr]], #8\n" + "smlal v23.8h, v3.8b, v9.8b\n" + "smlal v24.8h, v0.8b, v10.8b\n" + "ld1 {v12.8b}, [%[b_ptr]], #8\n" + "smlal v25.8h, v1.8b, v10.8b\n" + "smlal v26.8h, v2.8b, v10.8b\n" + "ld1 {v13.8b}, [%[b_ptr]], #8\n" + "smlal v27.8h, v3.8b, v10.8b\n" + "smlal v28.8h, v0.8b, v11.8b\n" + "ld1 {v14.8b}, [%[b_ptr]], #8\n" + "smlal v29.8h, v1.8b, v11.8b\n" + "smlal v30.8h, v2.8b, v11.8b\n" + "ld1 {v15.8b}, [%[b_ptr]], #8\n" + "smlal v31.8h, v3.8b, v11.8b\n" + //! 
k = 8 + "smlal v16.8h, v4.8b, v12.8b\n" + "ld1 {v0.8b}, [%[a_ptr]], #8\n" + "smlal v17.8h, v5.8b, v12.8b\n" + "smlal v18.8h, v6.8b, v12.8b\n" + "ld1 {v1.8b}, [%[a_ptr]], #8\n" + "smlal v19.8h, v7.8b, v12.8b\n" + "smlal v20.8h, v4.8b, v13.8b\n" + "ld1 {v2.8b}, [%[a_ptr]], #8\n" + "smlal v21.8h, v5.8b, v13.8b\n" + "smlal v22.8h, v6.8b, v13.8b\n" + "ld1 {v3.8b}, [%[a_ptr]], #8\n" + "smlal v23.8h, v7.8b, v13.8b\n" + "smlal v24.8h, v4.8b, v14.8b\n" + "ld1 {v8.8b}, [%[b_ptr]], #8\n" + "smlal v25.8h, v5.8b, v14.8b\n" + "smlal v26.8h, v6.8b, v14.8b\n" + "ld1 {v9.8b}, [%[b_ptr]], #8\n" + "smlal v27.8h, v7.8b, v14.8b\n" + "smlal v28.8h, v4.8b, v15.8b\n" + "ld1 {v10.8b}, [%[b_ptr]], #8\n" + "smlal v29.8h, v5.8b, v15.8b\n" + "smlal v30.8h, v6.8b, v15.8b\n" + "ld1 {v11.8b}, [%[b_ptr]], #8\n" + "smlal v31.8h, v7.8b, v15.8b\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + //! even tail + //! k = 0 + "smlal v16.8h, v0.8b, v8.8b\n" + "ld1 {v4.8b}, [%[a_ptr]], #8\n" + "smlal v17.8h, v1.8b, v8.8b\n" + "smlal v18.8h, v2.8b, v8.8b\n" + "ld1 {v5.8b}, [%[a_ptr]], #8\n" + "smlal v19.8h, v3.8b, v8.8b\n" + "smlal v20.8h, v0.8b, v9.8b\n" + "ld1 {v6.8b}, [%[a_ptr]], #8\n" + "smlal v21.8h, v1.8b, v9.8b\n" + "smlal v22.8h, v2.8b, v9.8b\n" + "ld1 {v7.8b}, [%[a_ptr]], #8\n" + "smlal v23.8h, v3.8b, v9.8b\n" + "smlal v24.8h, v0.8b, v10.8b\n" + "ld1 {v12.8b}, [%[b_ptr]], #8\n" + "smlal v25.8h, v1.8b, v10.8b\n" + "smlal v26.8h, v2.8b, v10.8b\n" + "ld1 {v13.8b}, [%[b_ptr]], #8\n" + "smlal v27.8h, v3.8b, v10.8b\n" + "smlal v28.8h, v0.8b, v11.8b\n" + "ld1 {v14.8b}, [%[b_ptr]], #8\n" + "smlal v29.8h, v1.8b, v11.8b\n" + "smlal v30.8h, v2.8b, v11.8b\n" + "ld1 {v15.8b}, [%[b_ptr]], #8\n" + "smlal v31.8h, v3.8b, v11.8b\n" + //! k = 8 + "smlal v16.8h, v4.8b, v12.8b\n" + "smlal v17.8h, v5.8b, v12.8b\n" + "smlal v18.8h, v6.8b, v12.8b\n" + "smlal v19.8h, v7.8b, v12.8b\n" + "smlal v20.8h, v4.8b, v13.8b\n" + "smlal v21.8h, v5.8b, v13.8b\n" + "smlal v22.8h, v6.8b, v13.8b\n" + "smlal v23.8h, v7.8b, v13.8b\n" + "smlal v24.8h, v4.8b, v14.8b\n" + "smlal v25.8h, v5.8b, v14.8b\n" + "smlal v26.8h, v6.8b, v14.8b\n" + "smlal v27.8h, v7.8b, v14.8b\n" + "smlal v28.8h, v4.8b, v15.8b\n" + "smlal v29.8h, v5.8b, v15.8b\n" + "smlal v30.8h, v6.8b, v15.8b\n" + "smlal v31.8h, v7.8b, v15.8b\n" + "b 6f\n" + + "5:\n" + //! odd tail + "smlal v16.8h, v0.8b, v8.8b\n" + "smlal v17.8h, v1.8b, v8.8b\n" + "smlal v18.8h, v2.8b, v8.8b\n" + "smlal v19.8h, v3.8b, v8.8b\n" + "smlal v20.8h, v0.8b, v9.8b\n" + "smlal v21.8h, v1.8b, v9.8b\n" + "smlal v22.8h, v2.8b, v9.8b\n" + "smlal v23.8h, v3.8b, v9.8b\n" + "smlal v24.8h, v0.8b, v10.8b\n" + "smlal v25.8h, v1.8b, v10.8b\n" + "smlal v26.8h, v2.8b, v10.8b\n" + "smlal v27.8h, v3.8b, v10.8b\n" + "smlal v28.8h, v0.8b, v11.8b\n" + "smlal v29.8h, v1.8b, v11.8b\n" + "smlal v30.8h, v2.8b, v11.8b\n" + "smlal v31.8h, v3.8b, v11.8b\n" + + "6:\n" + //! 
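reduce

// Each smlal above adds eight s8*s8 products into the int16 lanes of one
// accumulator, so every register still needs a horizontal reduction; the
// addp ladder that follows folds lanes pairwise until a single 4-lane
// result per output column remains. An intrinsics sketch of one column
// for a single 8-deep k-block -- assumed equivalent; the real kernel
// accumulates across k-blocks with smlal before reducing:
#include <arm_neon.h>
static inline int16x4_t dot_col_model(int8x8_t a0, int8x8_t a1, int8x8_t a2,
                                      int8x8_t a3, int8x8_t b) {
    int16x8_t r0 = vmull_s8(a0, b);      // "smlal v16.8h, v0.8b, v8.8b"
    int16x8_t r1 = vmull_s8(a1, b);
    int16x8_t r2 = vmull_s8(a2, b);
    int16x8_t r3 = vmull_s8(a3, b);
    int16x8_t p01 = vpaddq_s16(r0, r1);  // "addp v16.8h, v16.8h, v17.8h"
    int16x8_t p23 = vpaddq_s16(r2, r3);
    int16x8_t p = vpaddq_s16(p01, p23);  // four sums per source register
    p = vpaddq_s16(p, p);                // row totals land in lanes 0..3
    return vget_low_s16(p);              // what "st1 {v16.4h}" stores
}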
+            "addp v16.8h, v16.8h, v17.8h\n"
+            "addp v18.8h, v18.8h, v19.8h\n"
+            "addp v20.8h, v20.8h, v21.8h\n"
+            "addp v22.8h, v22.8h, v23.8h\n"
+            "addp v24.8h, v24.8h, v25.8h\n"
+            "addp v26.8h, v26.8h, v27.8h\n"
+
+            "addp v16.8h, v16.8h, v18.8h\n"
+            "addp v28.8h, v28.8h, v29.8h\n"
+            "addp v30.8h, v30.8h, v31.8h\n"
+            "addp v20.8h, v20.8h, v22.8h\n"
+
+            "addp v16.8h, v16.8h, v16.8h\n"
+            "addp v20.8h, v20.8h, v20.8h\n"
+
+            "addp v24.8h, v24.8h, v26.8h\n"
+            "addp v24.8h, v24.8h, v24.8h\n"
+
+            "addp v28.8h, v28.8h, v30.8h\n"
+            "addp v28.8h, v28.8h, v28.8h\n"
+
+            "cmp %w[remain_n], #4\n"
+            "bne 7f\n"
+
+            "st1 {v16.4h}, [x0], #8\n"
+            "st1 {v20.4h}, [x0], #8\n"
+            "st1 {v24.4h}, [x0], #8\n"
+            "st1 {v28.4h}, [x0], #8\n"
+            "b 101f\n"
+
+            "7:\n" STORE_C
+
+            "101:\n"
+            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
+              [oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
+              [remain_n] "+r"(remain_n)
+            :
+            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
+              "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
+              "x8", "x9", "x10", "cc", "memory");
+
+#undef STORE_C
+#undef STORE_LINE
+}
+static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) {
+    int8x8x4_t in0 = vld4_s8(inptr);
+    vst1_s8(outptr + 0 * 8, in0.val[0]);
+    vst1_s8(outptr + 1 * 8, in0.val[1]);
+    vst1_s8(outptr + 2 * 8, in0.val[2]);
+    vst1_s8(outptr + 3 * 8, in0.val[3]);
+}
+
+static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2,
+                                   dt_int8* outptr) {
+    int8x16_t in0 = vld1q_s8(inptr);
+    int8x16_t in1 = vld1q_s8(inptr2);
+    int32x4x2_t in_x2 = {
+            {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
+    vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
+}
+
+static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) {
+    int8x16_t in0 = vld1q_s8(inptr);
+    int8x16_t in1 = vdupq_n_s8(0);
+    int32x4x2_t in_x2 = {
+            {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
+    vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
+}
+
+static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in,
+                                          int ldin, int m0, int mmax, int k0,
+                                          int kmax) {
+    megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4");
+    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
+    constexpr int pack_m = 4;
+    constexpr int pack_k = 8;
+    constexpr int pack_size = 4;
+    const int ksize = kmax - k0;
+    const int remain_k = ksize % pack_k;
+    const int kend = kmax - remain_k;
+    int8_t tmpbuff[pack_m * pack_k]{0};
+
+    for (int m_idx = m0; m_idx < mmax; m_idx += pack_m) {
+        const int8_t* inptr0 = in + m_idx / pack_size * ldin + k0;
+
+        for (int k_idx = k0; k_idx < kend; k_idx += pack_k) {
+            transpose_8x4_b(inptr0, out);
+            inptr0 += pack_m * pack_k;
+            out += pack_m * pack_k;
+        }
+        if (remain_k > 0) {
+            int8x16_t tmp = vld1q_s8(inptr0);
+            vst1q_s8(&tmpbuff[0], tmp);
+            transpose_8x4_b(&tmpbuff[0], out);
+            inptr0 += pack_m * pack_size;
+            out += pack_m * pack_k;
+        }
+    }
+}
+
+static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in,
+                                          int ldin, int n0, int nmax, int k0,
+                                          int kmax) {
+    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
+
+    constexpr int pack_n = 4;
+    constexpr int pack_k = 8;
+    constexpr int pack_size = 4;
+    const int ksize = kmax - k0;
+    const int packed_ksize = round_up(ksize, pack_k);
+    const int remain_k = ksize % pack_k;
+    const int kend = kmax - remain_k;
+    const int nsize =
nmax - n0; + const int remain_n = nsize % pack_n; + const int nend = nmax - remain_n; + const int stride_input = pack_size * nsize; + int8_t tmpbuff[pack_n * pack_k]{0}; + int8_t tmpbuff2[pack_n * pack_k]{0}; + + for (int k_idx = k0; k_idx < kend; k_idx += pack_k) { + const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size; + const int8_t* inptr2 = inptr + stride_input; + int8_t* outptr = out + k_idx * pack_n; + for (int n_idx = n0; n_idx < nend; n_idx += pack_n) { + interleve_8x4_b(inptr, inptr2, outptr); + inptr += pack_n * pack_size; + inptr2 += pack_n * pack_size; + outptr += pack_n * packed_ksize; + } + if (remain_n > 0) { + memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t)); + memcpy(&tmpbuff2[0], inptr2, remain_n * pack_size * sizeof(int8_t)); + interleve_8x4_b(&tmpbuff[0], &tmpbuff2[0], outptr); + outptr += pack_n * packed_ksize; + } + } + if (remain_k > 0) { + const int8_t* inptr = in + kend / pack_size * ldin + n0 * pack_size; + int8_t* outptr = out + kend * pack_n; + for (int n_idx = n0; n_idx < nend; n_idx += pack_n) { + interleve_8x4_b_pad(inptr, outptr); + inptr += pack_n * pack_size; + outptr += pack_n * packed_ksize; + } + if (remain_n > 0) { + memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t)); + interleve_8x4_b_pad(&tmpbuff[0], outptr); + outptr += pack_n * packed_ksize; + } + } +} + +} // namespace matmul_mk4_4x4x8_a72 +} // namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp index ff495ab63..d96ce9641 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp @@ -6,12 +6,15 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" +#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" +#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" @@ -197,4 +200,161 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, packA += K4; } } + +// ===========================gemm_s8x8x16_mk4_16x12================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); + +void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in, + int ldin, int y0, int ymax, int k0, + int kmax, bool) const { + matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0, + ymax, k0, kmax); +} + +void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in, + int ldin, int x0, int xmax, int k0, + int kmax, bool) const { + matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0, + xmax, k0, kmax); +} + +void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, + const dt_int8* packB, size_t M, size_t N, + size_t K, dt_int16* C, size_t LDC, + bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + C_dtype.enumv() == DTypeEnum::Int16 && + A_dtype.enumv() == DTypeEnum::Int8); + megdnn_assert(is_first_k == true, "only impl is_first_k"); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); + + constexpr size_t pack_size = 4; + constexpr size_t pack_m = 16; + constexpr size_t pack_n = 12; + const size_t remain_n = N % pack_n; + size_t remain_m = M % pack_m; + + size_t m_idx = 0; + for (; m_idx + pack_m <= M; m_idx += pack_m) { + int16_t* output = C + (m_idx / pack_size * LDC); + + size_t n_idx = 0; + const int8_t* cur_packB = packB; + for (; n_idx + pack_n <= N; n_idx += pack_n) { + matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, + is_first_k, pack_n); + output += pack_n * pack_size; + cur_packB += pack_n * K; + } + if (remain_n > 0) { + matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, + is_first_k, remain_n); + output += remain_n * pack_size; + cur_packB += pack_n * K; + } + packA += pack_m * K; + } + + if (remain_m >= 8) { + int16_t* output = C + (m_idx / pack_size * LDC); + size_t n_idx = 0; + const int8_t* cur_packB = packB; + for (; n_idx + pack_n <= N; n_idx += pack_n) { + matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, + is_first_k, pack_n); + output += pack_n * pack_size; + cur_packB += pack_n * K; + } + if (remain_n > 0) { + matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, + is_first_k, remain_n); + output += remain_n * pack_size; + cur_packB += pack_n * K; + } + packA += 8 * K; + m_idx += 8; + remain_m -= 8; + } + + if (remain_m == 4) { + int16_t* output = C + (m_idx / pack_size * LDC); + size_t n_idx = 0; + const int8_t* cur_packB = packB; + for (; n_idx + pack_n <= N; n_idx += pack_n) { + matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, + is_first_k, pack_n); + output += pack_n * pack_size; + cur_packB += pack_n * K; + } + if (remain_n > 0) { + matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, + is_first_k, remain_n); + output += remain_n 
* pack_size; + cur_packB += pack_n * K; + } + } +} + +// ===========================gemm_s8x8x16_mk4_4x4_a72================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); + +void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin, + int y0, int ymax, int k0, int kmax, + bool) const { + matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax, + k0, kmax); +} + +void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax, + bool) const { + matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax, + k0, kmax); +} + +void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, + const dt_int16*, dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + C_dtype.enumv() == DTypeEnum::Int16 && + A_dtype.enumv() == DTypeEnum::Int8); + megdnn_assert(is_first_k == true, "only impl is_first_k"); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); + + constexpr size_t pack_size = 4; + constexpr size_t pack_m = 4; + constexpr size_t pack_n = 4; + constexpr size_t pack_k = 8; + const size_t remain_n = N % pack_n; + const size_t nend = N - remain_n; + const size_t packed_k = round_up(K, pack_k); + + for (size_t m_idx = 0; m_idx < M; m_idx += pack_m) { + int16_t* output = C + (m_idx / pack_size * LDC); + + const int8_t* cur_packB = packB; + for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { + matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, pack_n); + output += pack_n * pack_size; + cur_packB += pack_n * packed_k; + } + if (remain_n > 0) { + matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, remain_n); + output += remain_n * pack_size; + cur_packB += pack_n * packed_k; + } + packA += pack_m * packed_k; + } +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h index 2b06ceefe..61c2dddfe 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once @@ -20,6 +21,11 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s8x8x16_8x8); MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true, gemm_s8x8x16_4x4); +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false, + gemm_s8x8x16_mk4_4x4_a72); +MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16, + 16, 12, 4, false, false, + gemm_s8x8x16_mk4_16x12_a53); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index 709bf64bc..9c19594c0 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -6,10 +6,11 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "src/aarch64/matrix_mul/opr_impl.h" #include "src/aarch64/matrix_mul/algos.h" +#include "src/aarch64/matrix_mul/opr_impl.h" #include "src/common/metahelper.h" #include "src/common/utils.h" @@ -36,6 +37,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #endif AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; + AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; + AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8; AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; @@ -70,6 +73,8 @@ public: #endif all_algos.emplace_back(&int8x8x16_k4x4x16); all_algos.emplace_back(&int8x8x16_k8x8x8); + all_algos.emplace_back(&int8x8x16_mk4_4x4x8); + all_algos.emplace_back(&int8x8x16_mk4_16x12x4); all_algos.emplace_back(&int16x16x32_k12x8x1); all_algos.emplace_back(&int16x16x32_mk8_8x8); diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index d7a625e48..1a5049825 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include "src/arm_common/matrix_mul/opr_impl.h" @@ -21,28 +22,30 @@ public: SmallVector algo_pack() override; private: - class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 - class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1 - class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 - class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 - class AlgoF32Gemv; // Aarch64 F32 Gemv + class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 + class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1 + class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 + class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 + class AlgoF32Gemv; // Aarch64 F32 Gemv #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1 class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 #endif #if __ARM_FEATURE_DOTPROD - class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel - // 8x12x4 DotProduct - class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel - // 8x12x4 DotProduct + class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel + // 8x12x4 DotProduct + class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel + // 8x12x4 DotProduct #else class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 - class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 - class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 + class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 + class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 #endif - class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 - class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 + class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 + class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 + class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x16 + class AlgoInt8x8x16MK4_4x4x8; // Aarch64 Int8x8x16 Kernel 4x4x8 class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 @@ -52,7 +55,7 @@ private: // 8x8x4 DotProduct class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct #else - class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 + class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 #endif class AlgoPack; diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index b71588c52..4daeb9c86 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -214,7 +214,6 @@ void* const ConvBiasImpl::sm_arm_common_algo_type = bool ConvBiasImpl::is_matmul_quantized_prefer( const ConvBiasImpl::NCBKernSizeParam& param) const { - // fallback::ConvBiasImpl::NCBKernParam conv_ncb_param; fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param( param, 0, param::MatrixMul::Format::DEFAULT, {}, 0, BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY); diff --git a/dnn/src/common/cpuinfo_arch_vendor.cpp b/dnn/src/common/cpuinfo_arch_vendor.cpp index f33a45995..00916e7d4 100644 --- a/dnn/src/common/cpuinfo_arch_vendor.cpp +++ b/dnn/src/common/cpuinfo_arch_vendor.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ - -#ifdef MGB_ENABLE_CPUINFO_CHECK +#include "src/common/utils.h" +#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO #include "cpuinfo_arch_vendor.h" diff --git a/dnn/src/common/cpuinfo_arch_vendor.h b/dnn/src/common/cpuinfo_arch_vendor.h index 8d11b9e80..efe701827 100644 --- a/dnn/src/common/cpuinfo_arch_vendor.h +++ b/dnn/src/common/cpuinfo_arch_vendor.h @@ -11,8 +11,8 @@ */ #pragma once - -#ifdef MGB_ENABLE_CPUINFO_CHECK +#include "src/common/utils.h" +#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO #include diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp index 2cb5ec472..1b09be711 100644 --- a/dnn/test/aarch64/matrix_mul.cpp +++ b/dnn/test/aarch64/matrix_mul.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "test/aarch64/fixture.h" @@ -16,6 +17,7 @@ #include "test/common/matrix_mul.h" #include "test/common/rng.h" +#include "test/arm_common/cpuinfo_help.h" using namespace megdnn; using namespace test; @@ -24,6 +26,20 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K8X12) { dtype::Float32{}, handle(), "AARCH64_F32K8X12X1"); } +#if MGB_ENABLE_CPUINFO +TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A53) { + CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); + matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, handle(), + "AARCH64_F32K8X12X1"); +} +TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A55) { + CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); + matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, handle(), + "AARCH64_F32K8X12X1"); +} +#endif TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) { matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, @@ -36,6 +52,20 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) { dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); } +#if MGB_ENABLE_CPUINFO +TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A53) { + CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); + matrix_mul::check_matrix_mul( + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), + "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); +} +TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A55) { + CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); + matrix_mul::check_matrix_mul( + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), + "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); +} +#endif TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { matrix_mul::check_matrix_mul( @@ -92,6 +122,18 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) { std::move(args)); } +TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle(), "AARCH64_INT8X8X16_MK4_4X4X8", + param::MatrixMul::Format::MK4, 1); +} + +TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle(), "AARCH64_INT8X8X16_MK4_16X12X4", + param::MatrixMul::Format::MK4, 1); +} + TEST_F(AARCH64, MATRIX_MUL_INT8x8x32_K8x8x8) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "AARCH64_INT8X8X32_K8X8X8"); @@ -172,6 +214,7 @@ TEST_F(AARCH64, 
BENCHMARK_MATRIX_MUL_FP32_K4X16) { }; run(256, 256, 128); + run(384, 384, 384); for (size_t k = 4; k <= 256; k *= 8) { for (size_t m = 4; m <= 256; m *= 4) { @@ -235,7 +278,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_8X8X8) { int32_used / int_used); }; - run(256, 256, 128); + run(256, 256, 256); for (size_t k = 4; k <= 256; k *= 8) { for (size_t m = 4; m <= 256; m *= 4) { @@ -297,6 +340,62 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT32_MK_4X4X16) { } } +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_mk4(handle()); + Benchmarker benchmarker_mk4_16x12(handle()); + benchmarker.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + benchmarker.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X16_K4X4X16")); + + param.format = MatrixMul::Param::Format::MK4; + benchmarker_mk4.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X16_MK4_4X4X8")); + benchmarker_mk4.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + + benchmarker_mk4_16x12.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X16_MK4_16X12X4")); + benchmarker_mk4_16x12.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + + auto run = [&](size_t M, size_t N, size_t K) { + auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; + auto mk_used = benchmarker_mk4.exec( + {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / + RUNS; + auto mk4_16x12_used = + benchmarker_mk4_16x12.exec( + {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / + RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " + "%f Gflops speedup: %f, mk4_16x12 %f Gflops speedup: %f\n", + M, K, N, default_used, computations / default_used, mk_used, + computations / mk_used, default_used / mk_used, + computations / mk4_16x12_used, default_used / mk4_16x12_used); + }; + + run(384, 384, 384); +} + TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { constexpr size_t RUNS = 50; param::MatrixMul param; @@ -350,9 +449,11 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { run(256, 256, 128); - for (size_t k = 4; k <= 16; k *= 2) { - for (size_t m = 4; m <= 64; m *= 2) { - for (size_t n = 4; n <= 64; n *= 2) { + run(256, 256, 256); + + for (size_t k = 4; k <= 256; k *= 4) { + for (size_t m = 4; m <= 256; m *= 4) { + for (size_t n = 4; n <= 256; n *= 4) { run(m, n, k); } } diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index c299ae073..29f9be376 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -736,15 +736,21 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { } #endif -#if MEGDNN_ARMV7 TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) { +#if MEGDNN_ARMV7 const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"; const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"; printf("compare %s vs %s \n", default_algo, mk4_algo); BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, dtype::Int8(), dtype::Int16()); -} +#else + const char* default_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"; + const char* 
mk4_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8"; + printf("compare %s vs %s \n", default_algo, mk4_algo); + BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, + dtype::Int8(), dtype::Int16()); #endif +} TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE1) { BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32("S8_CHAN_WISE_STRD1_NCHW44", diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index fae40b977..b918f439c 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -14,6 +14,8 @@ #include "test/common/benchmarker.h" #include "test/common/conv_bias.h" +#include "test/arm_common/cpuinfo_help.h" + using namespace megdnn; using namespace test; using namespace conv_bias; @@ -487,11 +489,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, handle(), "S8_CHAN_WISE_STRD2_NCHW44"); } -TEST_F(ARM_COMMON, - CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) { +TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) { Checker checker(handle()); - checker.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); + checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( + "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); checker.set_dtype(0, dtype::Int8()); checker.set_dtype(1, dtype::Int8()); checker.set_dtype(2, dtype::Int16()); @@ -505,8 +506,8 @@ TEST_F(ARM_COMMON, TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) { Checker checker(handle()); - checker.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); + checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( + "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); checker.set_dtype(0, dtype::Int8()); checker.set_dtype(1, dtype::Int8()); checker.set_dtype(2, dtype::Int16()); @@ -1803,8 +1804,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2_PREPROCESS) { handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ dtype::Float32(), dtype::Float32(), name); #if MEGDNN_AARCH64 - cb("IM2COLMATMUL:AARCH64_F32K8X12X1") - cb("IM2COLMATMUL:AARCH64_F32K4X16X1") + cb("IM2COLMATMUL:AARCH64_F32K8X12X1") cb("IM2COLMATMUL:AARCH64_F32K4X16X1") #elif MEGDNN_ARMV7 cb("IM2COLMATMUL:ARMV7_F32") #endif @@ -1858,6 +1858,94 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) { #undef cb } +//! 
CPUINFO ralated test +#if MEGDNN_AARCH64 +#if MGB_ENABLE_CPUINFO +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_A55) { +CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); +#define cb(name,stride) \ + check_conv_bias( \ + get_conv_bias_args({2, 3, 4, 5, 6, 7}, stride, false, false, false), \ + handle(), name); + + cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 1) + cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 2) +#undef cb +} +#endif +#endif + +#if MEGDNN_AARCH64 +#if MGB_ENABLE_CPUINFO +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_A53) { +CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); +#define cb(name,stride) \ + check_conv_bias( \ + get_conv_bias_args({2, 3, 4, 5, 6, 7}, stride, false, false, false), \ + handle(), name); + + cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 1) + cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 2) +#undef cb +} +#endif +#endif + +#if MEGDNN_AARCH64 +#if MGB_ENABLE_CPUINFO +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_MK4_PACK_F32_A55) { + CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); + using namespace conv_bias; + std::vector args = get_nchw44_conv_bias_args( + {2, 3, 7}, 1, false, false, false, false, false, true, true); + check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); + args = get_nchw44_conv_bias_args( + {2, 3, 7}, 2, false, false, false, false, false, true, true); + check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); +} +#endif +#endif + +#if MEGDNN_AARCH64 +#if MGB_ENABLE_CPUINFO +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_MK4_PACK_F32_A53) { + CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); + using namespace conv_bias; + std::vector args = get_nchw44_conv_bias_args( + {2, 3, 7}, 1, false, false, false, false, false, true, true); + check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); + args = get_nchw44_conv_bias_args( + {2, 3, 7}, 2, false, false, false, false, false, true, true); + check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); +} +#endif +#endif + + +#if MEGDNN_AARCH64 +#if MGB_ENABLE_CPUINFO +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_MK4_PACK_F32_A55) { + CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({1}, 1, true, false, false); + check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); +} +#endif +#endif + +#if MEGDNN_AARCH64 +#if MGB_ENABLE_CPUINFO +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_MK4_PACK_F32_A53) { + CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({1}, 1, true, false, false); + check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); +} +#endif +#endif + TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { UniformIntRNG rng{-50, 50}; @@ -2216,7 +2304,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { #undef cb } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPROCESS) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; #define cb(name) \ @@ -2247,7 +2336,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPRO #undef cb } - TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; @@ -2276,19 +2364,21 @@ 
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { #if MEGDNN_AARCH64 cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); - cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); + cb_nchw44("IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8"); + cb_nchw44("IM2COLMATMUL:AARCH64_INT8X8X16_MK4_16X12X4"); #elif MEGDNN_ARMV7 - cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"); #endif + cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); #undef cb #undef cb_nchw44 } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCESS) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; #define cb(name) \ @@ -2311,7 +2401,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES #undef cb } -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; #define cb(name) \ @@ -2415,8 +2506,9 @@ void checker_conv_bias_mul_int8x8x32(std::vector args, } } -void checker_conv_bias_int8x8x32_preprocess(std::vector args, - Handle* handle, const char* algo_name) { +void checker_conv_bias_int8x8x32_preprocess( + std::vector args, Handle* handle, + const char* algo_name) { using namespace conv_bias; Checker> checker( @@ -2461,7 +2553,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { #undef cb } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { using namespace conv_bias; std::vector args = get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); @@ -2490,7 +2583,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { #undef cb } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1_PREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1_PREPROCESS) { using namespace conv_bias; std::vector args = get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); @@ -2541,7 +2635,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, #undef cb } - TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) { UniformIntRNG rng{-50, 50}; @@ -2678,7 +2771,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { #undef cb } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32_FILTER_PREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_IM2COLMATMUL_INT8X8X32_FILTER_PREPROCESS) { using namespace conv_bias; std::vector args = get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); @@ -2722,7 +2816,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32_PREPROCESS) { using namespace conv_bias; std::vector args = get_nchw44_conv_bias_args( - {2, 4, 7}, 1, false, false, false, false, false, true,true); + {2, 4, 7}, 1, false, false, false, false, false, true, true); #define cb(name) \ check_conv_bias_preprocess(args, handle(), nullptr, 0.001, \ dtype::Float32(), dtype::Float32(), \ @@ -2748,7 +2842,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { #undef cb } 
-TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { using namespace conv_bias; std::vector args = get_nchw44_conv_bias_args( {3}, 2, false, false, false, false, false, true, true, false); @@ -2884,12 +2979,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16_PREPROCESS) { NormalRNG rng(1); #if MEGDNN_AARCH64 check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{}, - dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, - "CONV1x1:AARCH64_F16_K8X24X1:48"); + dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, + "CONV1x1:AARCH64_F16_K8X24X1:48"); #elif MEGDNN_ARMV7 check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{}, - dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, - "CONV1x1:AARCH32_F16_K4X16X1:24"); + dtype::Float16{}, dtype::Float16{}, + dtype::Float16{}, + "CONV1x1:AARCH32_F16_K4X16X1:24"); #endif } @@ -2951,7 +3048,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM_PREPROCESS) { #undef cb } - #if MEGDNN_AARCH64 || MEGDNN_ARMV7 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) { UniformIntRNG rng{-50, 50}; @@ -3074,7 +3170,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) { cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24"); #endif #undef cb - } TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { @@ -3095,6 +3190,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { #if MEGDNN_AARCH64 cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); + cb_nchw44("CONV1x1:AARCH64_INT8X8X16_MK4_4X4X8:48"); + cb_nchw44("CONV1x1:AARCH64_INT8X8X16_MK4_16X12X4:48"); #elif MEGDNN_ARMV7 cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); @@ -3128,11 +3225,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) { #if MEGDNN_AARCH64 cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); - cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test + cb("CONV1x1:ARM_COMMON_INT8X8X16:24"); //! add nopack test #elif MEGDNN_ARMV7 cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); - cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test + cb("CONV1x1:ARM_COMMON_INT8X8X16:24"); //! add nopack test #endif #undef cb } @@ -3245,11 +3342,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4_PREPROCESS) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; -#define cb(name) \ - check_conv_bias_preprocess(get_nchw44_conv_bias_args({1}, 1, true, false, false), \ - handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ - dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ - dtype::QuantizedS8(60.25f), name); +#define cb(name) \ + check_conv_bias_preprocess( \ + get_nchw44_conv_bias_args({1}, 1, true, false, false), handle(), \ + &rng, epsilon, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), \ + dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), name); #if MEGDNN_AARCH64 cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24"); #elif MEGDNN_ARMV7 diff --git a/dnn/test/arm_common/cpuinfo.cpp b/dnn/test/arm_common/cpuinfo.cpp index d5c27d56c..0f7fdcc04 100644 --- a/dnn/test/arm_common/cpuinfo.cpp +++ b/dnn/test/arm_common/cpuinfo.cpp @@ -9,7 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#ifdef MGB_ENABLE_CPUINFO_CHECK +#include "src/common/utils.h" +#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO #include #include #include "gtest/gtest.h" @@ -18,7 +19,6 @@ namespace megdnn { namespace test { TEST(ARM_RUNTIME, CPUINFO_KIRIN980) { - ASSERT_TRUE(cpuinfo_initialize()); int right_soc = strcmp(cpuinfo_get_package(0)->name, "HiSilicon Kirin 980"); @@ -68,7 +68,6 @@ TEST(ARM_RUNTIME, CPUINFO_KIRIN980) { } TEST(ARM_RUNTIME, CPUINFO_SDM8150) { - ASSERT_TRUE(cpuinfo_initialize()); int right_soc = @@ -119,7 +118,6 @@ TEST(ARM_RUNTIME, CPUINFO_SDM8150) { } TEST(ARM_RUNTIME, CPUINFO_SDM660) { - ASSERT_TRUE(cpuinfo_initialize()); int right_soc = @@ -173,4 +171,3 @@ TEST(ARM_RUNTIME, CPUINFO_SDM660) { } // namespace megdnn #endif // vim: syntax=cpp.doxygen - diff --git a/dnn/test/arm_common/cpuinfo_help.cpp b/dnn/test/arm_common/cpuinfo_help.cpp new file mode 100644 index 000000000..85a66ce9e --- /dev/null +++ b/dnn/test/arm_common/cpuinfo_help.cpp @@ -0,0 +1,17 @@ +/** + * \file dnn/test/arm_common/cpuinfo_help.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/common/utils.h" +#include "test/arm_common/cpuinfo_help.h" +#if MGB_ENABLE_CPUINFO +std::mutex CpuInfoTmpReplace::m_cpuinfo_lock; +#endif +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/test/arm_common/cpuinfo_help.h b/dnn/test/arm_common/cpuinfo_help.h new file mode 100644 index 000000000..cfcdeddd2 --- /dev/null +++ b/dnn/test/arm_common/cpuinfo_help.h @@ -0,0 +1,47 @@ +/** + * \file dnn/test/arm_common/cpuinfo_help.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include +#include +#include "src/common/utils.h" + +#if MGB_ENABLE_CPUINFO +#include "cpuinfo.h" +extern const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map; +class CpuInfoTmpReplace { +public: + CpuInfoTmpReplace(enum cpuinfo_uarch arch) { + m_cpuinfo_lock.lock(); + for (uint32_t i = 0; i < cpuinfo_get_cores_count(); ++i) { + m_arch_bak_vec.push_back(cpuinfo_linux_cpu_to_core_map[i]->uarch); + ((struct cpuinfo_core**)cpuinfo_linux_cpu_to_core_map)[i]->uarch = + arch; + } + } + ~CpuInfoTmpReplace() { + if (m_arch_bak_vec.size() > 0) { + for (uint32_t i = 0; i < cpuinfo_get_cores_count(); ++i) { + ((struct cpuinfo_core**)cpuinfo_linux_cpu_to_core_map)[i] + ->uarch = m_arch_bak_vec[i]; + } + } + + m_cpuinfo_lock.unlock(); + } + +private: + static std::mutex m_cpuinfo_lock; + std::vector m_arch_bak_vec; +}; +#endif + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/test/x86/cpuinfo.cpp b/dnn/test/x86/cpuinfo.cpp index 2b30e75ff..b67375491 100644 --- a/dnn/test/x86/cpuinfo.cpp +++ b/dnn/test/x86/cpuinfo.cpp @@ -9,7 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#ifdef MGB_ENABLE_CPUINFO_CHECK +#include "src/common/utils.h" +#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO #include #include #include "gtest/gtest.h" @@ -18,14 +19,12 @@ namespace megdnn { namespace test { TEST(X86_RUNTIME, CPUINFO_XEON6130) { - ASSERT_TRUE(cpuinfo_initialize()); int right_cpu = strcmp(cpuinfo_get_package(0)->name, "Intel Xeon Gold 6130"); if (!right_cpu) { - ASSERT_TRUE(cpuinfo_get_processors()); ASSERT_TRUE(cpuinfo_has_x86_avx2()); @@ -44,4 +43,3 @@ TEST(X86_RUNTIME, CPUINFO_XEON6130) { } // namespace megdnn #endif // vim: syntax=cpp.doxygen - -- GitLab