From 07d1d0abd2643aaf1c45afd7857b2c308ab6f55a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 12 May 2020 00:10:36 +0800 Subject: [PATCH] feat(dnn/arm64): add fp32 mk4 matmul GitOrigin-RevId: f6df006547e08ba5b76be984a2fe87cf053c31de --- dnn/src/aarch64/matrix_mul/algos.cpp | 61 ++ dnn/src/aarch64/matrix_mul/algos.h | 11 + dnn/src/aarch64/matrix_mul/asm/common.h | 65 ++ .../matrix_mul/fp32/kernel_general_8x12.h | 4 + .../aarch64/matrix_mul/fp32/kernel_mk4_8x12.h | 874 ++++++++++++++++++ dnn/src/aarch64/matrix_mul/fp32/strategy.cpp | 77 ++ dnn/src/aarch64/matrix_mul/fp32/strategy.h | 3 + dnn/src/aarch64/matrix_mul/opr_impl.cpp | 2 + dnn/src/aarch64/matrix_mul/opr_impl.h | 1 + dnn/src/arm_common/conv_bias/f16/algos.cpp | 2 + dnn/src/arm_common/conv_bias/fp32/algos.cpp | 4 + dnn/src/arm_common/conv_bias/int8/algos.cpp | 2 + dnn/test/aarch64/matrix_mul.cpp | 15 + dnn/test/common/matrix_mul.cpp | 4 +- 14 files changed, 1123 insertions(+), 2 deletions(-) create mode 100644 dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index b98045c9..e8ad513d 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -86,6 +86,67 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, "AlgoF32K8x12x1Impl"_hash, aarch64::matmul::sgemm_8x12, float, float); +/* ===================== F32_MK4_8X12X1 algo ===================== */ +bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32() && + kern_size_param.format == param::MatrixMul::Format::MK4 && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; +} + +size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type, + C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern( + const KernSizeParam&) const { + auto f32_kern_mk4_8x12 = [](const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, + LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type, + C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); + }; + return f32_kern_mk4_8x12; +} +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1, + megdnn_aarch64_matmul_kern, + "AlgoF32MK4_8x12x1Impl"_hash, + aarch64::matmul::sgemm_mk4_8x12, float, + float); + /* ===================== F32K4X16X1 algo ===================== */ bool MatrixMulImpl::AlgoF32K4x16x1::usable( diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index 266f1e47..dc46b5e8 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -29,6 +29,17 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; +class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { public: bool is_reproducible() const override { return true; } diff --git a/dnn/src/aarch64/matrix_mul/asm/common.h b/dnn/src/aarch64/matrix_mul/asm/common.h index fc9fa537..334d3f9b 100644 --- a/dnn/src/aarch64/matrix_mul/asm/common.h +++ b/dnn/src/aarch64/matrix_mul/asm/common.h @@ -1103,6 +1103,36 @@ static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1, : "v0", "v1", "v2", "v3", "cc", "memory"); } +template +static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, + T* outptr) { + static_assert(sizeof(T) == 4, "interleave_2x4_4_s only support size == 4"); + asm volatile( + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr1]], #64\n" + "stp q0, q4, [%[outptr]]\n" + "stp q1, q5, [%[outptr], #32]\n" + "stp q2, q6, [%[outptr], #64]\n" + "stp q3, q7, [%[outptr], #96]\n" + + : [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), + [ outptr ] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); +} + +template +static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) { + static_assert(sizeof(T) == 4, "interleave_1x4_4_s only support size == 4"); + asm volatile( + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" + + : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "memory"); +} + template static inline void interleave_4x8_2_b(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, @@ -1479,6 +1509,41 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, "v11", "memory"); } +template +static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { + static_assert(sizeof(T) == 4, + "transpose_1x12_4_s only support sizeof(T) == 4"); + + asm volatile( + "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" + "ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr0]], #64\n" + "ld4 {v8.4s, v9.4s, v10.4s, v11.4s},[%[inptr0]], #64\n" + + "stp q0, q4, [%[outptr]] \n" + "stp q8, q1, [%[outptr], #32] \n" + "stp q5, q9, [%[outptr], #64] \n" + "stp q2, q6, [%[outptr], #96] \n" + "stp q10, q3, [%[outptr], #128] \n" + "stp q7, q11, [%[outptr], #160] \n" + : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "memory"); +} + +template +static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { + static_assert(sizeof(T) == 4, + "transpose_1x4_4_s only support sizeof(T) == 4"); + + asm volatile( + "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" + : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "memory"); +} + template static inline void transpose_8x4_1_s(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h index 076c63b9..2ed7787b 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h @@ -899,6 +899,10 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, : : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", "x10", "cc", "memory"); +#undef LOAD_LINE +#undef LOAD_C +#undef STORE_LINE +#undef STORE_C } void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, int y0, diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h new file mode 100644 index 00000000..6dccd097 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h @@ -0,0 +1,874 @@ +/** + * \file dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + + +namespace megdnn { +namespace aarch64 { +namespace matmul_mk4_8x12 { + +// Overview of register layout: +// +// A 1x12 cell of Rhs is stored in 32bit in v2-v7 +// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) +// A 8x12 block of accumulators is stored in 32bit in v8-v31. +// +// +--------+--------+--------+ +// | v2[0-3]| v3[0-3]| v4[0-3]| +// | v5[0-3]| v6[0-3]| v7[0-3]| +// Rhs +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +--+ --- - +--------+--------+--------+ +// |v0| | v8[0-3]| v9[0-3]|v10[0-3]| +// |v0| |v11[0-3]|v12[0-3]|v13[0-3]| +// |v0| |v14[0-3]|v15[0-3]|v16[0-3]| +// |v0| |v17[0-3]|v18[0-3]|v19[0-3]| +// |v1| |v20[0-3]|v21[0-3]|v22[0-3]| +// |v1| |v23[0-3]|v24[0-3]|v25[0-3]| +// |v1| |v26[0-3]|v27[0-3]|v28[0-3]| +// |v1| |v29[0-3]|v30[0-3]|v31[0-3]| +// +--+ --- - +--------+--------+--------+ +// +// Accumulator +void kern_8x12(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k) { + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + float* output1 = output0 + LDC; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + asm volatile( + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "mov x1, %[output0]\n" + "mov x2, %[output1]\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "prfm pstl1keep, [%[output1]]\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "eor v20.16b, v20.16b, v20.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v21.16b, v21.16b, v21.16b\n" + "eor v22.16b, v22.16b, v22.16b\n" + "eor v23.16b, v23.16b, v23.16b\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "fmla v21.4s, v1.4s, v2.s[1]\n" + "fmla v22.4s, v1.4s, v2.s[2]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "fmla v24.4s, v1.4s, v3.s[0]\n" + "fmla v25.4s, v1.4s, v3.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v1.4s, v3.s[2]\n" + "fmla v27.4s, v1.4s, v3.s[3]\n" + "fmla v28.4s, v1.4s, v4.s[0]\n" + "fmla v29.4s, v1.4s, v4.s[1]\n" + "fmla v30.4s, v1.4s, v4.s[2]\n" + "fmla v31.4s, v1.4s, v4.s[3]\n" + + "fmla v8.4s, v0.4s, v5.s[0]\n" + "fmla v9.4s, v0.4s, v5.s[1]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "fmla v10.4s, v0.4s, v5.s[2]\n" + "fmla v11.4s, v0.4s, v5.s[3]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v12.4s, v0.4s, v6.s[0]\n" + "fmla v13.4s, v0.4s, v6.s[1]\n" + "fmla v14.4s, v0.4s, v6.s[2]\n" + "fmla v15.4s, v0.4s, v6.s[3]\n" + "fmla v16.4s, v0.4s, v7.s[0]\n" + "fmla v17.4s, v0.4s, v7.s[1]\n" + "fmla v18.4s, v0.4s, v7.s[2]\n" + "fmla v19.4s, v0.4s, v7.s[3]\n" + + "fmla v20.4s, v1.4s, v5.s[0]\n" + "fmla v21.4s, v1.4s, v5.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v22.4s, v1.4s, v5.s[2]\n" + "fmla v23.4s, v1.4s, v5.s[3]\n" + "fmla v24.4s, v1.4s, v6.s[0]\n" + "subs %w[K], %w[K], #1\n" + "fmla v25.4s, v1.4s, v6.s[1]\n" + "fmla v26.4s, v1.4s, v6.s[2]\n" + "fmla v27.4s, v1.4s, v6.s[3]\n" + "fmla v28.4s, v1.4s, v7.s[0]\n" + "fmla v29.4s, v1.4s, v7.s[1]\n" + "fmla v30.4s, v1.4s, v7.s[2]\n" + "fmla v31.4s, v1.4s, v7.s[3]\n" + + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "fmla v21.4s, v1.4s, v2.s[1]\n" + "fmla v22.4s, v1.4s, v2.s[2]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "fmla v24.4s, v1.4s, v3.s[0]\n" + "fmla v25.4s, v1.4s, v3.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v26.4s, v1.4s, v3.s[2]\n" + "fmla v27.4s, v1.4s, v3.s[3]\n" + "fmla v28.4s, v1.4s, v4.s[0]\n" + "fmla v29.4s, v1.4s, v4.s[1]\n" + "fmla v30.4s, v1.4s, v4.s[2]\n" + "fmla v31.4s, v1.4s, v4.s[3]\n" + + "fmla v8.4s, v0.4s, v5.s[0]\n" + "fmla v9.4s, v0.4s, v5.s[1]\n" + "fmla v10.4s, v0.4s, v5.s[2]\n" + "fmla v11.4s, v0.4s, v5.s[3]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v12.4s, v0.4s, v6.s[0]\n" + "fmla v13.4s, v0.4s, v6.s[1]\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" + "fmla v14.4s, v0.4s, v6.s[2]\n" + "fmla v15.4s, v0.4s, v6.s[3]\n" + "fmla v16.4s, v0.4s, v7.s[0]\n" + "fmla v17.4s, v0.4s, v7.s[1]\n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" + "fmla v18.4s, v0.4s, v7.s[2]\n" + "fmla v19.4s, v0.4s, v7.s[3]\n" + + "fmla v20.4s, v1.4s, v5.s[0]\n" + "fmla v21.4s, v1.4s, v5.s[1]\n" + "fmla v22.4s, v1.4s, v5.s[2]\n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" + "fmla v23.4s, v1.4s, v5.s[3]\n" + "fmla v24.4s, v1.4s, v6.s[0]\n" + "fmla v25.4s, v1.4s, v6.s[1]\n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output1]], #64\n" + "fmla v26.4s, v1.4s, v6.s[2]\n" + "fmla v27.4s, v1.4s, v6.s[3]\n" + "fmla v28.4s, v1.4s, v7.s[0]\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n" + "fmla v29.4s, v1.4s, v7.s[1]\n" + "fmla v30.4s, v1.4s, v7.s[2]\n" + "fmla v31.4s, v1.4s, v7.s[3]\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v20.4s, v1.4s, v2.s[0]\n" + "fmla v21.4s, v1.4s, v2.s[1]\n" + "fmla v22.4s, v1.4s, v2.s[2]\n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "fmla v24.4s, v1.4s, v3.s[0]\n" + "fmla v25.4s, v1.4s, v3.s[1]\n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output1]], #64\n" + "fmla v26.4s, v1.4s, v3.s[2]\n" + "fmla v27.4s, v1.4s, v3.s[3]\n" + "fmla v28.4s, v1.4s, v4.s[0]\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n" + "fmla v29.4s, v1.4s, v4.s[1]\n" + "fmla v30.4s, v1.4s, v4.s[2]\n" + "fmla v31.4s, v1.4s, v4.s[3]\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n" + + "6:\n" + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), + [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), + [ output0 ] "+r"(output0), [ output1 ] "+r"(output1) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "x1", "x2", "cc", "memory"); +} + +// Overview of register layout: +// +// A 1x12 cell of Rhs is stored in 32bit in v2-v7 +// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) +// A 8x12 block of accumulators is stored in 32bit in v8-v31. +// +// +--------+ +// | v2[0-3]| +// | v3[0-3]| +// Rhs +--------+ +// +// | | +// +// Lhs | | +// +// +--+ --- - +--------+ +// |v0| | v8[0-3]| +// |v0| |v11[0-3]| +// |v0| |v14[0-3]| +// |v0| |v17[0-3]| +// |v1| |v20[0-3]| +// |v1| |v23[0-3]| +// |v1| |v26[0-3]| +// |v1| |v29[0-3]| +// +--+ --- - +--------+ +// +// Accumulator +void kern_8x4(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k, int n_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + float* output1 = output0 + LDC; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + //clang-format off +#define LOAD_C \ + "cmp %w[n_remain], #4\n" \ + "blt 11f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \ + "b 14f\n" \ + "11:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 12f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \ + "b 14f\n" \ + "12:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 13f\n" \ + "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "ld1 {v12.4s, v13.4s},[%[output1]]\n" \ + "b 14f\n" \ + "13:\n" \ + "ld1 {v8.4s}, [%[output0]]\n" \ + "ld1 {v12.4s},[%[output1]]\n" \ + "14:\n" + +#define STORE_C \ + "cmp %w[n_remain], #4\n" \ + "blt 21f\n" \ + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \ + "b 24f\n" \ + "21:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 22f\n" \ + "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \ + "b 23f\n" \ + "22:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 23f\n" \ + "st1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "st1 {v12.4s, v13.4s},[%[output1]]\n" \ + "b 24f\n" \ + "23:\n" \ + "st1 {v8.4s}, [%[output0]]\n" \ + "st1 {v12.4s},[%[output1]]\n" \ + "24:\n" + //clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "prfm pstl1keep, [%[output1]]\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], #16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "fmla v8.4s, v0.4s, v3.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v3.s[1]\n" + "fmla v10.4s, v0.4s, v3.s[2]\n" + "fmla v11.4s, v0.4s, v3.s[3]\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "fmla v12.4s, v1.4s, v3.s[0]\n" + "subs %w[K], %w[K], #1\n" + "fmla v13.4s, v1.4s, v3.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "fmla v14.4s, v1.4s, v3.s[2]\n" + "fmla v15.4s, v1.4s, v3.s[3]\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], #16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "fmla v8.4s, v0.4s, v3.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v3.s[1]\n" + "fmla v10.4s, v0.4s, v3.s[2]\n" + "fmla v11.4s, v0.4s, v3.s[3]\n" + "fmla v12.4s, v1.4s, v3.s[0]\n" + "fmla v13.4s, v1.4s, v3.s[1]\n" + "fmla v14.4s, v1.4s, v3.s[2]\n" + "fmla v15.4s, v1.4s, v3.s[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], #16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v1.4s, v2.s[0]\n" + "fmla v13.4s, v1.4s, v2.s[1]\n" + "fmla v14.4s, v1.4s, v2.s[2]\n" + "fmla v15.4s, v1.4s, v2.s[3]\n" + + "6:\n" STORE_C + + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), + [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), + [ output0 ] "+r"(output0), [ output1 ] "+r"(output1), + [ n_remain ] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "cc", "memory"); + +#undef LOAD_C +#undef STORE_C +} + + +// Overview of register layout: +// +// A 1x12 cell of Rhs is stored in 32bit in v2-v7 +// A 8x1 cell of Lhs is stored in 32bit in (v0-v1) +// A 8x12 block of accumulators is stored in 32bit in v8-v31. +// +// +--------+--------+--------+ +// | v2[0-3]| v3[0-3]| v4[0-3]| +// | v5[0-3]| v6[0-3]| v7[0-3]| +// Rhs +--------+--------+--------+ +// +// | | | | +// +// Lhs | | | | +// +// +--+ --- - +--------+--------+--------+ +// |v0| | v8[0-3]| v9[0-3]|v10[0-3]| +// |v0| |v11[0-3]|v12[0-3]|v13[0-3]| +// |v0| |v14[0-3]|v15[0-3]|v16[0-3]| +// |v0| |v17[0-3]|v18[0-3]|v19[0-3]| +// +--+ --- - +--------+--------+--------+ +// +// Accumulator + +void kern_4x12(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + asm volatile( + "cmp %w[is_first_k], #1\n" + "beq 1f\n" + "mov x1, %[output0]\n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n" + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + "eor v12.16b, v12.16b, v12.16b\n" + "eor v13.16b, v13.16b, v13.16b\n" + "eor v14.16b, v14.16b, v14.16b\n" + "eor v15.16b, v15.16b, v15.16b\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v8.4s, v1.4s, v5.s[0]\n" + "fmla v9.4s, v1.4s, v5.s[1]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n" + "fmla v10.4s, v1.4s, v5.s[2]\n" + "fmla v11.4s, v1.4s, v5.s[3]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v12.4s, v1.4s, v6.s[0]\n" + "fmla v13.4s, v1.4s, v6.s[1]\n" + "subs %w[K], %w[K], #1\n" + "fmla v14.4s, v1.4s, v6.s[2]\n" + "fmla v15.4s, v1.4s, v6.s[3]\n" + "fmla v16.4s, v1.4s, v7.s[0]\n" + "fmla v17.4s, v1.4s, v7.s[1]\n" + "fmla v18.4s, v1.4s, v7.s[2]\n" + "fmla v19.4s, v1.4s, v7.s[3]\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + + "fmla v8.4s, v1.4s, v5.s[0]\n" + "fmla v9.4s, v1.4s, v5.s[1]\n" + "fmla v10.4s, v1.4s, v5.s[2]\n" + "fmla v11.4s, v1.4s, v5.s[3]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v12.4s, v1.4s, v6.s[0]\n" + "fmla v13.4s, v1.4s, v6.s[1]\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" + "fmla v14.4s, v1.4s, v6.s[2]\n" + "fmla v15.4s, v1.4s, v6.s[3]\n" + "fmla v16.4s, v1.4s, v7.s[0]\n" + "fmla v17.4s, v1.4s, v7.s[1]\n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" + "fmla v18.4s, v1.4s, v7.s[2]\n" + "fmla v19.4s, v1.4s, v7.s[3]\n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + "fmla v12.4s, v0.4s, v3.s[0]\n" + "fmla v13.4s, v0.4s, v3.s[1]\n" + "fmla v14.4s, v0.4s, v3.s[2]\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n" + "fmla v15.4s, v0.4s, v3.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v17.4s, v0.4s, v4.s[1]\n" + "fmla v18.4s, v0.4s, v4.s[2]\n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n" + "fmla v19.4s, v0.4s, v4.s[3]\n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n" + + "6:\n" + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), + [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), + [ output0 ] "+r"(output0) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "x1", "cc", "memory"); +} + + + + +// Overview of register layout: +// +// A 2x4 cell of Rhs is stored in 32bit in v2 - v3 +// A 4x2 cell of Lhs is stored in 32bit in v0 - v1 +// A 4x4 block of accumulators is stored in 32bit in v4-v6 +// +// +--------+ +// | v2[0-3]| +// | v5[0-3]| +// Rhs +--------+ +// +// | | +// +// Lhs | | +// +// +--+ --- - +--------+ +// |v0| | v8[0-3]| +// |v0| |v11[0-3]| +// |v0| |v14[0-3]| +// |v0| |v17[0-3]| +// +--+ --- - +--------+ +// +// Accumulator +void kern_4x4(const float* packA, const float* packB, int K, float* output, + int LDC, bool is_first_k, int n_remain) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + //clang-format off +#define LOAD_C \ + "cmp %w[n_remain], #4\n" \ + "blt 11f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "11:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 12f\n" \ + "ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "12:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 13f\n" \ + "ld1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "b 14f\n" \ + "13:\n" \ + "ld1 {v8.4s}, [%[output0]]\n" \ + "14:\n" + +#define STORE_C \ + "cmp %w[n_remain], #4\n" \ + "blt 21f\n" \ + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "21:\n" \ + "cmp %w[n_remain], #3\n" \ + "blt 22f\n" \ + "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ + "b 23f\n" \ + "22:\n" \ + "cmp %w[n_remain], #2\n" \ + "blt 23f\n" \ + "st1 {v8.4s, v9.4s}, [%[output0]]\n" \ + "b 24f\n" \ + "23:\n" \ + "st1 {v8.4s}, [%[output0]]\n" \ + "24:\n" + //clang-format on + + asm volatile( + // load accumulator C + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "b 2f\n" + + "1:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "ld1 {v2.4s}, [%[b_ptr]], #16\n" + "eor v9.16b, v9.16b, v9.16b\n" + "ld1 {v0.4s}, [%[a_ptr]], #16\n" + "eor v10.16b, v10.16b, v10.16b\n" + "prfm pstl1keep, [%[output0]]\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "2: \n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "fmla v8.4s, v1.4s, v3.s[0]\n" + "fmla v9.4s, v1.4s, v3.s[1]\n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "fmla v10.4s, v1.4s, v3.s[2]\n" + "fmla v11.4s, v1.4s, v3.s[3]\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "fmla v8.4s, v0.4s, v2.s[0]\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "fmla v8.4s, v1.4s, v3.s[0]\n" + "fmla v9.4s, v1.4s, v3.s[1]\n" + "fmla v10.4s, v1.4s, v3.s[2]\n" + "fmla v11.4s, v1.4s, v3.s[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "fmla v8.4s, v0.4s, v2.s[0]\n" + "fmla v9.4s, v0.4s, v2.s[1]\n" + "fmla v10.4s, v0.4s, v2.s[2]\n" + "fmla v11.4s, v0.4s, v2.s[3]\n" + + "6:\n" STORE_C + + : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), + [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), + [ output0 ] "+r"(output0), [ n_remain ] "+r"(n_remain) + : + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); + +#undef LOAD_C +#undef STORE_C +} + +void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, int y0, + int ymax, int k0, int kmax) { + megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + constexpr int PACK_SIZE_32 = 4 * 8; + constexpr int PACK_SIZE_16 = 4 * 4; + constexpr int PACK_C_SIZE = 4; + int y = y0; + for (; y + 7 < ymax; y += 8) { + const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0; + const float* inptr1 = inptr0 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + int k = (kmax - k0); + for (; k > 3; k -= 4) { + interleave_2x4_4_s(inptr0, inptr1, outptr); + outptr += PACK_SIZE_32; + } + } + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k0; + prefetch_2x(inptr0); + int K = (kmax - k0); + for (; K > 3; K -= 4) { + interleave_1x4_4_s(inptr0, outptr); + outptr += PACK_SIZE_16; + } + } +} + +void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0, + int xmax, int k0, int kmax) { + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + float tmpbuff[16] = {0.0f}; + + constexpr int PACK_C_SIZE = 4; + int ksize = kmax - k0; + int ksize12 = ksize * 12; + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE; + prefetch_3x(inptr); + + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + transpose_1x12_4_s(inptr, outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + transpose_1x4_4_s(inptr, outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + std::memcpy(tmpbuff, inptr, + sizeof(float) * (xmax - x) * PACK_C_SIZE); + auto outptr_interleave = outptr; + const float* tmp_ptr = &tmpbuff[0]; + transpose_1x4_4_s(tmp_ptr, outptr_interleave); + outptr += ksize4; + } + outptr_base += 12 * 4; + outptr_base4 += 4 * 4; + } +} + +} // namespace matmul_mk4_8x12 +} // aarch64 +} // megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp index b180d8c7..4fa5f6b8 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp @@ -12,6 +12,7 @@ #include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" +#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" #include "src/common/utils.h" using namespace megdnn; @@ -163,4 +164,80 @@ void sgemm_8x12::kern(const float* packA, const float* packB, } } +MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); + +void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, + int ymax, int k0, int kmax, + bool transpose_A) const { + megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A"); + matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); +} + +void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0, + int xmax, int k0, int kmax, + bool transpose_B) const { + megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B"); + matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); +} + +void sgemm_mk4_8x12::kern(const float* packA, const float* packB, + size_t M, size_t N, size_t K, float* C, size_t LDC, + bool is_first_k, const float*, float*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); + + constexpr size_t PACK_C_SIZE = 4; + constexpr size_t A_INTERLEAVE = 8; + constexpr size_t A_INTERLEAVE4 = 4; + constexpr size_t B_INTERLEAVE = 12; + const int K12 = K * 12; + const int K8 = K * 8; + const int K4 = K * 4; + + size_t m = 0; + for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) { + float* output = C + (m / PACK_C_SIZE * LDC); + + size_t n = 0; + const float* cur_packB = packB; + for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { + matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE * PACK_C_SIZE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(N - n, 4)); + output += 4 * PACK_C_SIZE; + cur_packB += K4; + } + packA += K8; + } + for (; m < M; m += A_INTERLEAVE4) { + float* output = C + (m / PACK_C_SIZE * LDC); + size_t n = 0; + const float* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC, + is_first_k); + output += B_INTERLEAVE * PACK_C_SIZE; + cur_packB += K12; + } + for (; n < N; n += 4) { + matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC, + is_first_k, std::min(N - n, 4)); + output += 4 * PACK_C_SIZE; + cur_packB += K4; + } + packA += K4; + } +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.h b/dnn/src/aarch64/matrix_mul/fp32/strategy.h index 8cd877e2..a2faf6ed 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/strategy.h +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.h @@ -20,6 +20,9 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, sgemm_4x16); +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, + sgemm_mk4_8x12); + MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, sgemm_nopack_4x16); diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index c9691647..58ec3ef6 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -18,6 +18,7 @@ using namespace aarch64; class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoF32K8x12x1 f32K8x12x1; + AlgoF32MK4_8x12x1 f32_mk4_8x12x1; AlgoF32K4x16x1 f32k4x16x1; AlgoF32MK4_4x16 f32mk4_4x16; AlgoF32Gemv f32_gemv; @@ -53,6 +54,7 @@ public: AlgoPack() { all_algos.emplace_back(&f32_gemv); all_algos.emplace_back(&f32K8x12x1); + all_algos.emplace_back(&f32_mk4_8x12x1); all_algos.emplace_back(&f32k4x16x1); all_algos.emplace_back(&f32mk4_4x16); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index 01ec762f..1d4e0ab3 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -22,6 +22,7 @@ public: private: class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 + class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1 class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 class AlgoF32Gemv; // Aarch64 F32 Gemv diff --git a/dnn/src/arm_common/conv_bias/f16/algos.cpp b/dnn/src/arm_common/conv_bias/f16/algos.cpp index c24af258..7183c524 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.cpp +++ b/dnn/src/arm_common/conv_bias/f16/algos.cpp @@ -244,6 +244,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable( if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) return false; using Strategy = winograd::winograd_2x3_8x8_f16; + using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = megdnn::winograd::ConvBiasusable(matmul_param) && + m_matmul_algo->packmode() == PackMode::NO_PACK && (opr->param().format == param::ConvBias::Format::NCHW || (opr->param().format == param::ConvBias::Format::NCHW_WINOGRAD && diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.cpp b/dnn/src/arm_common/conv_bias/fp32/algos.cpp index 28e08032..af3de843 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/algos.cpp @@ -38,6 +38,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; using Strategy = winograd::winograd_2x3_4x4_f; + using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = megdnn::winograd::ConvBiasusable(matmul_param) && + m_matmul_algo->packmode() == PackMode::NO_PACK && (opr->param().format == param::ConvBias::Format::NCHW || (opr->param().format == param::ConvBias::Format::NCHW_WINOGRAD && @@ -319,6 +321,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; using Strategy = winograd::winograd_6x3_4x4_f; + using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = megdnn::winograd::ConvBiasusable(matmul_param) && + m_matmul_algo->packmode() == PackMode::NO_PACK && (opr->param().format == param::ConvBias::Format::NCHW || (opr->param().format == param::ConvBias::Format::NCHW_WINOGRAD && diff --git a/dnn/src/arm_common/conv_bias/int8/algos.cpp b/dnn/src/arm_common/conv_bias/int8/algos.cpp index 625ef9c6..f6d7022e 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.cpp +++ b/dnn/src/arm_common/conv_bias/int8/algos.cpp @@ -217,6 +217,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) return false; using Strategy = winograd::winograd_2x3_8x8_s8; + using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = megdnn::winograd::ConvBias( @@ -224,6 +225,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( param.osz[1], param.filter_meta.ocpg) .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && + m_matmul_algo->packmode() == PackMode::NO_PACK && ((opr->param().format == param::ConvBias::Format::NCHW && param.filter_type.enumv() == DTypeEnum::QuantizedS8) || (opr->param().format == param::ConvBias::Format::NCHW_WINOGRAD && diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp index e4baef13..71d338f1 100644 --- a/dnn/test/aarch64/matrix_mul.cpp +++ b/dnn/test/aarch64/matrix_mul.cpp @@ -31,6 +31,12 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) { "AARCH64_F32K4X16X1"); } +TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) { + matrix_mul::check_matrix_mul( + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), + "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); +} + TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { //! nbase should be 4 in order to test the last rest 4 in N dim matrix_mul::check_matrix_mul( @@ -527,6 +533,15 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_MK4) { dtype::Float32{}); } +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_PACK_MK4) { + auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(16); + matrix_mul::benchmark_with_contrast( + handle(), args, dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, "AARCH64_F32_MK4_K8X12X1", + param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{}, + dtype::Float32{}, "AARCH64_F32K8X12X1"); +} + TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) { auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8); matrix_mul::benchmark_with_contrast( diff --git a/dnn/test/common/matrix_mul.cpp b/dnn/test/common/matrix_mul.cpp index 48682d17..74cc7c83 100644 --- a/dnn/test/common/matrix_mul.cpp +++ b/dnn/test/common/matrix_mul.cpp @@ -40,8 +40,8 @@ std::vector matrix_mul::get_matmul_mk_packed_args( size_t nbase) { std::vector args; for (size_t m : {1, 2, 3, 4, 5}) - for (size_t n : {1, 2, 3, 4, 5, 8, 16, 24}) - for (size_t k : {1, 2, 3, 4, 5}) + for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24}) + for (size_t k : {1, 2, 3, 4, 5, 9, 10}) args.emplace_back(m, n * nbase, k, 0); return args; } -- GitLab