From 27ef788f84e5585b82b512116bd0767a7f1e83fa Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 17 May 2020 16:35:19 +0800
Subject: [PATCH] feat(dnn/armv7): add armv7 mk4 matmul

GitOrigin-RevId: 8ef24bf53b19f7863b8a51da5a81135c941d1a72
---
 .../aarch64/matrix_mul/fp32/kernel_mk4_8x12.h |   2 +-
 dnn/src/armv7/matrix_mul/algos.cpp            |  68 +++
 dnn/src/armv7/matrix_mul/algos.h              |  11 +
 dnn/src/armv7/matrix_mul/asm/common.h         |  56 +++
 dnn/src/armv7/matrix_mul/fp32/strategy.h      |   3 +
 .../matrix_mul/fp32/strategy_mk_4x12.cpp      | 451 ++++++++++++++++++
 dnn/src/armv7/matrix_mul/opr_impl.cpp         |   2 +
 dnn/src/armv7/matrix_mul/opr_impl.h           |   1 +
 .../arm_common/conv_bias_multi_thread.cpp     |  18 +-
 dnn/test/armv7/matrix_mul.cpp                 |  15 +
 10 files changed, 620 insertions(+), 7 deletions(-)
 create mode 100644 dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp

diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
index 6dccd0978..0e2bab774 100644
--- a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
+++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
@@ -707,7 +707,7 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
         "cmp %w[n_remain], #3\n"                     \
         "blt 22f\n"                                  \
         "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
-        "b 23f\n"                                    \
+        "b 24f\n"                                    \
         "22:\n"                                      \
         "cmp %w[n_remain], #2\n"                     \
         "blt 23f\n"                                  \
diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp
index 59b327ec3..c9a669460 100644
--- a/dnn/src/armv7/matrix_mul/algos.cpp
+++ b/dnn/src/armv7/matrix_mul/algos.cpp
@@ -85,6 +85,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern,
                                      "AlgoF32Impl"_hash,
                                      armv7::matmul::sgemm_4x12, float, float);
 
+/* ===================== F32 algo mk4 K4x12 ===================== */
+
+namespace {
+void f32_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) {
+    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
+                 midout_iv("f32_mk4_pack_4x12_kern"_hash)) {
+        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
+        auto trA = kern_param.trA, trB = kern_param.trB;
+        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
+        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
+             C_type = kern_param.C_type;
+        const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
+        auto Cptr = kern_param.C<float>();
+
+        armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type,
+                                                    C_type);
+        megdnn::matmul::GemmInterleaved<armv7::matmul::sgemm_mk4_pack_4x12>(
+                M, N, K, trA, trB, strategy)
+                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
+                         kern_param.workspace_ptr);
+    }
+    MIDOUT_END();
+}
+
+}  // anonymous namespace
+
+bool MatrixMulImpl::AlgoF32MK4Pack4x12::usable(
+        const KernSizeParam& kern_size_param) const {
+    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+           kern_size_param.format == param::MatrixMul::Format::MK4 &&
+           kern_size_param.B_type == kern_size_param.A_type &&
+           kern_size_param.C_type == kern_size_param.A_type &&
+           kern_size_param.A_type == dtype::Float32() &&
+           !kern_size_param.trA && !kern_size_param.trB &&
+           kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
+}
+
+size_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_workspace(
+        const KernSizeParam& kern_size_param) const {
+    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
+                 midout_iv("AlgoF32MK4Pack4x12::get_workspace"_hash)) {
+        auto M = kern_size_param.M, N = kern_size_param.N,
+             K = kern_size_param.K;
+        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
+        auto A_type = kern_size_param.A_type,
+             B_type = kern_size_param.B_type,
+             C_type = kern_size_param.C_type;
+        armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type,
+                                                    C_type);
+        return megdnn::matmul::GemmInterleaved<
+                       armv7::matmul::sgemm_mk4_pack_4x12>(M, N, K, trA, trB,
+                                                           strategy)
+                .get_workspace_size();
+    }
+    MIDOUT_END();
+}
+
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern(
+        const KernSizeParam&) const {
+    return f32_mk4_pack_4x12_kern;
+}
+
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12,
+                                     megdnn_armv7_matmul_kern,
+                                     "AlgoF32MK4Pack4x12"_hash,
+                                     armv7::matmul::sgemm_mk4_pack_4x12, float,
+                                     float);
+
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 /* ===================== F16 K4x16x1 algo ===================== */
 namespace {
diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h
index 141d5650b..388074c26 100644
--- a/dnn/src/armv7/matrix_mul/algos.h
+++ b/dnn/src/armv7/matrix_mul/algos.h
@@ -29,6 +29,17 @@ public:
     MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
 };
 
+class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; }
+    bool usable(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+};
+
 class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase {
 public:
     bool is_reproducible() const override { return true; }
diff --git a/dnn/src/armv7/matrix_mul/asm/common.h b/dnn/src/armv7/matrix_mul/asm/common.h
index 20bbd3f35..a7aad47e9 100644
--- a/dnn/src/armv7/matrix_mul/asm/common.h
+++ b/dnn/src/armv7/matrix_mul/asm/common.h
@@ -1120,6 +1120,62 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1,
             : "q0", "q1", "q2", "memory");
 }
 
+template <typename T>
+static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
+    static_assert(sizeof(T) == 4,
+                  "transpose_1x12_4_s only support sizeof(T) == 4");
+
+    asm volatile(
+            "vld4.32 {d0-d3}, [%[inptr0]]!\n"
+            "vld4.32 {d4-d7}, [%[inptr0]]!\n"
+            "vld4.32 {d8-d11}, [%[inptr0]]!\n"
+            "vld4.32 {d12-d15}, [%[inptr0]]!\n"
+            "vld4.32 {d16-d19}, [%[inptr0]]!\n"
+            "vld4.32 {d20-d23}, [%[inptr0]]!\n"
+            "vswp d1, d4\n"
+            "vswp d3, d6\n"
+            "vswp d9, d12\n"
+            "vswp d11, d14\n"
+            "vswp d17, d20\n"
+            "vswp d19, d22\n"
+
+            "vst1.32 {d0-d1}, [%[outptr]]!\n"
+            "vst1.32 {d8-d9}, [%[outptr]]!\n"
+            "vst1.32 {d16-d17}, [%[outptr]]!\n"
+            "vst1.32 {d4-d5}, [%[outptr]]!\n"
+            "vst1.32 {d12-d13}, [%[outptr]]!\n"
+            "vst1.32 {d20-d21}, [%[outptr]]!\n"
+            "vst1.32 {d2-d3}, [%[outptr]]!\n"
+            "vst1.32 {d10-d11}, [%[outptr]]!\n"
+            "vst1.32 {d18-d19}, [%[outptr]]!\n"
+            "vst1.32 {d6-d7}, [%[outptr]]!\n"
+            "vst1.32 {d14-d15}, [%[outptr]]!\n"
+            "vst1.32 {d22-d23}, [%[outptr]]!\n"
+            : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
+            :
+            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+              "q10", "q11", "memory");
+}
+
+template <typename T>
+static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
+    static_assert(sizeof(T) == 4,
+                  "transpose_1x4_4_s only support sizeof(T) == 4");
+    asm volatile(
+            "vld4.32 {d0-d3}, [%[inptr0]]!\n"
+            "vld4.32 {d4-d7}, [%[inptr0]]!\n"
+            "vswp d1, d4\n"
+            "vswp d3, d6\n"
+            "vst1.32 {d0-d1}, [%[outptr]]!\n"
+            "vst1.32 {d4-d5}, [%[outptr]]!\n"
+            "vst1.32 {d2-d3}, [%[outptr]]!\n"
+            "vst1.32 {d6-d7}, [%[outptr]]!\n"
\n" + : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "memory"); +} + + template static inline void transpose_4(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, T* outptr, diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy.h b/dnn/src/armv7/matrix_mul/fp32/strategy.h index 9b7b0bae3..ef50e8a8a 100644 --- a/dnn/src/armv7/matrix_mul/fp32/strategy.h +++ b/dnn/src/armv7/matrix_mul/fp32/strategy.h @@ -18,6 +18,9 @@ namespace matmul { MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, sgemm_4x12); +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, false, + sgemm_mk4_pack_4x12); + MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 8, 1, false, true, sgemm_nopack_4x8); diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp b/dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp new file mode 100644 index 000000000..db89491bf --- /dev/null +++ b/dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp @@ -0,0 +1,451 @@ +/** + * \file dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/armv7/matrix_mul/fp32/strategy.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" + +using namespace megdnn; +using namespace armv7; +using namespace armv7::matmul; + +namespace { + +// Overview of register layout: +// +// A 1x12 cell of Rhs is stored in 32bit in q1-q3 +// A 4x1 cell of Lhs is stored in 132bit in q0 +// A 4x12 block of accumulators is stored in 32bit in q4-q15. 
+//
+//                     +--------+--------+--------+
+//                     | q1[0-3]| q2[0-3]| q3[0-3]|
+//               Rhs   +--------+--------+--------+
+//
+//                     |        |        |        |
+//
+//    Lhs              |        |        |        |
+//
+//  +--+ - - - -       +--------+--------+--------+
+//  |q0|               | q4[0-3]| q5[0-3]| q6[0-3]|
+//  |q0|               | q7[0-3]| q8[0-3]| q9[0-3]|
+//  |q0|               |q10[0-3]|q11[0-3]|q12[0-3]|
+//  |q0|               |q13[0-3]|q14[0-3]|q15[0-3]|
+//  +--+ - - - -       +--------+--------+--------+
+//
+//                         Accumulator
+void kern_4x12(const float* packA, const float* packB, int K, float* output,
+               int LDC, bool is_first_k) {
+    MEGDNN_MARK_USED_VAR(LDC);
+    const float* a_ptr = packA;
+    const float* b_ptr = packB;
+    float* output0 = output;
+
+    int oddk = (K & 1);
+    K = ((K + 1) / 2) - 1;
+
+    asm volatile(
+            "cmp %[is_first_k], #1\n"
+            "beq 1f\n"
+            "mov r1, %[output0]\n"
+            "vld1.32 {d8-d11}, [r1]!\n"
+            "vld1.32 {d12-d15}, [r1]!\n"
+            "vld1.32 {d16-d19}, [r1]!\n"
+            "vld1.32 {d20-d23}, [r1]!\n"
+            "vld1.32 {d24-d27}, [r1]!\n"
+            "vld1.32 {d28-d31}, [r1]!\n"
+
+            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
+            "vld1.32 {d4-d7}, [%[b_ptr]]!\n"
+            "b 2f\n"
+
+            "1:\n"
+            "veor.32 q4, q4, q4\n"
+            "pld [%[output0]]\n"
+            "veor.32 q5, q4, q4\n"
+            "veor.32 q6, q4, q4\n"
+            "veor.32 q7, q4, q4\n"
+            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
+            "veor.32 q8, q4, q4\n"
+            "veor.32 q9, q4, q4\n"
+            "veor.32 q10, q4, q4\n"
+            "veor.32 q11, q4, q4\n"
+            "vld1.32 {d4-d7}, [%[b_ptr]]!\n"
+            "veor.32 q12, q4, q4\n"
+            "veor.32 q13, q4, q4\n"
+            "veor.32 q14, q4, q4\n"
+            "veor.32 q15, q4, q4\n"
+
+            "2: \n"
+            "cmp %[K], #0\n"
+            "beq 4f\n"
+
+            "3:\n"
+            "vmla.f32 q4, q0, d4[0]\n"
+            "vmla.f32 q5, q0, d4[1]\n"
+            "vmla.f32 q6, q0, d5[0]\n"
+            "vmla.f32 q7, q0, d5[1]\n"
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+            "vmla.f32 q8, q0, d6[0]\n"
+            "vmla.f32 q9, q0, d6[1]\n"
+            "vmla.f32 q10, q0, d7[0]\n"
+            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
+            "vmla.f32 q11, q0, d7[1]\n"
+            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
+            "vmla.f32 q12, q0, d4[0]\n"
+            "vmla.f32 q13, q0, d4[1]\n"
+            "vmla.f32 q14, q0, d5[0]\n"
+            "vmla.f32 q15, q0, d5[1]\n"
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+
+            "vmla.f32 q4, q1, d6[0]\n"
+            "subs %[K], %[K], #1\n"
+            "vmla.f32 q5, q1, d6[1]\n"
+            "vmla.f32 q6, q1, d7[0]\n"
+            "vmla.f32 q7, q1, d7[1]\n"
+            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
+            "vmla.f32 q8, q1, d4[0]\n"
+            "vmla.f32 q9, q1, d4[1]\n"
+            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
+            "vmla.f32 q10, q1, d5[0]\n"
+            "vmla.f32 q11, q1, d5[1]\n"
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+            "vmla.f32 q12, q1, d6[0]\n"
+            "vmla.f32 q13, q1, d6[1]\n"
+            "vmla.f32 q14, q1, d7[0]\n"
+            "vmla.f32 q15, q1, d7[1]\n"
+            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
+            "bne 3b\n"
+
+            "4:\n"
+            "cmp %[oddk], #1\n"
+            "beq 5f\n"
+
+            // Even tail
+            "vmla.f32 q4, q0, d4[0]\n"
+            "vmla.f32 q5, q0, d4[1]\n"
+            "vmla.f32 q6, q0, d5[0]\n"
+            "vmla.f32 q7, q0, d5[1]\n"
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+            "vmla.f32 q8, q0, d6[0]\n"
+            "vmla.f32 q9, q0, d6[1]\n"
+            "vmla.f32 q10, q0, d7[0]\n"
+            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
+            "vmla.f32 q11, q0, d7[1]\n"
+            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
+            "vmla.f32 q12, q0, d4[0]\n"
+            "vmla.f32 q13, q0, d4[1]\n"
+            "vmla.f32 q14, q0, d5[0]\n"
+            "vmla.f32 q15, q0, d5[1]\n"
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+
+            "vmla.f32 q4, q1, d6[0]\n"
+            "subs %[K], %[K], #1\n"
+            "vmla.f32 q5, q1, d6[1]\n"
+            "vmla.f32 q6, q1, d7[0]\n"
+            "vmla.f32 q7, q1, d7[1]\n"
+            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
+            "vmla.f32 q8, q1, d4[0]\n"
+            "vmla.f32 q9, q1, d4[1]\n"
+            "vst1.32 {d8-d11}, [%[output0]]!\n"
+            "vmla.f32 q10, q1, d5[0]\n"
+            "vmla.f32 q11, q1, d5[1]\n"
+            "vst1.32 {d12-d15}, [%[output0]]!\n"
+            "vmla.f32 q12, q1, d6[0]\n"
+            "vmla.f32 q13, q1, d6[1]\n"
+            "vst1.32 {d16-d19}, [%[output0]]!\n"
+            "vmla.f32 q14, q1, d7[0]\n"
+            "vmla.f32 q15, q1, d7[1]\n"
+            "vst1.32 {d20-d23}, [%[output0]]!\n"
+            "vst1.32 {d24-d27}, [%[output0]]!\n"
+            "vst1.32 {d28-d31}, [%[output0]]!\n"
+
+            "b 6f\n"
+
+            // odd tail
+            "5:\n"
+            "vmla.f32 q4, q0, d4[0]\n"
+            "vmla.f32 q5, q0, d4[1]\n"
+            "vmla.f32 q6, q0, d5[0]\n"
+            "vmla.f32 q7, q0, d5[1]\n"
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+            "vmla.f32 q8, q0, d6[0]\n"
+            "vst1.32 {d8-d11}, [%[output0]]!\n"
+            "vmla.f32 q9, q0, d6[1]\n"
+            "vmla.f32 q10, q0, d7[0]\n"
+            "vst1.32 {d12-d15}, [%[output0]]!\n"
+            "vmla.f32 q11, q0, d7[1]\n"
+            "vmla.f32 q12, q0, d4[0]\n"
+            "vst1.32 {d16-d19}, [%[output0]]!\n"
+            "vmla.f32 q13, q0, d4[1]\n"
+            "vst1.32 {d20-d23}, [%[output0]]!\n"
+            "vmla.f32 q14, q0, d5[0]\n"
+            "vst1.32 {d24-d27}, [%[output0]]!\n"
+            "vmla.f32 q15, q0, d5[1]\n"
+            "vst1.32 {d28-d31}, [%[output0]]!\n"
+
+            "6:\n"
+            : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
+              [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
+              [ output0 ] "+r"(output0)
+            :
+            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+              "q10", "q11", "q12", "q13", "q14", "q15", "r1", "cc", "memory");
+}
+
+// Overview of register layout:
+//
+// A 2x4 cell of Rhs is stored in 32bit in q2 - q3
+// A 4x2 cell of Lhs is stored in 32bit in q0 - q1
+// A 4x4 block of accumulators is stored in 32bit in q4-q7
+//
+//                 +--------+
+//                 | q2[0-3]|
+//                 | q3[0-3]|
+//           Rhs   +--------+
+//
+//                 |        |
+//
+//    Lhs          |        |
+//
+//  +--+ --- -     +--------+
+//  |q0|           | q4[0-3]|
+//  |q0|           | q5[0-3]|
+//  |q0|           | q6[0-3]|
+//  |q0|           | q7[0-3]|
+//  +--+ --- -     +--------+
+//
+//                 Accumulator
+void kern_4x4(const float* packA, const float* packB, int K, float* output,
+              int LDC, bool is_first_k, int n_remain) {
+    MEGDNN_MARK_USED_VAR(LDC);
+    const float* a_ptr = packA;
+    const float* b_ptr = packB;
+
+    int oddk = (K & 1);
+    K = ((K + 1) / 2) - 1;
+
+// clang-format off
+#define LOAD_C                            \
+    "cmp %[n_remain], #4\n"               \
+    "blt 11f\n"                           \
+    "vld1.32 {d8-d11}, [r1]!\n"           \
+    "vld1.32 {d12-d15}, [r1]!\n"          \
+    "b 14f\n"                             \
+    "11:\n"                               \
+    "cmp %[n_remain], #3\n"               \
+    "blt 12f\n"                           \
+    "vld1.32 {d8-d11}, [r1]!\n"           \
+    "vld1.32 {d12-d13}, [r1]!\n"          \
+    "b 14f\n"                             \
+    "12:\n"                               \
+    "cmp %[n_remain], #2\n"               \
+    "blt 13f\n"                           \
+    "vld1.32 {d8-d11}, [r1]\n"            \
+    "b 14f\n"                             \
+    "13:\n"                               \
+    "vld1.32 {d8-d9}, [r1]\n"             \
+    "14:\n"
+
+#define STORE_C                           \
+    "cmp %[n_remain], #4\n"               \
+    "blt 21f\n"                           \
+    "vst1.32 {d8-d11}, [%[output]]!\n"    \
+    "vst1.32 {d12-d15}, [%[output]]!\n"   \
+    "b 24f\n"                             \
+    "21:\n"                               \
+    "cmp %[n_remain], #3\n"               \
+    "blt 22f\n"                           \
+    "vst1.32 {d8-d11}, [%[output]]!\n"    \
+    "vst1.32 {d12-d13}, [%[output]]!\n"   \
+    "b 24f\n"                             \
+    "22:\n"                               \
+    "cmp %[n_remain], #2\n"               \
+    "blt 23f\n"                           \
+    "vst1.32 {d8-d11}, [%[output]]!\n"    \
+    "b 24f\n"                             \
+    "23:\n"                               \
+    "vst1.32 {d8-d9}, [%[output]]!\n"     \
+    "24:\n"
+// clang-format on
+
+    asm volatile(
+            "cmp %[is_first_k], #1\n"
+            "beq 1f\n"
+            "mov r1, %[output]\n" LOAD_C
+            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+            "b 2f\n"
+
+            "1:\n"
+            "veor.32 q4, q4, q4\n"
+            "pld [%[output]]\n"
+            "veor.32 q5, q4, q4\n"
+            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
+            "veor.32 q6, q4, q4\n"
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+            "veor.32 q7, q4, q4\n"
+
+            "2: \n"
+            "cmp %[K], #0\n"
+            "beq 4f\n"
+
+            "3:\n"
+            "vmla.f32 q4, q0, d4[0]\n"
+            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
+            "vmla.f32 q5, q0, d4[1]\n"
+            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
+            "vmla.f32 q6, q0, d5[0]\n"
+            "vmla.f32 q7, q0, d5[1]\n"
+
+            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
+            "vmla.f32 q4, q1, d6[0]\n"
+            "subs %[K], %[K], #1\n"
+            "vmla.f32 q5, q1, d6[1]\n"
+            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
+            "vmla.f32 q6, q1, d7[0]\n"
+            "vmla.f32 q7, q1, d7[1]\n"
+            "bne 3b\n"
+
+            "4:\n"
+            "cmp %[oddk], #1\n"
+            "beq 5f\n"
+
+            // Even tail
+            "vmla.f32 q4, q0, d4[0]\n"
+            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
+            "vmla.f32 q5, q0, d4[1]\n"
+            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
+            "vmla.f32 q6, q0, d5[0]\n"
+            "vmla.f32 q7, q0, d5[1]\n"
+
+            "vmla.f32 q4, q1, d6[0]\n"
+            "vmla.f32 q5, q1, d6[1]\n"
+            "vmla.f32 q6, q1, d7[0]\n"
+            "vmla.f32 q7, q1, d7[1]\n"
+            "b 6f\n"
+
+            // odd tail
+            "5:\n"
+            "vmla.f32 q4, q0, d4[0]\n"
+            "vmla.f32 q5, q0, d4[1]\n"
+            "vmla.f32 q6, q0, d5[0]\n"
+            "vmla.f32 q7, q0, d5[1]\n"
+
+            "6:\n" STORE_C
+            : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
+              [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
+              [ output ] "+r"(output), [ n_remain ] "+r"(n_remain)
+            :
+            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "r1", "cc",
+              "memory");
+#undef LOAD_C
+#undef STORE_C
+}
+
+}  // namespace
+
+MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_pack_4x12);
+//! Conv1x1 and im2col do not yet support a matmul mode that packs only B,
+//! so pack_A simply copies the weights.
+void sgemm_mk4_pack_4x12::pack_A(float* out, const float* in, int ldin, int y0,
+                                 int ymax, int k0, int kmax, bool) const {
+    megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be a multiple of 4");
+    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
+    constexpr int PACK_C_SIZE = 4;
+    size_t cp_length = (kmax - k0) * PACK_C_SIZE;
+    for (int m = y0; m < ymax; m += 4) {
+        const float* src = in + (m / PACK_C_SIZE) * ldin + k0 * PACK_C_SIZE;
+        memcpy(out, src, cp_length * sizeof(float));
+        out += cp_length;
+    }
+}
+
+void sgemm_mk4_pack_4x12::pack_B(float* out, const float* in, int ldin, int x0,
+                                 int xmax, int k0, int kmax,
+                                 bool transpose_B) const {
+    megdnn_assert(!transpose_B);
+    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
+    float tmpbuff[16] = {0.0f};
+
+    constexpr int PACK_C_SIZE = 4;
+    int ksize = kmax - k0;
+    int ksize12 = ksize * 12;
+    int ksize4 = (ksize << 2);
+    float* outptr_base = out;
+    float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12;
+
+    int k = k0;
+    for (; k + 3 < kmax; k += 4) {
+        const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
+        prefetch_3x(inptr);
+
+        int x = x0;
+        auto outptr = outptr_base;
+        for (; x + 12 <= xmax; x += 12) {
+            auto outptr_interleave = outptr;
+            transpose_1x12_4_s(inptr, outptr_interleave);
+            outptr += ksize12;
+        }
+        outptr = outptr_base4;
+        for (; x + 4 <= xmax; x += 4) {
+            auto outptr_interleave = outptr;
+            transpose_1x4_4_s(inptr, outptr_interleave);
+            outptr += ksize4;
+        }
+        if (x < xmax) {
+            memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE);
+            auto outptr_interleave = outptr;
+            const float* tmp_ptr = &tmpbuff[0];
+            transpose_1x4_4_s(tmp_ptr, outptr_interleave);
+            outptr += ksize4;
+        }
+        outptr_base += 12 * PACK_C_SIZE;
+        outptr_base4 += 4 * PACK_C_SIZE;
+    }
+}
+
+void sgemm_mk4_pack_4x12::kern(const float* packA, const float* packB,
+                               size_t M, size_t N, size_t K, float* C,
+                               size_t LDC, bool is_first_k, const float*,
+                               float*) const {
+    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
+                  A_dtype.enumv() == C_dtype.enumv() &&
+                  A_dtype.enumv() == DTypeEnum::Float32);
+    constexpr int PACK_C_SIZE = 4;
+    constexpr size_t A_INTERLEAVE = 4;
+    constexpr size_t B_INTERLEAVE = 12;
+    const int K12 = K * 12;
+    const int K4 = K * 4;
+    size_t m = 0;
+    for (; m < M; m += A_INTERLEAVE) {
+        float* output = C + (m / 4 * LDC);
+
+        size_t n = 0;
+        const float* cur_packB = packB;
+        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
+            kern_4x12(packA, cur_packB, K, output, LDC, is_first_k);
+            output += PACK_C_SIZE * B_INTERLEAVE;
+            cur_packB += K12;
+        }
+        for (; n < N; n += 4) {
+            kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
+                     std::min<size_t>(N - n, 4));
+            output += PACK_C_SIZE * 4;
+            cur_packB += K4;
+        }
+        packA += K4;
+    }
+}
+
+// vim: syntax=cpp.doxygen
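For reference while reviewing: the product the kernels above compute, spelled out for the MK4 layout as this patch uses it. The layout reading (A packed as (M/4, K/4, 4, 4) with m innermost, B as (K/4, N, 4) with k innermost, C as (M/4, N, 4) with m innermost) is inferred from the loads in kern_4x12 and from pack_A being a plain memcpy; the function below is an illustrative sketch, not part of the patch or of megdnn's API.

#include <cstddef>

// Scalar reference for the MK4 matmul implemented by kern_4x12/kern_4x4,
// assuming dense buffers (i.e. LDC == N * 4). Name and plain-pointer
// signature are illustrative only.
void mk4_matmul_ref(const float* A, const float* B, float* C, size_t M,
                    size_t N, size_t K) {
    for (size_t mb = 0; mb < M / 4; ++mb)        // 4-row blocks of A/C
        for (size_t n = 0; n < N; ++n)           // output column
            for (size_t mi = 0; mi < 4; ++mi) {  // row inside the block
                float acc = 0.f;
                for (size_t kb = 0; kb < K / 4; ++kb)  // 4-deep k blocks
                    for (size_t ki = 0; ki < 4; ++ki)  // k inside the block
                        acc += A[((mb * (K / 4) + kb) * 4 + ki) * 4 + mi] *
                               B[(kb * N + n) * 4 + ki];
                C[(mb * N + n) * 4 + mi] = acc;
            }
}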
diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp
index f6a9d91c3..736e2f7c7 100644
--- a/dnn/src/armv7/matrix_mul/opr_impl.cpp
+++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp
@@ -20,6 +20,7 @@ using namespace armv7;
 class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoF32 f32;
+    AlgoF32MK4Pack4x12 f32_mk4_pack_4x12;
     AlgoF32MK4_4x8 f32_mk4_4x8;
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     AlgoF16K4x16x1 f16_k4x16x1;
@@ -48,6 +49,7 @@ public:
     AlgoPack() {
         all_algos.emplace_back(&f32_gemv);
         all_algos.emplace_back(&f32);
+        all_algos.emplace_back(&f32_mk4_pack_4x12);
         all_algos.emplace_back(&f32_mk4_4x8);
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         all_algos.emplace_back(&f16_k4x16x1);
diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h
index 51e74db02..fe3d22a3b 100644
--- a/dnn/src/armv7/matrix_mul/opr_impl.h
+++ b/dnn/src/armv7/matrix_mul/opr_impl.h
@@ -21,6 +21,7 @@ public:
     SmallVector<AlgoBase*> algo_pack() override;
 
 private:
     class AlgoF32;              // Armv7 F32
+    class AlgoF32MK4Pack4x12;   // Armv7 F32 Kernel 4x12 with pack
     class AlgoF32MK4_4x8;       // Armv7 F32 Kernel 4x8 nopack
     class AlgoF32Gemv;          // Armv7 F32 Gemv
     class AlgoInt8x8x32K4x8x8;  // Armv7 Int8x8x32 Kernel 4x8x8
diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp
index 77ab60e74..85b55dec8 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -1287,23 +1287,27 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
 #undef cb
 }
 
-#if MEGDNN_AARCH64
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
             get_nchw44_conv_bias_args({2, 4, 7}, 1);
+#if MEGDNN_AARCH64
     check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
-}
+#elif MEGDNN_ARMV7
+    check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
 #endif
+}
 
-#if MEGDNN_AARCH64
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
            get_nchw44_conv_bias_args({3, 5, 6}, 2);
+#if MEGDNN_AARCH64
     check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
-}
+#elif MEGDNN_ARMV7
+    check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
 #endif
+}
 
 /***************************** Conv1x1 Algo Test ***********************/
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
@@ -1316,14 +1320,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
 #endif
 }
 
-#if MEGDNN_AARCH64
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
             get_nchw44_conv_bias_args({1}, 1, true, false, false);
+#if MEGDNN_AARCH64
     check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
-}
+#elif MEGDNN_ARMV7
+    check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
 #endif
+}
 
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
     using namespace conv_bias;
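The conv_bias tests above select the new kernel purely by algorithm name. If an armv7-only smoke test were wanted during review, it could follow the same pattern; the test below is hypothetical (it reuses this file's existing helpers but is not part of the patch):

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32_SMOKE) {
    using namespace conv_bias;
    // One filter size is enough for a quick correctness pass.
    std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args({3}, 1);
#if MEGDNN_ARMV7
    check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12");
#endif
}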
diff --git a/dnn/test/armv7/matrix_mul.cpp b/dnn/test/armv7/matrix_mul.cpp
index 56994f15c..78d82bc0b 100644
--- a/dnn/test/armv7/matrix_mul.cpp
+++ b/dnn/test/armv7/matrix_mul.cpp
@@ -28,6 +28,12 @@ TEST_F(ARMV7, MATRIX_MUL_MK4) {
             "ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4);
 }
 
+TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) {
+    matrix_mul::check_matrix_mul(
+            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
+            "ARMV7_F32_MK4_PACK_4X12", param::MatrixMul::Format::MK4, 1);
+}
+
 TEST_F(ARMV7, MATRIX_MUL_MK4_INT8) {
     std::vector<matrix_mul::TestArg> args;
     for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11})
@@ -349,6 +355,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) {
                                     dtype::Float32{});
 }
 
+TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_PACK_MK4) {
+    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
+    matrix_mul::benchmark_with_contrast(
+            handle(), args, dtype::Float32{}, dtype::Float32{},
+            dtype::Float32{}, "ARMV7_F32_MK4_PACK_4X12",
+            param::MatrixMul::Format::MK4, dtype::Float32{}, dtype::Float32{},
+            dtype::Float32{});
+}
+
 TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) {
     auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4);
     matrix_mul::benchmark_with_contrast(
-- 
GitLab
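A note on the packed-B layout for anyone cross-checking pack_B against kern_4x12: based on my reading of the vld4/vswp/vst1 sequence in transpose_1x12_4_s, each 12-column panel is emitted k-row-major, which is exactly the order in which kern_4x12 consumes twelve B scalars per k step. A scalar sketch follows (the helper name is mine and it covers full 12-wide panels only; the 4-wide and padded tails follow the same idea via transpose_1x4_4_s):

#include <cstddef>

// B in MK4 is (K/4, N, 4) with k innermost. The packed panel stores, for
// each 4-row k block, the 12 values of k row 0, then rows 1, 2 and 3.
void pack_b_ref(const float* B, float* out, size_t K, size_t N) {
    for (size_t p = 0; p < N / 12; ++p)          // full 12-wide panels
        for (size_t kb = 0; kb < K / 4; ++kb)    // 4-row k blocks
            for (size_t ki = 0; ki < 4; ++ki)    // row inside the block
                for (size_t n = 0; n < 12; ++n)  // panel column
                    *out++ = B[(kb * N + p * 12 + n) * 4 + ki];
}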