From 5d950063cfa0b4bacf37147b413f8c50271c83d9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 10 Jun 2020 14:35:28 +0800 Subject: [PATCH] feat(dnn): refactor dot gemv for both aarch64 and aarch32 GitOrigin-RevId: 2b98867e4563bffe69f0340676e1a6c9ca8d0a2d --- dnn/src/aarch64/matrix_mul/algos.cpp | 34 ------ dnn/src/aarch64/matrix_mul/algos.h | 19 --- dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp | 116 ------------------- dnn/src/aarch64/matrix_mul/int8_dot/gemv.h | 34 ------ dnn/src/aarch64/matrix_mul/opr_impl.cpp | 4 - dnn/src/aarch64/matrix_mul/opr_impl.h | 2 - dnn/src/arm_common/matrix_mul/algos.cpp | 3 - dnn/src/arm_common/matrix_mul/algos.h | 5 - dnn/src/arm_common/matrix_mul/int8/gemv.cpp | 79 ++++++++++++- dnn/src/arm_common/matrix_mul/int8/gemv.h | 3 +- dnn/src/arm_common/matrix_mul/opr_impl.cpp | 3 +- dnn/src/arm_common/matrix_mul/opr_impl.h | 2 - dnn/src/arm_common/simd_macro/marm_neon.h | 13 +++ dnn/src/armv7/matrix_mul/algos.h | 5 - dnn/src/armv7/matrix_mul/opr_impl.cpp | 6 - dnn/src/armv7/matrix_mul/opr_impl.h | 3 - dnn/test/arm_common/matrix_mul.cpp | 30 +++++ 17 files changed, 122 insertions(+), 239 deletions(-) delete mode 100644 dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp delete mode 100644 dnn/src/aarch64/matrix_mul/int8_dot/gemv.h diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index 6ae30e821..7ef15b450 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -14,7 +14,6 @@ #include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/aarch64/matrix_mul/int16/strategy.h" #include "src/aarch64/matrix_mul/int8/strategy.h" -#include "src/aarch64/matrix_mul/int8_dot/gemv.h" #include "src/aarch64/matrix_mul/int8_dot/strategy.h" #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" #include "src/aarch64/matrix_mul/quint8/strategy.h" @@ -441,39 +440,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, aarch64::matmul::gemm_s8_8x12, int8_t, int32_t); -/* ===================== Int8x8x32 Gemv DotProd algo ===================== */ -namespace { -void int8x8x32_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { - auto M = kern_param.M, N = kern_param.N, K = kern_param.K; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; - const auto Aptr = kern_param.A(), Bptr = kern_param.B(); - auto Cptr = kern_param.C(); - aarch64::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); -} -} // anonymous namespace - -bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::usable( - const KernSizeParam& kern_size_param) const { - return can_be_treated_as_int8x8x32(kern_size_param) && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.N == 1 && kern_size_param.LDB == 1; -} - -bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::preferred( - const KernSizeParam& kern_size_param) const { - auto N = kern_size_param.N, LDB = kern_size_param.LDB; - return (N == 1 && LDB == 1); -} - -MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern( - const KernSizeParam&) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x32GemvDotProd::get_kern"_hash)) { - return int8x8x32_gemv_dotprod_kern; - } - MIDOUT_END(); - return nullptr; -} /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ namespace { diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index 672187560..7e5a3613e 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -104,21 +104,6 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; -class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase { -public: - bool is_reproducible() const override { return true; } - const char* name() const override { - return "AARCH64_INT8X8X32_GEMV_DOTPROD"; - } - bool usable(const KernSizeParam&) const override; - bool preferred(const KernSizeParam&) const override; - size_t get_workspace(const KernSizeParam&) const override { return 0; } - kern_t get_kern(const KernSizeParam&) const override; - void* type() const override { return sm_arm_common_algo_type; } - AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } - PackMode packmode() const override { return PackMode::NO_PACK; } -}; - class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { public: bool is_reproducible() const override { return true; } @@ -174,10 +159,6 @@ public: void* type() const override { return sm_arm_common_algo_type; } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; - -class MatrixMulImpl::AlgoInt8x8x32Gemv final - : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; - #endif class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp b/dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp deleted file mode 100644 index cb1ac71b4..000000000 --- a/dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp +++ /dev/null @@ -1,116 +0,0 @@ -/** - * \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#include "src/aarch64/matrix_mul/int8_dot/gemv.h" -#include -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/common/utils.h" -#include "src/common/unroll_macro.h" - -#if __ARM_FEATURE_DOTPROD - -namespace { - -void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { - megdnn_assert(N == 1 && Bstride == 1); - size_t m = 0; - for (; m + 2 <= M; m += 2) { - int32_t acc[4]; - int32x4_t acc_neon = vdupq_n_s32(0); - size_t k = 0; - for (; k + 16 <= K; k += 16) { - int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); - int64x2_t a1 = - vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); - //! the first 8 elements is m, the last 8 elements is m + 1 - int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); - int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); - - int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); - int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); - int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); - - acc_neon = vdotq_s32(acc_neon, a2, b2); - acc_neon = vdotq_s32(acc_neon, a3, b3); - } - vst1q_s32(acc, acc_neon); - - for (; k + 8 <= K; k += 8) { - int8x8_t a0 = vld1_s8(A + m * Astride + k); - int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); - int8x8_t b0 = vld1_s8(B + k); - uint32x2_t zero = vdup_n_s32(0); - acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); - zero = vdup_n_s32(0); - acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); - } - - for (; k < K; ++k) { - acc[0] += static_cast(A[m * Astride + k]) * B[k]; - acc[3] += static_cast(A[(m + 1) * Astride + k]) * B[k]; - } - C[m * Cstride] = acc[0] + acc[1]; - C[(m + 1) * Cstride] = acc[2] + acc[3]; - } - - for (; m < M; ++m) { - int32_t acc[4]; - int32x4_t acc_neon = vdupq_n_s32(0); - size_t k = 0; - for (; k + 16 <= K; k += 16) { - int8x16_t a0 = vld1q_s8(A + m * Astride + k); - int8x16_t b0 = vld1q_s8(B + k); - acc_neon = vdotq_s32(acc_neon, a0, b0); - } - vst1q_s32(acc, acc_neon); - - for (; k + 8 <= K; k += 8) { - int8x8_t a0 = vld1_s8(A + m * Astride + k); - int8x8_t b0 = vld1_s8(B + k); - uint32x2_t zero = vdup_n_s32(0); - acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); - } - - for (; k < K; ++k) { - acc[0] += static_cast(A[m * Astride + k]) * B[k]; - } - C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; - } -} - -} // namespace - -bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8( - bool transposeA, bool transposeB, size_t M, size_t N, size_t K, - size_t /* LDA */, size_t LDB, size_t /* LDC */) { - if (transposeA) - return false; - if (transposeB) - return false; - MEGDNN_MARK_USED_VAR(K); - MEGDNN_MARK_USED_VAR(M); - return (N == 1 && LDB == 1); -} - -void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A, - const int8_t* __restrict B, - int32_t* __restrict C, size_t M, - size_t N, size_t K, size_t Astride, - size_t Bstride, size_t Cstride) { - megdnn_assert(N == 1); - return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); -} - -#endif - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/gemv.h b/dnn/src/aarch64/matrix_mul/int8_dot/gemv.h deleted file mode 100644 index 61041ab12..000000000 --- a/dnn/src/aarch64/matrix_mul/int8_dot/gemv.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ -#pragma once - -#include -#include - -#if __ARM_FEATURE_DOTPROD -namespace megdnn { -namespace aarch64 { -namespace matmul { - -bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M, - size_t N, size_t K, size_t LDA, size_t LDB, - size_t LDC); - -void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride); - -} // namespace matmul -} // namespace aarch64 -} // namespace megdnn -#endif - -// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index d616e665a..709bf64bc 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -28,13 +28,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #endif #if __ARM_FEATURE_DOTPROD AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; - AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod; AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; #else AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; - AlgoInt8x8x32Gemv int8x8x32_gemv; #endif AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; @@ -63,11 +61,9 @@ public: all_algos.emplace_back(&f16_mk8_8x8); #endif #if __ARM_FEATURE_DOTPROD - all_algos.emplace_back(&int8x8x32_gemv_dotprod); all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); #else - all_algos.emplace_back(&int8x8x32_gemv); all_algos.emplace_back(&int8x8x32_k4x4x16); all_algos.emplace_back(&int8x8x32_k8x8x8); all_algos.emplace_back(&int8x8x32_mk4_4x4x16); diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index dd8e01f9a..d7a625e48 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -34,14 +34,12 @@ private: #if __ARM_FEATURE_DOTPROD class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel // 8x12x4 DotProduct - class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel // 8x12x4 DotProduct #else class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 - class AlgoInt8x8x32Gemv; // Aarch64 Int8x8x32 Gemv #endif class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 diff --git a/dnn/src/arm_common/matrix_mul/algos.cpp b/dnn/src/arm_common/matrix_mul/algos.cpp index be5b663c8..d06c07e79 100644 --- a/dnn/src/arm_common/matrix_mul/algos.cpp +++ b/dnn/src/arm_common/matrix_mul/algos.cpp @@ -72,7 +72,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( return exec_int_8x8x16; } -#if !__ARM_FEATURE_DOTPROD /* ===================== Int8x8x32 Gemv algo ===================== */ namespace { void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { @@ -102,7 +101,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( const KernSizeParam&) const { return int8x8x32_gemv_kern; } -#endif /* ===================== F32 Gemv algo ===================== */ namespace { @@ -112,7 +110,6 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); } } // anonymous namespace diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index 8db0b316b..68b13a85b 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -27,11 +27,7 @@ public: PackMode packmode() const override { return PackMode::NO_PACK; } }; -#if !__ARM_FEATURE_DOTPROD class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { -protected: - ~AlgoInt8x8x32Gemv() = default; - public: bool is_reproducible() const override { return true; } const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } @@ -43,7 +39,6 @@ public: AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } }; -#endif class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { protected: diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp index 676d0ba8a..2f86fb15b 100644 --- a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp +++ b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp @@ -9,8 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if !__ARM_FEATURE_DOTPROD - #include #include "src/arm_common/matrix_mul/int8/gemv.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -23,6 +21,8 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv) using namespace megdnn; using namespace arm_common; +#if !__ARM_FEATURE_DOTPROD + namespace { void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, @@ -95,8 +95,82 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, C[m * Cstride] = acc0; } } +} // namespace +#endif + +#if __ARM_FEATURE_DOTPROD +namespace { + +void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, + int32_t* __restrict C, size_t M, size_t N, size_t K, + size_t Astride, size_t Bstride, size_t Cstride) { + megdnn_assert(N == 1 && Bstride == 1); + size_t m = 0; + for (; m + 2 <= M; m += 2) { + int32_t acc[4]; + int32x4_t acc_neon = vdupq_n_s32(0); + size_t k = 0; + for (; k + 16 <= K; k += 16) { + int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); + int64x2_t a1 = + vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); + //! the first 8 elements is m, the last 8 elements is m + 1 + int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); + int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); + + int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); + int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); + int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); + + acc_neon = vdotq_s32(acc_neon, a2, b2); + acc_neon = vdotq_s32(acc_neon, a3, b3); + } + vst1q_s32(acc, acc_neon); + + for (; k + 8 <= K; k += 8) { + int8x8_t a0 = vld1_s8(A + m * Astride + k); + int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); + int8x8_t b0 = vld1_s8(B + k); + uint32x2_t zero = vdup_n_s32(0); + acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); + zero = vdup_n_s32(0); + acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); + } + + for (; k < K; ++k) { + acc[0] += static_cast(A[m * Astride + k]) * B[k]; + acc[3] += static_cast(A[(m + 1) * Astride + k]) * B[k]; + } + C[m * Cstride] = acc[0] + acc[1]; + C[(m + 1) * Cstride] = acc[2] + acc[3]; + } + + for (; m < M; ++m) { + int32_t acc[4]; + int32x4_t acc_neon = vdupq_n_s32(0); + size_t k = 0; + for (; k + 16 <= K; k += 16) { + int8x16_t a0 = vld1q_s8(A + m * Astride + k); + int8x16_t b0 = vld1q_s8(B + k); + acc_neon = vdotq_s32(acc_neon, a0, b0); + } + vst1q_s32(acc, acc_neon); + for (; k + 8 <= K; k += 8) { + int8x8_t a0 = vld1_s8(A + m * Astride + k); + int8x8_t b0 = vld1_s8(B + k); + uint32x2_t zero = vdup_n_s32(0); + acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); + } + + for (; k < K; ++k) { + acc[0] += static_cast(A[m * Astride + k]) * B[k]; + } + C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; + } +} } // namespace +#endif bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M, size_t N, size_t K, @@ -124,6 +198,5 @@ void matmul::gemv_like_int8(const int8_t* __restrict A, } MIDOUT_END(); } -#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.h b/dnn/src/arm_common/matrix_mul/int8/gemv.h index dc66d84f9..5b1a558cc 100644 --- a/dnn/src/arm_common/matrix_mul/int8/gemv.h +++ b/dnn/src/arm_common/matrix_mul/int8/gemv.h @@ -13,7 +13,6 @@ #include #include -#if !__ARM_FEATURE_DOTPROD namespace megdnn { namespace arm_common { namespace matmul { @@ -28,6 +27,6 @@ void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B, } // namespace matmul } // namespace arm_common } // namespace megdnn -#endif + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index d06311d1a..0767751e8 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -27,13 +27,14 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC AlgoF16Gemv f16gemv; #endif - + AlgoInt8x8x32Gemv int8x8x32_gemv; public: AlgoPack() { all_algos.emplace_back(&int8x8x16); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC all_algos.emplace_back(&f16gemv); #endif + all_algos.emplace_back(&int8x8x32_gemv); } SmallVector all_algos; }; diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h index 78a1bc2a8..bfde8615d 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.h +++ b/dnn/src/arm_common/matrix_mul/opr_impl.h @@ -25,9 +25,7 @@ public: protected: static void* const sm_arm_common_algo_type; -#if !__ARM_FEATURE_DOTPROD class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv -#endif class AlgoF32Gemv; // Arm_common F32 Gemv #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class AlgoF16Gemv; diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h index 9443e67c3..0019b178f 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.h +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -388,6 +388,19 @@ __ai int64x2_t vmovl_high_s32(int32x4_t __p0) { __ai uint64x2_t vmovl_high_u32(uint32x4_t __p0) { return vmovl_u32(vget_high_u32(__p0)); } + +__ai int64x2_t vzip1q_s64(int64x2_t& a, int64x2_t& b) { + return vcombine_s64(vget_low_s64(a), vget_low_s64(b)); +} + +__ai int64x2_t vzip2q_s64(int64x2_t& a, int64x2_t& b) { + return vcombine_s64(vget_high_s64(a), vget_high_s64(b)); +} + +__ai int32_t vaddv_s32(int32x2_t a) { + return vget_lane_s32(a, 0) + vget_lane_s32(a, 1); +} + #endif // MEGDNN_ARMV7 //! pack vmovl_low_xx() on armv7 and armv8 diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index 03d4c36ea..ae5d0fb71 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -134,11 +134,6 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; -#if !__ARM_FEATURE_DOTPROD -class MatrixMulImpl::AlgoInt8x8x32Gemv final - : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; -#endif - class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { public: bool is_reproducible() const override { return true; } diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp index 4bb770820..3f3a037bd 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.cpp +++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp @@ -35,9 +35,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; AlgoInt8x8x32K4x2x16 int8x8x32_k4x2x16; AlgoInt8x8x32K4x8x8 int8x8x32_k4x8x8; -#if !__ARM_FEATURE_DOTPROD - AlgoInt8x8x32Gemv int8x8x32_gemv; -#endif AlgoQuint8K4x8x8 quint8_k4x8x8; AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; @@ -60,9 +57,6 @@ public: all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); all_algos.emplace_back(&int8_k6x8x4); all_algos.emplace_back(&quint8_k4x8x4); -#endif -#if !__ARM_FEATURE_DOTPROD - all_algos.emplace_back(&int8x8x32_gemv); #endif all_algos.emplace_back(&int8x8x32_mk4_4x2x16); all_algos.emplace_back(&int8x8x32_k4x2x16); diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h index a9973dafe..d502b63ce 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.h +++ b/dnn/src/armv7/matrix_mul/opr_impl.h @@ -27,9 +27,6 @@ private: class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 -#if !__ARM_FEATURE_DOTPROD - class AlgoInt8x8x32Gemv; // Armv7 Int8x8x32 Gemv -#endif class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 diff --git a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp index 26874e1bc..e4d1ad49e 100644 --- a/dnn/test/arm_common/matrix_mul.cpp +++ b/dnn/test/arm_common/matrix_mul.cpp @@ -133,6 +133,36 @@ TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) { } #endif +TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { + Checker checker(handle()); + using Param = MatrixMul::Param; + + checker.set_before_exec_callback( + AlgoChecker("ARM_COMMON_INT8X8X32_GEMV")); + + std::unique_ptr rng = std::make_unique(-127, 127); + checker.set_rng(0, rng.get()).set_rng(1, rng.get()); + + auto run = [&](size_t M, size_t K, size_t N) { + Param param; + param.transposeA = false; + param.transposeB = false; + TensorShape A, B; + A = TensorShape{M, K}; + B = TensorShape{K, N}; + checker.set_param(param) + .set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .execs({A, B, {}}); + }; + + // N = 1 + for (size_t M : {1, 10, 16, 33, 64}) + for (size_t K : {7, 512, 1024}) + for (size_t N : {1}) + run(M, K, N); +} #if MEGDNN_WITH_BENCHMARK -- GitLab