From d2184af3b2a9f27c4440dc9d8d1610f86e38eaa3 Mon Sep 17 00:00:00 2001
From: zjl <610098971@qq.com>
Date: Sun, 5 Sep 2021 09:11:52 +0800
Subject: [PATCH] feat(dnn/src/x86/matmul): add matmul_6x16 for x86

---
 CMakeLists.txt                               |    1 +
 dnn/src/fallback/conv_bias/im2col/algos.cpp  |    8 +-
 dnn/src/fallback/matrix_mul/opr_impl.h       |    1 +
 dnn/src/x86/avx_helper.h                     |    9 +
 dnn/src/x86/matrix_mul/algos.cpp             |   70 +
 dnn/src/x86/matrix_mul/algos.h               |   13 +
 dnn/src/x86/matrix_mul/f32/strategy.h        |    2 +
 dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp | 1255 ++++++++++++++++++
 dnn/src/x86/matrix_mul/opr_impl.cpp          |    2 +
 dnn/src/x86/matrix_mul/opr_impl.h            |    2 +-
 dnn/test/x86/accuracy_shake.cpp              |    9 +
 dnn/test/x86/conv_bias.cpp                   |  250 ++++
 dnn/test/x86/matrix_mul.cpp                  |   15 +
 13 files changed, 1632 insertions(+), 5 deletions(-)
 create mode 100644 dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index e367ae1b7..aae966f0c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -10,6 +10,7 @@ project(MegEngine LANGUAGES C CXX VERSION ${MGB_VER_STRING})
 set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
 set(CMAKE_POLICY_DEFAULT_CMP0048 NEW)
diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp
index eac03a594..9c3785694 100644
--- a/dnn/src/fallback/conv_bias/im2col/algos.cpp
+++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp
@@ -70,16 +70,16 @@ static void choice_ohw_oc_block(
         fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) {
     //! calculate m_oc_tile_size in choice_ohw_oc_block() function,
     //! when ohw_tile_size < this value ohw_tile_size = ohw
-    static constexpr size_t DEFAULT_OHW_MIN_TILE_SIZE = 32;
+    size_t DEFAULT_OHW_MIN_TILE_SIZE = round_up(32UL, block_n);
     //! when nr_threads > 1 and round(ohw,nr_threads)>nr_threads,
     //! oc_tile_size = DEFAULT_OC_TILE_SIZE
-    static constexpr size_t DEFAULT_OC_TILE_SIZE = 512;
+    size_t DEFAULT_OC_TILE_SIZE = round_up(512UL, block_m);
     //! when oc_tile_size > this value m_oc_tile_size =
     //! DEFAULT_OC_MAX_TILE_SIZE
-    static constexpr size_t DEFAULT_OC_MAX_TILE_SIZE = 1024;
+    size_t DEFAULT_OC_MAX_TILE_SIZE = round_up(1024UL, block_m);
     //! when oc_tile_size < this value oc_tile_size =
     //! DEFAULT_OC_MIN_TILE_SIZE, to keep the calculation block-aligned
-    static constexpr size_t DEFAULT_OC_MIN_TILE_SIZE = 128;
+    size_t DEFAULT_OC_MIN_TILE_SIZE = round_up(128UL, block_m);
     size_t nr_threads = param.nr_threads;
     size_t OC = param.filter_meta.ocpg;
     size_t ohw = param.osz[0] * param.osz[1];
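
The im2col hunk above swaps fixed tile-size constants for values rounded up
to the matmul kernel's block sizes, so packed panels are never split
mid-block. A minimal sketch of the round_up semantics assumed here (the real
helper lives in MegDNN's common utils):

    // Hypothetical stand-in for round_up: the smallest multiple of `align`
    // that is >= `v`.
    static inline size_t round_up(size_t v, size_t align) {
        return (v + align - 1) / align * align;
    }
    // With this kernel's blocks (block_m = 6, block_n = 16):
    //   round_up(512UL, 6)  == 516   // DEFAULT_OC_TILE_SIZE
    //   round_up(32UL, 16)  == 32    // DEFAULT_OHW_MIN_TILE_SIZE
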
diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h
index d13aede45..0c0af0f1b 100644
--- a/dnn/src/fallback/matrix_mul/opr_impl.h
+++ b/dnn/src/fallback/matrix_mul/opr_impl.h
@@ -122,6 +122,7 @@ public:
         X86_INT8X8X16_SSE,
         X86_INT8X8X32_SSE_4X8X2,
         X86_F32_MK8_8X8,
+        X86_F32_6x16,
         X86_INT8X8X32_VNNI,
         X86_INT8X8X32_MKLDNN,
 #elif MEGDNN_AARCH64 || MEGDNN_ARMV7
diff --git a/dnn/src/x86/avx_helper.h b/dnn/src/x86/avx_helper.h
index 86974b7d2..c13311414 100644
--- a/dnn/src/x86/avx_helper.h
+++ b/dnn/src/x86/avx_helper.h
@@ -31,6 +31,15 @@ static inline __m256 _mm256_loadu2_m128_emulate(
                     _mm_loadu_ps(hiaddr), 1);
 }
 
+MEGDNN_ATTRIBUTE_TARGET("avx")
+static inline void _mm256_storeu2_m128_emulate(float *hiaddr, float *loaddr,
+                                               __m256 reg) {
+    auto xmm0 = _mm256_extractf128_ps(reg, 0);
+    auto xmm1 = _mm256_extractf128_ps(reg, 1);
+    _mm_storeu_ps(loaddr, xmm0);
+    _mm_storeu_ps(hiaddr, xmm1);
+}
+
 template 
 struct Vector;
 
diff --git a/dnn/src/x86/matrix_mul/algos.cpp b/dnn/src/x86/matrix_mul/algos.cpp
index 9425e4cf9..3edf0d959 100644
--- a/dnn/src/x86/matrix_mul/algos.cpp
+++ b/dnn/src/x86/matrix_mul/algos.cpp
@@ -320,6 +320,35 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
     MIDOUT_END();
 }
 
+void gemm_f32_avx2_6x16(const MatrixMulImpl::KernParam& kern_param) {
+    MEGDNN_MARK_USED_VAR(kern_param);
+    MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_6x16x2, midout_iv(0)) {
+        constexpr int cacheline = 64;
+        const size_t m = kern_param.M;
+        const size_t n = kern_param.N;
+        const size_t k = kern_param.K;
+        const bool trans_a = kern_param.trA;
+        const bool trans_b = kern_param.trB;
+        const size_t lda = kern_param.LDA;
+        const size_t ldb = kern_param.LDB;
+        const size_t ldc = kern_param.LDC;
+        auto a_type = kern_param.A_type;
+        auto b_type = kern_param.B_type;
+        auto c_type = kern_param.C_type;
+        const auto a_ptr = kern_param.A<float>();
+        const auto b_ptr = kern_param.B<float>();
+        auto c_ptr = kern_param.C<float>();
+        x86::matmul::sgemm_pack_6x16_avx2 strategy(m, n, k, a_type, b_type,
+                                                   c_type);
+
+        megdnn::matmul::GemmInterleaved<x86::matmul::sgemm_pack_6x16_avx2>(
+                m, n, k, trans_a, trans_b, strategy, cacheline)
+                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
+                         kern_param.workspace_ptr);
+    }
+    MIDOUT_END();
+}
+
 }  // namespace
 
 /*************************AlgoInt8x8x16AVX2********************/
@@ -662,4 +691,45 @@ size_t MatrixMulImpl::AlgoF32MK8_8x8::get_workspace(
     MIDOUT_END();
 }
 
+/*************************AlgoFloatAVX2M6N16********************/
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoFloatAVX2M6N16::get_kern(
+        const KernSizeParam&) const {
+    return gemm_f32_avx2_6x16;
+}
+bool MatrixMulImpl::AlgoFloatAVX2M6N16::usable(
+        const KernSizeParam& kern_size_param) const {
+    bool is_param_ok =
+            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
+            ((kern_size_param.A_type.enumv() == DTypeEnum::Float32 &&
+              kern_size_param.C_type.enumv() == DTypeEnum::Float32)) &&
+            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+            kern_size_param.format == Param::Format::DEFAULT &&
+            is_supported(SIMDType::AVX2);
+    return is_param_ok;
+}
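
Before the workspace query below, a rough mental model of what
GemmInterleaved allocates for this strategy (a sketch only; the real
get_workspace_size() tiles M/N/K against the cache budget and aligns each
panel rather than packing whole matrices):

    // Per tile of the interleaved GEMM, approximately:
    //   packed A panel ~ round_up(tile_m, 6)  * tile_k * sizeof(float)
    //   packed B panel ~ round_up(tile_n, 16) * tile_k * sizeof(float)
    // Using the same instantiation as gemm_f32_avx2_6x16() keeps the
    // queried size consistent with what execute() later consumes.
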
+size_t MatrixMulImpl::AlgoFloatAVX2M6N16::get_workspace(
+        const KernSizeParam& kern_param) const {
+    constexpr int cacheline = 64;
+    const size_t m = kern_param.M;
+    const size_t n = kern_param.N;
+    const size_t k = kern_param.K;
+    const bool trans_a = kern_param.trA;
+    const bool trans_b = kern_param.trB;
+    auto a_type = kern_param.A_type;
+    auto b_type = kern_param.B_type;
+    auto c_type = kern_param.C_type;
+    x86::matmul::sgemm_pack_6x16_avx2 strategy(m, n, k, a_type, b_type,
+                                               c_type);
+
+    return megdnn::matmul::GemmInterleaved<
+                   x86::matmul::sgemm_pack_6x16_avx2>(
+                   m, n, k, trans_a, trans_b, strategy, cacheline)
+            .get_workspace_size();
+}
+
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
+        AlgoFloatAVX2M6N16, megdnn_x86_matmul_kern,
+        "AlgoFloatAVX2M6N16"_hash, x86::matmul::sgemm_pack_6x16_avx2,
+        float, float, float, AlgoDataType::FLOAT32, DEFAULT);
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/x86/matrix_mul/algos.h b/dnn/src/x86/matrix_mul/algos.h
index e1fe4cb93..698584af6 100644
--- a/dnn/src/x86/matrix_mul/algos.h
+++ b/dnn/src/x86/matrix_mul/algos.h
@@ -149,6 +149,19 @@ public:
     MEGDNN_DECL_ALGO_TYPE(X86_F32_MK8_8X8)
 };
 
+class MatrixMulImpl::AlgoFloatAVX2M6N16 : public AlgoBase {
+public:
+    AlgoAttribute attribute() const override {
+        return AlgoAttribute::REPRODUCIBLE;
+    }
+    const char *name() const override { return "X86_F32_6x16"; }
+    bool usable(const KernSizeParam &) const override;
+    size_t get_workspace(const KernSizeParam &) const override;
+    kern_t get_kern(const KernSizeParam &) const override;
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+    MEGDNN_DECL_ALGO_TYPE(X86_F32_6x16)
+};
+
 #if MEGDNN_X86_WITH_VNNI
 class MatrixMulImpl::AlgoInt8x8x32Vnni : public AlgoBase {
 public:
diff --git a/dnn/src/x86/matrix_mul/f32/strategy.h b/dnn/src/x86/matrix_mul/f32/strategy.h
index 7f7990023..9adbbd928 100644
--- a/dnn/src/x86/matrix_mul/f32/strategy.h
+++ b/dnn/src/x86/matrix_mul/f32/strategy.h
@@ -19,6 +19,8 @@ namespace matmul {
 MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 8, 8, 8, false, true,
                                 sgemm_nopack_8x8_avx2);
 
+MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(float, float, float, float,
+                                          6, 16, 1, false, false, sgemm_pack_6x16_avx2);
 }  // namespace matmul
 }  // namespace x86
 }  // namespace megdnn
\ No newline at end of file
diff --git a/dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp b/dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp
new file mode 100644
index 000000000..48a4cc1ae
--- /dev/null
+++ b/dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp
@@ -0,0 +1,1255 @@
+/**
+ * \file dnn/src/x86/matrix_mul/f32/strategy_6x16.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include <immintrin.h>
+
+#include "src/common/utils.h"
+#include "src/x86/avx_helper.h"
+#include "src/x86/matrix_mul/common/common.h"
+#include "src/x86/matrix_mul/f32/strategy.h"
+#include "src/common/unroll_macro.h"
+
+using namespace megdnn;
+using namespace x86;
+
+#define DNN_AVX2_TARGET
+#if !defined(__clang__)
+//!
bypass gcc bug https://bugs.launchpad.net/ubuntu/+source/gcc-5/+bug/1642109 +#pragma GCC target("avx2") +#else +#undef DNN_AVX2_TARGET +#define DNN_AVX2_TARGET MEGDNN_ATTRIBUTE_TARGET("avx2") +#endif + +#define UNROLL_CODE(cb, i, a...) UNROLL_CALL1(i,cb,##a) +namespace { + +DNN_AVX2_TARGET +void transpose_16x8_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + const float *inptr6, const float *inptr7, + const float *inptr8, const float *inptr9, + const float *inptr10, const float *inptr11, + const float *inptr12, const float *inptr13, + const float *inptr14, const float *inptr15, + float *outptr) { + auto ymm0 = _mm256_loadu_ps(inptr0); // A0A1A2A3A4A5A6A7 + auto ymm1 = _mm256_loadu_ps(inptr1); // B0B1B2B3B4B5B6B7 + auto ymm2 = _mm256_loadu_ps(inptr2); // C0C1C2C3C4C5C6C7 + auto ymm3 = _mm256_loadu_ps(inptr3); // D0D1D2D3D4D5D6D7 + auto ymm4 = _mm256_loadu_ps(inptr4); // E0E1E2E3E4E5E6E7 + auto ymm5 = _mm256_loadu_ps(inptr5); // F0F1F2F3F4F5F6F7 + auto ymm6 = _mm256_loadu_ps(inptr6); // G0G1G2G3G4G5G6G7 + auto ymm7 = _mm256_loadu_ps(inptr7); // H0H1H2H3H4H5H6H7 + + auto ymm8 = _mm256_unpacklo_ps(ymm0, ymm2); // A0C0A1C1A4C4A5C5 + auto ymm9 = _mm256_unpackhi_ps(ymm0, ymm2); // A2C2A3C3A6C6A7C7 + auto ymm10 = _mm256_unpacklo_ps(ymm1, ymm3); // B0D0B1D1B4D4B5D5 + auto ymm11 = _mm256_unpackhi_ps(ymm1, ymm3); // B2D2B3D3B6D6B7D7 + auto ymm12 = _mm256_unpacklo_ps(ymm4, ymm6); // E0G0E1G1E4G4E5G5 + auto ymm13 = _mm256_unpackhi_ps(ymm4, ymm6); // E2G2E3G3E6G6E7G7 + auto ymm14 = _mm256_unpacklo_ps(ymm5, ymm7); // F0H0F1H1F4H4F5H5 + auto ymm15 = _mm256_unpackhi_ps(ymm5, ymm7); // F2H2F3H3F6H6F7H7 + + ymm0 = _mm256_unpacklo_ps(ymm8, ymm10); // A0B0C0D0A4B4C4D4 + ymm1 = _mm256_unpackhi_ps(ymm8, ymm10); // A1B1C1D1A5B5C5D5 + ymm2 = _mm256_unpacklo_ps(ymm9, ymm11); // A2B2C2D2A6B6C6D6 + ymm3 = _mm256_unpackhi_ps(ymm9, ymm11); // A3B3C3D3A7B7C7D7 + ymm4 = _mm256_unpacklo_ps(ymm12, ymm14); // E0F0G0H0E4F4G4H4 + ymm5 = _mm256_unpackhi_ps(ymm12, ymm14); // E1F1G1H1E5F5G5H5 + ymm6 = _mm256_unpacklo_ps(ymm13, ymm15); // E2F2G2H2E6F6G6H6 + ymm7 = _mm256_unpackhi_ps(ymm13, ymm15); // E3F3G3H3E7F7G7H7 + + ymm8 = _mm256_permute2f128_ps(ymm0, ymm4, 0x20); // A0B0C0D0E0F0G0H0 + ymm9 = _mm256_permute2f128_ps(ymm1, ymm5, 0x20); // A1B1C1D1E1F1G1H1 + ymm10 = _mm256_permute2f128_ps(ymm2, ymm6, 0x20); // A2B2C2D2E2F2G2H2 + ymm11 = _mm256_permute2f128_ps(ymm3, ymm7, 0x20); // A3B3C3D3E3F3G3H3 + ymm12 = _mm256_permute2f128_ps(ymm0, ymm4, 0x31); // A4B4C4D4E4F4G4H4 + ymm13 = _mm256_permute2f128_ps(ymm1, ymm5, 0x31); // A5B5C5D5E5F5G5H5 + ymm14 = _mm256_permute2f128_ps(ymm2, ymm6, 0x31); // A6B6C6D6E6F6G6H6 + ymm15 = _mm256_permute2f128_ps(ymm3, ymm7, 0x31); // A7B7C7D7E7F7G7H7 + + _mm256_storeu_ps(outptr + 16 * 0, ymm8); + _mm256_storeu_ps(outptr + 16 * 1, ymm9); + _mm256_storeu_ps(outptr + 16 * 2, ymm10); + _mm256_storeu_ps(outptr + 16 * 3, ymm11); + _mm256_storeu_ps(outptr + 16 * 4, ymm12); + _mm256_storeu_ps(outptr + 16 * 5, ymm13); + _mm256_storeu_ps(outptr + 16 * 6, ymm14); + _mm256_storeu_ps(outptr + 16 * 7, ymm15); + ymm0 = _mm256_loadu_ps(inptr8); // A0A1A2A3A4A5A6A7 + ymm1 = _mm256_loadu_ps(inptr9); // B0B1B2B3B4B5B6B7 + ymm2 = _mm256_loadu_ps(inptr10); // C0C1C2C3C4C5C6C7 + ymm3 = _mm256_loadu_ps(inptr11); // D0D1D2D3D4D5D6D7 + ymm4 = _mm256_loadu_ps(inptr12); // E0E1E2E3E4E5E6E7 + ymm5 = _mm256_loadu_ps(inptr13); // F0F1F2F3F4F5F6F7 + ymm6 = _mm256_loadu_ps(inptr14); // G0G1G2G3G4G5G6G7 + ymm7 = _mm256_loadu_ps(inptr15); // 
H0H1H2H3H4H5H6H7 + + ymm8 = _mm256_unpacklo_ps(ymm0, ymm2); // A0C0A1C1A4C4A5C5 + ymm9 = _mm256_unpackhi_ps(ymm0, ymm2); // A2C2A3C3A6C6A7C7 + ymm10 = _mm256_unpacklo_ps(ymm1, ymm3); // B0D0B1D1B4D4B5D5 + ymm11 = _mm256_unpackhi_ps(ymm1, ymm3); // B2D2B3D3B6D6B7D7 + ymm12 = _mm256_unpacklo_ps(ymm4, ymm6); // E0G0E1G1E4G4E5G5 + ymm13 = _mm256_unpackhi_ps(ymm4, ymm6); // E2G2E3G3E6G6E7G7 + ymm14 = _mm256_unpacklo_ps(ymm5, ymm7); // F0H0F1H1F4H4F5H5 + ymm15 = _mm256_unpackhi_ps(ymm5, ymm7); // F2H2F3H3F6H6F7H7 + + ymm0 = _mm256_unpacklo_ps(ymm8, ymm10); // A0B0C0D0A4B4C4D4 + ymm1 = _mm256_unpackhi_ps(ymm8, ymm10); // A1B1C1D1A5B5C5D5 + ymm2 = _mm256_unpacklo_ps(ymm9, ymm11); // A2B2C2D2A6B6C6D6 + ymm3 = _mm256_unpackhi_ps(ymm9, ymm11); // A3B3C3D3A7B7C7D7 + ymm4 = _mm256_unpacklo_ps(ymm12, ymm14); // E0F0G0H0E4F4G4H4 + ymm5 = _mm256_unpackhi_ps(ymm12, ymm14); // E1F1G1H1E5F5G5H5 + ymm6 = _mm256_unpacklo_ps(ymm13, ymm15); // E2F2G2H2E6F6G6H6 + ymm7 = _mm256_unpackhi_ps(ymm13, ymm15); // E3F3G3H3E7F7G7H7 + + ymm8 = _mm256_permute2f128_ps(ymm0, ymm4, 0x20); // A0B0C0D0E0F0G0H0 + ymm9 = _mm256_permute2f128_ps(ymm1, ymm5, 0x20); // A1B1C1D1E1F1G1H1 + ymm10 = _mm256_permute2f128_ps(ymm2, ymm6, 0x20); // A2B2C2D2E2F2G2H2 + ymm11 = _mm256_permute2f128_ps(ymm3, ymm7, 0x20); // A3B3C3D3E3F3G3H3 + ymm12 = _mm256_permute2f128_ps(ymm0, ymm4, 0x31); // A4B4C4D4E4F4G4H4 + ymm13 = _mm256_permute2f128_ps(ymm1, ymm5, 0x31); // A5B5C5D5E5F5G5H5 + ymm14 = _mm256_permute2f128_ps(ymm2, ymm6, 0x31); // A6B6C6D6E6F6G6H6 + ymm15 = _mm256_permute2f128_ps(ymm3, ymm7, 0x31); // A7B7C7D7E7F7G7H7 + + _mm256_storeu_ps(outptr + 16 * 0 + 8, ymm8); + _mm256_storeu_ps(outptr + 16 * 1 + 8, ymm9); + _mm256_storeu_ps(outptr + 16 * 2 + 8, ymm10); + _mm256_storeu_ps(outptr + 16 * 3 + 8, ymm11); + _mm256_storeu_ps(outptr + 16 * 4 + 8, ymm12); + _mm256_storeu_ps(outptr + 16 * 5 + 8, ymm13); + _mm256_storeu_ps(outptr + 16 * 6 + 8, ymm14); + _mm256_storeu_ps(outptr + 16 * 7 + 8, ymm15); +} + +DNN_AVX2_TARGET +void transpose_16x4_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + const float *inptr6, const float *inptr7, + const float *inptr8, const float *inptr9, + const float *inptr10, const float *inptr11, + const float *inptr12, const float *inptr13, + const float *inptr14, const float *inptr15, + float *outptr) { + const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7}; + __m256i order = _mm256_loadu_si256((const __m256i *)arr); + auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0); // A0A1A2A3C0C1C2C3 + auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1); // B0B1B2B3D0D1D2D3 + auto ymm2 = _mm256_loadu2_m128_emulate(inptr6, inptr4); // E0E1E2E3G0G1G2G3 + auto ymm3 = _mm256_loadu2_m128_emulate(inptr7, inptr5); // F0F1F2F3H0H1H2H3 + + auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm1); // A0B0A1B1C0D0C1D1 + auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm1); // A2B2A3B3C2D2C3D3 + auto ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); // E0F0E1F1G0H0G1H1 + auto ymm7 = _mm256_unpackhi_ps(ymm2, ymm3); // E2F2E3F3G2H2G3H3 + + auto ymm8 = _mm256_permutevar8x32_ps(ymm4, order); // A0B0C0D0A1B1C1D1 + auto ymm9 = _mm256_permutevar8x32_ps(ymm5, order); // A2B2C2D2A3B3C3D3 + auto ymm10 = _mm256_permutevar8x32_ps(ymm6, order); // E0F0G0H0E1F1G1H1 + auto ymm11 = _mm256_permutevar8x32_ps(ymm7, order); // E2F2G2H2E3F3G3H3 + + ymm0 = _mm256_permute2f128_ps(ymm8, ymm10, 0x20); // A0B0C0D0E0F0G0H0 + ymm1 = _mm256_permute2f128_ps(ymm8, ymm10, 0x31); // A1B1C1D1E1F1G1H1 + ymm2 = 
_mm256_permute2f128_ps(ymm9, ymm11, 0x20); // A2B2C2D2E2F2G2H2 + ymm3 = _mm256_permute2f128_ps(ymm9, ymm11, 0x31); // A3B3C3D3E3F3G3H3 + + _mm256_storeu_ps(outptr + 16 * 0, ymm0); + _mm256_storeu_ps(outptr + 16 * 1, ymm1); + _mm256_storeu_ps(outptr + 16 * 2, ymm2); + _mm256_storeu_ps(outptr + 16 * 3, ymm3); + ymm0 = _mm256_loadu2_m128_emulate(inptr10, inptr8); // A0A1A2A3C0C1C2C3 + ymm1 = _mm256_loadu2_m128_emulate(inptr11, inptr9); // B0B1B2B3D0D1D2D3 + ymm2 = _mm256_loadu2_m128_emulate(inptr14, inptr12); // E0E1E2E3G0G1G2G3 + ymm3 = _mm256_loadu2_m128_emulate(inptr15, inptr13); // F0F1F2F3H0H1H2H3 + + ymm4 = _mm256_unpacklo_ps(ymm0, ymm1); // A0B0A1B1C0D0C1D1 + ymm5 = _mm256_unpackhi_ps(ymm0, ymm1); // A2B2A3B3C2D2C3D3 + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); // E0F0E1F1G0H0G1H1 + ymm7 = _mm256_unpackhi_ps(ymm2, ymm3); // E2F2E3F3G2H2G3H3 + + ymm8 = _mm256_permutevar8x32_ps(ymm4, order); // A0B0C0D0A1B1C1D1 + ymm9 = _mm256_permutevar8x32_ps(ymm5, order); // A2B2C2D2A3B3C3D3 + ymm10 = _mm256_permutevar8x32_ps(ymm6, order); // E0F0G0H0E1F1G1H1 + ymm11 = _mm256_permutevar8x32_ps(ymm7, order); // E2F2G2H2E3F3G3H3 + + ymm0 = _mm256_permute2f128_ps(ymm8, ymm10, 0x20); // A0B0C0D0E0F0G0H0 + ymm1 = _mm256_permute2f128_ps(ymm8, ymm10, 0x31); // A1B1C1D1E1F1G1H1 + ymm2 = _mm256_permute2f128_ps(ymm9, ymm11, 0x20); // A2B2C2D2E2F2G2H2 + ymm3 = _mm256_permute2f128_ps(ymm9, ymm11, 0x31); // A3B3C3D3E3F3G3H3 + + _mm256_storeu_ps(outptr + 16 * 0 + 8, ymm0); + _mm256_storeu_ps(outptr + 16 * 1 + 8, ymm1); + _mm256_storeu_ps(outptr + 16 * 2 + 8, ymm2); + _mm256_storeu_ps(outptr + 16 * 3 + 8, ymm3); +} + +static size_t min(size_t a, size_t b) { return a > b ? b : a; } + +DNN_AVX2_TARGET +void transpose_6x16_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + float *outptr) { + auto ymm0 = _mm256_loadu_ps(inptr0 + 0); // A0A1A2A3A4A5A6A7 + auto ymm1 = _mm256_loadu_ps(inptr0 + 8); // a0a1a2a3a4a5a6a7 + auto ymm2 = _mm256_loadu_ps(inptr1 + 0); // B0B1B2B3B4B5B6B7 + auto ymm3 = _mm256_loadu_ps(inptr1 + 8); // b0b1b2b3b4b5b6b7 + auto ymm4 = _mm256_loadu_ps(inptr2 + 0); // C0C1C2C3C4C5C6C7 + auto ymm5 = _mm256_loadu_ps(inptr2 + 8); // c0c1c2c3c4c5c6c7 + auto ymm6 = _mm256_loadu_ps(inptr3 + 0); // D0D1D2D3D4D5D6D7 + auto ymm7 = _mm256_loadu_ps(inptr3 + 8); // d0d1d2d3d4d5d6d7 + + auto ymm8 = _mm256_unpacklo_ps(ymm0, ymm4); // A0C0A1C1A4C4A5C5 + auto ymm9 = _mm256_unpackhi_ps(ymm0, ymm4); // A2C2A3C3A6C6A7C7 + auto ymm10 = _mm256_unpacklo_ps(ymm2, ymm6); // B0D0B1D1B4D4B5D5 + auto ymm11 = _mm256_unpackhi_ps(ymm2, ymm6); // B2D2B3D3B6D6B7D7 + + auto ymm12 = _mm256_unpacklo_ps(ymm1, ymm5); // a0c0a1c1a4c4a5c5 + auto ymm13 = _mm256_unpackhi_ps(ymm1, ymm5); // a2c2a3c3a6c6a7c7 + auto ymm14 = _mm256_unpacklo_ps(ymm3, ymm7); // b0d0b1d1b4d4b5d5 + auto ymm15 = _mm256_unpackhi_ps(ymm3, ymm7); // b2d2b3d3b6d6b7d7 + + ymm0 = _mm256_unpacklo_ps(ymm8, ymm10); // A0B0C0D0A4B4C4D4 + ymm1 = _mm256_unpackhi_ps(ymm8, ymm10); // A1B1C1D1A5B5C5D5 + ymm2 = _mm256_unpacklo_ps(ymm9, ymm11); // A2B2C2D2A6B6C6D6 + ymm3 = _mm256_unpackhi_ps(ymm9, ymm11); // A3B3C3D3A7B7C7D7 + + ymm4 = _mm256_unpacklo_ps(ymm12, ymm14); // a0b0c0d0a4b4c4d4 + ymm5 = _mm256_unpackhi_ps(ymm12, ymm14); // a1b1c1d1a5b5c5d5 + ymm6 = _mm256_unpacklo_ps(ymm13, ymm15); // a2b2c2d2a6b6c6d6 + ymm7 = _mm256_unpackhi_ps(ymm13, ymm15); // a3b3c3d3a7b7c7d7 + + _mm256_storeu2_m128_emulate(outptr + 6 * 4, outptr + 6 * 0, ymm0); + _mm256_storeu2_m128_emulate(outptr + 6 * 5, outptr + 6 * 1, 
ymm1); + _mm256_storeu2_m128_emulate(outptr + 6 * 6, outptr + 6 * 2, ymm2); + _mm256_storeu2_m128_emulate(outptr + 6 * 7, outptr + 6 * 3, ymm3); + _mm256_storeu2_m128_emulate(outptr + 6 * 12, outptr + 6 * 8, ymm4); + _mm256_storeu2_m128_emulate(outptr + 6 * 13, outptr + 6 * 9, ymm5); + _mm256_storeu2_m128_emulate(outptr + 6 * 14, outptr + 6 * 10, ymm6); + _mm256_storeu2_m128_emulate(outptr + 6 * 15, outptr + 6 * 11, ymm7); + + float other[4 * 8]; + ymm8 = _mm256_loadu_ps(inptr4 + 0); // E0E1E2E3E4E5E6E7 + ymm9 = _mm256_loadu_ps(inptr4 + 8); // e0e1e2e3e4e5e6e7 + ymm10 = _mm256_loadu_ps(inptr5 + 0); // F0F1F2F3F4F5F6F7 + ymm11 = _mm256_loadu_ps(inptr5 + 8); // f0f1f2f3f4f5f6f7 + _mm256_storeu_ps(other, ymm8); + _mm256_storeu_ps(other + 8, ymm9); + _mm256_storeu_ps(other + 16, ymm10); + _mm256_storeu_ps(other + 24, ymm11); + + for (size_t i = 0; i < 16; i++) { + outptr[6 * i + 4] = other[i]; + outptr[6 * i + 5] = other[i + 16]; + } +} + +DNN_AVX2_TARGET +void transpose_6x8_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + float *outptr) { + auto ymm0 = _mm256_loadu_ps(inptr0); // A0A1A2A3A4A5A6A7 + auto ymm1 = _mm256_loadu_ps(inptr1); // B0B1B2B3B4B5B6B7 + auto ymm2 = _mm256_loadu_ps(inptr2); // C0C1C2C3C4C5C6C7 + auto ymm3 = _mm256_loadu_ps(inptr3); // D0D1D2D3D4D5D6D7 + + auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm2); // A0C0A1C1A4C4A5C5 + auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm2); // A2C2A3C3A6C6A7C7 + auto ymm6 = _mm256_unpacklo_ps(ymm1, ymm3); // B0D0B1D1B4D4B5D5 + auto ymm7 = _mm256_unpackhi_ps(ymm1, ymm3); // B2D2B3D3B6D6B7D7 + + auto ymm8 = _mm256_unpacklo_ps(ymm4, ymm6); // A0B0C0D0A4B4C4D4 + auto ymm9 = _mm256_unpackhi_ps(ymm4, ymm6); // A1B1C1D1A5B5C5D5 + auto ymm10 = _mm256_unpacklo_ps(ymm5, ymm7); // A2B2C2D2A6B6C6D6 + auto ymm11 = _mm256_unpackhi_ps(ymm5, ymm7); // A3B3C3D3A7B7C7D7 + + _mm256_storeu2_m128_emulate(outptr + 6 * 4, outptr + 6 * 0, ymm8); + _mm256_storeu2_m128_emulate(outptr + 6 * 5, outptr + 6 * 1, ymm9); + _mm256_storeu2_m128_emulate(outptr + 6 * 6, outptr + 6 * 2, ymm10); + _mm256_storeu2_m128_emulate(outptr + 6 * 7, outptr + 6 * 3, ymm11); + float other[16]; + auto ymm12 = _mm256_loadu_ps(inptr4); // E0E1E2E3E4E5E6E7 + auto ymm13 = _mm256_loadu_ps(inptr5); // F0F1F2F3F4F5F6F7 + _mm256_storeu_ps(other, ymm12); + _mm256_storeu_ps(other + 8, ymm13); + + for (size_t i = 0; i < 8; i++) { + outptr[6 * i + 4] = other[i]; + outptr[6 * i + 5] = other[8 + i]; + } +} + +DNN_AVX2_TARGET +void transpose_6x4_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + float *outptr) { + const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7}; + __m256i order = _mm256_loadu_si256((const __m256i *)arr); + auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0); // A0A1A2A3C0C1C2C3 + auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1); // B0B1B2B3D0D1D2D3 + auto ymm2 = _mm256_unpacklo_ps(ymm0, ymm1); // A0B0A1B1C0D0C1D1 + auto ymm3 = _mm256_unpackhi_ps(ymm0, ymm1); // A2B2A3B3C2D2C3D3 + auto ymm4 = _mm256_permutevar8x32_ps(ymm2, order); // A0B0C0D0A1B1C1D1 + auto ymm5 = _mm256_permutevar8x32_ps(ymm3, order); // A2B2C2D2A3B3C3D3 + + _mm256_storeu2_m128_emulate(outptr + 6 * 1, outptr + 6 * 0, ymm4); + _mm256_storeu2_m128_emulate(outptr + 6 * 3, outptr + 6 * 2, ymm5); + float other[8]; + auto ymm6 = _mm256_loadu2_m128_emulate(inptr5, inptr4); // E0E1E2E3E4E5E6E7 + _mm256_storeu_ps(other, ymm6); + + for (size_t i = 0; i < 4; 
i++) { + outptr[6 * i + 4] = other[i]; + outptr[6 * i + 5] = other[4 + i]; + } +} + +DNN_AVX2_TARGET +void transpose_4x8_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + float *outptr) { + auto ymm0 = _mm256_loadu_ps(inptr0); // A0A1A2A3A4A5A6A7 + auto ymm1 = _mm256_loadu_ps(inptr1); // B0B1B2B3B4B5B6B7 + auto ymm2 = _mm256_loadu_ps(inptr2); // C0C1C2C3C4C5C6C7 + auto ymm3 = _mm256_loadu_ps(inptr3); // D0D1D2D3D4D5D6D7 + + auto ymm4 = _mm256_unpacklo_ps(ymm0, ymm2); // A0C0A1C1A4C4A5C5 + auto ymm5 = _mm256_unpackhi_ps(ymm0, ymm2); // A2C2A3C3A6C6A7C7 + auto ymm6 = _mm256_unpacklo_ps(ymm1, ymm3); // B0D0B1D1B4D4B5D5 + auto ymm7 = _mm256_unpackhi_ps(ymm1, ymm3); // B2D2B3D3B6D6B7D7 + + auto ymm8 = _mm256_unpacklo_ps(ymm4, ymm6); // A0B0C0D0A4B4C4D4 + auto ymm9 = _mm256_unpackhi_ps(ymm4, ymm6); // A1B1C1D1A5B5C5D5 + auto ymm10 = _mm256_unpacklo_ps(ymm5, ymm7); // A2B2C2D2A6B6C6D6 + auto ymm11 = _mm256_unpackhi_ps(ymm5, ymm7); // A3B3C3D3A7B7C7D7 + + ymm0 = _mm256_permute2f128_ps(ymm8, ymm9, 0x20); // A0B0C0D0A1B1C1D1 + ymm1 = _mm256_permute2f128_ps(ymm10, ymm11, 0x20); // A2B2C2D2A3B3C3D3 + ymm2 = _mm256_permute2f128_ps(ymm8, ymm9, 0x31); // A4B4C4D4A5B5C5D5 + ymm3 = _mm256_permute2f128_ps(ymm10, ymm11, 0x31); // A6B6C6D6A7B7C7D7 + + _mm256_storeu_ps(outptr + 8 * 0, ymm0); + _mm256_storeu_ps(outptr + 8 * 1, ymm1); + _mm256_storeu_ps(outptr + 8 * 2, ymm2); + _mm256_storeu_ps(outptr + 8 * 3, ymm3); +} + +DNN_AVX2_TARGET +void transpose_4x4_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + float *outptr) { + const std::uint32_t arr[8] = {0, 1, 4, 5, 2, 3, 6, 7}; + __m256i order = _mm256_loadu_si256((const __m256i *)arr); + auto ymm0 = _mm256_loadu2_m128_emulate(inptr2, inptr0); // A0A1A2A3C0C1C2C3 + auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr1); // B0B1B2B3D0D1D2D3 + auto ymm2 = _mm256_unpacklo_ps(ymm0, ymm1); // A0B0A1B1C0D0C1D1 + auto ymm3 = _mm256_unpackhi_ps(ymm0, ymm1); // A2B2A3B3C2D2C3D3 + auto ymm4 = _mm256_permutevar8x32_ps(ymm2, order); // A0B0C0D0A1B1C1D1 + auto ymm5 = _mm256_permutevar8x32_ps(ymm3, order); // A2B2C2D2A3B3C3D3 + _mm256_storeu_ps(outptr, ymm4); + _mm256_storeu_ps(outptr + 8, ymm5); +} + +void transpose_2x16_1_s(const float *inptr0, const float *inptr1, + float *outptr) { + for (size_t i = 0; i < 16; i++) { + *outptr++ = inptr0[i]; + *outptr++ = inptr1[i]; + } +} +void transpose_2x8_1_s(const float *inptr0, const float *inptr1, + float *outptr) { + for (size_t i = 0; i < 8; i++) { + *outptr++ = inptr0[i]; + *outptr++ = inptr1[i]; + } +} +void transpose_2x4_1_s(const float *inptr0, const float *inptr1, + float *outptr) { + for (size_t i = 0; i < 4; i++) { + *outptr++ = inptr0[i]; + *outptr++ = inptr1[i]; + } +} + +DNN_AVX2_TARGET +void interleave_1x16_1_s(const float *inptr0, float *outptr) { + auto ymm0 = _mm256_loadu_ps(inptr0); + auto ymm1 = _mm256_loadu_ps(inptr0 + 8); + _mm256_storeu_ps(outptr, ymm0); + _mm256_storeu_ps(outptr + 8, ymm1); +} + +DNN_AVX2_TARGET +void interleave_8x16_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + const float *inptr6, const float *inptr7, + float *outptr) { + auto ymm0 = _mm256_loadu_ps(inptr0); + auto ymm1 = _mm256_loadu_ps(inptr0 + 8); + auto ymm2 = _mm256_loadu_ps(inptr1); + auto ymm3 = _mm256_loadu_ps(inptr1 + 8); + auto ymm4 = _mm256_loadu_ps(inptr2); + auto ymm5 = _mm256_loadu_ps(inptr2 + 8); + auto ymm6 = _mm256_loadu_ps(inptr3); + auto ymm7 = 
_mm256_loadu_ps(inptr3 + 8); + auto ymm8 = _mm256_loadu_ps(inptr4); + auto ymm9 = _mm256_loadu_ps(inptr4 + 8); + auto ymm10 = _mm256_loadu_ps(inptr5); + auto ymm11 = _mm256_loadu_ps(inptr5 + 8); + auto ymm12 = _mm256_loadu_ps(inptr6); + auto ymm13 = _mm256_loadu_ps(inptr6 + 8); + auto ymm14 = _mm256_loadu_ps(inptr7); + auto ymm15 = _mm256_loadu_ps(inptr7 + 8); + + _mm256_storeu_ps(outptr + 8 * 0, ymm0); + _mm256_storeu_ps(outptr + 8 * 1, ymm1); + _mm256_storeu_ps(outptr + 8 * 2, ymm2); + _mm256_storeu_ps(outptr + 8 * 3, ymm3); + _mm256_storeu_ps(outptr + 8 * 4, ymm4); + _mm256_storeu_ps(outptr + 8 * 5, ymm5); + _mm256_storeu_ps(outptr + 8 * 6, ymm6); + _mm256_storeu_ps(outptr + 8 * 7, ymm7); + _mm256_storeu_ps(outptr + 8 * 8, ymm8); + _mm256_storeu_ps(outptr + 8 * 9, ymm9); + _mm256_storeu_ps(outptr + 8 * 10, ymm10); + _mm256_storeu_ps(outptr + 8 * 11, ymm11); + _mm256_storeu_ps(outptr + 8 * 12, ymm12); + _mm256_storeu_ps(outptr + 8 * 13, ymm13); + _mm256_storeu_ps(outptr + 8 * 14, ymm14); + _mm256_storeu_ps(outptr + 8 * 15, ymm15); +} + +DNN_AVX2_TARGET +void interleave_8x4_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + const float *inptr6, const float *inptr7, + float *outptr) { + auto ymm0 = _mm256_loadu2_m128_emulate(inptr1, inptr0); // A0A1A2A3B0B1B2B3 + auto ymm1 = _mm256_loadu2_m128_emulate(inptr3, inptr2); // C0C1C2C3D0D1D2D3 + auto ymm2 = _mm256_loadu2_m128_emulate(inptr5, inptr4); // E0E1E2E3F0F1F2F3 + auto ymm3 = _mm256_loadu2_m128_emulate(inptr7, inptr6); // G0G1G2G3H0H1H2H3 + _mm256_storeu_ps(outptr + 8 * 0, ymm0); + _mm256_storeu_ps(outptr + 8 * 1, ymm1); + _mm256_storeu_ps(outptr + 8 * 2, ymm2); + _mm256_storeu_ps(outptr + 8 * 3, ymm3); +} + +void interleave_8x2_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + const float *inptr6, const float *inptr7, + float *outptr) { +#define cb(i) \ + *outptr++ = inptr##i[0]; \ + *outptr++ = inptr##i[1]; + UNROLL_CODE(cb, 8) +#undef cb +} + +void interleave_1x4_1_s(const float *inptr0, float *outptr) { + outptr[0] = inptr0[0]; + outptr[1] = inptr0[1]; + outptr[2] = inptr0[2]; + outptr[3] = inptr0[3]; +} +void interleave_8x6_1_s(const float *inptr0, const float *inptr1, + const float *inptr2, const float *inptr3, + const float *inptr4, const float *inptr5, + const float *inptr6, const float *inptr7, + float *outptr) { +#define cb(i) auto xmm##i = _mm_loadu_ps(inptr##i); + UNROLL_CODE(cb, 8) +#undef cb +#define cb(i) _mm_storeu_ps(outptr + 6 * i, xmm##i); + UNROLL_CODE(cb, 8) +#undef cb +#define cb(i) \ + outptr[6 * i + 4] = inptr##i[4]; \ + outptr[6 * i + 5] = inptr##i[5]; + UNROLL_CODE(cb, 8) +#undef cb +} + +void interleave_1x6_1_s(const float *inptr0, float *outptr) { + outptr[0] = inptr0[0]; + outptr[1] = inptr0[1]; + outptr[2] = inptr0[2]; + outptr[3] = inptr0[3]; + outptr[4] = inptr0[4]; + outptr[5] = inptr0[5]; +} + +void interleave_1x2_1_s(const float *inptr0, float *outptr) { + outptr[0] = inptr0[0]; + outptr[1] = inptr0[1]; +} + +static inline void interleave_helper(const float *inptr, float *outptr, + int unroll_k, int ksize, float val) { + int k = 0; + for (; k < ksize; k++) { + *outptr++ = *inptr++; + } + for (; k < unroll_k; k++) { + *outptr++ = val; + } +} +void interleave_1(const float *inptr0, float *outptr, int unroll_k, int ksize, + float val) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = min(unroll_k, ksize - k); + 
interleave_helper(inptr0, outptr, unroll_k, size, val); + inptr0 += size; + outptr += unroll_k; + } +} + +void interleave_8(const float *inptr0, const float *inptr1, const float *inptr2, + const float *inptr3, const float *inptr4, const float *inptr5, + const float *inptr6, const float *inptr7, float *outptr, + int unroll_k, int ksize, float val) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + inptr0 += size; + outptr += unroll_k; + interleave_helper(inptr1, outptr, unroll_k, size, val); + inptr1 += size; + outptr += unroll_k; + interleave_helper(inptr2, outptr, unroll_k, size, val); + inptr2 += size; + outptr += unroll_k; + interleave_helper(inptr3, outptr, unroll_k, size, val); + inptr3 += size; + outptr += unroll_k; + interleave_helper(inptr4, outptr, unroll_k, size, val); + inptr4 += size; + outptr += unroll_k; + interleave_helper(inptr5, outptr, unroll_k, size, val); + inptr5 += size; + outptr += unroll_k; + interleave_helper(inptr6, outptr, unroll_k, size, val); + inptr6 += size; + outptr += unroll_k; + interleave_helper(inptr7, outptr, unroll_k, size, val); + inptr7 += size; + outptr += unroll_k; + } +} + +DNN_AVX2_TARGET +MEGDNN_ATTRIBUTE_TARGET("fma") +void gemm_6x16_kern2x16(const float *packA, const float *packB, int K, + float *output, int LDC, bool is_first_k, int m_remain) { + const float *cur_b = packB; + const float *cur_a = packA; + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 b_tmp0, b_tmp1; + __m256 tmp; + if (is_first_k) { +#define cb(i) ymm##i = _mm256_set1_ps(0.0f); + UNROLL_CODE(cb, 4) +#undef cb + } else { + ymm0 = _mm256_loadu_ps(output + LDC * 0 + 0); + ymm1 = _mm256_loadu_ps(output + LDC * 0 + 8); + ymm2 = _mm256_loadu_ps(output + LDC * 1 + 0); + ymm3 = _mm256_loadu_ps(output + LDC * 1 + 8); + } + b_tmp0 = _mm256_loadu_ps(cur_b); + b_tmp1 = _mm256_loadu_ps(cur_b + 8); + size_t i = 0; + for (; i + 2 <= K; i += 2) { + cur_b += 16; + +#define CAL_OUPUT(i, first, second) \ + tmp = _mm256_broadcast_ss(cur_a + i); \ + ymm##first = _mm256_fmadd_ps(b_tmp0, tmp, ymm##first); \ + ymm##second = _mm256_fmadd_ps(b_tmp1, tmp, ymm##second); + + CAL_OUPUT(0, 0, 1) + CAL_OUPUT(1, 2, 3) + b_tmp0 = _mm256_loadu_ps(cur_b); + b_tmp1 = _mm256_loadu_ps(cur_b + 8); + cur_b += 16; + CAL_OUPUT(2, 0, 1) + CAL_OUPUT(3, 2, 3) + cur_a += 4; + b_tmp0 = _mm256_loadu_ps(cur_b); + b_tmp1 = _mm256_loadu_ps(cur_b + 8); + } + if (i < K) { + CAL_OUPUT(0, 0, 1) + CAL_OUPUT(1, 2, 3) + } +#undef CAL_OUPUT + switch (m_remain) { + case 2: + _mm256_storeu_ps(output + LDC * 1 + 0, ymm2); + _mm256_storeu_ps(output + LDC * 1 + 8, ymm3); + case 1: + _mm256_storeu_ps(output + LDC * 0 + 0, ymm0); + _mm256_storeu_ps(output + LDC * 0 + 8, ymm1); + default: + break; + } +} + +DNN_AVX2_TARGET +MEGDNN_ATTRIBUTE_TARGET("fma") +void gemm_6x16_kern6x4(const float *packA, const float *packB, int K, + float *output, int LDC, bool is_first_k, int n_remain) { + const float *cur_b = packB; + const float *cur_a = packA; + __m128 xmm0, xmm1, xmm2, xmm3, xmm4, xmm5; + __m128 tmp_a, tmp_b; + if (is_first_k) { + xmm0 = _mm_set1_ps(0.0f); + xmm1 = _mm_set1_ps(0.0f); + xmm2 = _mm_set1_ps(0.0f); + xmm3 = _mm_set1_ps(0.0f); + xmm4 = _mm_set1_ps(0.0f); + xmm5 = _mm_set1_ps(0.0f); + } else { + xmm0 = _mm_loadu_ps(output + LDC * 0); + xmm1 = _mm_loadu_ps(output + LDC * 1); + xmm2 = _mm_loadu_ps(output + LDC * 2); + xmm3 = _mm_loadu_ps(output + LDC * 3); + xmm4 = _mm_loadu_ps(output + LDC * 4); + xmm5 = _mm_loadu_ps(output + LDC * 5); + } + + 
for (size_t i = 0; i < K; i++) { + tmp_b = _mm_loadu_ps(cur_b); + cur_b += 4; + tmp_a = _mm_broadcast_ss(cur_a); + xmm0 = _mm_fmadd_ps(tmp_a, tmp_b, xmm0); + tmp_a = _mm_broadcast_ss(cur_a + 1); + xmm1 = _mm_fmadd_ps(tmp_a, tmp_b, xmm1); + tmp_a = _mm_broadcast_ss(cur_a + 2); + xmm2 = _mm_fmadd_ps(tmp_a, tmp_b, xmm2); + tmp_a = _mm_broadcast_ss(cur_a + 3); + xmm3 = _mm_fmadd_ps(tmp_a, tmp_b, xmm3); + tmp_a = _mm_broadcast_ss(cur_a + 4); + xmm4 = _mm_fmadd_ps(tmp_a, tmp_b, xmm4); + tmp_a = _mm_broadcast_ss(cur_a + 5); + xmm5 = _mm_fmadd_ps(tmp_a, tmp_b, xmm5); + cur_a += 6; + } + if (n_remain == 4) { + _mm_storeu_ps(output + LDC * 0, xmm0); + _mm_storeu_ps(output + LDC * 1, xmm1); + _mm_storeu_ps(output + LDC * 2, xmm2); + _mm_storeu_ps(output + LDC * 3, xmm3); + _mm_storeu_ps(output + LDC * 4, xmm4); + _mm_storeu_ps(output + LDC * 5, xmm5); + } else { + float dst[6 * 4]; + _mm_storeu_ps(dst + 4 * 0, xmm0); + _mm_storeu_ps(dst + 4 * 1, xmm1); + _mm_storeu_ps(dst + 4 * 2, xmm2); + _mm_storeu_ps(dst + 4 * 3, xmm3); + _mm_storeu_ps(dst + 4 * 4, xmm4); + _mm_storeu_ps(dst + 4 * 5, xmm5); + for (size_t i = 0; i < n_remain; i++) { + for (size_t j = 0; j < 6; j++) { + output[LDC * j + i] = dst[4 * j + i]; + } + } + } +} + +DNN_AVX2_TARGET +MEGDNN_ATTRIBUTE_TARGET("fma") +void gemm_6x16_kern2x4(const float *packA, const float *packB, int K, + float *output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + const float *cur_b = packB; + const float *cur_a = packA; + __m128 xmm0, xmm1; + __m128 tmp_a, tmp_b; + if (is_first_k) { + xmm0 = _mm_set1_ps(0.0f); + xmm1 = _mm_set1_ps(0.0f); + } else { + xmm0 = _mm_loadu_ps(output + LDC * 0); + xmm1 = _mm_loadu_ps(output + LDC * 1); + } + + for (size_t i = 0; i < K; i++) { + tmp_b = _mm_loadu_ps(cur_b); + cur_b += 4; + tmp_a = _mm_broadcast_ss(cur_a); + xmm0 = _mm_fmadd_ps(tmp_a, tmp_b, xmm0); + tmp_a = _mm_broadcast_ss(cur_a + 1); + xmm1 = _mm_fmadd_ps(tmp_a, tmp_b, xmm1); + cur_a += 2; + } + float dst[2 * 4]; + _mm_storeu_ps(dst + 4 * 0, xmm0); + _mm_storeu_ps(dst + 4 * 1, xmm1); + for (size_t i = 0; i < n_remain; i++) { + for (size_t j = 0; j < m_remain; j++) { + output[LDC * j + i] = dst[4 * j + i]; + } + } +} + +DNN_AVX2_TARGET +MEGDNN_ATTRIBUTE_TARGET("fma") +void gemm_6x16_kern6x16(const float *packA, const float *packB, int K, + float *output, int LDC, bool is_first_k) { + const float *cur_b = packB; + const float *cur_a = packA; + __m256 ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, + ymm11; + __m256 b_tmp0, b_tmp1; + __m256 tmp; + if (is_first_k) { +#define cb(i) ymm##i = _mm256_set1_ps(0.0f); + UNROLL_CODE(cb, 12) +#undef cb + } else { + ymm0 = _mm256_loadu_ps(output + LDC * 0 + 0); + ymm1 = _mm256_loadu_ps(output + LDC * 0 + 8); + ymm2 = _mm256_loadu_ps(output + LDC * 1 + 0); + ymm3 = _mm256_loadu_ps(output + LDC * 1 + 8); + ymm4 = _mm256_loadu_ps(output + LDC * 2 + 0); + ymm5 = _mm256_loadu_ps(output + LDC * 2 + 8); + ymm6 = _mm256_loadu_ps(output + LDC * 3 + 0); + ymm7 = _mm256_loadu_ps(output + LDC * 3 + 8); + ymm8 = _mm256_loadu_ps(output + LDC * 4 + 0); + ymm9 = _mm256_loadu_ps(output + LDC * 4 + 8); + ymm10 = _mm256_loadu_ps(output + LDC * 5 + 0); + ymm11 = _mm256_loadu_ps(output + LDC * 5 + 8); + } + b_tmp0 = _mm256_loadu_ps(cur_b); + b_tmp1 = _mm256_loadu_ps(cur_b + 8); + size_t i = 0; + for (; i + 2 <= K; i += 2) { + cur_b += 16; + +#define CAL_OUPUT(i, first, second) \ + tmp = _mm256_broadcast_ss(cur_a + i); \ + ymm##first = _mm256_fmadd_ps(b_tmp0, tmp, ymm##first); \ + ymm##second = _mm256_fmadd_ps(b_tmp1, 
tmp, ymm##second); + + CAL_OUPUT(0, 0, 1) + CAL_OUPUT(1, 2, 3) + CAL_OUPUT(2, 4, 5) + CAL_OUPUT(3, 6, 7) + CAL_OUPUT(4, 8, 9) + CAL_OUPUT(5, 10, 11) + b_tmp0 = _mm256_loadu_ps(cur_b); + b_tmp1 = _mm256_loadu_ps(cur_b + 8); + cur_b += 16; + CAL_OUPUT(6, 0, 1) + CAL_OUPUT(7, 2, 3) + CAL_OUPUT(8, 4, 5) + CAL_OUPUT(9, 6, 7) + CAL_OUPUT(10, 8, 9) + CAL_OUPUT(11, 10, 11) + cur_a += 12; + b_tmp0 = _mm256_loadu_ps(cur_b); + b_tmp1 = _mm256_loadu_ps(cur_b + 8); + } + if (i < K) { + CAL_OUPUT(0, 0, 1) + CAL_OUPUT(1, 2, 3) + CAL_OUPUT(2, 4, 5) + CAL_OUPUT(3, 6, 7) + CAL_OUPUT(4, 8, 9) + CAL_OUPUT(5, 10, 11) + } +#undef CAL_OUPUT + _mm256_storeu_ps(output + LDC * 0 + 0, ymm0); + _mm256_storeu_ps(output + LDC * 0 + 8, ymm1); + _mm256_storeu_ps(output + LDC * 1 + 0, ymm2); + _mm256_storeu_ps(output + LDC * 1 + 8, ymm3); + _mm256_storeu_ps(output + LDC * 2 + 0, ymm4); + _mm256_storeu_ps(output + LDC * 2 + 8, ymm5); + _mm256_storeu_ps(output + LDC * 3 + 0, ymm6); + _mm256_storeu_ps(output + LDC * 3 + 8, ymm7); + _mm256_storeu_ps(output + LDC * 4 + 0, ymm8); + _mm256_storeu_ps(output + LDC * 4 + 8, ymm9); + _mm256_storeu_ps(output + LDC * 5 + 0, ymm10); + _mm256_storeu_ps(output + LDC * 5 + 8, ymm11); +} + +void gemm_6x16_kern(const float *packA, const float *packB, size_t M, size_t N, + size_t K, float *C, size_t LDC, int is_first_k) { + size_t n = 0; + const int K2 = K * 2; + const int K4 = K * 4; + const int K6 = K * 6; + const int K16 = K * 16; + const int A_INTERLEAVE6 = 6; + const int A_INTERLEAVE2 = 2; + const int B_INTERLEAVE16 = 16; + const int B_INTERLEAVE4 = 4; + auto *cur_packB = packB; + for (; n + B_INTERLEAVE16 <= N; n += B_INTERLEAVE16) { + size_t m = 0; + auto output = C + n; + auto *cur_packA = packA; + for (; m + A_INTERLEAVE6 <= M; m += A_INTERLEAVE6) { + gemm_6x16_kern6x16(cur_packA, cur_packB, K, output, LDC, is_first_k); + output += A_INTERLEAVE6 * LDC; + cur_packA += K6; + } + for (; m < M; m += A_INTERLEAVE2) { + gemm_6x16_kern2x16(cur_packA, cur_packB, K, output, LDC, is_first_k, + min(M - m, 2)); + output += A_INTERLEAVE2 * LDC; + cur_packA += K2; + } + cur_packB += K16; + } + + for (; n < N; n += B_INTERLEAVE4) { + size_t m = 0; + auto output = C + n; + auto *cur_packA = packA; + for (; m + A_INTERLEAVE6 <= M; m += A_INTERLEAVE6) { + gemm_6x16_kern6x4(cur_packA, cur_packB, K, output, LDC, is_first_k, + min(N - n, 4)); + output += A_INTERLEAVE6 * LDC; + cur_packA += K6; + } + for (; m < M; m += A_INTERLEAVE2) { + + gemm_6x16_kern2x4(cur_packA, cur_packB, K, output, LDC, is_first_k, + min(M - m, 2), min(N - n, 4)); + output += A_INTERLEAVE2 * LDC; + cur_packA += K2; + } + cur_packB += K4; + } +} + +void gemm_6x16_pack_A_t(float *outptr, const float *inptr, int ldin, int x0, + int xmax, int k0, int kmax) { + size_t ksize = kmax - k0; + size_t ksize6 = ksize * 6; + size_t ksize2 = ksize * 2; + float *outptr_base6 = outptr; + float *outptr_base2 = outptr_base6 + (xmax - x0) / 6 * ksize6; + size_t k = k0; + + for (; k + 7 < kmax; k += 8) { + const float *cur_inptr = inptr + k * ldin + k0; +#define cb(i) const float *inptr##i = cur_inptr + ldin * i; + UNROLL_CODE(cb, 8) +#undef cb +#define cb(i) __builtin_prefetch(inptr##i, 0, 3); + UNROLL_CODE(cb, 8) +#undef cb + int x = x0; + float *outptr = outptr_base6; + for (; x + 6 <= xmax; x += 6) { + interleave_8x6_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr); +#define cb(i) inptr##i += 6; + UNROLL_CODE(cb, 8) +#undef cb + outptr += ksize6; + } + outptr = outptr_base2; + for (; x + 2 <= xmax; x += 2) { + 
interleave_8x2_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr); +#define cb(i) inptr##i += 2; + UNROLL_CODE(cb, 8) +#undef cb + outptr += ksize2; + } + if (x < xmax) { + interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, + inptr7, outptr, 2, xmax - x, 0); + inptr0 += xmax - x; + inptr1 += xmax - x; + inptr2 += xmax - x; + inptr3 += xmax - x; + inptr4 += xmax - x; + inptr5 += xmax - x; + inptr6 += xmax - x; + inptr7 += xmax - x; + } + outptr_base6 += 8 * 6; + outptr_base2 += 8 * 2; + } + for (; k < kmax; k++) { + const float *inptr0 = inptr + k * ldin + k0; + __builtin_prefetch(inptr0, 0, 3); + int x = x0; + float *outptr = outptr_base6; + for (; x + 6 <= xmax; x += 6) { + interleave_1x6_1_s(inptr0, outptr); + inptr0 += 6; + outptr += ksize6; + } + outptr = outptr_base2; + for (; x + 2 <= xmax; x += 2) { + interleave_1x2_1_s(inptr0, outptr); + inptr0 += 2; + outptr += ksize2; + } + if (x < xmax) { + interleave_1(inptr0, outptr, 2, xmax - x, 0); + inptr0 += xmax - x; + outptr += 2; + } + outptr_base6 += 6; + outptr_base2 += 2; + } +} + +void gemm_6x16_pack_A_n(float *outptr, const float *inptr, int ldin, int y0, + int ymax, int k0, int kmax) { + float zerobuff[16]; + memset(zerobuff, 0, sizeof(float) * 16); + size_t y = y0; + const size_t PACK_SIZE_96 = 6 * 16; + const size_t PACK_SIZE_48 = 6 * 8; + const size_t PACK_SIZE_24 = 6 * 4; + const size_t PACK_SIZE_32 = 4 * 8; + const size_t PACK_SIZE_16 = 4 * 4; + const size_t PACK_SIZE_8 = 4 * 2; + for (; y + 5 < ymax; y += 6) { + const float *cur_inptr = inptr + y * ldin + k0; +#define cb(i) const float *inptr##i = cur_inptr + ldin * i; + UNROLL_CODE(cb, 6) +#undef cb +#define cb(i) __builtin_prefetch(inptr##i, 0, 3); + UNROLL_CODE(cb, 6) +#undef cb + int x = (kmax - k0); + for (; x > 15; x -= 16) { + transpose_6x16_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, + outptr); +#define cb(i) inptr##i += 16; + UNROLL_CODE(cb, 6) +#undef cb + outptr += PACK_SIZE_96; + } + for (; x > 7; x -= 8) { + transpose_6x8_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr); +#define cb(i) inptr##i += 8; + UNROLL_CODE(cb, 6) +#undef cb + outptr += PACK_SIZE_48; + } + for (; x > 3; x -= 4) { + transpose_6x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr); +#define cb(i) inptr##i += 4; + UNROLL_CODE(cb, 6) +#undef cb + outptr += PACK_SIZE_24; + } + for (; x > 0; x--) { +#define cb(i) *outptr++ = *inptr##i++; + UNROLL_CODE(cb, 6) +#undef cb + } + } + for (; y < ymax; y += 2) { + const float *cur_inptr = inptr + y * ldin + k0; +#define cb(i) const float *inptr##i = cur_inptr + ldin * i; + UNROLL_CODE(cb, 2) +#undef cb +#define cb(i) __builtin_prefetch(inptr##i, 0, 3); + UNROLL_CODE(cb, 2) +#undef cb + int x = kmax - k0; + for (; x > 15; x -= 16) { + if ((y + 1) >= ymax) { + inptr1 = zerobuff; + } + transpose_2x16_1_s(inptr0, inptr1, outptr); +#define cb(i) inptr##i += 16; + UNROLL_CODE(cb, 2) +#undef cb + outptr += PACK_SIZE_32; + } + for (; x > 7; x -= 8) { + if ((y + 1) >= ymax) { + inptr1 = zerobuff; + } + transpose_2x8_1_s(inptr0, inptr1, outptr); +#define cb(i) inptr##i += 8; + UNROLL_CODE(cb, 2) +#undef cb + outptr += PACK_SIZE_16; + } + for (; x > 3; x -= 4) { + if ((y + 1) >= ymax) { + inptr1 = zerobuff; + } + transpose_2x4_1_s(inptr0, inptr1, outptr); +#define cb(i) inptr##i += 4; + UNROLL_CODE(cb, 2) +#undef cb + outptr += PACK_SIZE_8; + } + if (x > 0) { + if ((y + 1) >= ymax) { + inptr1 = zerobuff; + } + for (size_t i = 0; i < x; i++) { + *outptr++ = *inptr0++; + *outptr++ = 
*inptr1++;
+            }
+        }
+    }
+}
+
+void gemm_6x16_pack_B_t(float *outptr, const float *inptr, int ldin, int y0,
+                        int ymax, int k0, int kmax) {
+    float zerobuff[16];
+    memset(zerobuff, 0, sizeof(float) * 16);
+    const size_t PACK_SIZE_128 = 8 * 16;
+    const size_t PACK_SIZE_64 = 4 * 16;
+    const size_t PACK_SIZE_32 = 4 * 8;
+    const size_t PACK_SIZE_16 = 4 * 4;
+    size_t y = y0;
+    for (; y + 15 < ymax; y += 16) {
+        const float *cur_inptr = inptr + y * ldin + k0;
+#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
+        UNROLL_CODE(cb, 16)
+#undef cb
+#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
+        UNROLL_CODE(cb, 16)
+#undef cb
+        int x = (kmax - k0);
+        for (; x > 7; x -= 8) {
+            transpose_16x8_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
+                               inptr7, inptr8, inptr9, inptr10, inptr11, inptr12,
+                               inptr13, inptr14, inptr15, outptr);
+#define cb(i) inptr##i += 8;
+            UNROLL_CODE(cb, 16)
+#undef cb
+            outptr += PACK_SIZE_128;
+        }
+        for (; x > 3; x -= 4) {
+            transpose_16x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
+                               inptr7, inptr8, inptr9, inptr10, inptr11, inptr12,
+                               inptr13, inptr14, inptr15, outptr);
+#define cb(i) inptr##i += 4;
+            UNROLL_CODE(cb, 16)
+#undef cb
+            outptr += PACK_SIZE_64;
+        }
+        for (; x > 0; x--) {
+#define cb(i) *outptr++ = *inptr##i++;
+            UNROLL_CODE(cb, 16)
+#undef cb
+        }
+    }
+    for (; y < ymax; y += 4) {
+        const float *cur_inptr = inptr + y * ldin + k0;
+#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
+        UNROLL_CODE(cb, 4)
+#undef cb
+#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
+        UNROLL_CODE(cb, 4)
+#undef cb
+        int x = kmax - k0;
+        for (; x > 7; x -= 8) {
+            if ((y + 3) >= ymax) {
+                switch ((y + 3) - ymax) {
+                    case 2:
+                        inptr1 = zerobuff;
+                    case 1:
+                        inptr2 = zerobuff;
+                    case 0:
+                        inptr3 = zerobuff;
+                    default:
+                        break;
+                }
+            }
+            transpose_4x8_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
+#define cb(i) inptr##i += 8;
+            UNROLL_CODE(cb, 4)
+#undef cb
+            outptr += PACK_SIZE_32;
+        }
+        for (; x > 3; x -= 4) {
+            if ((y + 3) >= ymax) {
+                switch ((y + 3) - ymax) {
+                    case 2:
+                        inptr1 = zerobuff;
+                    case 1:
+                        inptr2 = zerobuff;
+                    case 0:
+                        inptr3 = zerobuff;
+                    default:
+                        break;
+                }
+            }
+            transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
+#define cb(i) inptr##i += 4;
+            UNROLL_CODE(cb, 4)
+#undef cb
+            outptr += PACK_SIZE_16;
+        }
+        if (x > 0) {
+            if ((y + 3) >= ymax) {
+                switch ((y + 3) - ymax) {
+                    case 2:
+                        inptr1 = zerobuff;
+                    case 1:
+                        inptr2 = zerobuff;
+                    case 0:
+                        inptr3 = zerobuff;
+                        break;
+                }
+            }
+            for (size_t i = 0; i < x; i++) {
+                *outptr++ = *inptr0++;
+                *outptr++ = *inptr1++;
+                *outptr++ = *inptr2++;
+                *outptr++ = *inptr3++;
+            }
+        }
+    }
+}
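
Both pack_B paths above and below emit the same panel layout consumed by
gemm_6x16_kern; a descriptive note (comments only, summarizing the code):

    // Packed B, as the kernels read it: first all 16-wide column panels,
    // each stored k-major as K consecutive rows of 16 floats; then 4-wide
    // panels for the column tail, zero-padded to a multiple of 4 by
    // interleave_8/interleave_1.
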
+
+void gemm_6x16_pack_B_n(float *outptr, const float *inptr, int ldin, int x0,
+                        int xmax, int k0, int kmax) {
+    size_t ksize = kmax - k0;
+    size_t ksize16 = ksize * 16;
+    size_t ksize4 = ksize * 4;
+    float *outptr_base16 = outptr;
+    float *outptr_base4 = outptr_base16 + (xmax - x0) / 16 * ksize16;
+    size_t k = k0;
+
+    for (; k + 7 < kmax; k += 8) {
+        const float *cur_inptr = inptr + k * ldin + k0;
+#define cb(i) const float *inptr##i = cur_inptr + ldin * i;
+        UNROLL_CODE(cb, 8)
+#undef cb
+#define cb(i) __builtin_prefetch(inptr##i, 0, 3);
+        UNROLL_CODE(cb, 8)
+#undef cb
+        int x = x0;
+        float *outptr = outptr_base16;
+        for (; x + 16 <= xmax; x += 16) {
+            interleave_8x16_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
+                                inptr6, inptr7, outptr);
+#define cb(i) inptr##i += 16;
+            UNROLL_CODE(cb, 8)
+#undef cb
+            outptr += ksize16;
+        }
+        outptr = outptr_base4;
+        for (; x + 4 <= xmax; x += 4) {
+            interleave_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
+                               inptr7, outptr);
+#define cb(i) inptr##i += 4;
+            UNROLL_CODE(cb, 8)
+#undef cb
+            outptr += ksize4;
+        }
+
+        if (x < xmax) {
+            interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
+                         inptr7, outptr, 4, xmax - x, 0);
+            inptr0 += xmax - x;
+            inptr1 += xmax - x;
+            inptr2 += xmax - x;
+            inptr3 += xmax - x;
+            inptr4 += xmax - x;
+            inptr5 += xmax - x;
+            inptr6 += xmax - x;
+            inptr7 += xmax - x;
+        }
+        outptr_base16 += 8 * 16;
+        outptr_base4 += 8 * 4;
+    }
+
+    for (; k < kmax; k++) {
+        const float *inptr0 = inptr + k * ldin + k0;
+        __builtin_prefetch(inptr0, 0, 3);
+        int x = x0;
+        float *outptr = outptr_base16;
+        for (; x + 16 <= xmax; x += 16) {
+            interleave_1x16_1_s(inptr0, outptr);
+            inptr0 += 16;
+            outptr += ksize16;
+        }
+        outptr = outptr_base4;
+        for (; x + 4 <= xmax; x += 4) {
+            interleave_1x4_1_s(inptr0, outptr);
+            inptr0 += 4;
+            outptr += ksize4;
+        }
+        if (x < xmax) {
+            interleave_1(inptr0, outptr, 4, xmax - x, 0);
+            inptr0 += xmax - x;
+            outptr += 4;
+        }
+        outptr_base16 += 16;
+        outptr_base4 += 4;
+    }
+}
+}  // namespace
+#undef UNROLL_CODE
+
+namespace megdnn {
+namespace x86 {
+namespace matmul {
+void sgemm_pack_6x16_avx2::pack_A(float *out, const float *in, int ldin, int y0,
+                                  int ymax, int k0, int kmax,
+                                  bool transpose_A) const {
+    if (!transpose_A)
+        gemm_6x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
+    else
+        gemm_6x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
+}
+
+void sgemm_pack_6x16_avx2::pack_B(float *out, const float *in, int ldin, int x0,
+                                  int xmax, int k0, int kmax,
+                                  bool transpose_B) const {
+    if (!transpose_B)
+        gemm_6x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
+    else
+        gemm_6x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
+}
+
+void sgemm_pack_6x16_avx2::kern(const float *packA, const float *packB,
+                                size_t M, size_t N, size_t K, float *C,
+                                size_t LDC, bool is_first_k, const float *bias,
+                                float *workspace) const {
+    MEGDNN_MARK_USED_VAR(bias);
+    MEGDNN_MARK_USED_VAR(workspace);
+    gemm_6x16_kern(packA, packB, M, N, K, C, LDC, is_first_k);
+}
+MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_pack_6x16_avx2);
+}  // namespace matmul
+}  // namespace x86
+}  // namespace megdnn
diff --git a/dnn/src/x86/matrix_mul/opr_impl.cpp b/dnn/src/x86/matrix_mul/opr_impl.cpp
index cf6e64e93..7cbb75779 100644
--- a/dnn/src/x86/matrix_mul/opr_impl.cpp
+++ b/dnn/src/x86/matrix_mul/opr_impl.cpp
@@ -34,6 +34,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
     AlgoInt8x8x16SSE algoint8x8x16sse_m4n8k2;
     AlgoF32MK8_8x8 algof32mk8_8x8;
+    AlgoFloatAVX2M6N16 algof32_6x16;
     SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
     fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
 
@@ -51,6 +52,7 @@ public:
         m_all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
         m_all_algos.emplace_back(&algoint8x8x16sse_m4n8k2);
         m_all_algos.emplace_back(&algof32mk8_8x8);
+        m_all_algos.emplace_back(&algof32_6x16);
 #if MEGDNN_X86_WITH_MKL_DNN
         m_all_algos.emplace_back(&algoint8x8x32mkldnn);
 #endif
diff --git a/dnn/src/x86/matrix_mul/opr_impl.h b/dnn/src/x86/matrix_mul/opr_impl.h
index 5d9eb78e9..537006ff5 100644
--- a/dnn/src/x86/matrix_mul/opr_impl.h
+++ b/dnn/src/x86/matrix_mul/opr_impl.h
@@ -68,7 +68,7 @@ private:
     class AlgoInt8x8x16SSE;
     class AlgoPack;
     class AlgoF32MK8_8x8;
-
+    class AlgoFloatAVX2M6N16;
 public:
     static const AlgoPack& algo_pack();
 };
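
For cross-reference (taken from this patch's own strings): "X86_F32_6x16" is
the name AlgoFloatAVX2M6N16 reports, and the composite conv algorithms select
the matmul by that name, optionally with a trailing tile-size hint:

    // Algorithm selectors appearing in the tests below:
    //   "X86_F32_6x16"                    // the bare matmul
    //   "IM2COLMATMUL:X86_F32_6x16:192"   // im2col + this matmul
    //   "CONV1x1:X86_F32_6x16:48"         // conv1x1 + this matmul
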
diff --git a/dnn/test/x86/accuracy_shake.cpp b/dnn/test/x86/accuracy_shake.cpp
index 2e43c33eb..bfde6bdee 100644
--- a/dnn/test/x86/accuracy_shake.cpp
+++ b/dnn/test/x86/accuracy_shake.cpp
@@ -98,6 +98,15 @@ TEST_F(X86, SHAKE_MATRIX_MUL_FORWARD) {
             .exec({{20, 100}, {100, 60}, {}});
 }
 
+TEST_F(X86, SHAKE_MATRIX_MUL_6x16_FORWARD) {
+    AccuracyShakeChecker<MatrixMul> checker(handle());
+    checker.set_before_exec_callback(AlgoGenerator<MatrixMul>("X86_F32_6x16"));
+    checker.set_dtype(0, dtype::Float32())
+            .set_dtype(1, dtype::Float32())
+            .set_dtype(2, dtype::Float32())
+            .exec({{20, 100}, {100, 60}, {}});
+}
+
 }  // namespace test
 }  // namespace megdnn
diff --git a/dnn/test/x86/conv_bias.cpp b/dnn/test/x86/conv_bias.cpp
index 247ee073b..63a3d1173 100644
--- a/dnn/test/x86/conv_bias.cpp
+++ b/dnn/test/x86/conv_bias.cpp
@@ -1171,6 +1171,110 @@ TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32_NOPACK_PREPROCESS) {
 
 #endif
 
+TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_6x16) {
+    using namespace conv_bias;
+    std::vector<TestArg> args;
+    auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
+                   size_t p, NonlineMode nonline_mode) {
+        if (w + 2 * p < kernel || h + 2 * p < kernel)
+            return;
+        param::ConvBias param;
+        param.stride_h = 1;
+        param.stride_w = 1;
+        param.pad_h = p;
+        param.pad_w = p;
+        param.nonlineMode = nonline_mode;
+
+        //! no bias
+        args.emplace_back(param, TensorShape{1, ic, h, w},
+                          TensorShape{oc, ic, kernel, kernel}, TensorShape{});
+        args.emplace_back(param, TensorShape{1, ic, h, w},
+                          TensorShape{oc, ic, kernel, kernel},
+                          TensorShape{1, oc, 1, 1});
+        args.emplace_back(
+                param, TensorShape{1, ic, h, w},
+                TensorShape{oc, ic, kernel, kernel},
+                TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1,
+                            (w + 2 * p - kernel) / param.stride_w + 1});
+    };
+
+    for (size_t kernel : {2, 3, 4, 5, 6, 7})
+        for (size_t ic : {1, 4, 8, 16})
+            for (size_t oc : {1, 4, 8, 16, 300})
+                for (size_t p : {0, 2})
+                    for (size_t size : {8, 24})
+                        for (NonlineMode nonline_mode :
+                             {NonlineMode::IDENTITY, NonlineMode::RELU}) {
+                            run(oc, ic, size, size, kernel, p, nonline_mode);
+                        }
+
+    run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY);
+    Checker<ConvBias> checker(handle());
+
+#define cb(algo_name)                                                  \
+    checker.set_before_exec_callback(                                  \
+            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));      \
+    for (auto&& arg : args) {                                          \
+        checker.set_param(arg.param).execs(                            \
+                {arg.src, arg.filter, arg.bias, {}, {}});              \
+    }
+    cb("IM2COLMATMUL:X86_F32_6x16:192");
+#undef cb
+}
+
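
The full-bias case in run() uses the standard convolution output-size
arithmetic; one worked instance with values from the loops above:

    // oh = (h + 2*p - kernel) / stride + 1
    // h = w = 24, p = 2, kernel = 3, stride = 1:
    //   oh = ow = (24 + 4 - 3) / 1 + 1 = 26  ->  bias shape {1, oc, 26, 26}
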
+TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32_6x16) {
+    using namespace conv_bias;
+    std::vector<TestArg> args;
+
+    auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
+                   size_t p, NonlineMode nonline_mode) {
+        if (w + 2 * p < kernel || h + 2 * p < kernel)
+            return;
+        param::ConvBias param;
+        param.stride_h = 1;
+        param.stride_w = 1;
+        param.pad_h = p;
+        param.pad_w = p;
+        param.nonlineMode = nonline_mode;
+
+        //! no bias
+        args.emplace_back(param, TensorShape{1, ic, h, w},
+                          TensorShape{oc, ic, kernel, kernel}, TensorShape{});
+        args.emplace_back(param, TensorShape{1, ic, h, w},
+                          TensorShape{oc, ic, kernel, kernel},
+                          TensorShape{1, oc, 1, 1});
+        args.emplace_back(
+                param, TensorShape{1, ic, h, w},
+                TensorShape{oc, ic, kernel, kernel},
+                TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1,
+                            (w + 2 * p - kernel) / param.stride_w + 1});
+    };
+
+    for (size_t kernel : {2, 3, 4, 5, 6, 7})
+        for (size_t ic : {1, 4, 8, 16})
+            for (size_t oc : {1, 4, 8, 16, 300})
+                for (size_t p : {0, 2})
+                    for (size_t size : {8, 24})
+                        for (NonlineMode nonline_mode :
+                             {NonlineMode::IDENTITY, NonlineMode::RELU}) {
+                            run(oc, ic, size, size, kernel, p, nonline_mode);
+                        }
+
+    run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY);
+    Checker<ConvBias> checker(handle());
+#define cb(algo_name)                                                  \
+    checker.set_before_exec_callback(                                  \
+            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));      \
+    for (auto&& arg : args) {                                          \
+        checker.set_param(arg.param).execs(                            \
+                {arg.src, arg.filter, arg.bias, {}, {}});              \
+    }
+
+    cb("IM2COLMATMUL:X86_F32_6x16:192");
+
+#undef cb
+}
+
+
 #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
 TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) {
     using namespace conv_bias;
@@ -1377,6 +1481,12 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) {
     using namespace conv_bias;
     std::vector<TestArg> args = get_conv_bias_1x1_args(false, false);
     check_conv_bias(args, handle(), "CONV1x1:X86_F32_BLAS:48");
 }
 
+TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_6x16) {
+    using namespace conv_bias;
+    std::vector<TestArg> args = get_conv_bias_1x1_args(false, false);
+    check_conv_bias(args, handle(), "CONV1x1:X86_F32_6x16:48");
+}
+
 TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS_NOPACK_REPROCESS) {
     using namespace conv_bias;
     std::vector<TestArg> args = get_conv_bias_1x1_args(false, false);
@@ -2627,6 +2737,76 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32) {
     shapes_and_computation.clear();
 }
 
+TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_6x16) {
+    constexpr size_t RUNS = 50;
+
+    param::ConvBias param;
+    param.nonlineMode = param::ConvBias::NonlineMode::RELU;
+    param.pad_h = 1;
+    param.pad_w = 1;
+    param.stride_h = 1;
+    param.stride_w = 1;
+
+    std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
+                                    dtype::Float32(), dtype::Float32()};
+    std::vector<std::pair<SmallVector<TensorShape>, float>>
+            shapes_and_computation;
+    auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
+                          size_t FS, size_t group) {
+        SmallVector<TensorShape> shapes{{N, IC, H, W},
+                                        {OC / group, IC / group, FS, FS},
+                                        {1, OC, 1, 1},
+                                        {},
+                                        {N, OC, H, W}};
+        TensorShape dst{N, OC, H, W};
+        float computations =
+                ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
+                 dst.total_nr_elems()) *
+                1e-6;
+        shapes_and_computation.push_back(std::make_pair(shapes, computations));
+    };
+
+    bench_case(1, 32, 32, 200, 200, 3, 1);
+    bench_case(1, 32, 32, 200, 200, 3, 1);
+    bench_case(1, 32, 32, 128, 128, 3, 1);
+    bench_case(1, 32, 32, 128, 128, 3, 1);
+    bench_case(1, 32, 32, 100, 100, 3, 1);
+    bench_case(1, 32, 32, 100, 100, 3, 1);
+    bench_case(1, 32, 32, 80, 80, 3, 1);
+    bench_case(1, 32, 32, 80, 80, 3, 1);
+
+    bench_case(1, 64, 32, 7, 7, 3, 1);
+    bench_case(1, 64, 64, 7, 7, 3, 1);
+    bench_case(1, 64, 128, 7, 7, 3, 1);
+    bench_case(1, 64, 256, 7, 7, 3, 1);
+    bench_case(1, 64, 512, 7, 7, 3, 1);
+    bench_case(1, 64, 1024, 7, 7, 3, 1);
+
+    bench_case(1, 64, 32, 14, 14, 3, 1);
+    bench_case(1, 64, 64, 14, 14, 3, 1);
+    bench_case(1, 64, 128, 14, 14, 3, 1);
+    bench_case(1, 64, 256, 14, 14, 3, 1);
+    bench_case(1, 64, 512, 14, 14, 3, 1);
+
+    bench_case(1, 64, 1024, 14, 14, 3, 1);
+    bench_case(1, 128, 128, 14, 14, 3, 1);
+    bench_case(1, 128, 256, 14, 14, 3, 1);
+    bench_case(1, 512, 512, 14, 14, 3, 1);
+    bench_case(1, 256, 512, 14, 14, 3, 1);
+    bench_case(1, 512, 1024, 14, 14, 3, 1);
+    bench_case(1, 1024, 1024, 14, 14, 3, 1);
+
+    std::string algo_name = "IM2COLMATMUL:X86_F32_6x16:192";
+    printf("Benchmark IM2COLMATMUL:X86_F32_6x16 algo\n");
+    benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
+                   {4, {4, 5, 6, 7}}, {1, {4}}, data_type);
+    benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
+                   {4, {4, 5, 6, 7}}, {1, {7}}, data_type);
+    benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
+                   {1, {4}}, data_type);
+    shapes_and_computation.clear();
+}
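
For orientation, the computations field above is an MFLOP estimate per run; a
worked instance for the first shape (arithmetic only, values from the code):

    // bench_case(1, 32, 32, 200, 200, 3, 1):
    //   dst.total_nr_elems() = 1 * 32 * 200 * 200 = 1,280,000
    //   computations = (32 * 3 * 3 * 1,280,000 * 2 + 1,280,000) * 1e-6
    //                ~ 738.56 MFLOPs
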
 
 TEST_F(X86_BENCHMARK_MULTI_THREADS,
        BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) {
     constexpr size_t RUNS = 50;
@@ -2697,6 +2877,76 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS,
     shapes_and_computation.clear();
 }
 
+TEST_F(X86_BENCHMARK_MULTI_THREADS,
+       BENCHMARK_CONVBIAS_IM2COL_F32_6X16_single_thread) {
+    constexpr size_t RUNS = 50;
+
+    param::ConvBias param;
+    param.nonlineMode = param::ConvBias::NonlineMode::RELU;
+    param.pad_h = 1;
+    param.pad_w = 1;
+    param.stride_h = 1;
+    param.stride_w = 1;
+
+    std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
+                                    dtype::Float32(), dtype::Float32()};
+    std::vector<std::pair<SmallVector<TensorShape>, float>>
+            shapes_and_computation;
+    auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
+                          size_t FS, size_t group) {
+        SmallVector<TensorShape> shapes{{N, IC, H, W},
+                                        {OC / group, IC / group, FS, FS},
+                                        {1, OC, 1, 1},
+                                        {},
+                                        {N, OC, H, W}};
+        TensorShape dst{N, OC, H, W};
+        float computations =
+                ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
+                 dst.total_nr_elems()) *
+                1e-6;
+        shapes_and_computation.push_back(std::make_pair(shapes, computations));
+    };
+
+    bench_case(1, 32, 32, 200, 200, 3, 1);
+    bench_case(1, 32, 32, 200, 200, 3, 1);
+    bench_case(1, 32, 32, 128, 128, 3, 1);
+    bench_case(1, 32, 32, 128, 128, 3, 1);
+    bench_case(1, 32, 32, 100, 100, 3, 1);
+    bench_case(1, 32, 32, 100, 100, 3, 1);
+    bench_case(1, 32, 32, 80, 80, 3, 1);
+    bench_case(1, 32, 32, 80, 80, 3, 1);
+
+    bench_case(1, 64, 32, 7, 7, 3, 1);
+    bench_case(1, 64, 64, 7, 7, 3, 1);
+    bench_case(1, 64, 128, 7, 7, 3, 1);
+    bench_case(1, 64, 256, 7, 7, 3, 1);
+    bench_case(1, 64, 512, 7, 7, 3, 1);
+    bench_case(1, 64, 1024, 7, 7, 3, 1);
+
+    bench_case(1, 64, 32, 14, 14, 3, 1);
+    bench_case(1, 64, 64, 14, 14, 3, 1);
+    bench_case(1, 64, 128, 14, 14, 3, 1);
+    bench_case(1, 64, 256, 14, 14, 3, 1);
+    bench_case(1, 64, 512, 14, 14, 3, 1);
+
+    bench_case(1, 64, 1024, 14, 14, 3, 1);
+    bench_case(1, 128, 128, 14, 14, 3, 1);
+    bench_case(1, 128, 256, 14, 14, 3, 1);
+    bench_case(1, 512, 512, 14, 14, 3, 1);
+    bench_case(1, 256, 512, 14, 14, 3, 1);
+    bench_case(1, 512, 1024, 14, 14, 3, 1);
+    bench_case(1, 1024, 1024, 14, 14, 3, 1);
+
+    std::string algo_name = "IM2COLMATMUL:X86_F32_MKL_PACKA:192";
+    std::string algo_name1 = "IM2COLMATMUL:X86_F32_6x16:192";
+    printf("Benchmark IM2COLMATMUL:X86_F32_6x16 algo\n");
+    benchmark_impl_comp(param, shapes_and_computation, algo_name, algo_name1,
+                        RUNS, {1, {4}}, {1, {4}}, data_type);
+    benchmark_impl_comp(param, shapes_and_computation, algo_name, algo_name1,
+                        RUNS, {1, {7}}, {1, {7}}, data_type);
+    shapes_and_computation.clear();
+}
+
 TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_INT8X8X32) {
     constexpr size_t RUNS = 50;
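
A hedged reading of the thread arguments passed to benchmark_impl and
benchmark_impl_comp above (inferred from their {threads, {core ids}} shape,
not from the helpers' documentation):

    // {4, {4, 5, 6, 7}} -> multi-thread pass, 4 workers pinned to cores 4..7
    // {1, {4}} / {1, {7}} -> single-thread baseline pinned to one core
    // benchmark_impl_comp additionally contrasts the MKL_PACKA and 6x16
    // algorithms on the same shapes.
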
diff --git a/dnn/test/x86/matrix_mul.cpp b/dnn/test/x86/matrix_mul.cpp
index 5f8903a11..759e89bc5 100644
--- a/dnn/test/x86/matrix_mul.cpp
+++ b/dnn/test/x86/matrix_mul.cpp
@@ -85,6 +85,13 @@ TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) {
                                  param::MatrixMul::Format::MK8, 1, 1e-3,
                                  false);
 }
 
+TEST_F(X86, MATRIX_MUL_AVX2_6x16) {
+    matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
+                                 dtype::Float32{}, handle(), "X86_F32_6x16",
+                                 param::MatrixMul::Format::DEFAULT, 1, 1e-3,
+                                 false);
+}
+
+
 #if MEGDNN_WITH_BENCHMARK
 
 TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
@@ -96,6 +103,14 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) {
             "X86_F32_BLAS");
 }
 
+TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_6x16) {
+    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
+    matrix_mul::benchmark_with_contrast(
+            handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
+            "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, dtype::Float32{},
+            dtype::Float32{}, dtype::Float32{}, "X86_F32_BLAS");
+}
+
 TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
     constexpr size_t RUNS = 50;
     auto rng = std::make_unique<UniformIntRNG>(-127, 127);
-- 
GitLab