diff --git a/dnn/src/x86/matrix_mul/algos.cpp b/dnn/src/x86/matrix_mul/algos.cpp index c7d7ddd0f34f6071072b7d6d600ea81459f4e4e7..5ec7d957c7b4e38fb0ed02047b6030fed94543fc 100644 --- a/dnn/src/x86/matrix_mul/algos.cpp +++ b/dnn/src/x86/matrix_mul/algos.cpp @@ -184,7 +184,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern( } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni, - megdnn_x86_matmul_kern, 5, + megdnn_x86_matmul_kern, + "AlgoInt8x8x32Vnni"_hash, x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32, dt_uint8); #endif @@ -318,6 +319,8 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) { } } // namespace + +/*************************AlgoInt8x8x16AVX2********************/ void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2( const MatrixMulImpl::KernParam& kern_param) { MEGDNN_MARK_USED_VAR(kern_param); @@ -389,9 +392,86 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace( .get_workspace_size(); } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( - AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, 8, + AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash, x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16); +/*************************AlgoInt8x8x16SSE********************/ +void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2( + const MatrixMulImpl::KernParam& kern_param) { + MEGDNN_MARK_USED_VAR(kern_param); + MIDOUT_BEGIN(megdnn_x86_matmul_kern_sse_4x8x2, midout_iv(2)) { + constexpr int cacheline = 64; + const size_t m = kern_param.M; + const size_t n = kern_param.N; + const size_t k = kern_param.K; + const bool trans_a = kern_param.trA; + const bool trans_b = kern_param.trB; + const size_t lda = kern_param.LDA; + const size_t ldb = kern_param.LDB; + const size_t ldc = kern_param.LDC; + auto a_type = kern_param.A_type; + auto b_type = kern_param.B_type; + auto c_type = kern_param.C_type; + const auto a_ptr = kern_param.A(); + const auto b_ptr = kern_param.B(); + auto c_ptr = kern_param.C(); + x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type, + c_type); + + megdnn::matmul::GemmInterleaved( + m, n, k, trans_a, trans_b, strategy, cacheline) + .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16SSE::get_kern( + const KernSizeParam&) const { + return gemm_s8s8s16_sse_4x8x2; +} +bool MatrixMulImpl::AlgoInt8x8x16SSE::usable( + const KernSizeParam& kern_size_param) const { + bool is_ab_same = + kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv(); + bool is_type_ok = + ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 && + kern_size_param.C_type.enumv() == DTypeEnum::Int16) || + (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16)); + bool is_mode_ok = + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + is_supported(SIMDType::SSE4_1); + bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok; + return is_param_ok; +} +bool MatrixMulImpl::AlgoInt8x8x16SSE::preferred(const KernSizeParam&) const { + return true; +} +size_t MatrixMulImpl::AlgoInt8x8x16SSE::get_workspace( + const KernSizeParam& kern_param) const { + constexpr int cacheline = 64; + const size_t m = kern_param.M; + const size_t n = kern_param.N; + const size_t k = kern_param.K; + const bool trans_a = kern_param.trA; + const bool trans_b = kern_param.trB; + auto a_type = kern_param.A_type; + auto b_type = kern_param.B_type; + auto c_type = kern_param.C_type; + x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type, + c_type); + + return megdnn::matmul::GemmInterleaved( + m, n, k, trans_a, trans_b, strategy, cacheline) + .get_workspace_size(); +} +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE, + megdnn_x86_matmul_kern, + "AlgoInt8x8x16SSE"_hash, + x86::matmul::gemm_sse_s8s8s16_4x8x2, + dt_int8, dt_int16, dt_int16); + +/*************************AlgoInt8x8x32AVX2M4N16K2********************/ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern( const KernSizeParam&) const { return gemm_s8s8s32_avx2_4x16x2; @@ -426,8 +506,9 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace( .get_workspace_size(); } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( - AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern, 8, - x86::matmul::gemm_avx2_s8s8s32_4x16x2, dt_int8, dt_int32, dt_int16); + AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern, + "AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2, + dt_int8, dt_int32, dt_int16); MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern( const KernSizeParam&) const { @@ -463,7 +544,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace( .get_workspace_size(); } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16, - megdnn_x86_matmul_kern, 8, + megdnn_x86_matmul_kern, + "AlgoInt8x8x32AVX2M2N4K16"_hash, x86::matmul::gemm_avx2_s8s8s32_2x4x16, dt_int8, dt_int32); @@ -501,7 +583,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace( .get_workspace_size(); } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2, - megdnn_x86_matmul_kern, 9, + megdnn_x86_matmul_kern, + "AlgoInt8x8x32SSEM4N8K2"_hash, x86::matmul::gemm_sse_s8s8s32_4x8x2, dt_int8, dt_int32, dt_int16); diff --git a/dnn/src/x86/matrix_mul/algos.h b/dnn/src/x86/matrix_mul/algos.h index a4dabae26491e5fca3775a4b57b221977a4a981c..01e51dd9fb97e05d82866cb1f285b9f38178b970 100644 --- a/dnn/src/x86/matrix_mul/algos.h +++ b/dnn/src/x86/matrix_mul/algos.h @@ -76,7 +76,6 @@ class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase { private: static void gemm_s8s8s16_avx2_4x16x2( const MatrixMulImpl::KernParam& kern_param); - static MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 m_algo; public: bool is_reproducible() const override { return true; } @@ -89,6 +88,22 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; +class MatrixMulImpl::AlgoInt8x8x16SSE : public AlgoBase { +private: + static void gemm_s8s8s16_sse_4x8x2( + const MatrixMulImpl::KernParam& kern_param); + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "X86_INT8X8X16_SSE"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_x86_algo_type; } + bool preferred(const KernSizeParam&) const override; + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase { public: bool is_reproducible() const override { return true; } diff --git a/dnn/src/x86/matrix_mul/int8/kernel_sse_4x8x2.h b/dnn/src/x86/matrix_mul/int8/kernel_sse_4x8x2.h index 6d10cf3709c576d0389c0114f9784d4929747cfe..c0e90c4df4b055d775c77d691bc34f3c3ea7173d 100644 --- a/dnn/src/x86/matrix_mul/int8/kernel_sse_4x8x2.h +++ b/dnn/src/x86/matrix_mul/int8/kernel_sse_4x8x2.h @@ -6,10 +6,17 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include +#ifdef WIN32 +#include +#include +#include +#include +#endif #include #include #include @@ -21,10 +28,44 @@ namespace x86 { namespace matmul_sse_4x8x2 { +template +MEGDNN_ATTRIBUTE_TARGET("sse4.1") +void store_overflow(void* ptr, __m128i a); + +template <> +void store_overflow(void* ptr, __m128i a) { + a = _mm_shufflelo_epi16(a, 0x08); + a = _mm_shufflehi_epi16(a, 0x08); + a = _mm_shuffle_epi32(a, 0x08); + _mm_storel_epi64((__m128i*)ptr, a); +} +template <> +void store_overflow(void* ptr, __m128i a) { + _mm_storeu_si128((__m128i*)(ptr), a); +} +template +MEGDNN_ATTRIBUTE_TARGET("sse4.1") +void store_overflow(void* ptr, __m128i a, int remain); + +template <> +void store_overflow(void* ptr, __m128i a, int remain) { + __m128i mask = _mm_continue_mask(remain * sizeof(int16_t)); + a = _mm_shufflelo_epi16(a, 0x08); + a = _mm_shufflehi_epi16(a, 0x08); + a = _mm_shuffle_epi32(a, 0x08); + _mm_maskmoveu_si128(a, mask, reinterpret_cast(ptr)); +} +template <> +void store_overflow(void* ptr, __m128i a, int remain) { + __m128i mask = _mm_continue_mask(remain * sizeof(int32_t)); + _mm_maskmoveu_si128(a, mask, reinterpret_cast(ptr)); +} + +template MEGDNN_ATTRIBUTE_TARGET("sse4.1") static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, - int32_t* c_ptr, const int ldc, + CType* c_ptr, const int ldc, const int k) { constexpr int k_step = 2; @@ -102,20 +143,20 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr, pack_a_ptr += 8; pack_b_ptr += 16; } - - _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]); - _mm_storeu_si128((__m128i*)(c_ptr + 4), c_vec[1]); - _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); - _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]); - _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]); - _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]); - _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc + 4), c_vec[7]); + store_overflow(c_ptr, c_vec[0]); + store_overflow(c_ptr + 4, c_vec[1]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + ldc + 4, c_vec[3]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 2 * ldc + 4, c_vec[5]); + store_overflow(c_ptr + 3 * ldc, c_vec[6]); + store_overflow(c_ptr + 3 * ldc + 4, c_vec[7]); } +template MEGDNN_ATTRIBUTE_TARGET("sse4.1") static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m( - const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, + const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const int ldc, const int k, const int remain_m) { constexpr int k_step = 2; @@ -194,34 +235,35 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m( pack_b_ptr += 16; } - _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]); - _mm_storeu_si128((__m128i*)(c_ptr + 4), c_vec[1]); + store_overflow(c_ptr, c_vec[0]); + store_overflow(c_ptr + 4, c_vec[1]); switch (remain_m) { case 2: - _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); - _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + ldc + 4, c_vec[3]); break; case 3: - _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); - _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]); - _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + ldc + 4, c_vec[3]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 2 * ldc + 4, c_vec[5]); break; case 4: - _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); - _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]); - _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]); - _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]); - _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc + 4), c_vec[7]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + ldc + 4, c_vec[3]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 2 * ldc + 4, c_vec[5]); + store_overflow(c_ptr + 3 * ldc, c_vec[6]); + store_overflow(c_ptr + 3 * ldc + 4, c_vec[7]); default: break; } } +template MEGDNN_ATTRIBUTE_TARGET("sse4.1") static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n( - const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, + const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const int ldc, const int k, int remain_n) { constexpr int k_step = 2; @@ -301,10 +343,10 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n( } if (remain_n >= 4) { - _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]); - _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); - _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]); + store_overflow(c_ptr, c_vec[0]); + store_overflow(c_ptr + ldc, c_vec[2]); + store_overflow(c_ptr + 2 * ldc, c_vec[4]); + store_overflow(c_ptr + 3 * ldc, c_vec[6]); c_ptr += 4; remain_n -= 4; c_vec[0] = c_vec[1]; @@ -312,35 +354,16 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n( c_vec[4] = c_vec[5]; c_vec[6] = c_vec[7]; } - - switch (remain_n) { - case 0: - break; - case 1: - *(c_ptr) = _mm_extract_epi32(c_vec[0], 0); - *(c_ptr + ldc) = _mm_extract_epi32(c_vec[2], 0); - *(c_ptr + 2 * ldc) = _mm_extract_epi32(c_vec[4], 0); - *(c_ptr + 3 * ldc) = _mm_extract_epi32(c_vec[6], 0); - break; - case 2: - case 3: - _mm_storel_epi64((__m128i*)(c_ptr), c_vec[0]); - _mm_storel_epi64((__m128i*)(c_ptr + ldc), c_vec[2]); - _mm_storel_epi64((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); - _mm_storel_epi64((__m128i*)(c_ptr + 3 * ldc), c_vec[6]); - break; - } - if (remain_n == 3) { - *(c_ptr + 2) = _mm_extract_epi32(c_vec[0], 2); - *(c_ptr + ldc + 2) = _mm_extract_epi32(c_vec[2], 2); - *(c_ptr + 2 * ldc + 2) = _mm_extract_epi32(c_vec[4], 2); - *(c_ptr + 3 * ldc + 2) = _mm_extract_epi32(c_vec[6], 2); - } + store_overflow(c_ptr, c_vec[0], remain_n); + store_overflow(c_ptr + ldc, c_vec[2], remain_n); + store_overflow(c_ptr + 2 * ldc, c_vec[4], remain_n); + store_overflow(c_ptr + 3 * ldc, c_vec[6], remain_n); } +template MEGDNN_ATTRIBUTE_TARGET("sse4.1") static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n( - const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, + const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const int ldc, const int k, int remain_m, int remain_n) { constexpr int k_step = 2; @@ -421,8 +444,7 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n( int index_array[4]{0, 2, 4, 6}; if (remain_n >= 4) { for (int m = 0; m < remain_m; ++m) { - _mm_storeu_si128((__m128i*)(c_ptr + m * ldc), - c_vec[index_array[m]]); + store_overflow(c_ptr + m * ldc, c_vec[index_array[m]]); } c_ptr += 4; remain_n -= 4; @@ -431,29 +453,8 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n( c_vec[4] = c_vec[5]; c_vec[6] = c_vec[7]; } - - switch (remain_n) { - case 0: - break; - case 1: - for (int m = 0; m < remain_m; ++m) { - *(c_ptr + m * ldc) = - _mm_extract_epi32(c_vec[index_array[m]], 0); - } - break; - case 2: - case 3: - for (int m = 0; m < remain_m; ++m) { - _mm_storel_epi64((__m128i*)(c_ptr + m * ldc), - c_vec[index_array[m]]); - } - break; - } - if (remain_n == 3) { - for (int m = 0; m < remain_m; ++m) { - *(c_ptr + m * ldc + 2) = - _mm_extract_epi32(c_vec[index_array[m]], 2); - } + for (int m = 0; m < remain_m; ++m) { + store_overflow(c_ptr + m * ldc, c_vec[index_array[m]], remain_n); } } diff --git a/dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp b/dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp index 875991f514f68bde1dca4196e31bf897edd11e4f..d4f931a9e914b1c51391e7ee3b733d20a7861f7b 100644 --- a/dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp +++ b/dnn/src/x86/matrix_mul/int8/sse_strategy_4x8x2.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/common/utils.h" @@ -18,11 +19,9 @@ using namespace megdnn; using namespace x86; using namespace x86::matmul; -MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s32_4x8x2); - -void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +static inline void gemm_packa(dt_int16* out, const dt_int8* in, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) { if (transpose) { matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_at(out, in, ldin, y0, ymax, k0, kmax); @@ -31,10 +30,8 @@ void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin, ymax, k0, kmax); } } - -void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax, - bool transpose) const { +static inline void gemm_packb(dt_int8* out, const dt_int8* in, int ldin, int x0, + int xmax, int k0, int kmax, bool transpose) { if (transpose) { matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_bt(out, in, ldin, x0, xmax, k0, kmax); @@ -43,20 +40,11 @@ void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, xmax, k0, kmax); } } - -void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr, - const dt_int8* pack_b_ptr, size_t m, size_t n, - size_t k, dt_int32* c_ptr, size_t ldc, - bool is_first_k, const dt_int32*, - dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); - megdnn_assert(is_first_k == true); +template +static inline void gemm_kern(const dt_int16* pack_a_ptr, + const dt_int8* pack_b_ptr, size_t m, size_t n, + size_t k, CType* c_ptr, size_t ldc, + bool is_first_k) { constexpr int m_tile = 4; constexpr int n_tile = 8; constexpr int k_tile = 2; @@ -99,4 +87,62 @@ void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr, } } } +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s32_4x8x2); +void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose); +} + +void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax, + bool transpose) const { + gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose); +} + +void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr, + const dt_int8* pack_b_ptr, size_t m, size_t n, + size_t k, dt_int32* c_ptr, size_t ldc, + bool is_first_k, const dt_int32*, + dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + megdnn_assert(is_first_k == true); + gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k); +} + +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s16_4x8x2); +void gemm_sse_s8s8s16_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin, + int y0, int ymax, int k0, int kmax, + bool transpose) const { + gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose); +} + +void gemm_sse_s8s8s16_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax, + bool transpose) const { + gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose); +} + +void gemm_sse_s8s8s16_4x8x2::kern(const dt_int16* pack_a_ptr, + const dt_int8* pack_b_ptr, size_t m, size_t n, + size_t k, dt_int16* c_ptr, size_t ldc, + bool is_first_k, const dt_int32*, + dt_int32*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS16)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), + C_dtype.name()); + megdnn_assert(is_first_k == true); + gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k); +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/x86/matrix_mul/int8/strategy.h b/dnn/src/x86/matrix_mul/int8/strategy.h index 97ba36a0cf039d0da4c208ecd2313564bd69e962..2be50dec80e4ca816f01d3c2307e0216e925715b 100644 --- a/dnn/src/x86/matrix_mul/int8/strategy.h +++ b/dnn/src/x86/matrix_mul/int8/strategy.h @@ -38,6 +38,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32, 4, 8, 2, false, false, gemm_sse_s8s8s32_4x8x2); +MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int32, + 4, 8, 2, false, false, + gemm_sse_s8s8s16_4x8x2); + } // namespace matmul } // namespace x86 } // namespace megdnn diff --git a/dnn/src/x86/matrix_mul/opr_impl.cpp b/dnn/src/x86/matrix_mul/opr_impl.cpp index 125b5abbf9f37ba7d3a034756468f639c244ab86..a9d7c312e5809201aae92a04399cb76ae49e0762 100644 --- a/dnn/src/x86/matrix_mul/opr_impl.cpp +++ b/dnn/src/x86/matrix_mul/opr_impl.cpp @@ -38,6 +38,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt8x8x32AVX2M2N4K16 algoint8x8x32avx2_m2n4k16; AlgoInt8x8x32SSEM4N8K2 algoint8x8x32sse_m4n8k2; AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2; + AlgoInt8x8x16SSE algoint8x8x16sse_m4n8k2; AlgoF32MK8_8x8 algof32mk8_8x8; public: @@ -51,6 +52,7 @@ public: all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2); all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16); all_algos.emplace_back(&algoint8x8x32sse_m4n8k2); + all_algos.emplace_back(&algoint8x8x16sse_m4n8k2); all_algos.emplace_back(&algof32mk8_8x8); #if MEGDNN_X86_WITH_MKL_DNN all_algos.emplace_back(&algoint8x8x32mkldnn); diff --git a/dnn/src/x86/matrix_mul/opr_impl.h b/dnn/src/x86/matrix_mul/opr_impl.h index d15e72625d3f2e5d19194236da1f8e7655513b65..be76c26cae5b701cc29cbe691df4697c17efa24b 100644 --- a/dnn/src/x86/matrix_mul/opr_impl.h +++ b/dnn/src/x86/matrix_mul/opr_impl.h @@ -56,6 +56,7 @@ protected: class AlgoInt8x8x32AVX2M4N16K2; class AlgoInt8x8x32SSEM4N8K2; class AlgoInt8x8x16AVX2; + class AlgoInt8x8x16SSE; class AlgoPack; class AlgoF32MK8_8x8; }; diff --git a/dnn/test/x86/conv_bias.cpp b/dnn/test/x86/conv_bias.cpp index 903950ca1f118f607cb6040bd78b81a497753dbf..8e5edd79816ac927026dadff0e051826de7e0471 100644 --- a/dnn/test/x86/conv_bias.cpp +++ b/dnn/test/x86/conv_bias.cpp @@ -835,6 +835,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X) { } if (::megdnn::x86::is_supported(::megdnn::x86::SIMDType::SSE4_2)) { cb("IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2"); + cb2("IM2COLMATMUL:X86_INT8X8X16_SSE"); } #undef cb @@ -1002,7 +1003,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) { } #endif -TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { +TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X) { using namespace conv_bias; UniformIntRNG rng{-50, 50}; float epsilon = 0.001; @@ -1028,10 +1029,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, "CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24"); + checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, + dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, + "CONV1x1:X86_INT8X8X16_AVX2"); } checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, "CONV1x1:X86_INT8X8X32_SSE_4X8X2:48"); + checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, + dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, + "CONV1x1:X86_INT8X8X16_SSE"); } /************************* End Conv1x1 PackA ************************/ diff --git a/dnn/test/x86/convolution.cpp b/dnn/test/x86/convolution.cpp index ab0aaeb9309b89fe49007234ef8b955b39854302..5de87bcd70916b868ee7607ba9585158de6b342e 100644 --- a/dnn/test/x86/convolution.cpp +++ b/dnn/test/x86/convolution.cpp @@ -403,6 +403,7 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x16) { benchmark.set_dtype(0, dtype::Int8()) .set_dtype(1, dtype::Int8()) .set_dtype(2, dtype::Int16()); + benchmark.set_before_exec_callback(AlgoChecker(".*")); benchmark.set_display(false); benchmark.set_times(RUN); diff --git a/dnn/test/x86/matrix_mul.cpp b/dnn/test/x86/matrix_mul.cpp index 8533f7c73d7a99d28d094cb7387a798011d3e258..4ce8f1983dce37519ea618338d1005c299d6f40e 100644 --- a/dnn/test/x86/matrix_mul.cpp +++ b/dnn/test/x86/matrix_mul.cpp @@ -52,6 +52,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, handle(), "X86_INT8X8X16_AVX2"); } +TEST_F(X86, MATRIX_MUL_SSE_8X8X16) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle(), "X86_INT8X8X16_SSE"); +} TEST_F(X86, MATRIX_MUL_SSE_8X8X32) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "X86_INT8X8X32_SSE_4X8X2"); @@ -132,6 +136,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) { benchmarker_avx2_4x16x2_8816.set_before_exec_callback( AlgoChecker("X86_INT8X8X16_AVX2")); + Benchmarker benchmarker_sse_4x8x2_8816(handle()); + benchmarker_sse_4x8x2_8816.set_display(false) + .set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_rng(0, rng.get()) + .set_rng(1, rng.get()); + benchmarker_sse_4x8x2_8816.set_before_exec_callback( + AlgoChecker("X86_INT8X8X16_SSE")); + Benchmarker benchmarker_avx2_2x4x16(handle()); benchmarker_avx2_2x4x16.set_display(false) .set_times(RUNS) @@ -212,9 +227,15 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) { std::cout << "sse: " << sse_used << " ms, " << computations / sse_used << " Gflops, " << "speed_up " << float_used / sse_used << ", "; + auto sse_used_8816 = + benchmarker_sse_4x8x2_8816.exec({{M, K}, {K, N}, {}}) / + RUNS; + std::cout << "sse_8816: " << sse_used_8816 << " ms, " + << computations / sse_used_8816 << " Gflops, "; } std::cout << std::endl; }; + run(256, 256, 256); for (size_t M : {8, 64, 112, 256, 512}) { for (size_t K : {8, 16, 32, 64, 112, 256, 512}) {