From 5e306b756b9f8212613ff677e472f2551f3f3bf2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 21 Jun 2022 19:26:11 +0800 Subject: [PATCH] feat(x86): make conv1x1 and im2col available on with x86-NCHW44 add AlgoF32GiMK4Pack4x12 matrix_mul algo GitOrigin-RevId: 47cfe1d733d80c4c8ca8f4a8fa2d8469e530da75 --- dnn/src/fallback/conv_bias/conv1x1/algos.cpp | 11 +- dnn/src/fallback/conv_bias/im2col/algos.cpp | 14 +- dnn/src/fallback/conv_bias/im2col/factory.h | 6 +- .../fallback/conv_bias/im2col/strategy_base.h | 2 - .../im2col/strategy_fuse_nchw44_fp32_s2.cpp | 83 ++-- dnn/src/fallback/matrix_mul/algos.cpp | 72 ++- dnn/src/fallback/matrix_mul/algos.h | 13 + .../fallback/matrix_mul/generic_strategy.h | 2 + dnn/src/fallback/matrix_mul/gi/fp32/common.h | 107 ++++ .../matrix_mul/gi/fp32/strategy_mk_4x12.cpp | 458 ++++++++++++++++++ dnn/src/fallback/matrix_mul/opr_impl.cpp | 2 + dnn/src/fallback/matrix_mul/opr_impl.h | 10 +- dnn/src/x86/conv_bias/postprocess_helper.h | 4 +- dnn/src/x86/elemwise_op.h | 8 +- dnn/test/fallback/conv_bias.cpp | 68 +++ dnn/test/fallback/matrix_mul.cpp | 15 +- 16 files changed, 814 insertions(+), 61 deletions(-) create mode 100644 dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index cd84efa6d..54bbbd8c5 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -197,10 +197,17 @@ bool ConvBiasImpl::AlgoConv1x1::usable( return false; } } -#else //! x86 only support nchw mode - if (format != param::ConvBias::Format::NCHW) { +#else //! x86 and RISC-V do not support NCHW44_DOT + if (format != param::ConvBias::Format::NCHW && + format != param::ConvBias::Format::NCHW44) { return false; } + //! hybird mode is not support + if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { + if (param.filter_meta.icpg < 4_z || param.filter_meta.ocpg == 1) { + return false; + } + } #endif //! param if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) { diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp index 37efbf967..61d568bc3 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.cpp +++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp @@ -345,9 +345,21 @@ bool ConvBiasImpl::AlgoIm2col::usable( } } #else - if (format != param::ConvBias::Format::NCHW) { + if (format != param::ConvBias::Format::NCHW && + format != param::ConvBias::Format::NCHW44) { return false; } + if (format == param::ConvBias::Format::NCHW44) { + //! current NCHW44 im2col only support DEFAULT mode matmul + if (matmul_desc.packmode != Pack_Mode::DEFAULT) { + return false; + //! nchw44 hybird mode and channel wise is not support + } else if ( + param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || + param.filter_meta.ocpg == 1) { + return false; + } + } #endif if (param.src_type.enumv() != param.filter_type.enumv() || (param.src_type.enumv() != DTypeEnum::Int8 && diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index e4544c206..f23581f87 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -216,10 +216,9 @@ public: cb1(NCHW, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash); } else if (format == param::ConvBias::Format::NCHW44) { -#if MEGDNN_AARCH64 || MEGDNN_ARMV7 auto matmul_block = matmul_algo->get_inner_block_size(); - //! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 - //! im2col+pack fuse + //! Optimize NCHW44 3x3s2 on aarch64 8X12X4 and fallback/armv7 + //! 4x12x4 im2col+pack fuse if ((matmul_block.m == 8 || matmul_block.m == 4) && matmul_block.n == 12 && matmul_block.k == 1 && param.filter_meta.spatial[0] == 3 && @@ -236,7 +235,6 @@ public: MIDOUT_END(); return {}; } -#endif cb1(NCHW44, DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, "DefaultStrategyTypeNCHW44::FLOAT"_hash); diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_base.h b/dnn/src/fallback/conv_bias/im2col/strategy_base.h index 47261288f..7d3475b84 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_base.h +++ b/dnn/src/fallback/conv_bias/im2col/strategy_base.h @@ -530,7 +530,6 @@ public: }; #endif -#if MEGDNN_AARCH64 || MEGDNN_ARMV7 template < typename op_ctype, typename op_dtype, megdnn::PostprocessMode postprocess_mode> class StrategyFuseXx12x1Nchw44K3x3S2 @@ -553,7 +552,6 @@ public: fallback::MatrixMulImpl::KernParam matmul_param, const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; }; -#endif } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp index 8d22e842d..1330ad8f6 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp +++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp @@ -1,7 +1,6 @@ #include "src/fallback/conv_bias/im2col/strategy_base.h" -#if MEGDNN_AARCH64 || MEGDNN_ARMV7 -#include +#include "src/fallback/general_intrinsic/gi_float.h" using namespace megdnn; @@ -11,32 +10,32 @@ namespace { int out_index = 0; \ outptr = output_base; \ for (; out_index + 11 < block_size; out_index += 12) { \ - float32x4x4_t v0 = vld4q_f32(tmp_output); \ - float32x4x4_t v1 = vld4q_f32(tmp_output + 16); \ - float32x4x4_t v2 = vld4q_f32(tmp_output + 32); \ - vst1q_f32(outptr, v0.val[0]); \ - vst1q_f32(outptr + 4, v1.val[0]); \ - vst1q_f32(outptr + 8, v2.val[0]); \ - vst1q_f32(outptr + 12, v0.val[1]); \ - vst1q_f32(outptr + 16, v1.val[1]); \ - vst1q_f32(outptr + 20, v2.val[1]); \ - vst1q_f32(outptr + 24, v0.val[2]); \ - vst1q_f32(outptr + 28, v1.val[2]); \ - vst1q_f32(outptr + 32, v2.val[2]); \ - vst1q_f32(outptr + 36, v0.val[3]); \ - vst1q_f32(outptr + 40, v1.val[3]); \ - vst1q_f32(outptr + 44, v2.val[3]); \ + GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(tmp_output); \ + GI_FLOAT32_V4_t v1 = GiLoadUzipFloat32V4(tmp_output + 16); \ + GI_FLOAT32_V4_t v2 = GiLoadUzipFloat32V4(tmp_output + 32); \ + GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ + GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v1, 0)); \ + GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v2, 0)); \ + GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 1)); \ + GiStoreFloat32(outptr + 16, GiGetSubVectorFloat32V4(v1, 1)); \ + GiStoreFloat32(outptr + 20, GiGetSubVectorFloat32V4(v2, 1)); \ + GiStoreFloat32(outptr + 24, GiGetSubVectorFloat32V4(v0, 2)); \ + GiStoreFloat32(outptr + 28, GiGetSubVectorFloat32V4(v1, 2)); \ + GiStoreFloat32(outptr + 32, GiGetSubVectorFloat32V4(v2, 2)); \ + GiStoreFloat32(outptr + 36, GiGetSubVectorFloat32V4(v0, 3)); \ + GiStoreFloat32(outptr + 40, GiGetSubVectorFloat32V4(v1, 3)); \ + GiStoreFloat32(outptr + 44, GiGetSubVectorFloat32V4(v2, 3)); \ outptr += ksize12; \ tmp_output += 48; \ } \ \ outptr = output_base4; \ for (; out_index + 3 < block_size; out_index += 4) { \ - float32x4x4_t v0 = vld4q_f32(tmp_output); \ - vst1q_f32(outptr, v0.val[0]); \ - vst1q_f32(outptr + 4, v0.val[1]); \ - vst1q_f32(outptr + 8, v0.val[2]); \ - vst1q_f32(outptr + 12, v0.val[3]); \ + GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(tmp_output); \ + GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ + GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v0, 1)); \ + GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v0, 2)); \ + GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 3)); \ outptr += ksize4; \ tmp_output += 16; \ } \ @@ -45,23 +44,23 @@ namespace { float zerobuffer[16] = {0}; \ size_t out_remain = std::min(block_size - out_index, 4); \ std::memcpy(zerobuffer, tmp_output, out_remain * sizeof(float) * 4); \ - float32x4x4_t v0 = vld4q_f32(zerobuffer); \ - vst1q_f32(outptr, v0.val[0]); \ - vst1q_f32(outptr + 4, v0.val[1]); \ - vst1q_f32(outptr + 8, v0.val[2]); \ - vst1q_f32(outptr + 12, v0.val[3]); \ + GI_FLOAT32_V4_t v0 = GiLoadUzipFloat32V4(zerobuffer); \ + GiStoreFloat32(outptr, GiGetSubVectorFloat32V4(v0, 0)); \ + GiStoreFloat32(outptr + 4, GiGetSubVectorFloat32V4(v0, 1)); \ + GiStoreFloat32(outptr + 8, GiGetSubVectorFloat32V4(v0, 2)); \ + GiStoreFloat32(outptr + 12, GiGetSubVectorFloat32V4(v0, 3)); \ } \ output_base += 48; \ output_base4 += 16; -#define LOAD_AND_STOR_IM2COL_DST() \ - float32x4_t v1 = vld1q_f32(&src[index + 4]); \ - float32x4_t v2 = vld1q_f32(&src[index + 8]); \ - vst1q_f32(&output0[i], v0); \ - vst1q_f32(&output1[i], v1); \ - vst1q_f32(&output2[i], v2); \ - i += 4; \ - index += 8; \ +#define LOAD_AND_STOR_IM2COL_DST() \ + GI_FLOAT32_t v1 = GiLoadFloat32(&src[index + 4]); \ + GI_FLOAT32_t v2 = GiLoadFloat32(&src[index + 8]); \ + GiStoreFloat32(&output0[i], v0); \ + GiStoreFloat32(&output1[i], v1); \ + GiStoreFloat32(&output2[i], v2); \ + i += 4; \ + index += 8; \ v0 = v2; void fuse_packb( @@ -94,12 +93,12 @@ void fuse_packb( size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + cur_remain_w * SW); for (int w = cur_remain_w; w < end_remain_w; w++) { - vst1q_f32(&output02[i], vld1q_f32(&src[index])); - vst1q_f32(&output1[i], vld1q_f32(&src[index + 4])); + GiStoreFloat32(&output02[i], GiLoadFloat32(&src[index])); + GiStoreFloat32(&output1[i], GiLoadFloat32(&src[index + 4])); i += 4; index += 8; } - vst1q_f32(&output02[i], vld1q_f32(&src[index])); + GiStoreFloat32(&output02[i], GiLoadFloat32(&src[index])); float* output[3]; output[0] = output02; output[1] = output1; @@ -120,19 +119,19 @@ void fuse_packb( size_t index = 4 * (ic * IH * IW + (start_h * SH + fh) * IW + (cur_remain_w * SW)); - float32x4_t v0 = vld1q_f32(&src[index]); + GI_FLOAT32_t v0 = GiLoadFloat32(&src[index]); for (int w = cur_remain_w; w < OW; w++) { LOAD_AND_STOR_IM2COL_DST(); } for (int h = start_h + 1; h < end_h; h++) { size_t index = 4 * (ic * IH * IW + (h * SH + fh) * IW); - v0 = vld1q_f32(&src[index]); + v0 = GiLoadFloat32(&src[index]); rep(ow, OW) { LOAD_AND_STOR_IM2COL_DST(); } } index = 4 * (ic * IH * IW + (end_h * SH + fh) * IW); - v0 = vld1q_f32(&src[index]); + v0 = GiLoadFloat32(&src[index]); for (int w = 0; w < end_remain_w; w++) { LOAD_AND_STOR_IM2COL_DST(); } @@ -190,6 +189,4 @@ template class StrategyFuseXx12x1Nchw44K3x3S2< float, float, megdnn::PostprocessMode::FLOAT>; } // namespace megdnn -#endif - // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/algos.cpp b/dnn/src/fallback/matrix_mul/algos.cpp index 3188b23d8..52f104a2b 100644 --- a/dnn/src/fallback/matrix_mul/algos.cpp +++ b/dnn/src/fallback/matrix_mul/algos.cpp @@ -57,7 +57,7 @@ void kern_naive(const MatrixMulImpl::KernParam& kern_param) { size_t pack_size = get_pack_size(); megdnn_assert( (M % pack_size == 0 && K % pack_size == 0), - "M and N must time of pack_size M: %zu N: %zu pack_size: %zu", M, N, + "M and K must time of pack_size M: %zu K: %zu pack_size: %zu", M, N, pack_size); #define DISPATCH(TA, TB) \ @@ -263,12 +263,15 @@ void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { } // anonymous namespace bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable( const KernSizeParam& kern_size_param) const { + constexpr size_t MB = 4; + constexpr size_t KB = 4; return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.format == param::MatrixMul::Format::MK4 && kern_size_param.B_type == kern_size_param.A_type && kern_size_param.C_type == kern_size_param.A_type && kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && - !kern_size_param.trB; + !kern_size_param.trB && kern_size_param.M % MB == 0 && + kern_size_param.K % KB == 0; } size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace( @@ -295,6 +298,71 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( return gi_f32_mk4_4x8_kern; } +/* ===================== F32 algo gi mk4 pack K4x12 ===================== */ +namespace { +void f32_gi_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN( + megdnn_fb_gi_matmul_kern, midout_iv("f32_gi_mk4_pack_4x12_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + matmul::fallback::gi_sgemm_mk4_pack_4x12 strategy( + M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); + } + MIDOUT_END(); +} + +} // anonymous namespace + +bool MatrixMulImpl::AlgoF32GiMK4Pack4x12::usable( + const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::MK4 && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && + !kern_size_param.trB && kern_size_param.M % 4 == 0 && + kern_size_param.K % 4 == 0 && !kern_size_param.trA && !kern_size_param.trB; +} + +size_t MatrixMulImpl::AlgoF32GiMK4Pack4x12::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN( + megdnn_fb_gi_matmul_kern, + midout_iv("AlgoF32GiMK4Pack4x12::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + matmul::fallback::gi_sgemm_mk4_pack_4x12 strategy( + M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved< + matmul::fallback::gi_sgemm_mk4_pack_4x12>( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); + return 0; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4Pack4x12::get_kern( + const KernSizeParam&) const { + return f32_gi_mk4_pack_4x12_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF32GiMK4Pack4x12, megdnn_fb_gi_matmul_kern, "AlgoF32GiMK4Pack4x12"_hash, + matmul::fallback::gi_sgemm_mk4_pack_4x12, float, float, AlgoDataType::FLOAT32, + MK4); + /* ===================== F32 algo ===================== */ namespace { void f32_kern(const MatrixMulImpl::KernParam& kern_param) { diff --git a/dnn/src/fallback/matrix_mul/algos.h b/dnn/src/fallback/matrix_mul/algos.h index bef120257..fcf0ff37b 100644 --- a/dnn/src/fallback/matrix_mul/algos.h +++ b/dnn/src/fallback/matrix_mul/algos.h @@ -97,6 +97,19 @@ public: MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) }; +class MatrixMulImpl::AlgoF32GiMK4Pack4x12 final : public AlgoBase { +public: + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + } + const char* name() const override { return "FB_GI_F32_MK4_PACK_4x12"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_PACK_4x12) +}; + class MatrixMulImpl::AlgoF32Gi4x12 final : public AlgoBase { public: AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } diff --git a/dnn/src/fallback/matrix_mul/generic_strategy.h b/dnn/src/fallback/matrix_mul/generic_strategy.h index 4dc1aecd4..3b28c95e5 100644 --- a/dnn/src/fallback/matrix_mul/generic_strategy.h +++ b/dnn/src/fallback/matrix_mul/generic_strategy.h @@ -9,6 +9,8 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12) MEGDNN_REG_GEMM_STRATEGY_NOPACK( float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12); +MEGDNN_REG_GEMM_STRATEGY( + float, float, float, 4, 12, 1, false, false, gi_sgemm_mk4_pack_4x12); } // namespace fallback } // namespace matmul diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/common.h b/dnn/src/fallback/matrix_mul/gi/fp32/common.h index f282bc1ae..cc7813b64 100644 --- a/dnn/src/fallback/matrix_mul/gi/fp32/common.h +++ b/dnn/src/fallback/matrix_mul/gi/fp32/common.h @@ -214,6 +214,113 @@ static GI_FORCEINLINE void transpose_4x4_1_s( outptr += stride; } +template +static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { + static_assert(sizeof(T) == 4, "transpose_1x12_4_s only support sizeof(T) == 4"); + GI_FLOAT32_t tmp_a, tmp_b; +#define LOAD() \ + tmp_a = GiLoadFloat32(inptr0); \ + inptr0 += 4; \ + tmp_b = GiLoadFloat32(inptr0); \ + inptr0 += 4; + + LOAD(); + GI_FLOAT32_V2_t d0d1d2d3 = GiZipqFloat32(tmp_a, tmp_b); + LOAD(); + GI_FLOAT32_V2_t d4d5d6d7 = GiZipqFloat32(tmp_a, tmp_b); + LOAD(); + GI_FLOAT32_V2_t d8d9d10d11 = GiZipqFloat32(tmp_a, tmp_b); + LOAD(); + GI_FLOAT32_V2_t d12d13d14d15 = GiZipqFloat32(tmp_a, tmp_b); + LOAD(); + GI_FLOAT32_V2_t d16d17d18d19 = GiZipqFloat32(tmp_a, tmp_b); + LOAD(); + GI_FLOAT32_V2_t d20d21d22d23 = GiZipqFloat32(tmp_a, tmp_b); +#undef LOAD + GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); + GiSt1Float32(outptr + 1 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); + GiSt1Float32( + outptr + 2 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 0))); + GiSt1Float32( + outptr + 3 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 0))); + GiSt1Float32( + outptr + 4 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 0))); + GiSt1Float32( + outptr + 5 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 0))); + GiSt1Float32( + outptr + 6 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); + GiSt1Float32( + outptr + 7 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); + GiSt1Float32( + outptr + 8 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 0))); + GiSt1Float32( + outptr + 9 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 0))); + GiSt1Float32( + outptr + 10 * 2, + GiGetHighFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 0))); + GiSt1Float32( + outptr + 11 * 2, + GiGetHighFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 0))); + GiSt1Float32( + outptr + 12 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); + GiSt1Float32( + outptr + 13 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); + GiSt1Float32( + outptr + 14 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 1))); + GiSt1Float32( + outptr + 15 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 1))); + GiSt1Float32( + outptr + 16 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 1))); + GiSt1Float32( + outptr + 17 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 1))); + GiSt1Float32( + outptr + 18 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); + GiSt1Float32( + outptr + 19 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); + GiSt1Float32( + outptr + 20 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d8d9d10d11, 1))); + GiSt1Float32( + outptr + 21 * 2, + GiGetHighFloat32(GiGetSubVectorFloat32V2(d12d13d14d15, 1))); + GiSt1Float32( + outptr + 22 * 2, + GiGetHighFloat32(GiGetSubVectorFloat32V2(d16d17d18d19, 1))); + GiSt1Float32( + outptr + 23 * 2, + GiGetHighFloat32(GiGetSubVectorFloat32V2(d20d21d22d23, 1))); + outptr += 23 * 2; +} + +template +static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { + static_assert(sizeof(T) == 4, "transpose_1x4_4_s only support sizeof(T) == 4"); + GI_FLOAT32_t tmp_a, tmp_b; +#define LOAD() \ + tmp_a = GiLoadFloat32(inptr0); \ + inptr0 += 4; \ + tmp_b = GiLoadFloat32(inptr0); \ + inptr0 += 4; + + LOAD(); + GI_FLOAT32_V2_t d0d1d2d3 = GiZipqFloat32(tmp_a, tmp_b); + LOAD(); + GI_FLOAT32_V2_t d4d5d6d7 = GiZipqFloat32(tmp_a, tmp_b); +#undef LOAD + GiSt1Float32(outptr, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); + GiSt1Float32(outptr + 1 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); + GiSt1Float32( + outptr + 2 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 0))); + GiSt1Float32( + outptr + 3 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 0))); + GiSt1Float32(outptr + 4 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); + GiSt1Float32(outptr + 5 * 2, GiGetLowFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); + GiSt1Float32( + outptr + 6 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d0d1d2d3, 1))); + GiSt1Float32( + outptr + 7 * 2, GiGetHighFloat32(GiGetSubVectorFloat32V2(d4d5d6d7, 1))); + outptr += 7 * 2; +} + } // namespace fallback } // namespace matmul } // namespace megdnn diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp new file mode 100644 index 000000000..2752e8d70 --- /dev/null +++ b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp @@ -0,0 +1,458 @@ +//! risc-v gcc will error report uninitialized var at if/else case when use RVV type +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wuninitialized" + +#ifdef __GNUC__ +#ifndef __has_warning +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#else +#if __has_warning("-Wmaybe-uninitialized") +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif +#endif +#endif + +#include "src/fallback/matrix_mul/generic_strategy.h" +#include "src/fallback/matrix_mul/gi/fp32/common.h" + +using namespace megdnn; +using namespace matmul::fallback; + +namespace { + +void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + float* output0 = output; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + float* r1 = output; + + GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, + d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; + + if (is_first_k) { + d8d9 = GiBroadcastFloat32(0.0f); + d10d11 = GiBroadcastFloat32(0.0f); + d12d13 = GiBroadcastFloat32(0.0f); + d14d15 = GiBroadcastFloat32(0.0f); + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d16d17 = GiBroadcastFloat32(0.0f); + d18d19 = GiBroadcastFloat32(0.0f); + d20d21 = GiBroadcastFloat32(0.0f); + d22d23 = GiBroadcastFloat32(0.0f); + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d24d25 = GiBroadcastFloat32(0.0f); + d26d27 = GiBroadcastFloat32(0.0f); + d28d29 = GiBroadcastFloat32(0.0f); + d30d31 = GiBroadcastFloat32(0.0f); + } else { + d8d9 = GiLoadFloat32(r1); + r1 = r1 + 4; + d10d11 = GiLoadFloat32(r1); + r1 = r1 + 4; + + d12d13 = GiLoadFloat32(r1); + r1 = r1 + 4; + d14d15 = GiLoadFloat32(r1); + r1 = r1 + 4; + + d16d17 = GiLoadFloat32(r1); + r1 = r1 + 4; + d18d19 = GiLoadFloat32(r1); + r1 = r1 + 4; + + d20d21 = GiLoadFloat32(r1); + r1 = r1 + 4; + d22d23 = GiLoadFloat32(r1); + r1 = r1 + 4; + + d24d25 = GiLoadFloat32(r1); + r1 = r1 + 4; + d26d27 = GiLoadFloat32(r1); + r1 = r1 + 4; + + d28d29 = GiLoadFloat32(r1); + r1 = r1 + 4; + d30d31 = GiLoadFloat32(r1); + r1 = r1 + 4; + + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + } + for (; K > 0; K--) { + d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); + d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); + d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); + d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); + d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); + d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); + d2d3 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); + d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); + d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); + d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + + d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); + d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); + d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); + d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0); + d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1); + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2); + d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3); + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0); + d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1); + d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2); + d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3); + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + } + + if (1 == oddk) { + d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); + d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); + d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); + d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); + GiStoreFloat32(output0, d8d9); + output0 = output0 + 4; + GiStoreFloat32(output0, d10d11); + output0 = output0 + 4; + d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); + d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); + GiStoreFloat32(output0, d12d13); + output0 = output0 + 4; + GiStoreFloat32(output0, d14d15); + output0 = output0 + 4; + d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); + d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); + GiStoreFloat32(output0, d16d17); + output0 = output0 + 4; + GiStoreFloat32(output0, d18d19); + output0 = output0 + 4; + d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); + GiStoreFloat32(output0, d20d21); + output0 = output0 + 4; + GiStoreFloat32(output0, d22d23); + output0 = output0 + 4; + d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); + GiStoreFloat32(output0, d24d25); + output0 = output0 + 4; + GiStoreFloat32(output0, d26d27); + output0 = output0 + 4; + d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); + GiStoreFloat32(output0, d28d29); + output0 = output0 + 4; + GiStoreFloat32(output0, d30d31); + output0 = output0 + 4; + + } else { + d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); + d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); + d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); + d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d16d17 = GiSimdFmaLane(d16d17, d0d1, d6d7, 0); + d18d19 = GiSimdFmaLane(d18d19, d0d1, d6d7, 1); + d20d21 = GiSimdFmaLane(d20d21, d0d1, d6d7, 2); + d2d3 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d22d23 = GiSimdFmaLane(d22d23, d0d1, d6d7, 3); + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d24d25 = GiSimdFmaLane(d24d25, d0d1, d4d5, 0); + d26d27 = GiSimdFmaLane(d26d27, d0d1, d4d5, 1); + d28d29 = GiSimdFmaLane(d28d29, d0d1, d4d5, 2); + d30d31 = GiSimdFmaLane(d30d31, d0d1, d4d5, 3); + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + + d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); + d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); + d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); + d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d16d17 = GiSimdFmaLane(d16d17, d2d3, d4d5, 0); + d18d19 = GiSimdFmaLane(d18d19, d2d3, d4d5, 1); + GiStoreFloat32(output0, d8d9); + output0 = output0 + 4; + GiStoreFloat32(output0, d10d11); + output0 = output0 + 4; + d20d21 = GiSimdFmaLane(d20d21, d2d3, d4d5, 2); + d22d23 = GiSimdFmaLane(d22d23, d2d3, d4d5, 3); + GiStoreFloat32(output0, d12d13); + output0 = output0 + 4; + GiStoreFloat32(output0, d14d15); + output0 = output0 + 4; + d24d25 = GiSimdFmaLane(d24d25, d2d3, d6d7, 0); + d26d27 = GiSimdFmaLane(d26d27, d2d3, d6d7, 1); + GiStoreFloat32(output0, d16d17); + output0 = output0 + 4; + GiStoreFloat32(output0, d18d19); + output0 = output0 + 4; + d28d29 = GiSimdFmaLane(d28d29, d2d3, d6d7, 2); + d30d31 = GiSimdFmaLane(d30d31, d2d3, d6d7, 3); + GiStoreFloat32(output0, d20d21); + output0 = output0 + 4; + GiStoreFloat32(output0, d22d23); + output0 = output0 + 4; + GiStoreFloat32(output0, d24d25); + output0 = output0 + 4; + GiStoreFloat32(output0, d26d27); + output0 = output0 + 4; + GiStoreFloat32(output0, d28d29); + output0 = output0 + 4; + GiStoreFloat32(output0, d30d31); + output0 = output0 + 4; + } +} + +void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { + MEGDNN_MARK_USED_VAR(LDC); + const float* a_ptr = packA; + const float* b_ptr = packB; + + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + float* r1 = output; + + GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15; + + if (is_first_k) { + d8d9 = GiBroadcastFloat32(0.0f); + d10d11 = GiBroadcastFloat32(0.0f); + + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + + d12d13 = GiBroadcastFloat32(0.0f); + + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + + d14d15 = GiBroadcastFloat32(0.0f); + } else { + if (n_remain == 4) { + d8d9 = GiLoadFloat32(r1); + r1 = r1 + 4; + d10d11 = GiLoadFloat32(r1); + r1 = r1 + 4; + d12d13 = GiLoadFloat32(r1); + r1 = r1 + 4; + d14d15 = GiLoadFloat32(r1); + r1 = r1 + 4; + } else if (n_remain == 3) { + d8d9 = GiLoadFloat32(r1); + r1 = r1 + 4; + d10d11 = GiLoadFloat32(r1); + r1 = r1 + 4; + d12d13 = GiLoadFloat32(r1); + r1 = r1 + 4; + } else if (n_remain == 2) { + d8d9 = GiLoadFloat32(r1); + r1 = r1 + 4; + d10d11 = GiLoadFloat32(r1); + r1 = r1 + 4; + } else if (n_remain == 1) { + d8d9 = GiLoadFloat32(r1); + r1 = r1 + 4; + } + } + + for (; K > 0; K--) { + d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); + d2d3 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); + d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); + + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); + d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); + d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); + } + + if (1 == oddk) { + d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); + d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); + d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); + d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); + } else { + d8d9 = GiSimdFmaLane(d8d9, d0d1, d4d5, 0); + d2d3 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d10d11 = GiSimdFmaLane(d10d11, d0d1, d4d5, 1); + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d12d13 = GiSimdFmaLane(d12d13, d0d1, d4d5, 2); + d14d15 = GiSimdFmaLane(d14d15, d0d1, d4d5, 3); + + d8d9 = GiSimdFmaLane(d8d9, d2d3, d6d7, 0); + d10d11 = GiSimdFmaLane(d10d11, d2d3, d6d7, 1); + d12d13 = GiSimdFmaLane(d12d13, d2d3, d6d7, 2); + d14d15 = GiSimdFmaLane(d14d15, d2d3, d6d7, 3); + } + + if (n_remain == 4) { + GiStoreFloat32(output, d8d9); + output = output + 4; + GiStoreFloat32(output, d10d11); + output = output + 4; + GiStoreFloat32(output, d12d13); + output = output + 4; + GiStoreFloat32(output, d14d15); + output = output + 4; + } else if (n_remain == 3) { + GiStoreFloat32(output, d8d9); + output = output + 4; + GiStoreFloat32(output, d10d11); + output = output + 4; + GiStoreFloat32(output, d12d13); + output = output + 4; + } else if (n_remain == 2) { + GiStoreFloat32(output, d8d9); + output = output + 4; + GiStoreFloat32(output, d10d11); + output = output + 4; + } else if (n_remain == 1) { + GiStoreFloat32(output, d8d9); + output = output + 4; + } +} + +} // namespace + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_mk4_pack_4x12); +//! Now no matmul mode of only packB support in conv1x1 and im2col, so just copy +//! the weight +void gi_sgemm_mk4_pack_4x12::pack_A( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, + bool) const { + megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + constexpr int PACK_C_SIZE = 4; + size_t cp_length = (kmax - k0) * PACK_C_SIZE; + for (int m = y0; m < ymax; m += 4) { + const float* src = in + (m / PACK_C_SIZE) * ldin + k0 * PACK_C_SIZE; + memcpy(out, src, cp_length * sizeof(float)); + out += cp_length; + } +} + +void gi_sgemm_mk4_pack_4x12::pack_B( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose_B) const { + megdnn_assert(!transpose_B); + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + float tmpbuff[16] = {0.0f}; + + constexpr int PACK_C_SIZE = 4; + int ksize = kmax - k0; + int ksize12 = ksize * 12; + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE; + + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + transpose_1x12_4_s(inptr, outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + transpose_1x4_4_s(inptr, outptr_interleave); + outptr += ksize4; + } + if (x < xmax) { + memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE); + auto outptr_interleave = outptr; + const float* tmp_ptr = &tmpbuff[0]; + transpose_1x4_4_s(tmp_ptr, outptr_interleave); + outptr += ksize4; + } + outptr_base += 12 * PACK_C_SIZE; + outptr_base4 += 4 * PACK_C_SIZE; + } +} + +void gi_sgemm_mk4_pack_4x12::kern( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k, const float*, float*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); + constexpr int PACK_C_SIZE = 4; + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 12; + const int K12 = K * 12; + const int K4 = K * 4; + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + float* output = C + (m / 4 * LDC); + + size_t n = 0; + const float* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); + output += PACK_C_SIZE * B_INTERLEAVE; + cur_packB += K12; + } + for (; n < N; n += 4) { + kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); + output += PACK_C_SIZE * 4; + cur_packB += K4; + } + packA += K4; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp index bb51963cb..8343fbb26 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp @@ -28,6 +28,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoNaive naive; AlgoF32GiGemvMK4 f32_gemv_mk4; AlgoF32GiMK4_4x8 f32_mk4_4x8; + AlgoF32GiMK4Pack4x12 f32_mk4_gi_pack_4x12; AlgoF32Gi4x12 f32_4x8; SmallVector m_all_algos; AlgoBase::Mapper m_all_algos_map; @@ -36,6 +37,7 @@ public: AlgoPack() { m_all_algos.emplace_back(&f32_gemv_mk4); m_all_algos.emplace_back(&f32_mk4_4x8); + m_all_algos.emplace_back(&f32_mk4_gi_pack_4x12); m_all_algos.emplace_back(&f32_4x8); m_all_algos.emplace_back(&gemv); m_all_algos.emplace_back(&f32_k8x12x1); diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index 1af625cf0..93367b338 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -103,6 +103,7 @@ public: FB_NAIVE, FB_GI_F32_GEMV_MK4, FB_GI_F32_MK4_4x8, + FB_GI_F32_MK4_PACK_4x12, FB_GI_F32_4x12, #if MEGDNN_X86 @@ -230,10 +231,11 @@ public: }; private: - class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 - class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 - class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 - class AlgoF32Gi4x12; // fallback F32 gi Gemm + class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 + class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 + class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 + class AlgoF32GiMK4Pack4x12; // fallback F32 gi Gemm pack NCHW44 + class AlgoF32Gi4x12; // fallback F32 gi Gemm class AlgoGemv; class AlgoNaive; class AlgoPack; diff --git a/dnn/src/x86/conv_bias/postprocess_helper.h b/dnn/src/x86/conv_bias/postprocess_helper.h index 72d83899d..de7b48aab 100644 --- a/dnn/src/x86/conv_bias/postprocess_helper.h +++ b/dnn/src/x86/conv_bias/postprocess_helper.h @@ -364,7 +364,9 @@ struct PostProcess { DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { MEGDNN_MARK_USED_VAR(pack_oc_size); - megdnn_assert(pack_oc_size == 1, "PostProcess only support nchw in x86"); + megdnn_assert( + pack_oc_size == 1 || pack_oc_size == 4, + "PostProcess only support nchw/44 in x86"); megdnn_assert( nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, "Add bias PostProcess only support IDENTITY"); diff --git a/dnn/src/x86/elemwise_op.h b/dnn/src/x86/elemwise_op.h index 18df3ae01..29f99d0eb 100644 --- a/dnn/src/x86/elemwise_op.h +++ b/dnn/src/x86/elemwise_op.h @@ -59,6 +59,11 @@ cb(dt_float32, float, "avx2", float, __m256, mm256, ps, ps, SIMDType::AVX2); template struct ParamElemVisitorHalfBoardCast; +//! some compiler do not define _mm256_set_m128 +#define _mm256_set_m128ff(xmm1, xmm2) \ + _mm256_permute2f128_ps( \ + _mm256_castps128_ps256(xmm1), _mm256_castps128_ps256(xmm2), 2) + #define cb( \ _ctype, _simd_ptr_type, load_half_fuc, half_type, _simd_type, board_cast_func) \ template <> \ @@ -78,9 +83,10 @@ cb(dt_int32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); -cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128); +cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128ff); #undef cb +#undef _mm256_set_m128ff /*! * \brief broadcast type * BCAST_x[0]x[1]...: x[i] == !stride[i] diff --git a/dnn/test/fallback/conv_bias.cpp b/dnn/test/fallback/conv_bias.cpp index 194d0c561..620477632 100644 --- a/dnn/test/fallback/conv_bias.cpp +++ b/dnn/test/fallback/conv_bias.cpp @@ -239,6 +239,74 @@ void checker_conv_bias( } } +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_1X1_S1_MK4_PACK_F32) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({1}, FULL_NLMODE, ALL_BIASMODE, 1, true); + check_conv_bias(args, handle(), "CONV1x1:FB_GI_F32_MK4_PACK_4x12:24"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S1_MK4_PACK_F32_PREPROCESS) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({2, 4, 7}, FULL_NLMODE, BR_AND_NO_BIASMODE, 1); +#define cb(name) \ + check_conv_bias_preprocess( \ + args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ + dtype::Float32(), dtype::Float32(), name); + cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); +#undef cb +} + +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({3}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 2); +#define cb(name) \ + check_conv_bias_preprocess( \ + args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ + dtype::Float32(), dtype::Float32(), name); + cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); +#undef cb +} + +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_1X1_S1_MK4_PACK_F32_PREPROCESS) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({1}, FULL_NLMODE, ALL_BIASMODE, 1, true); +#define cb(name) \ + check_conv_bias_preprocess( \ + args, handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ + dtype::Float32(), dtype::Float32(), name); + cb("CONV1x1:FB_GI_F32_MK4_PACK_4x12:24"); +#undef cb +} + +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S1_MK4_PACK_F32) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({2, 4, 7}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 1); + check_conv_bias(args, handle(), "IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({3, 5, 6}, FULL_NLMODE, BR_AND_BIAS_BIASMODE, 2); +#define cb(name) check_conv_bias(args, handle(), name); + cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); +#undef cb +} + +TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_GI_IM2COL_S2_MK4_PACK_F32_FUSE) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({3}, FULL_NLMODE, ALL_BIASMODE, 2); +#define cb(name) check_conv_bias(args, handle(), name); + cb("IM2COLMATMUL:FB_GI_F32_MK4_PACK_4x12"); +#undef cb +} + TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { using namespace conv_bias; param::ConvBias cur_param; diff --git a/dnn/test/fallback/matrix_mul.cpp b/dnn/test/fallback/matrix_mul.cpp index b1b71a70b..197804ea7 100644 --- a/dnn/test/fallback/matrix_mul.cpp +++ b/dnn/test/fallback/matrix_mul.cpp @@ -42,12 +42,18 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); } -TEST_F(FALLBACK, MATRIX_MULF_GI_F32_4x12) { +TEST_F(FALLBACK, MATRIX_MUL_GI_F32_4x12) { matrix_mul::check_matrix_mul( dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "FB_GI_F32_4x12"); } +TEST_F(FALLBACK, MATRIX_MUL_GI_PACK_MK4) { + matrix_mul::check_matrix_mul( + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), + "FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4, 1); +} + TEST_F(FALLBACK, MATRIX_MUL_RECORD) { TaskRecordChecker checker(1); using Param = MatrixMul::Param; @@ -163,6 +169,13 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_FB_GI_F32_4x12) { "FB_GI_F32_4x12", param::MatrixMul::Format::DEFAULT); } +TEST_F(FALLBACK, BENCHMARK_MATRIX_MUL_GI_PACK_MK4) { + auto args = matrix_mul::get_benchmark_matmul_args(); + matrix_mul::benchmark_single_algo( + handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, + "FB_GI_F32_MK4_PACK_4x12", param::MatrixMul::Format::MK4); +} + #endif } // namespace test } // namespace megdnn -- GitLab