From 25b6a13148ba587dc769aab96043e9a092047d83 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 12 Jun 2020 18:28:37 +0800
Subject: [PATCH] feat(dnn/x86): add x86 avx2 8x8x16 matmul

GitOrigin-RevId: d2172c50b244a0683b00710a88bc507d02a9734f
---
 dnn/src/x86/matrix_mul/algos.cpp              |  73 ++++++
 dnn/src/x86/matrix_mul/algos.h                |  22 +-
 dnn/src/x86/matrix_mul/common/common.h        |  45 +++-
 .../matrix_mul/int8/avx2_strategy_4x16x2.cpp  |  93 ++++++--
 .../x86/matrix_mul/int8/kernel_avx2_4x16x2.h  | 209 +++++++++++-------
 dnn/src/x86/matrix_mul/int8/strategy.h        |   7 +-
 dnn/src/x86/matrix_mul/opr_impl.cpp           |   2 +
 dnn/src/x86/matrix_mul/opr_impl.h             |   4 +-
 dnn/test/x86/conv_bias.cpp                    |  49 +++-
 dnn/test/x86/convolution.cpp                  |  61 ++++-
 dnn/test/x86/matrix_mul.cpp                   |  24 +-
 11 files changed, 464 insertions(+), 125 deletions(-)

diff --git a/dnn/src/x86/matrix_mul/algos.cpp b/dnn/src/x86/matrix_mul/algos.cpp
index cf7e36698..c7d7ddd0f 100644
--- a/dnn/src/x86/matrix_mul/algos.cpp
+++ b/dnn/src/x86/matrix_mul/algos.cpp
@@ -318,6 +318,79 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
     }
 }  // namespace

+void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2(
+        const MatrixMulImpl::KernParam& kern_param) {
+    MEGDNN_MARK_USED_VAR(kern_param);
+    MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_4x16x2, midout_iv(1)) {
+        constexpr int cacheline = 64;
+        const size_t m = kern_param.M;
+        const size_t n = kern_param.N;
+        const size_t k = kern_param.K;
+        const bool trans_a = kern_param.trA;
+        const bool trans_b = kern_param.trB;
+        const size_t lda = kern_param.LDA;
+        const size_t ldb = kern_param.LDB;
+        const size_t ldc = kern_param.LDC;
+        auto a_type = kern_param.A_type;
+        auto b_type = kern_param.B_type;
+        auto c_type = kern_param.C_type;
+        const auto a_ptr = kern_param.A();
+        const auto b_ptr = kern_param.B();
+        auto c_ptr = kern_param.C();
+        x86::matmul::gemm_avx2_s8s8s16_4x16x2 strategy(m, n, k, a_type, b_type,
+                                                       c_type);
+
+        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_avx2_s8s8s16_4x16x2>(
+                m, n, k, trans_a, trans_b, strategy, cacheline)
+                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
+                         kern_param.workspace_ptr);
+    }
+    MIDOUT_END();
+}
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_kern(
+        const KernSizeParam&) const {
+    return gemm_s8s8s16_avx2_4x16x2;
+}
+bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable(
+        const KernSizeParam& kern_size_param) const {
+    bool is_ab_same =
+            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv();
+    bool is_type_ok =
+            ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
+              kern_size_param.C_type.enumv() == DTypeEnum::Int16) ||
+             (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
+              kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
+    bool is_mode_ok =
+            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+            is_supported(SIMDType::AVX2);
+    bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
+    return is_param_ok;
+}
+bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const {
+    return true;
+}
+size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
+        const KernSizeParam& kern_param) const {
+    constexpr int cacheline = 64;
+    const size_t m = kern_param.M;
+    const size_t n = kern_param.N;
+    const size_t k = kern_param.K;
+    const bool trans_a = kern_param.trA;
+    const bool trans_b = kern_param.trB;
+    auto a_type = kern_param.A_type;
+    auto b_type = kern_param.B_type;
+    auto c_type = kern_param.C_type;
+    x86::matmul::gemm_avx2_s8s8s16_4x16x2 strategy(m, n, k, a_type, b_type,
+                                                   c_type);
+
+    return
+            megdnn::matmul::GemmInterleaved<
+                    x86::matmul::gemm_avx2_s8s8s16_4x16x2>(
+                    m, n, k, trans_a, trans_b, strategy, cacheline)
+            .get_workspace_size();
+}
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
+        AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, 8,
+        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);

 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
         const KernSizeParam&) const {
diff --git a/dnn/src/x86/matrix_mul/algos.h b/dnn/src/x86/matrix_mul/algos.h
index 46c085cd7..a4dabae26 100644
--- a/dnn/src/x86/matrix_mul/algos.h
+++ b/dnn/src/x86/matrix_mul/algos.h
@@ -6,13 +6,14 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */

 #pragma once

-#include "src/x86/matrix_mul/opr_impl.h"
 #include "src/fallback/matrix_mul/gemm_common.h"
+#include "src/x86/matrix_mul/opr_impl.h"

 namespace megdnn {
 namespace x86 {
@@ -71,6 +72,23 @@ public:
     MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
 };

+class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase {
+private:
+    static void gemm_s8s8s16_avx2_4x16x2(
+            const MatrixMulImpl::KernParam& kern_param);
+    static MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 m_algo;
+
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "X86_INT8X8X16_AVX2"; }
+    bool usable(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_x86_algo_type; }
+    bool preferred(const KernSizeParam&) const override;
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+};
+
 class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase {
 public:
     bool is_reproducible() const override { return true; }
diff --git a/dnn/src/x86/matrix_mul/common/common.h b/dnn/src/x86/matrix_mul/common/common.h
index 45f11f94f..2b5a3e0b7 100644
--- a/dnn/src/x86/matrix_mul/common/common.h
+++ b/dnn/src/x86/matrix_mul/common/common.h
@@ -6,16 +6,17 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
*/ #pragma once #include #ifdef WIN32 -#include -#include #include +#include #include +#include #endif #include #include @@ -787,19 +788,49 @@ static inline void transpose_4x8_k2_int8_to_int16(const int8_t* inptr0, MEGDNN_ATTRIBUTE_TARGET("avx2") static inline __v8si _m256_continue_mask_v8si(const int& x) { + // clang-format off static __v8si map[9] = { - {0, 0, 0, 0, 0, 0, 0, 0}, {-1, 0, 0, 0, 0, 0, 0, 0}, - {-1, -1, 0, 0, 0, 0, 0, 0}, {-1, -1, -1, 0, 0, 0, 0, 0}, - {-1, -1, -1, -1, 0, 0, 0, 0}, {-1, -1, -1, -1, -1, 0, 0, 0}, - {-1, -1, -1, -1, -1, -1, 0, 0}, {-1, -1, -1, -1, -1, -1, -1, 0}, + {00, 00, 00, 00, 00, 00, 00, 00}, + {-1, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, 00}, {-1, -1, -1, -1, -1, -1, -1, -1}}; return map[x]; + // clang-format on } MEGDNN_ATTRIBUTE_TARGET("avx2") static inline __m256i _m256_continue_mask(const int& x) { return (__m256i)_m256_continue_mask_v8si(x); } +MEGDNN_ATTRIBUTE_TARGET("sse2") +static inline __m128i _mm_continue_mask(const int& x) { + static __v16qi map[17] = { + {00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + }; + return (__m128i)map[x]; +} + MEGDNN_ATTRIBUTE_TARGET("sse2") static inline void transpose_4xk_int8_to_int16_pad(const int8_t* inptr0, const int8_t* inptr1, diff --git a/dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp b/dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp index ee31b0e0c..57cd5c084 100644 --- a/dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp +++ b/dnn/src/x86/matrix_mul/int8/avx2_strategy_4x16x2.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "src/common/utils.h" @@ -18,10 +19,9 @@ using namespace megdnn; using namespace x86; using namespace x86::matmul; -MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_avx2_s8s8s32_4x16x2); -void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in, - int ldin, int y0, int ymax, int k0, - int kmax, bool transpose) const { +static inline void gemm_packa(dt_int16* out, const dt_int8* in, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) { if (transpose) { matmul_avx2_4x16x2::gemm_s8s8s32_avx2_4x16x2_pack_at(out, in, ldin, y0, ymax, k0, kmax); @@ -30,10 +30,8 @@ void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in, ymax, k0, kmax); } } - -void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax, - bool transpose) const { +static inline void gemm_packb(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) { if (transpose) { matmul_avx2_4x16x2::gemm_s8s8s32_avx2_4x16x2_pack_bt(out, in, ldin, x0, xmax, k0, kmax); @@ -42,20 +40,11 @@ void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, xmax, k0, kmax); } } - -void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr, - const dt_int8* pack_b_ptr, size_t m, - size_t n, size_t k, dt_int32* c_ptr, - size_t ldc, bool is_first_k, - const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); - megdnn_assert(is_first_k == true); +template +static inline void gemm_kern(const dt_int16* pack_a_ptr, + const dt_int8* pack_b_ptr, size_t m, size_t n, + size_t k, CType* c_ptr, size_t ldc, + bool is_first_k) { constexpr size_t m_tile = 4; constexpr size_t n_tile = 16; constexpr size_t k_tile = 2; @@ -109,4 +98,62 @@ void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr, } } } + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_avx2_s8s8s32_4x16x2); +void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in, + int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { + gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose); +} + +void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax, + bool transpose) const { + gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose); +} + +void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr, + const dt_int8* pack_b_ptr, size_t m, + size_t n, size_t k, dt_int32* c_ptr, + size_t ldc, bool is_first_k, + const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + megdnn_assert(is_first_k == true); + gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k); +} + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_avx2_s8s8s16_4x16x2); +void gemm_avx2_s8s8s16_4x16x2::pack_A(dt_int16* out, const dt_int8* in, + int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { + gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose); +} + +void gemm_avx2_s8s8s16_4x16x2::pack_B(dt_int8* out, const dt_int8* 
in, int ldin, + int x0, int xmax, int k0, int kmax, + bool transpose) const { + gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose); +} + +void gemm_avx2_s8s8s16_4x16x2::kern(const dt_int16* pack_a_ptr, + const dt_int8* pack_b_ptr, size_t m, + size_t n, size_t k, dt_int16* c_ptr, + size_t ldc, bool is_first_k, + const dt_int32*, dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS16)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + megdnn_assert(is_first_k == true); + gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k); +} // vim: syntax=cpp.doxygen diff --git a/dnn/src/x86/matrix_mul/int8/kernel_avx2_4x16x2.h b/dnn/src/x86/matrix_mul/int8/kernel_avx2_4x16x2.h index d68439095..01565354a 100644 --- a/dnn/src/x86/matrix_mul/int8/kernel_avx2_4x16x2.h +++ b/dnn/src/x86/matrix_mul/int8/kernel_avx2_4x16x2.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include @@ -20,11 +21,47 @@ namespace megdnn { namespace x86 { namespace matmul_avx2_4x16x2 { +template +MEGDNN_ATTRIBUTE_TARGET("avx2") +void store_overflow(void* ptr, __m256i a); + +template <> +void store_overflow(void* ptr, __m256i a) { + static __m256i idx = _mm256_setr_epi32(0, 2, 4, 6, 0, 0, 0, 0); + a = _mm256_shufflelo_epi16(a, 0x08); + a = _mm256_shufflehi_epi16(a, 0x08); + a = _mm256_permutevar8x32_epi32(a, idx); + _mm_storeu_si128((__m128i*)ptr, _mm256_extractf128_si256(a, 0)); +} +template <> +void store_overflow(void* ptr, __m256i a) { + _mm256_storeu_si256((__m256i*)(ptr), a); +} +template +MEGDNN_ATTRIBUTE_TARGET("avx2") +void store_overflow(void* ptr, __m256i a, int remain); + +template <> +void store_overflow(void* ptr, __m256i a, int remain) { + __m128i mask = _mm_continue_mask(remain * sizeof(int16_t)); + static __m256i idx = _mm256_setr_epi32(0, 2, 4, 6, 0, 0, 0, 0); + a = _mm256_shufflelo_epi16(a, 0x08); + a = _mm256_shufflehi_epi16(a, 0x08); + a = _mm256_permutevar8x32_epi32(a, idx); + _mm_maskmoveu_si128(_mm256_extractf128_si256(a, 0), mask, + reinterpret_cast(ptr)); +} +template <> +void store_overflow(void* ptr, __m256i a, int remain) { + __m256i mask = _m256_continue_mask(remain); + _mm256_maskstore_epi32(reinterpret_cast(ptr), mask, a); +} +template MEGDNN_ATTRIBUTE_TARGET("avx2") static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, - int32_t* c_ptr, + CType* c_ptr, const uint32_t ldc, const uint32_t k) { constexpr uint32_t k_step = 2; @@ -104,19 +141,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr, pack_b_ptr += 32; } - _mm256_storeu_si256((__m256i*)(c_ptr), c_vec[0]); - _mm256_storeu_si256((__m256i*)(c_ptr + 8), c_vec[1]); - _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]); - _mm256_storeu_si256((__m256i*)(c_ptr + ldc + 8), c_vec[3]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc + 8), c_vec[5]); - _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc), c_vec[6]); - _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc + 8), c_vec[7]); + store_overflow(c_ptr, 
c_vec[0]); + store_overflow(c_ptr + 8, c_vec[1]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + ldc + 8, c_vec[3]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 2 * ldc + 8, c_vec[5]); + store_overflow(c_ptr + 3 * ldc, c_vec[6]); + store_overflow(c_ptr + 3 * ldc + 8, c_vec[7]); } - +template MEGDNN_ATTRIBUTE_TARGET("avx2") static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n( - const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, + const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const uint32_t ldc, const uint32_t k, const uint32_t remain_n) { constexpr uint32_t k_step = 2; @@ -173,15 +210,15 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n( pack_b_ptr += 32; } - __m256i mask = _m256_continue_mask(remain_n); - _mm256_maskstore_epi32((c_ptr), mask, c_vec[0]); - _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]); - _mm256_maskstore_epi32((c_ptr + 3 * ldc), mask, c_vec[6]); + store_overflow(c_ptr, c_vec[0], remain_n); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); + store_overflow(c_ptr + 2 * ldc, c_vec[4], remain_n); + store_overflow(c_ptr + 3 * ldc, c_vec[6], remain_n); } +template MEGDNN_ATTRIBUTE_TARGET("avx2") static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n( - const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, + const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const uint32_t ldc, const uint32_t k, const uint32_t remain_m, uint32_t remain_n) { constexpr uint32_t k_step = 2; @@ -239,29 +276,29 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n( pack_b_ptr += 32; } - __m256i mask = _m256_continue_mask(remain_n); - _mm256_maskstore_epi32((c_ptr), mask, c_vec[0]); + store_overflow(c_ptr, c_vec[0], remain_n); switch (remain_m) { case 2: - _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); + break; case 3: - _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); + store_overflow(c_ptr + 2 * ldc, c_vec[4], remain_n); break; case 4: - _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]); - _mm256_maskstore_epi32((c_ptr + 3 * ldc), mask, c_vec[6]); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); + store_overflow(c_ptr + 2 * ldc, c_vec[4], remain_n); + store_overflow(c_ptr + 3 * ldc, c_vec[6], remain_n); break; default: break; } } - +template MEGDNN_ATTRIBUTE_TARGET("avx2") static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m( - const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, + const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const uint32_t ldc, const uint32_t k, const uint32_t remain_m) { constexpr uint32_t k_step = 2; @@ -339,34 +376,36 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m( pack_a_ptr += 8; pack_b_ptr += 32; } - _mm256_storeu_si256((__m256i*)(c_ptr), c_vec[0]); - _mm256_storeu_si256((__m256i*)(c_ptr + 8), c_vec[1]); + + store_overflow(c_ptr, c_vec[0]); + store_overflow(c_ptr + 8, c_vec[1]); + switch (remain_m) { case 2: - _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]); - _mm256_storeu_si256((__m256i*)(c_ptr + ldc + 8), c_vec[3]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + ldc + 8, c_vec[3]); break; case 3: - 
_mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]); - _mm256_storeu_si256((__m256i*)(c_ptr + ldc + 8), c_vec[3]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc + 8), c_vec[5]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + ldc + 8, c_vec[3]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 2 * ldc + 8, c_vec[5]); break; case 4: - _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]); - _mm256_storeu_si256((__m256i*)(c_ptr + ldc + 8), c_vec[3]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc + 8), c_vec[5]); - _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc), c_vec[6]); - _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc + 8), c_vec[7]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + ldc + 8, c_vec[3]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 2 * ldc + 8, c_vec[5]); + store_overflow(c_ptr + 3 * ldc, c_vec[6]); + store_overflow(c_ptr + 3 * ldc + 8, c_vec[7]); default: break; } } - +template MEGDNN_ATTRIBUTE_TARGET("avx2") static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n( - const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, + const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const uint32_t ldc, const uint32_t k, uint32_t remain_n) { constexpr uint32_t k_step = 2; @@ -446,29 +485,28 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n( } if (remain_n >= 8) { - _mm256_storeu_si256((__m256i*)(c_ptr), c_vec[0]); - _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc), c_vec[6]); + store_overflow(c_ptr, c_vec[0]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 3 * ldc, c_vec[6]); remain_n -= 8; if (remain_n > 0) { - __m256i mask = _m256_continue_mask(remain_n); - _mm256_maskstore_epi32((c_ptr + 8), mask, c_vec[1]); - _mm256_maskstore_epi32((c_ptr + ldc + 8), mask, c_vec[3]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc + 8), mask, c_vec[5]); - _mm256_maskstore_epi32((c_ptr + 3 * ldc + 8), mask, c_vec[7]); + store_overflow(c_ptr + 8, c_vec[1], remain_n); + store_overflow(c_ptr + ldc + 8, c_vec[3], remain_n); + store_overflow(c_ptr + 2 * ldc + 8, c_vec[5], remain_n); + store_overflow(c_ptr + 3 * ldc + 8, c_vec[7], remain_n); } } else { - __m256i mask = _m256_continue_mask(remain_n); - _mm256_maskstore_epi32((c_ptr), mask, c_vec[0]); - _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]); - _mm256_maskstore_epi32((c_ptr + 3 * ldc), mask, c_vec[6]); + store_overflow(c_ptr, c_vec[0], remain_n); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); + store_overflow(c_ptr + 2 * ldc, c_vec[4], remain_n); + store_overflow(c_ptr + 3 * ldc, c_vec[6], remain_n); } } +template MEGDNN_ATTRIBUTE_TARGET("avx2") static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n( - const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, + const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const uint32_t ldc, const uint32_t k, const uint32_t remain_m, uint32_t remain_n) { constexpr uint32_t k_step = 2; @@ -549,19 +587,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n( } if (remain_n >= 8) { - _mm256_storeu_si256((__m256i*)(c_ptr), c_vec[0]); + store_overflow(c_ptr, c_vec[0]); 
switch (remain_m) { case 2: - _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]); + store_overflow(c_ptr + ldc, c_vec[2]); break; case 3: - _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); break; case 4: - _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]); - _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc), c_vec[6]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 3 * ldc, c_vec[6]); break; default: break; @@ -569,43 +607,41 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n( remain_n -= 8; if (remain_n > 0) { - __m256i mask = _m256_continue_mask(remain_n); - _mm256_maskstore_epi32((c_ptr + 8), mask, c_vec[1]); + store_overflow(c_ptr + 8, c_vec[1], remain_n); switch (remain_m) { case 2: - _mm256_maskstore_epi32((c_ptr + ldc + 8), mask, c_vec[3]); + store_overflow(c_ptr + ldc + 8, c_vec[3], remain_n); break; case 3: - _mm256_maskstore_epi32((c_ptr + ldc + 8), mask, c_vec[3]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc + 8), mask, - c_vec[5]); + store_overflow(c_ptr + ldc + 8, c_vec[3], remain_n); + store_overflow(c_ptr + 2 * ldc + 8, c_vec[5], + remain_n); break; case 4: - _mm256_maskstore_epi32((c_ptr + ldc + 8), mask, c_vec[3]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc + 8), mask, - c_vec[5]); - _mm256_maskstore_epi32((c_ptr + 3 * ldc + 8), mask, - c_vec[7]); + store_overflow(c_ptr + ldc + 8, c_vec[3], remain_n); + store_overflow(c_ptr + 2 * ldc + 8, c_vec[5], + remain_n); + store_overflow(c_ptr + 3 * ldc + 8, c_vec[7], + remain_n); break; default: break; } } } else { - __m256i mask = _m256_continue_mask(remain_n); - _mm256_maskstore_epi32((c_ptr), mask, c_vec[0]); + store_overflow(c_ptr, c_vec[0], remain_n); switch (remain_m) { case 2: - _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); break; case 3: - _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); + store_overflow(c_ptr + 2 * ldc, c_vec[4], remain_n); break; case 4: - _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]); - _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]); - _mm256_maskstore_epi32((c_ptr + 3 * ldc), mask, c_vec[6]); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); + store_overflow(c_ptr + 2 * ldc, c_vec[4], remain_n); + store_overflow(c_ptr + 3 * ldc, c_vec[6], remain_n); break; default: break; @@ -833,4 +869,5 @@ static inline void gemm_s8s8s32_avx2_4x16x2_pack_at(dt_int16* out, } // namespace x86 } // namespace megdnn - // vim: syntax=cpp.doxygen \ No newline at end of file + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/x86/matrix_mul/int8/strategy.h b/dnn/src/x86/matrix_mul/int8/strategy.h index db4c557a8..97ba36a0c 100644 --- a/dnn/src/x86/matrix_mul/int8/strategy.h +++ b/dnn/src/x86/matrix_mul/int8/strategy.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
 */
 #pragma once
 #include "src/fallback/matrix_mul/gemm_common.h"

 namespace megdnn {
 namespace x86 {
 namespace matmul {
@@ -29,6 +30,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
                                           4, 16, 2, false, false,
                                           gemm_avx2_s8s8s32_4x16x2);

+MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int32,
+                                          4, 16, 2, false, false,
+                                          gemm_avx2_s8s8s16_4x16x2);
+
 MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
                                           4, 8, 2, false, false,
                                           gemm_sse_s8s8s32_4x8x2);
diff --git a/dnn/src/x86/matrix_mul/opr_impl.cpp b/dnn/src/x86/matrix_mul/opr_impl.cpp
index 032356b60..125b5abbf 100644
--- a/dnn/src/x86/matrix_mul/opr_impl.cpp
+++ b/dnn/src/x86/matrix_mul/opr_impl.cpp
@@ -37,6 +37,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoInt8x8x32AVX2M4N16K2 algoint8x8x32avx2_m4n16k2;
     AlgoInt8x8x32AVX2M2N4K16 algoint8x8x32avx2_m2n4k16;
     AlgoInt8x8x32SSEM4N8K2 algoint8x8x32sse_m4n8k2;
+    AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
     AlgoF32MK8_8x8 algof32mk8_8x8;

 public:
@@ -47,6 +48,7 @@ public:
 #endif
         }
         all_algos.emplace_back(&algoint8x8x32avx2_m4n16k2);
+        all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2);
         all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16);
         all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
         all_algos.emplace_back(&algof32mk8_8x8);
diff --git a/dnn/src/x86/matrix_mul/opr_impl.h b/dnn/src/x86/matrix_mul/opr_impl.h
index 10af99866..d15e72625 100644
--- a/dnn/src/x86/matrix_mul/opr_impl.h
+++ b/dnn/src/x86/matrix_mul/opr_impl.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
*/ #pragma once @@ -54,6 +55,7 @@ protected: class AlgoInt8x8x32AVX2M2N4K16; class AlgoInt8x8x32AVX2M4N16K2; class AlgoInt8x8x32SSEM4N8K2; + class AlgoInt8x8x16AVX2; class AlgoPack; class AlgoF32MK8_8x8; }; diff --git a/dnn/test/x86/conv_bias.cpp b/dnn/test/x86/conv_bias.cpp index bc65c0a79..903950ca1 100644 --- a/dnn/test/x86/conv_bias.cpp +++ b/dnn/test/x86/conv_bias.cpp @@ -752,7 +752,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) { } } -TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { +TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X) { using namespace conv_bias; std::vector args; @@ -807,6 +807,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { .set_param(arg.param) \ .execs({arg.src, arg.filter, {}, {}, {}}); \ } +#define cb2(algo_name) \ + checker.set_before_exec_callback( \ + conv_bias::ConvBiasAlgoChecker(algo_name)); \ + checker.set_dtype(0, dtype::Int8()); \ + checker.set_dtype(1, dtype::Int8()); \ + checker.set_dtype(2, dtype::Int16()); \ + checker.set_dtype(4, dtype::Int16()); \ + for (auto&& arg : args) { \ + checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); \ + } #if MEGDNN_X86_WITH_MKL_DNN if (megdnn::x86::is_supported(x86::SIMDType::VNNI)) { @@ -821,12 +831,14 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { if (megdnn::x86::is_supported(x86::SIMDType::AVX2)) { cb("IM2COLMATMUL:X86_INT8X8X32_AVX2_2X4X16"); cb("IM2COLMATMUL:X86_INT8X8X32_AVX2_4X16X2"); + cb2("IM2COLMATMUL:X86_INT8X8X16_AVX2"); } if (::megdnn::x86::is_supported(::megdnn::x86::SIMDType::SSE4_2)) { cb("IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2"); } #undef cb +#undef cb2 } TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) { @@ -1964,6 +1976,39 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) { shapes_and_computation.clear(); } +TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_8816) { + constexpr size_t RUNS = 30; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::DENSE; + + std::vector data_type = {dtype::Int8(), dtype::Int8(), + dtype::Int16(), dtype::Int16()}; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS) { + param.pad_h = FS / 2; + param.pad_w = FS / 2; + + SmallVector shapes{ + {N, IC, H, W}, {OC, IC, FS, FS}, {}, {}, {}}; + TensorShape dst{N, OC, (H + 2 * param.pad_h - FS) / param.stride_h + 1, + (W + 2 * param.pad_w - FS) / param.stride_w + 1}; + float computations = (IC * FS * FS * dst.total_nr_elems() * 2) * 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 48, 192, 15, 15, 1); + + std::string algo_name = "IM2COLMATMUL:X86_INT8X8X16_AVX2"; + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + shapes_and_computation.clear(); +} + TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8_STRIDE2) { constexpr size_t RUNS = 50; @@ -1985,7 +2030,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, SmallVector shapes{ {N, IC, H, W}, {OC, IC, FS, FS}, {}, {}, {}}; TensorShape dst{N, OC, (H + 2 * param.pad_h - FS) / param.stride_h + 1, - (W + 2 * param.pad_w - FS) / param.pad_w + 1}; + (W + 2 * param.pad_w - FS) / param.stride_w + 1}; float computations = (IC * FS * FS * dst.total_nr_elems() * 2) * 1e-6; shapes_and_computation.push_back(std::make_pair(shapes, computations)); }; diff --git a/dnn/test/x86/convolution.cpp 
b/dnn/test/x86/convolution.cpp index 039db8d4f..ab0aaeb93 100644 --- a/dnn/test/x86/convolution.cpp +++ b/dnn/test/x86/convolution.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "test/x86/fixture.h" @@ -369,6 +370,63 @@ TEST_F(X86, CONVOLUTION_DIRECT_MKLDNN_C8) { #endif #if MEGDNN_WITH_BENCHMARK +TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x16) { + using namespace convolution; + using Param = param::Convolution; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t stride, size_t group = 1) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + if (group > 1) { + param.sparse = param::Convolution::Sparse::GROUP; + args.emplace_back( + param, TensorShape{1, ic, h, w}, + TensorShape{group, oc / group, ic / group, kernel, kernel}); + } else { + param.sparse = param::Convolution::Sparse::DENSE; + args.emplace_back(param, TensorShape{1, ic, h, w}, + TensorShape{oc, ic, kernel, kernel}); + } + }; + + run(48, 96, 15, 15, 1, 1); + run(64, 64, 60, 60, 3, 1); + run(64, 64, 60, 60, 3, 1, 64); + + constexpr size_t RUN = 30; + Benchmarker benchmark(handle()); + benchmark.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int16()); + benchmark.set_display(false); + benchmark.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Float32()}, + {arg.filter, dtype::Float32()}, dst_layout); + //! dst.nr_elems * IC * FH * FW * 2 + float icpg = arg.filter.ndim == 4 ? arg.filter[1] : arg.filter[2]; + float filter = arg.filter.ndim == 4 ? arg.filter[2] : arg.filter[3]; + float computations = dst_layout.total_nr_elems() * icpg * filter * + filter * 2.0 / (1024 * 1024 * 1024) * 1e3; + + auto used_int = + benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) / + RUN; + + printf("%s %s: int: %f ms %f Gflops \n", arg.src.to_string().c_str(), + arg.filter.to_string().c_str(), used_int, + computations / used_int); + } +} #if MEGDNN_X86_WITH_MKL_DNN TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN) { using namespace convolution; @@ -419,7 +477,6 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN) { float computations = dst_layout.total_nr_elems() * arg.filter[1] * arg.filter[2] * arg.filter[3] * 2.0 / (1024 * 1024 * 1024) * 1e3; - auto used_int = benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) / RUN; diff --git a/dnn/test/x86/matrix_mul.cpp b/dnn/test/x86/matrix_mul.cpp index f58ec3122..8533f7c73 100644 --- a/dnn/test/x86/matrix_mul.cpp +++ b/dnn/test/x86/matrix_mul.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
 */

 #include "test/x86/fixture.h"
@@ -47,6 +48,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                  handle(), "X86_INT8X8X32_AVX2_4X16X2");
 }
+TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
+    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
+                                 handle(), "X86_INT8X8X16_AVX2");
+}
 TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                  handle(), "X86_INT8X8X32_SSE_4X8X2");
@@ -116,6 +121,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
     benchmarker_avx2_4x16x2.set_before_exec_callback(
             AlgoChecker("X86_INT8X8X32_AVX2_4X16X2"));

+    Benchmarker benchmarker_avx2_4x16x2_8816(handle());
+    benchmarker_avx2_4x16x2_8816.set_display(false)
+            .set_times(RUNS)
+            .set_dtype(0, dtype::Int8{})
+            .set_dtype(1, dtype::Int8{})
+            .set_dtype(2, dtype::Int16{})
+            .set_rng(0, rng.get())
+            .set_rng(1, rng.get());
+    benchmarker_avx2_4x16x2_8816.set_before_exec_callback(
+            AlgoChecker("X86_INT8X8X16_AVX2"));
+
     Benchmarker benchmarker_avx2_2x4x16(handle());
     benchmarker_avx2_2x4x16.set_display(false)
             .set_times(RUNS)
@@ -183,6 +199,12 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
                   << "k2_speed_up " << float_used / avx2_used_4x16x2
                   << ", k16_speed_up " << float_used / avx2_used_2x4x16
                   << ",";
+        auto avx2_used_4x16x2_8816 =
+                benchmarker_avx2_4x16x2_8816.exec({{M, K}, {K, N}, {}}) /
+                RUNS;
+        std::cout << "avx2_8816: " << avx2_used_4x16x2_8816
+                  << " ms, 8816 throughput "
+                  << computations / avx2_used_4x16x2_8816 << " Gflops,";
     }
     if (is_supported(SIMDType::SSE4_1)) {
         auto sse_used =
--
GitLab
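
Note on the int16 output path added above: the AVX2 4x16x2 kernels always accumulate into 32-bit lanes, and when the requested C type is dt_int16 the result is narrowed only at the final store, through the 16-bit specialisation of store_overflow. The standalone sketch below is illustrative only (it is not part of the patch, and the helper and file names are made up); it reproduces the same shuffle/permute narrowing sequence under the assumption that every accumulator value already fits in int16.

// narrow_store_demo.cpp (illustrative name) -- build with: g++ -O2 -mavx2 narrow_store_demo.cpp
#include <immintrin.h>

#include <cstdint>
#include <cstdio>

// Keep the low 16 bits of each of the eight 32-bit lanes, compact them into
// the low 128 bits, and store eight int16 values -- the same sequence the
// 16-bit store_overflow specialisation in kernel_avx2_4x16x2.h uses.
static void store_s32_as_s16(int16_t* dst, __m256i acc) {
    const __m256i idx = _mm256_setr_epi32(0, 2, 4, 6, 0, 0, 0, 0);
    acc = _mm256_shufflelo_epi16(acc, 0x08);      // pack lanes 0,1 (and 4,5)
    acc = _mm256_shufflehi_epi16(acc, 0x08);      // pack lanes 2,3 (and 6,7)
    acc = _mm256_permutevar8x32_epi32(acc, idx);  // gather the packed dwords
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst),
                     _mm256_extractf128_si256(acc, 0));
}

int main() {
    const int32_t acc[8] = {1, -2, 300, -400, 5000, -6000, 30000, -32768};
    int16_t out[8] = {0};
    store_s32_as_s16(out,
                     _mm256_loadu_si256(reinterpret_cast<const __m256i*>(acc)));
    for (int i = 0; i < 8; ++i) {
        std::printf("%d -> %d\n", acc[i], (int)out[i]);  // values unchanged
    }
    return 0;
}

The masked variant used for remainder columns follows the same idea, except the compacted 128-bit result is written with _mm_maskmoveu_si128 so that only remain_n * sizeof(int16_t) bytes of the C tile are touched.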