From 714cb232bb394ed02191390b234a2ae837a3011d Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 18 Jun 2020 19:40:08 +0800
Subject: [PATCH] feat(dnn): add gemv support in conv1x1 for NCHW44 and
 NCHW44_DOT (aarch64 binary size grows 2KB)

GitOrigin-RevId: f8b6d7a1b749fa469aa86f9b8bc2e9ef7f25c937
---
 .../arm_common/conv_bias/postprocess_helper.h |  48 ++--
 dnn/src/arm_common/matrix_mul/algos.cpp       | 125 +++++++++
 dnn/src/arm_common/matrix_mul/algos.h         |  47 +++-
 .../arm_common/matrix_mul/fp32/exec_sgemv.cpp |  95 ++++++-
 .../arm_common/matrix_mul/fp32/exec_sgemv.h   |   3 +
 dnn/src/arm_common/matrix_mul/int8/gemv.cpp   | 238 +++++++++++++++++-
 dnn/src/arm_common/matrix_mul/int8/gemv.h     |  10 +
 dnn/src/arm_common/matrix_mul/opr_impl.cpp    |  10 +
 dnn/src/arm_common/matrix_mul/opr_impl.h      |  11 +-
 dnn/src/arm_common/simd_macro/marm_neon.h     |  10 +
 .../conv_bias/conv1x1/algos_conv1x1_gemv.cpp  | 236 ++++++++++-------
 dnn/test/arm_common/conv_bias.cpp             |  77 +++++-
 .../arm_common/conv_bias_multi_thread.cpp     |  51 +++-
 dnn/test/arm_common/matrix_mul.cpp            | 169 +++++++++++--
 14 files changed, 977 insertions(+), 153 deletions(-)

diff --git a/dnn/src/arm_common/conv_bias/postprocess_helper.h b/dnn/src/arm_common/conv_bias/postprocess_helper.h
index b36f747b..d665bfec 100644
--- a/dnn/src/arm_common/conv_bias/postprocess_helper.h
+++ b/dnn/src/arm_common/conv_bias/postprocess_helper.h
@@ -210,27 +210,33 @@ struct PostProcess {
         DEFAULT                                                       \
     }

-#define FOR_BIAS(_bias_mode, OH, OW)                                  \
-    switch (_bias_mode) {                                             \
-        case megdnn::BiasMode::NO_BIAS:                               \
-            FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY);                \
-            break;                                                    \
-        case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS:                \
-            if (pack_oc_size == 1) {                                  \
-                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST);        \
-            } else {                                                  \
-                megdnn_assert(pack_oc_size == 4,                      \
-                              "Only support nchw44 in ARM");          \
-                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
-            }                                                         \
-            break;                                                    \
-        default:                                                      \
-            if (OH * OW == 1) {                                       \
-                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST);        \
-                break;                                                \
-            }                                                         \
-            megdnn_throw("quantized unsupported biasmode");           \
-            break;                                                    \
+#define FOR_BIAS(_bias_mode, OH, OW)                                      \
+    switch (_bias_mode) {                                                 \
+        case megdnn::BiasMode::NO_BIAS:                                   \
+            FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY);                    \
+            break;                                                        \
+        case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS:                    \
+            if (pack_oc_size == 1) {                                      \
+                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST);            \
+            } else {                                                      \
+                megdnn_assert(pack_oc_size == 4,                          \
+                              "Only support nchw44 in ARM");              \
+                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44);     \
+            }                                                             \
+            break;                                                        \
+        default:                                                          \
+            if (OH * OW == 1) {                                           \
+                if (pack_oc_size == 1) {                                  \
+                    FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST);        \
+                } else {                                                  \
+                    megdnn_assert(pack_oc_size == 4,                      \
+                                  "Only support nchw44 in ARM");          \
+                    FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
+                }                                                         \
+                break;                                                    \
+            }                                                             \
+            megdnn_throw("quantized unsupported biasmode");               \
+            break;                                                        \
     }

 template
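Why the FOR_BIAS change matters for this patch: the conv1x1 GEMV path always produces a 1x1 spatial output, so a full BiasMode::BIAS tensor carries exactly one value per output channel and is laid out identically to a channel-broadcast bias. The old macro only routed that case through the pack_oc_size == 1 kernel; the new branch extends it to NCHW44. A minimal sketch of the shape argument (my own illustration, not code from the patch):

#include <cstddef>

// Illustration: when the spatial output is a single pixel, a BiasMode::BIAS
// buffer (one value per output element) has the same element count and
// channel order as a BROADCAST_CHANNEL_BIAS buffer (one value per channel),
// which is why FOR_BIAS can fall back to the broadcast kernels here.
inline bool bias_layouts_coincide(size_t OC, size_t OH, size_t OW) {
    size_t full_bias_elems = OC * OH * OW;
    size_t broadcast_elems = OC;
    return full_bias_elems == broadcast_elems;  // holds iff OH * OW == 1
}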
diff --git a/dnn/src/arm_common/matrix_mul/algos.cpp b/dnn/src/arm_common/matrix_mul/algos.cpp
index b5f03e47..af344f40 100644
--- a/dnn/src/arm_common/matrix_mul/algos.cpp
+++ b/dnn/src/arm_common/matrix_mul/algos.cpp
@@ -101,6 +101,91 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
     return int8x8x32_gemv_kern;
 }

+/* ===================== Int8x8x32 Gemv MK4 algo ===================== */
+namespace {
+void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
+    auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
+    auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
+    const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
+    auto Cptr = kern_param.C<dt_int32>();
+    gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
+}
+}  // anonymous namespace
+
+bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::usable(
+        const KernSizeParam& kern_size_param) const {
+    auto M = kern_size_param.M;
+    auto N = kern_size_param.N;
+    auto K = kern_size_param.K;
+    auto LDB = kern_size_param.LDB;
+
+    bool is_dtype_ok =
+            kern_size_param.A_type == kern_size_param.B_type &&
+            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
+             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
+            (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
+             kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
+
+    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+           kern_size_param.format == param::MatrixMul::Format::MK4 &&
+           is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB &&
+           M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
+}
+
+bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::preferred(
+        const KernSizeParam& kern_size_param) const {
+    MEGDNN_MARK_USED_VAR(kern_size_param);
+    return true;
+}
+
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
+        const KernSizeParam&) const {
+    return int8x8x32_gemv_mk4_kern;
+}
+
+#if __ARM_FEATURE_DOTPROD
+/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
+namespace {
+void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
+    auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
+    auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
+    const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
+    auto Cptr = kern_param.C<dt_int32>();
+    gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
+}
+}  // anonymous namespace
+
+bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable(
+        const KernSizeParam& kern_size_param) const {
+    auto M = kern_size_param.M;
+    auto N = kern_size_param.N;
+    auto K = kern_size_param.K;
+    auto LDB = kern_size_param.LDB;
+
+    bool is_dtype_ok =
+            kern_size_param.A_type == kern_size_param.B_type &&
+            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
+             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
+            (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
+             kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
+
+    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+           kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
+           is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB &&
+           M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
+}
+
+bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::preferred(
+        const KernSizeParam& kern_size_param) const {
+    return true;
+}
+
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern(
+        const KernSizeParam&) const {
+    return int8x8x32_gemv_mk4_dot_kern;
+}
+#endif
+
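Both kernels registered above consume the MK4 layout: A is stored as (M/4) x (K/4) blocks of 4x4, B as K/4 packets of 4 contiguous values (hence the LDB == 4 requirement in usable()), and C as M/4 packets of 4. A scalar reference of the computation, matching how the NEON kernels later in this patch index memory (illustrative sketch, not part of the patch):

#include <cstddef>
#include <cstdint>

// Scalar reference for the MK4 GEMV: A is (M/4, K/4, 4, 4), B is (K/4, 1, 4),
// C is (M/4, 1, 4). Astride/Cstride count elements between consecutive
// row-blocks, as in gemv_like_mk4. Within one row-block, the element for
// K-index k and output row r sits at offset 4 * k + r, which is exactly the
// deinterleaved view that vld4q_s8 / vfmaq_laneq_f32 exploit.
void gemv_mk4_ref(const int8_t* A, const int8_t* B, int32_t* C, size_t M,
                  size_t K, size_t Astride, size_t Cstride) {
    for (size_t mb = 0; mb < M / 4; ++mb) {
        const int8_t* Ablk = A + mb * Astride;
        int32_t acc[4] = {0, 0, 0, 0};
        for (size_t k = 0; k < K; ++k) {
            for (size_t r = 0; r < 4; ++r) {
                acc[r] += int32_t(Ablk[4 * k + r]) * int32_t(B[k]);
            }
        }
        for (size_t r = 0; r < 4; ++r)
            C[mb * Cstride + r] = acc[r];
    }
}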
 /* ===================== F32 Gemv algo ===================== */
 namespace {
 void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -137,6 +222,46 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
     return f32_gemv_kern;
 }

+/* ================== F32 Gemv MK4 algo ================== */
+namespace {
+void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
+    auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
+    auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
+    const auto Aptr = kern_param.A<dt_float32>(),
+               Bptr = kern_param.B<dt_float32>();
+    auto Cptr = kern_param.C<dt_float32>();
+    gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
+}
+}  // anonymous namespace
+
+bool MatrixMulImpl::AlgoF32GemvMK4::usable(
+        const KernSizeParam& kern_size_param) const {
+    // enumerate the M, N, K, only usable when preferred
+    auto M = kern_size_param.M;
+    auto N = kern_size_param.N;
+    auto K = kern_size_param.K;
+    auto LDB = kern_size_param.LDB;
+
+    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+           kern_size_param.format == param::MatrixMul::Format::MK4 &&
+           kern_size_param.B_type == kern_size_param.A_type &&
+           kern_size_param.C_type == kern_size_param.A_type &&
+           kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
+           !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 &&
+           LDB == 4;
+}
+
+bool MatrixMulImpl::AlgoF32GemvMK4::preferred(
+        const KernSizeParam& kern_size_param) const {
+    MEGDNN_MARK_USED_VAR(kern_size_param);
+    return true;
+}
+
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern(
+        const KernSizeParam&) const {
+    return f32_gemv_mk4_kern;
+}
+
 /* ===================== F32 Gevm algo ===================== */
 namespace {
 template
diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h
index b34512d4..5a7f665d 100644
--- a/dnn/src/arm_common/matrix_mul/algos.h
+++ b/dnn/src/arm_common/matrix_mul/algos.h
@@ -43,6 +43,36 @@ public:
     MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
 };

+class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; }
+    bool usable(const KernSizeParam&) const override;
+    bool preferred(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override { return 0; }
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
+    PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+};
+
+#if __ARM_FEATURE_DOTPROD
+class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; }
+    bool usable(const KernSizeParam&) const override;
+    bool preferred(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override { return 0; }
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
+    PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+};
+#endif
+
 class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
 protected:
     ~AlgoF32Gemv() = default;
@@ -60,6 +90,20 @@ public:
     MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
 };

+class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
+    bool usable(const KernSizeParam&) const override;
+    bool preferred(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override { return 0; }
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_arm_common_algo_type; }
+    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
+    PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4)
+};
+
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 class MatrixMulImpl::AlgoF16Gemv : public AlgoBase {
 public:
@@ -87,10 +131,9 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4)
 };
-
 }  // namespace arm_common
 }  // namespace megdnn
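The usable() checks above pin down exactly when these kernels fire: MK4/MK4_DOT format, no transposes, M and K multiples of 4, N == 1, LDB == 4. In tensor terms (as the tests at the end of this patch spell out), that corresponds to A = {M/4, K/4, 4, 4} and B = {K/4, 1, 4}. A hedged helper showing that mapping (names are my own, for illustration):

#include <cstddef>

// Illustration only: the MK4 tensor shapes that satisfy
// AlgoInt8x8x32GemvMK4::usable() / AlgoF32GemvMK4::usable().
struct Mk4GemvShapes {
    size_t A[4];  // {M/4, K/4, 4, 4}
    size_t B[3];  // {K/4, 1, 4}  -> N == 1, LDB == 4
    size_t C[3];  // {M/4, 1, 4}
};

Mk4GemvShapes make_mk4_gemv_shapes(size_t M, size_t K) {
    // Preconditions mirrored from usable(): M % 4 == 0 && K % 4 == 0.
    return Mk4GemvShapes{{M / 4, K / 4, 4, 4}, {K / 4, 1, 4}, {M / 4, 1, 4}};
}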
diff --git a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp
index 01f50a45..daba2abc 100644
--- a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp
+++ b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp
@@ -13,11 +13,11 @@
 #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
 #include
 #include "include/megdnn/oprs.h"
-#include "midout.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/unroll_macro.h"
 #include "src/common/utils.h"

+#include "midout.h"
 MIDOUT_DECL(megdnn_fp32_sgemv)

 using namespace megdnn;
@@ -68,18 +68,10 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B,
 #if !defined(__aarch64__)
 #undef vaddvq_f32
 #endif
-}  // namespace
-
-namespace megdnn {
-namespace arm_common {
-void gemv_like(const float* __restrict A, const float* __restrict B,
-               float* __restrict C, size_t M, size_t N, size_t K,
-               size_t Astride, size_t Bstride, size_t Cstride) {
-    megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
-    if (N == 1) {
-        return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
-    }
+void sgemv_naive_m(const float* __restrict A, const float* __restrict B,
+                   float* __restrict C, size_t M, size_t N, size_t K,
+                   size_t Astride, size_t Bstride, size_t Cstride) {
     size_t m = 0;
     for (; m + 4 <= M; m += 4) {
         size_t k = 0;
@@ -762,6 +754,85 @@ void gemv_like(const float* __restrict A, const float* __restrict B,
         }
     }
 }
+
+void sgemv_naive_n_mk4(const float* __restrict A, const float* __restrict B,
+                       float* __restrict C, size_t M, size_t N, size_t K,
+                       size_t Astride, size_t Bstride, size_t Cstride) {
+    constexpr size_t PACK_SIZE = 4;
+    megdnn_assert(N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 &&
+                  K % PACK_SIZE == 0);
+    auto Aptr = A;
+    auto Cptr = C;
+    size_t m = 0;
+    while (m < M) {
+        auto Aptr0 = Aptr;
+        auto Cptr0 = Cptr;
+        float32x4_t c[4];
+#define INIT(step) c[step] = vdupq_n_f32(0.0f);
+        UNROLL_CALL_RAW(4, INIT)
+#undef INIT
+        auto Bptr = B;
+        size_t k = 0;
+        while (k < K) {
+            float32x4_t b = vld1q_f32(Bptr);
+            float32x4x2_t a[2];
+#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8);
+            UNROLL_CALL_RAW(2, LOAD_A)
+#undef LOAD_A
+
+#define COMPT(step) \
+    c[step] = vfmaq_laneq_f32(c[step], a[step / 2].val[step % 2], b, step % 4);
+            UNROLL_CALL_RAW(4, COMPT)
+#undef COMPT
+            Bptr += Bstride;
+            Aptr0 += PACK_SIZE * PACK_SIZE;
+            k += PACK_SIZE;
+        }
+
+#define ADD_C(step, stride) c[step] = vaddq_f32(c[step], c[step + stride]);
+        UNROLL_CALL_RAW(2, ADD_C, 2)
+        UNROLL_CALL_RAW(1, ADD_C, 1)
+#undef ADD_C
+        vst1q_f32(Cptr0, c[0]);
+
+        Aptr += Astride;
+        Cptr += Cstride;
+        m += PACK_SIZE;
+    }
+}
+
+}  // namespace
+
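sgemv_naive_n_mk4 above processes one 4x4 block of A per group of four K-values; each COMPT expansion is a lane-broadcast fused multiply-add. A scalar restatement of what one such step computes, for readers without the intrinsics reference at hand (illustrative only):

// Scalar equivalent of c = vfmaq_laneq_f32(c, a, b, lane): every element of
// c accumulates a[i] * b[lane]. In sgemv_naive_n_mk4, a holds one 4-wide
// column of the 4x4 A block and b[lane] is the matching packed B element.
void fma_lane_ref(float c[4], const float a[4], const float b[4], int lane) {
    for (int i = 0; i < 4; ++i)
        c[i] += a[i] * b[lane];
}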
+namespace megdnn {
+namespace arm_common {
+
+void gemv_like(const float* __restrict A, const float* __restrict B,
+               float* __restrict C, size_t M, size_t N, size_t K,
+               size_t Astride, size_t Bstride, size_t Cstride) {
+    megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
+    if (N == 1) {
+        MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_N"_hash)) {
+            return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
+        }
+        MIDOUT_END();
+    } else {
+        MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_M"_hash)) {
+            return sgemv_naive_m(A, B, C, M, N, K, Astride, Bstride, Cstride);
+        }
+        MIDOUT_END();
+    }
+}
+
+void gemv_like_mk4(const float* __restrict A, const float* __restrict B,
+                   float* __restrict C, size_t M, size_t N, size_t K,
+                   size_t Astride, size_t Bstride, size_t Cstride) {
+    megdnn_assert(N == 1 && Bstride == 4);
+    MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW44_N"_hash)) {
+        return sgemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
+    }
+    MIDOUT_END();
+}
+
 }  // namespace arm_common
 }  // namespace megdnn
diff --git a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h
index 982d22a0..b1cb1491 100644
--- a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h
+++ b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h
@@ -24,6 +24,9 @@ void gemv_like(const float* __restrict A, const float* __restrict B,
                float* __restrict C, size_t M, size_t N, size_t K,
                size_t Astride, size_t Bstride, size_t Cstride);

+void gemv_like_mk4(const float* __restrict A, const float* __restrict B,
+                   float* __restrict C, size_t M, size_t N, size_t K,
+                   size_t Astride, size_t Bstride, size_t Cstride);
 }  // namespace arm_common
 }  // namespace megdnn
diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp
index 7bc5509f..4fa50802 100644
--- a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp
+++ b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp
@@ -10,8 +10,8 @@
  */

 #include
-#include "src/arm_common/matrix_mul/int8/gemv.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/arm_common/matrix_mul/int8/gemv.h"
 #include "src/common/utils.h"

 #include "megdnn/oprs.h"
@@ -95,6 +95,80 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
         C[m * Cstride] = acc0;
     }
 }
+
+void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
+                      int32_t* __restrict C, size_t M, size_t N, size_t K,
+                      size_t Astride, size_t Bstride, size_t Cstride) {
+    constexpr size_t PACK_SIZE = 4;
+    megdnn_assert(N == 1 && Bstride == 4);
+    auto Aptr = A;
+    size_t m = 0;
+    for (; m < M; m += PACK_SIZE) {
+        auto Bptr = B;
+        auto Aptr0 = Aptr;
+        int32_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
+        size_t k = 0;
+        for (; k + 16 <= K; k += 16) {
+            int8x16x4_t a = vld4q_s8(Aptr0);
+            int8x16_t b = vld1q_s8(Bptr);
+            int16x8_t c[4];
+
+            c[0] = vmull_s8(vget_low_s8(a.val[0]), vget_low_s8(b));
+            c[1] = vmull_s8(vget_low_s8(a.val[1]), vget_low_s8(b));
+            c[2] = vmull_s8(vget_low_s8(a.val[2]), vget_low_s8(b));
+            c[3] = vmull_s8(vget_low_s8(a.val[3]), vget_low_s8(b));
+
+            c[0] = vmlal_high_s8(c[0], a.val[0], b);
+            c[1] = vmlal_high_s8(c[1], a.val[1], b);
+            c[2] = vmlal_high_s8(c[2], a.val[2], b);
+            c[3] = vmlal_high_s8(c[3], a.val[3], b);
+
+            acc0 += vaddlvq_s16(c[0]);
+            acc1 += vaddlvq_s16(c[1]);
+            acc2 += vaddlvq_s16(c[2]);
+            acc3 += vaddlvq_s16(c[3]);
+
+            Bptr += 16;
+            Aptr0 += PACK_SIZE * 16;
+        }
+
+        for (; k + 8 <= K; k += 8) {
+            int8x8x4_t a = vld4_s8(Aptr0);
+            int8x8_t b = vld1_s8(Bptr);
+            int16x8_t c[4];
+
+            c[0] = vmull_s8(a.val[0], b);
+            c[1] = vmull_s8(a.val[1], b);
+            c[2] = vmull_s8(a.val[2], b);
+            c[3] = vmull_s8(a.val[3], b);
+
+            acc0 += vaddlvq_s16(c[0]);
+            acc1 += vaddlvq_s16(c[1]);
+            acc2 += vaddlvq_s16(c[2]);
+            acc3 += vaddlvq_s16(c[3]);
+
+            Bptr += 8;
+            Aptr0 += PACK_SIZE * 8;
+        }
+
+        for (; k < K; ++k) {
+            acc0 += static_cast<int32_t>(*(Aptr0 + 0)) * B[k];
+            acc1 += static_cast<int32_t>(*(Aptr0 + 1)) * B[k];
+            acc2 += static_cast<int32_t>(*(Aptr0 + 2)) * B[k];
+            acc3 += static_cast<int32_t>(*(Aptr0 + 3)) * B[k];
+            Aptr0 += 4;
+        }
+
+        C[0] = acc0;
+        C[1] = acc1;
+        C[2] = acc2;
+        C[3] = acc3;
+
+        Aptr += Astride;
+        C += Cstride;
+    }
+}
+
 }  // namespace
 #endif
@@ -169,6 +243,139 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
         C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3];
     }
 }
+
+void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
+                      int32_t* __restrict C, size_t M, size_t N, size_t K,
+                      size_t Astride, size_t Bstride, size_t Cstride) {
+    constexpr size_t PACK_SIZE = 4;
+    megdnn_assert(N == 1 && Bstride == 4);
+
+    auto Aptr = A;
+    size_t m = 0;
+    for (; m < M; m += PACK_SIZE) {
+        auto Bptr = B;
+        auto Aptr0 = Aptr;
+        int32_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
+        size_t k = 0;
+        if (k + 16 <= K) {
+            int32x4_t acc_neon[4];
+            acc_neon[0] = vdupq_n_s32(0);
+            acc_neon[1] = vdupq_n_s32(0);
+            acc_neon[2] = vdupq_n_s32(0);
+            acc_neon[3] = vdupq_n_s32(0);
+            for (; k + 16 <= K; k += 16) {
+                int8x16x4_t a = vld4q_s8(Aptr0);
+                int8x16_t b = vld1q_s8(Bptr);
+
+                acc_neon[0] = vdotq_s32(acc_neon[0], a.val[0], b);
+                acc_neon[1] = vdotq_s32(acc_neon[1], a.val[1], b);
+                acc_neon[2] = vdotq_s32(acc_neon[2], a.val[2], b);
+                acc_neon[3] = vdotq_s32(acc_neon[3], a.val[3], b);
+
+                Bptr += 16;
+                Aptr0 += PACK_SIZE * 16;
+            }
+            acc0 = vaddvq_s32(acc_neon[0]);
+            acc1 = vaddvq_s32(acc_neon[1]);
+            acc2 = vaddvq_s32(acc_neon[2]);
+            acc3 = vaddvq_s32(acc_neon[3]);
+        }
+
+        if (k + 8 <= K) {
+            int32x2_t acc_neon[4];
+            acc_neon[0] = vdup_n_s32(0);
+            acc_neon[1] = vdup_n_s32(0);
+            acc_neon[2] = vdup_n_s32(0);
+            acc_neon[3] = vdup_n_s32(0);
+
+            int8x8x4_t a = vld4_s8(Aptr0);
+            int8x8_t b = vld1_s8(Bptr);
+            acc_neon[0] = vdot_s32(acc_neon[0], a.val[0], b);
+            acc_neon[1] = vdot_s32(acc_neon[1], a.val[1], b);
+            acc_neon[2] = vdot_s32(acc_neon[2], a.val[2], b);
+            acc_neon[3] = vdot_s32(acc_neon[3], a.val[3], b);
+
+            Bptr += 8;
+            Aptr0 += PACK_SIZE * 8;
+            k += 8;
+
+            acc0 += vaddv_s32(acc_neon[0]);
+            acc1 += vaddv_s32(acc_neon[1]);
+            acc2 += vaddv_s32(acc_neon[2]);
+            acc3 += vaddv_s32(acc_neon[3]);
+        }
+
+        for (; k < K; ++k) {
+            acc0 += static_cast<int32_t>(*(Aptr0 + 0)) * B[k];
+            acc1 += static_cast<int32_t>(*(Aptr0 + 1)) * B[k];
+            acc2 += static_cast<int32_t>(*(Aptr0 + 2)) * B[k];
+            acc3 += static_cast<int32_t>(*(Aptr0 + 3)) * B[k];
+            Aptr0 += 4;
+        }
+
+        C[0] = acc0;
+        C[1] = acc1;
+        C[2] = acc2;
+        C[3] = acc3;
+
+        Aptr += Astride;
+        C += Cstride;
+    }
+}
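The MK4_DOT path relies on the Armv8.2 dot-product instructions. For readers without the ISA reference at hand, a scalar sketch of the semantics used above and below (illustrative only, not part of the patch):

#include <cstdint>

// Scalar equivalent of acc = vdotq_s32(acc, a, b): each 32-bit lane of acc
// accumulates the dot product of the corresponding 4-byte groups of a and b.
void vdotq_s32_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int lane = 0; lane < 4; ++lane)
        for (int j = 0; j < 4; ++j)
            acc[lane] += int32_t(a[4 * lane + j]) * int32_t(b[4 * lane + j]);
}

// vdotq_laneq_s32(acc, a, b, l) is the same, except every 4-byte group of a
// is multiplied with the single group l of b -- which is exactly how
// gemv_naive_n_mk4_dot below broadcasts 4 consecutive K-values of B across
// the 4 output rows of an MK4 block.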
+
+void gemv_naive_n_mk4_dot(const int8_t* __restrict A,
+                          const int8_t* __restrict B, int32_t* __restrict C,
+                          size_t M, size_t N, size_t K, size_t Astride,
+                          size_t Bstride, size_t Cstride) {
+    constexpr size_t PACK_SIZE = 4;
+    megdnn_assert(N == 1 && Bstride == 4);
+
+    auto Aptr = A;
+    size_t m = 0;
+    for (; m < M; m += PACK_SIZE) {
+        auto Bptr = B;
+        auto Aptr0 = Aptr;
+        size_t k = 0;
+        int32x4_t acc_neon;
+        acc_neon = vdupq_n_s32(0);
+        for (; k + 16 <= K; k += 16) {
+            int8x16_t a0 = vld1q_s8(Aptr0);
+            int8x16_t a1 = vld1q_s8(Aptr0 + 16);
+            int8x16_t a2 = vld1q_s8(Aptr0 + 32);
+            int8x16_t a3 = vld1q_s8(Aptr0 + 48);
+            int8x16_t b = vld1q_s8(Bptr);
+            acc_neon = vdotq_laneq_s32(acc_neon, a0, b, 0);
+            acc_neon = vdotq_laneq_s32(acc_neon, a1, b, 1);
+            acc_neon = vdotq_laneq_s32(acc_neon, a2, b, 2);
+            acc_neon = vdotq_laneq_s32(acc_neon, a3, b, 3);
+            Bptr += 16;
+            Aptr0 += PACK_SIZE * 16;
+        }
+
+        if (k + 8 <= K) {
+            int8x16_t a0 = vld1q_s8(Aptr0);
+            int8x16_t a1 = vld1q_s8(Aptr0 + 16);
+            int8x8_t b = vld1_s8(Bptr);
+            acc_neon = vdotq_lane_s32(acc_neon, a0, b, 0);
+            acc_neon = vdotq_lane_s32(acc_neon, a1, b, 1);
+            Bptr += 8;
+            Aptr0 += PACK_SIZE * 8;
+            k += 8;
+        }
+
+        if (k + 4 <= K) {
+            int8x16_t a = vld1q_s8(Aptr0);
+            int32_t tmp = *(reinterpret_cast<const int32_t*>(Bptr));
+            int8x8_t b = vreinterpret_s8_s32(vdup_n_s32(tmp));
+            acc_neon = vdotq_lane_s32(acc_neon, a, b, 0);
+        }
+
+        vst1q_s32(C, acc_neon);
+        Aptr += Astride;
+        C += Cstride;
+    }
+}
+
 }  // namespace
 #endif
@@ -201,4 +408,33 @@ void arm_common::gemv_like(const int8_t* __restrict A,
     MIDOUT_END();
 }

+void arm_common::gemv_like_mk4(const int8_t* __restrict A,
+                               const int8_t* __restrict B,
+                               int32_t* __restrict C, size_t M, size_t N,
+                               size_t K, size_t Astride, size_t Bstride,
+                               size_t Cstride) {
+    megdnn_assert(N == 1);
+    MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
+                 midout_iv("INT8_gemv_like_mk4"_hash)) {
+        return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
+    }
+    MIDOUT_END();
+}
+
+#if __ARM_FEATURE_DOTPROD
+void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A,
+                                   const int8_t* __restrict B,
+                                   int32_t* __restrict C, size_t M, size_t N,
+                                   size_t K, size_t Astride, size_t Bstride,
+                                   size_t Cstride) {
+    megdnn_assert(N == 1);
+    MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
+                 midout_iv("INT8_gemv_like_mk4_dot"_hash)) {
+        return gemv_naive_n_mk4_dot(A, B, C, M, N, K, Astride, Bstride,
+                                    Cstride);
+    }
+    MIDOUT_END();
+}
+#endif
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.h b/dnn/src/arm_common/matrix_mul/int8/gemv.h
index 3080167f..9c8d6c69 100644
--- a/dnn/src/arm_common/matrix_mul/int8/gemv.h
+++ b/dnn/src/arm_common/matrix_mul/int8/gemv.h
@@ -24,6 +24,16 @@ void gemv_like(const int8_t* __restrict A, const int8_t* __restrict B,
                int32_t* __restrict C, size_t M, size_t N, size_t K,
                size_t Astride, size_t Bstride, size_t Cstride);

+void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
+                   int32_t* __restrict C, size_t M, size_t N, size_t K,
+                   size_t Astride, size_t Bstride, size_t Cstride);
+
+#if __ARM_FEATURE_DOTPROD
+void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B,
+                       int32_t* __restrict C, size_t M, size_t N, size_t K,
+                       size_t Astride, size_t Bstride, size_t Cstride);
+#endif
+
 }  // namespace arm_common
 }  // namespace megdnn
diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp
index 5ef42558..128da56f 100644
--- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp
+++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp
@@ -28,14 +28,24 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoF16Gemv f16gemv;
 #endif
     AlgoInt8x8x32Gemv int8x8x32_gemv;
+    AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
+#if __ARM_FEATURE_DOTPROD
+    AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
+#endif
     AlgoGevm gevm;
+    AlgoF32GemvMK4 f32_gemv_mk4;

 public:
     AlgoPack() {
         all_algos.emplace_back(&int8x8x16);
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         all_algos.emplace_back(&f16gemv);
+#endif
+#if __ARM_FEATURE_DOTPROD
+        all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
 #endif
         all_algos.emplace_back(&int8x8x32_gemv);
+        all_algos.emplace_back(&int8x8x32_gemv_mk4);
+        all_algos.emplace_back(&f32_gemv_mk4);
         all_algos.emplace_back(&gevm);
     }
     SmallVector<AlgoBase*> all_algos;
diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h
index bd53987d..9ed9f3c0 100644
--- a/dnn/src/arm_common/matrix_mul/opr_impl.h
+++ b/dnn/src/arm_common/matrix_mul/opr_impl.h
@@ -25,11 +25,16 @@ public:
 protected:
     static void* const sm_arm_common_algo_type;

-    class AlgoInt8x8x32Gemv;  // Arm_common Int 8x8x32 Gemv
-    class AlgoF32Gemv;        // Arm_common F32 Gemv
-    class AlgoGevm;           // Arm_common Gemv(support int8 and fp32)
+    class AlgoF32Gemv;           // Arm_common F32 Gemv
+    class AlgoF32GemvMK4;        // Arm_common F32 Gemv NCHW44
+    class AlgoInt8x8x32Gemv;     // Arm_common Int8x8x32 Gemv
+    class AlgoInt8x8x32GemvMK4;  // Arm_common Int8x8x32 Gemv NCHW44
+    class AlgoGevm;              // Arm_common Gevm(support int8 and fp32)
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     class AlgoF16Gemv;
+#endif
+#if __ARM_FEATURE_DOTPROD
+    class AlgoInt8x8x32GemvMK4Dot;  // Arm_common Int8x8x32 Gemv NCHW44_DOT
 #endif
     class AlgoInt8x8x16;  // Arm_common Int 8x8x16
     class AlgoPack;
diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h
index 8a6cde61..9deffc4d 100644
--- a/dnn/src/arm_common/simd_macro/marm_neon.h
+++ b/dnn/src/arm_common/simd_macro/marm_neon.h
@@ -407,6 +407,16 @@ __ai int32_t vaddv_s32(int32x2_t a) {
     return vget_lane_s32(a, 0) + vget_lane_s32(a, 1);
 }

+__ai int32_t vaddvq_s32(int32x4_t a) {
+    return vgetq_lane_s32(a, 0) + vgetq_lane_s32(a, 1) +
+           vgetq_lane_s32(a, 2) + vgetq_lane_s32(a, 3);
+}
+
+__ai float32_t vaddvq_f32(float32x4_t a) {
+    return vgetq_lane_f32(a, 0) + vgetq_lane_f32(a, 1) +
+           vgetq_lane_f32(a, 2) + vgetq_lane_f32(a, 3);
+}
+
 #endif  // MEGDNN_ARMV7

 //! pack vmovl_low_xx() on armv7 and armv8
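The two helpers added to marm_neon.h sit inside the MEGDNN_ARMV7 block: they emulate AArch64's native across-vector add so the MK4 kernels above can reduce their accumulator vectors on both architectures. For reference, the scalar meaning (illustration, not part of the patch):

#include <cstdint>

// What vaddvq_s32 computes: the sum of all four 32-bit lanes. The ARMv7
// fallback above expands to exactly this chain of lane extractions.
int32_t vaddvq_s32_ref(const int32_t lanes[4]) {
    return lanes[0] + lanes[1] + lanes[2] + lanes[3];
}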
diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
index c4c1e147..ccfa2bad 100644
--- a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
+++ b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
@@ -42,14 +42,27 @@ using namespace conv1x1;

 namespace {

-#if MEGDNN_X86
 template <typename stype, typename btype, param::ConvBias::Format format>
 struct GemvLike {
     inline static void do_gemv(const stype* A, const stype* B, btype* C,
                                size_t M, size_t N, size_t K, size_t LDA,
                                size_t LDB, size_t LDC, DType src,
                                DType filter) {
-        megdnn_throw("x86 conv1x1 gemv only supports format : NCHW");
+        MEGDNN_MARK_USED_VAR(A);
+        MEGDNN_MARK_USED_VAR(B);
+        MEGDNN_MARK_USED_VAR(C);
+        MEGDNN_MARK_USED_VAR(M);
+        MEGDNN_MARK_USED_VAR(N);
+        MEGDNN_MARK_USED_VAR(K);
+        MEGDNN_MARK_USED_VAR(LDA);
+        MEGDNN_MARK_USED_VAR(LDB);
+        MEGDNN_MARK_USED_VAR(LDC);
+        MEGDNN_MARK_USED_VAR(src);
+        MEGDNN_MARK_USED_VAR(filter);
+        megdnn_assert(false,
+                      "unsupported conv1x1 gemv : \nsrc_type : "
+                      "%s\nfilter_type : %s\n",
+                      src.name(), filter.name());
     }
 };

@@ -66,39 +79,29 @@ struct GemvLike {
     }
 };

-#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
-template <typename stype, typename btype, param::ConvBias::Format format>
-struct GemvLike {
-    inline static void do_gemv(const stype* A, const stype* B, btype* C,
-                               size_t M, size_t N, size_t K, size_t LDA,
-                               size_t LDB, size_t LDC, DType src,
-                               DType filter) {
-        megdnn_throw("arm conv1x1 gemv only supports format : NCHW");
-    }
-};
-
-template <typename stype, typename btype>
-struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
-    inline static void do_gemv(const stype* A, const stype* B, btype* C,
-                               size_t M, size_t N, size_t K, size_t LDA,
-                               size_t LDB, size_t LDC, DType src,
-                               DType filter) {
-        MEGDNN_MARK_USED_VAR(src);
-        MEGDNN_MARK_USED_VAR(filter);
-        megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
-    }
-};
+template <>
+struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
+    inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
+                               dt_int32* C, size_t M, size_t N, size_t K,
+                               size_t LDA, size_t LDB, size_t LDC, DType src,
+                               DType filter) {
+        uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
+        uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
+        megdnn::fallback::gemv_like(A, B, C, M, N, K, LDA,
+                                    LDB, LDC, zp0, zp1);
+    }
+};

+#if MEGDNN_AARCH64 || MEGDNN_ARMV7
 template <>
-struct GemvLike<dt_int8, dt_int16, param::ConvBias::Format::NCHW> {
-    inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int16* C,
-                               size_t M, size_t N, size_t K, size_t LDA,
-                               size_t LDB, size_t LDC, DType src,
+struct GemvLike<dt_float32, dt_float32, param::ConvBias::Format::NCHW> {
+    inline static void do_gemv(const dt_float32* A, const dt_float32* B,
+                               dt_float32* C, size_t M, size_t N, size_t K,
+                               size_t LDA, size_t LDB, size_t LDC, DType src,
                                DType filter) {
         MEGDNN_MARK_USED_VAR(src);
         MEGDNN_MARK_USED_VAR(filter);
-        megdnn::fallback::gemv_like(A, B, C, M, N, K, LDA,
-                                    LDB, LDC);
+        megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
     }
 };

@@ -118,21 +121,47 @@ struct GemvLike {
     }
 };
 #endif
-#endif

 template <>
-struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
-    inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
-                               dt_int32* C, size_t M, size_t N, size_t K,
-                               size_t LDA, size_t LDB, size_t LDC, DType src,
+struct GemvLike<dt_int8, dt_int32, param::ConvBias::Format::NCHW> {
+    inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int32* C,
+                               size_t M, size_t N, size_t K, size_t LDA,
+                               size_t LDB, size_t LDC, DType src,
                                DType filter) {
-        uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
-        uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
-        megdnn::fallback::gemv_like(A, B, C, M, N, K, LDA,
-                                    LDB, LDC, zp0, zp1);
+        MEGDNN_MARK_USED_VAR(src);
+        MEGDNN_MARK_USED_VAR(filter);
+        megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
     }
 };

+template <typename stype, typename btype>
+struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> {
+    inline static void do_gemv(const stype* A, const stype* B, btype* C,
+                               size_t M, size_t N, size_t K, size_t LDA,
+                               size_t LDB, size_t LDC, DType src,
+                               DType filter) {
+        MEGDNN_MARK_USED_VAR(src);
+        MEGDNN_MARK_USED_VAR(filter);
+        megdnn::arm_common::gemv_like_mk4(A, B, C, M, N, K, LDA, LDB, LDC);
+    }
+};
+
+#if __ARM_FEATURE_DOTPROD
+template <typename stype, typename btype>
+struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> {
+    inline static void do_gemv(const stype* A, const stype* B, btype* C,
+                               size_t M, size_t N, size_t K, size_t LDA,
+                               size_t LDB, size_t LDC, DType src,
+                               DType filter) {
+        MEGDNN_MARK_USED_VAR(src);
+        MEGDNN_MARK_USED_VAR(filter);
+        megdnn::arm_common::gemv_like_mk4_dot(A, B, C, M, N, K, LDA, LDB, LDC);
+    }
+};
+#endif
+
+#endif
+
 template
         auto gemv_dst = static_cast<btype*>(conv_bias_dst);
+        size_t pack_size = megdnn::fallback::pack_size(format);
         GemvLike<stype, btype, format>::do_gemv(
-                Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC, 1, 1,
-                ncb_param.filter_type, ncb_param.src_type);
+                Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC * pack_size,
+                pack_size, pack_size, ncb_param.filter_type,
+                ncb_param.src_type);
         //! do postprocess
         void* bias_ptr = nullptr;
-        if (param.bias_mode == megdnn::BiasMode::BIAS) {
+        if (param.bias_mode != megdnn::BiasMode::NO_BIAS) {
             bias_ptr = static_cast<void*>(const_cast<btype*>(
                     ncb_param.bias<btype>(batch_id, group_id) +
                     numbers_of_ncb_dst_offset));
-        } else {
-            bias_ptr = static_cast<void*>(const_cast<btype*>(
-                    ncb_param.bias<btype>(batch_id, group_id) + oc_start));
         }

         PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
@@ -211,9 +239,13 @@ struct Conv1x1GemvWorker {

 size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic(
         const NCBKernSizeParam& param) const {
-    size_t OC = param.filter_meta.ocpg;
-    size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
-    return round_up(oc_block_size_one_thread, 16);
+    MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
+                 midout_iv("AlgoConv1x1Gemv::get_oc_tile"_hash)) {
+        size_t OC = param.filter_meta.ocpg;
+        size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
+        return round_up(oc_block_size_one_thread, 16);
+    }
+    MIDOUT_END();
 }

 size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
@@ -286,6 +318,11 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
             cb1(param::ConvBias::Format::NCHW, dt_float16, __fp16,
                 PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT16_FP16"_hash);
+#else
+#if !MEGDNN_DISABLE_FLOAT16
+            cb1(param::ConvBias::Format::NCHW, dt_float16, dt_float16,
+                PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash);
+#endif
 #endif
             cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
                 dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
@@ -311,6 +348,37 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
                 "NCHW::GEMV::QUINT8x8x32_QUINT8"_hash);
             break;

+        case param::ConvBias::Format::NCHW44:
+            cb1(param::ConvBias::Format::NCHW44, dt_float32, dt_float32,
+                PostprocessMode::FLOAT, "NCHW44::GEMV::FLOAT"_hash);
+            cb2(param::ConvBias::Format::NCHW44, dt_int8, dt_int32, dt_int32,
+                dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
+                "NCHW44::GEMV::INT8x8x32_INT32"_hash);
+            cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8,
+                dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
+                dt_int32, PostprocessMode::NO_PROCESS,
+                "NCHW44::GEMV::QINT8x8x32_QINT32"_hash);
+            cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8,
+                dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
+                dt_int8, PostprocessMode::QUANTIZED,
+                "NCHW44::GEMV::QINT8x8x32_QINT8"_hash);
+            break;
+
+        case param::ConvBias::Format::NCHW44_DOT:
+            cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32,
+                dt_int32, dt_int8, dt_int32, dt_int32,
+                PostprocessMode::NO_PROCESS,
+                "NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash);
+            cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
+                dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
+                dt_int32, PostprocessMode::NO_PROCESS,
+                "NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash);
+            cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
+                dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
+                dt_int8, PostprocessMode::QUANTIZED,
+                "NCHW44_DOT::GEMV::QINT8x8x32_QINT8"_hash);
+            break;
+
         default:
             megdnn_throw("Invalid Format");
             break;
@@ -338,6 +406,16 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr,
                                            AlgoSelectionStrategy) const {
     MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
                  midout_iv("AlgoConv1x1Gemv::usable"_hash)) {
+#if MEGDNN_X86
+        if (opr->param().format != param::ConvBias::Format::NCHW)
+            return false;
+#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
+        if (opr->param().format != param::ConvBias::Format::NCHW &&
+            opr->param().format != param::ConvBias::Format::NCHW44 &&
+            opr->param().format != param::ConvBias::Format::NCHW44_DOT)
+            return false;
+#endif
+
         //! whether 1x1
         size_t FH = param.filter_meta.spatial[0],
                FW = param.filter_meta.spatial[1];
@@ -390,59 +468,43 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr,
             param.src_type.enumv() != DTypeEnum::Float32) {
             return false;
         }
-
-        bool is_param_ok =
-                (param.filter_meta.dilation[0] ==
-                         param.filter_meta.dilation[1] &&
-                 param.filter_meta.dilation[0] == 1) &&
-                param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
-
-        bool is_format_and_dtype_ok = false;
-#if MEGDNN_X86
-        if (opr->param().format == param::ConvBias::Format::NCHW) {
-            //! x86 supports all dtypes in NCHW
-            is_format_and_dtype_ok = true;
-        }
-#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
-        //! add NCHW44 and NCHW44_DOT support in the future
-        if (opr->param().format == param::ConvBias::Format::NCHW) {
-            //! NCHW format supports all dtype
-            is_format_and_dtype_ok = true;
+#if MEGDNN_AARCH64 || MEGDNN_ARMV7
+        if (opr->param().format == param::ConvBias::Format::NCHW44) {
+            if (param.src_type.enumv() != DTypeEnum::Float32 &&
+                param.src_type.enumv() != DTypeEnum::Int8 &&
+                param.src_type.enumv() != DTypeEnum::QuantizedS8) {
+                return false;
+            }
+        } else if (opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
+            if (param.src_type.enumv() != DTypeEnum::Int8 &&
+                param.src_type.enumv() != DTypeEnum::QuantizedS8) {
+                return false;
+            }
         }
 #endif
-        return is_param_ok && is_format_and_dtype_ok;
+        return (param.filter_meta.dilation[0] ==
+                        param.filter_meta.dilation[1] &&
+                param.filter_meta.dilation[0] == 1) &&
+               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
     }
     MIDOUT_END();
     return false;
 }

 bool ConvBiasImpl::AlgoConv1x1Gemv::is_preferred(
-        ConvBiasImpl*, const NCBKernSizeParam& param) const {
-    size_t OC = param.filter_meta.ocpg;
-    if (OC <= 2 && param.src_type.enumv() != DTypeEnum::Float32)
-        return true;
+        ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
+    MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
+                 midout_iv("AlgoConv1x1Gemv::is_preferred"_hash)) {
 #if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
-    //! maybe add support for QuantizedAsym in the future
-    return (param.src_type.enumv() == DTypeEnum::Int8 &&
-            param.filter_type.enumv() == DTypeEnum::Int8 &&
-            param.dst_type.enumv() == DTypeEnum::Int32) ||
-           (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
-            param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
-            param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
-           (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
-            param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
-            param.dst_type.enumv() == DTypeEnum::QuantizedS32) ||
-#if !MEGDNN_DISABLE_FLOAT16
-           (param.src_type.enumv() == DTypeEnum::Float16 &&
-            param.filter_type.enumv() == DTypeEnum::Float16 &&
-            param.dst_type.enumv() == DTypeEnum::Float16) ||
-#endif
-           (param.src_type.enumv() == DTypeEnum::Float32 &&
-            param.filter_type.enumv() == DTypeEnum::Float32 &&
-            param.dst_type.enumv() == DTypeEnum::Float32);
-#else
+        if (opr->param().format == param::ConvBias::Format::NCHW &&
+            param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
+            return false;
+        }
+#endif
+        return true;
+    }
+    MIDOUT_END();
     return false;
-#endif
 }

 // vim: syntax=cpp.doxygen
\ No newline at end of file
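Putting the dispatch conditions together: on ARM, CONV1x1_GEMV now also fires for NCHW44/NCHW44_DOT whenever the filter is 1x1 and the output collapses to a single pixel. A hedged sketch of a workload that reaches the new path, mirroring the tests further below (shape layout is {N, C/4, H, W, 4}; the helper name is my own):

#include <cstddef>

// Illustration: a conv1x1 workload that the NCHW44 GEMV path accepts.
//   src    {1, IC/4, 1, 1, 4}      -- one output pixel, so OH * OW == 1
//   filter {OC/4, IC/4, 1, 1, 4, 4}
//   dst    {1, OC/4, 1, 1, 4}
// Internally this becomes one gemv_like_mk4 call per thread's OC block:
//   M = OC, N = 1, K = IC, LDA = IC * 4, LDB = 4, LDC = 4.
struct Conv1x1GemvShapes {
    size_t src[5], filter[6], dst[5];
};

Conv1x1GemvShapes make_nchw44_1x1_case(size_t OC, size_t IC) {
    return {{1, IC / 4, 1, 1, 4},
            {OC / 4, IC / 4, 1, 1, 4, 4},
            {1, OC / 4, 1, 1, 4}};
}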
diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp
index 52ef5f11..79f8a55d 100644
--- a/dnn/test/arm_common/conv_bias.cpp
+++ b/dnn/test/arm_common/conv_bias.cpp
@@ -2036,7 +2036,6 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle,
                        RUNS;
         auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS;

-        printf("\n%s: ", matmul_algo_name);
         printf("%s %s:\n matmul: %f ms %f Gflops\nconv1x1: %f ms %f GFlops "
                "speedup: "
                "%f\n",
@@ -2120,6 +2119,82 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) {
 #endif
 }

+TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) {
+    using namespace conv_bias;
+    std::vector<TestArg> args;
+    param::ConvBias conv_param;
+    conv_param.stride_h = 1;
+    conv_param.stride_w = 1;
+    conv_param.pad_h = 0;
+    conv_param.pad_w = 0;
+    conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
+    auto run = [&](size_t M, size_t K) {
+        args.emplace_back(conv_param, TensorShape{1, K, 1, 1},
+                          TensorShape{M, K, 1, 1}, TensorShape{});
+    };
+    for (size_t M : {4, 64, 1024, 4096})
+        for (size_t K : {128, 256, 1024, 4096})
+            run(M, K);
+
+    constexpr size_t RUNS = 50;
+    param::MatrixMul param;
+    param.transposeA = false;
+    param.transposeB = false;
+    Benchmarker<MatrixMul> benchmark_matmul(handle());
+    benchmark_matmul.set_before_exec_callback(
+            AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV"));
+    benchmark_matmul.set_times(RUNS)
+            .set_dtype(0, dtype::Float32{})
+            .set_dtype(1, dtype::Float32{})
+            .set_dtype(2, dtype::Float32{})
+            .set_param(param)
+            .set_display(false);
+
+    Benchmarker<ConvBias> benchmark_conv1x1(handle());
+    benchmark_conv1x1.set_before_exec_callback(
+            conv_bias::ConvBiasAlgoChecker<ConvBias>("CONV1x1_GEMV"));
+    benchmark_conv1x1.set_times(RUNS)
+            .set_dtype(0, dtype::Float32{})
+            .set_dtype(1, dtype::Float32{})
+            .set_dtype(2, dtype::Float32{})
+            .set_dtype(4, dtype::Float32{})
+            .set_display(false);
+
+    std::cout << "warm up:\n";
+    for (int i = 0; i < 50; i++) {
+        benchmark_matmul.exec({{1, 1024}, {1024, 512}, {}});
+        benchmark_matmul.set_display(true);
+    }
+
+    for (auto&& arg : args) {
+        size_t IC = arg.src[1];
+        size_t OH = arg.src[2];
+        size_t OW = arg.src[3];
+        size_t OC = arg.filter[0];
+        size_t M = OC;
+        size_t K = IC;
+        size_t N = OH * OW;
+
+        float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3;
+
+        TensorShape A, B;
+        A = TensorShape{M, K};
+        B = TensorShape{K, N};
+
+        auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec(
+                                    {arg.src, arg.filter, arg.bias, {}, {}}) /
+                            RUNS;
+        auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS;
+
+        printf("%s %s:\n gemv: %f ms %f Gflops\nconv1x1: %f ms %f GFlops "
+               "speedup: "
+               "%f\n",
+               arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
+               matmul_used, computations / matmul_used, conv1x1_used,
+               computations / conv1x1_used, matmul_used / conv1x1_used);
+    }
+}
+
 #ifndef __ARM_FEATURE_DOTPROD
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
     std::vector<conv_bias::TestArg> conv_bias_1x1_args_nchw44 =
diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp
index 3c82e69f..f64f01c5 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -180,12 +180,15 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
     for (size_t kernel : kernel_vec)
         for (size_t oc : {4, 12})
             for (size_t ic : {1, 3, 4, 12})
-                for (size_t h : {3, 5, 12})
-                    for (size_t w : {7, 16, 23}) {
+                for (size_t h : {1, 3, 12})
+                    for (size_t w : {1, 16, 23}) {
                         for (size_t group = 1;
                              group <= std::min(std::min(oc, ic), 4_z);
                              ++group) {
+                            if (kernel != 1 && (h == 1 || w == 1)) {
+                                continue;
+                            }
                             pack(n, oc, ic, h, w, kernel, stride, group,
                                  nlmode, bias);
                         }
@@ -1897,6 +1900,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
 #elif MEGDNN_ARMV7
     check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
 #endif
+    std::vector<conv_bias::TestArg> gemv_args;
+    for (auto&& arg : args)
+        if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
+            gemv_args.emplace_back(arg);
+        }
+    check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
 }

 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
@@ -1932,7 +1941,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
 #endif
     std::vector<conv_bias::TestArg> gemv_args;
     for (auto&& arg : args)
-        if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
+        if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
             gemv_args.emplace_back(arg);
         }
     check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
@@ -2138,4 +2147,40 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
 }
 #endif

+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({1}, 1, true, false, false);
+    UniformIntRNG rng{-50, 50};
+    float epsilon = 0.001;
+    std::vector<conv_bias::TestArg> gemv_args;
+    for (auto&& arg : args)
+        if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
+            gemv_args.emplace_back(arg);
+        }
+    checker_conv_bias(gemv_args, handle(), &rng, epsilon,
+                      dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
+                      dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
+                      "CONV1x1_GEMV");
+}
+
+#ifdef __ARM_FEATURE_DOTPROD
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({1}, 1, true, false, false, false, true);
+    UniformIntRNG rng{-50, 50};
+    float epsilon = 0.001;
+    std::vector<conv_bias::TestArg> gemv_args;
+    for (auto&& arg : args)
+        if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
+            gemv_args.emplace_back(arg);
+        }
+    checker_conv_bias(gemv_args, handle(), &rng, epsilon,
+                      dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
+                      dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
+                      "CONV1x1_GEMV");
+}
+#endif
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp
index a2415898..faf4d816 100644
--- a/dnn/test/arm_common/matrix_mul.cpp
+++ b/dnn/test/arm_common/matrix_mul.cpp
@@ -156,7 +156,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) {
                 .set_dtype(2, dtype::QuantizedS32(6.25f))
                 .execs({A, B, {}});
     };
-
+    // N = 1
     for (size_t M : {1, 10, 16, 33, 64})
         for (size_t K : {7, 512, 1024})
             run(M, K, N);
 }

+TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4) {
+    Checker<MatrixMul> checker(handle());
+    using Param = MatrixMul::Param;
+
+    checker.set_before_exec_callback(
+            AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV_MK4"));
+
+    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
+    checker.set_rng(0, rng.get()).set_rng(1, rng.get());
+
+    auto run = [&](size_t M, size_t K, size_t N) {
+        Param param;
+        param.format = param::MatrixMul::Format::MK4;
+        param.transposeA = false;
+        param.transposeB = false;
+        TensorShape A, B;
+        A = TensorShape{M / 4, K / 4, 4, 4};
+        B = TensorShape{K / 4, 1, 4};
+        checker.set_param(param)
+                .set_dtype(0, dtype::QuantizedS8(2.5f))
+                .set_dtype(1, dtype::QuantizedS8(2.5f))
+                .set_dtype(2, dtype::QuantizedS32(6.25f))
+                .execs({A, B, {}});
+    };
+
+    // N = 1
+    for (size_t M : {4, 16, 128, 1024})
+        for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024})
+            run(M, K, 1);
+}
+
+#if __ARM_FEATURE_DOTPROD
+TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) {
+    Checker<MatrixMul> checker(handle());
+    using Param = MatrixMul::Param;
+
+    checker.set_before_exec_callback(
+            AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"));
+
+    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
+    checker.set_rng(0, rng.get()).set_rng(1, rng.get());
+
+    auto run = [&](size_t M, size_t K, size_t N) {
+        Param param;
+        param.format = param::MatrixMul::Format::MK4_DOT;
+        param.transposeA = false;
+        param.transposeB = false;
+        TensorShape A, B;
+        A = TensorShape{M / 4, K / 4, 4, 4};
+        B = TensorShape{K / 4, 1, 4};
+        checker.set_param(param)
+                .set_dtype(0, dtype::QuantizedS8(2.5f))
+                .set_dtype(1, dtype::QuantizedS8(2.5f))
+                .set_dtype(2, dtype::QuantizedS32(6.25f))
+                .execs({A, B, {}});
+    };
+
+    // N = 1
+    for (size_t M : {4, 16, 128, 1024})
+        for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024})
+            run(M, K, 1);
+}
+#endif
+
 TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
     Checker<MatrixMul> checker(handle());
     using Param = MatrixMul::Param;
@@ -220,6 +284,31 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
             run(M, K, N);
 }

+TEST_F(ARM_COMMON, FP32_GEMV_MK4) {
+    Checker<MatrixMul> checker(handle());
+    using Param = MatrixMul::Param;
+
+    checker.set_before_exec_callback(
+            AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV_MK4"));
+
+    checker.set_epsilon(1e-2);
+    auto run = [&](size_t M, size_t K) {
+        Param param;
+        param.format = param::MatrixMul::Format::MK4;
+        param.transposeA = false;
+        param.transposeB = false;
+        TensorShape A, B;
+        A = TensorShape{M / 4, K / 4, 4, 4};
+        B = TensorShape{K / 4, 1, 4};
+        checker.set_param(param).execs({A, B, {}});
+    };
+
+    // N = 1
+    for (size_t M : {4, 16, 128, 1024})
+        for (size_t K : {4, 8, 12, 128, 256, 4096})
+            run(M, K);
+}
+
 #if MEGDNN_WITH_BENCHMARK

 TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
     int exec_times = 10;
     Benchmarker<MatrixMul> benchmarker(handle());
     benchmarker.set_times(exec_times);

     auto run = [&](size_t M, size_t K, size_t N) {
-        std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
-                  << std::endl;
+        printf("SGEMV: (%zu, %zu, %zu)\n", M, K, N);
         benchmarker.set_dtype(0, dtype::Float32())
                 .set_dtype(1, dtype::Float32());
         auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
         auto computations = 2.f * M * K * N * 1e-6;
         auto perf = computations / time;
-        std::cout << "gemv fp32, Performance is " << perf << " Gflops"
-                  << std::endl;
+        printf("gemv fp32, Performance is %f Gflops\n", perf);
     };

-    std::cout << "warm up:\n";
+    printf("warm up:\n");
     for (int i = 0; i < 50; i++) {
         benchmarker.set_dtype(0, dtype::Float32())
                 .set_dtype(1, dtype::Float32())
@@ -253,6 +340,10 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
     for (size_t K : {1024, 1536, 2048})
         for (size_t N : {512, 1024})
             run(M, K, N);
+
+    for (size_t M : {4, 64, 1024, 4096})
+        for (size_t K : {128, 256, 1024, 4096})
+            run(M, K, 1);
 }

 TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
     int exec_times = 50;
     Benchmarker<MatrixMul> benchmarker(handle());
     benchmarker.set_times(exec_times).set_before_exec_callback(
             AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV"));

     auto run = [&](size_t M, size_t K, size_t N) {
-        std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
-                  << std::endl;
+        printf("SGEMV: (%zu, %zu, %zu)\n", M, K, N);
         benchmarker.set_dtype(0, dtype::Float32())
                 .set_dtype(1, dtype::Float32())
                 .set_dtype(2, dtype::Float32());
         auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
         auto computations = 2 * M * K * N * 1e-6;
         auto perf = computations / time;
-        std::cout << "gemv fp32, Performance is " << perf << " Gflops"
-                  << std::endl;
+        printf("gemv fp32, Performance is %f Gflops\n", perf);
     };

-    std::cout << "warm up:\n";
+    printf("warm up:\n");
     for (int i = 0; i < 50; i++) {
         benchmarker.set_dtype(0, dtype::Float32())
                 .set_dtype(1, dtype::Float32())
-                .set_dtype(2, dtype::Float32())
                 .set_display(false)
                 .exec({{2, 1024}, {1024, 512}, {}});
         benchmarker.set_display(true);
     }
-
+    // run gemv
     run(12, 48, 1);
     run(48, 12, 1);
@@ -298,6 +386,45 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
     run(1024, 256, 1);
 }

+TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
+    int exec_times = 10;
+    using Param = MatrixMul::Param;
+    Param param;
+    param.format = param::MatrixMul::Format::MK4;
+    param.transposeA = false;
+    param.transposeB = false;
+    Benchmarker<MatrixMul> benchmarker(handle());
+    benchmarker.set_times(exec_times);
+    benchmarker.set_dtype(0, dtype::Float32())
+            .set_dtype(1, dtype::Float32())
+            .set_param(param);
+
+    auto run = [&](size_t M, size_t K) {
+        printf("SGEMV_MK4: (%zu, %zu)\n", M, K);
+        TensorShape A, B;
+        A = TensorShape{M / 4, K / 4, 4, 4};
+        B = TensorShape{K / 4, 1, 4};
+        auto time = benchmarker.exec({A, B, {}}) / exec_times;
+        auto computations = 2.f * M * K * 1e-6;
+        auto perf = computations / time;
+        printf("gemv mk4 fp32, Performance is %f Gflops\n", perf);
+    };
+
+    printf("warm up:\n");
+    for (int i = 0; i < 50; i++) {
+        benchmarker.set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_display(false)
+                .exec({{4, 256, 4, 4}, {256, 1, 4}, {}});
+    }
+
+    // run gemv mk4
+    for (size_t M : {4, 64, 1024, 4096})
+        for (size_t K : {128, 1024, 4096})
+            run(M, K);
+}
+
 TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
     int exec_times = 50;
     Benchmarker<MatrixMul> benchmarker(handle());
     benchmarker.set_times(exec_times).set_before_exec_callback(
             AlgoChecker<MatrixMul>("ARM_COMMON_F16_GEMV"));

     auto run = [&](size_t M, size_t K, size_t N) {
-        std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
-                  << std::endl;
+        printf("SGEMV_FP16: (%zu, %zu, %zu)\n", M, K, N);
         benchmarker.set_dtype(0, dtype::Float16())
                 .set_dtype(1, dtype::Float16())
                 .set_dtype(2, dtype::Float16());
         auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
         auto computations = 2 * M * K * N * 1e-6;
         auto perf = computations / time;
-        std::cout << "gemv fp16, Performance is " << perf << " Gflops"
-                  << std::endl;
+        printf("gemv fp16, Performance is %f Gflops\n", perf);
     };

-    std::cout << "warm up:\n";
+    printf("warm up:\n");
     for (int i = 0; i < 50; i++) {
         benchmarker.set_dtype(0, dtype::Float16())
                 .set_dtype(1, dtype::Float16())
@@ -343,17 +468,15 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
     float mod = 1000 * exec_times / 1e9;
     auto run = [&](size_t M, size_t K, size_t N) {
         float time = 1.f, perf = 1.f;
-        std::cout << "SGEMM: (" << M << ", " << K << ", " << N << ")"
-                  << std::endl;
+        printf("SGEMM: (%zu, %zu, %zu)\n", M, K, N);
         benchmarker.set_dtype(0, dtype::Float32())
                 .set_dtype(1, dtype::Float32());
         time = benchmarker.exec({{M, K}, {K, N}, {}});
         perf = 2.f * M * K * N / time * mod;
-        std::cout << "gemm fp32, Performance is " << perf << " Gflops"
-                  << std::endl;
+        printf("gemm, Performance is %f Gflops\n", perf);
     };

-    std::cout << "warm up:\n";
+    printf("warm up:\n");
     for (int i = 0; i < 50; i++) {
         benchmarker.set_dtype(0, dtype::Float32())
                 .set_dtype(1, dtype::Float32())
-- 
GitLab