#include #ifdef WIN32 #include #include #include #include #endif #include #include #include #include "src/common/utils.h" #include "src/x86/matrix_mul/common/common.h" namespace megdnn { namespace x86 { namespace matmul_sse_4x8x2 { template MEGDNN_ATTRIBUTE_TARGET("sse4.1") void inline store_overflow(void* ptr, __m128i a); template <> void inline store_overflow(void* ptr, __m128i a) { a = _mm_shufflelo_epi16(a, 0x08); a = _mm_shufflehi_epi16(a, 0x08); a = _mm_shuffle_epi32(a, 0x08); _mm_storel_epi64((__m128i*)ptr, a); } template <> void inline store_overflow(void* ptr, __m128i a) { _mm_storeu_si128((__m128i*)(ptr), a); } template MEGDNN_ATTRIBUTE_TARGET("sse4.1") void inline store_overflow(void* ptr, __m128i a, int remain); template <> void inline store_overflow(void* ptr, __m128i a, int remain) { __m128i mask = _mm_continue_mask(remain * sizeof(int16_t)); a = _mm_shufflelo_epi16(a, 0x08); a = _mm_shufflehi_epi16(a, 0x08); a = _mm_shuffle_epi32(a, 0x08); _mm_maskmoveu_si128(a, mask, reinterpret_cast(ptr)); } template <> void inline store_overflow(void* ptr, __m128i a, int remain) { __m128i mask = _mm_continue_mask(remain * sizeof(int32_t)); _mm_maskmoveu_si128(a, mask, reinterpret_cast(ptr)); } template MEGDNN_ATTRIBUTE_TARGET("sse4.1") static inline void kern_gemm_s8s8s32_sse_4x8x2( const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const int ldc, const int k) { constexpr int k_step = 2; __m128i a_vec[2]; __m128i b_vec[2]; __m128i c_vec[4 * 2]; __m128i c_temp[4]; b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr); b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[0] = _mm_setzero_si128(); c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]); c_vec[1] = _mm_setzero_si128(); c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[2] = _mm_setzero_si128(); c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]); c_vec[3] = _mm_setzero_si128(); c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[4] = _mm_setzero_si128(); c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]); c_vec[5] = _mm_setzero_si128(); c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[6] = _mm_setzero_si128(); c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]); c_vec[7] = _mm_setzero_si128(); c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]); pack_a_ptr += 8; pack_b_ptr += 16; for (int iter_k = 2; iter_k < k; iter_k += k_step) { b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr); b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]); c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]); c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]); c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]); c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]); pack_a_ptr += 8; pack_b_ptr += 16; } store_overflow(c_ptr, c_vec[0]); store_overflow(c_ptr + 4, c_vec[1]); store_overflow(c_ptr + ldc, c_vec[2]); store_overflow(c_ptr + ldc + 4, c_vec[3]); store_overflow(c_ptr + 2 * ldc, c_vec[4]); store_overflow(c_ptr + 2 * ldc + 4, c_vec[5]); store_overflow(c_ptr + 3 * ldc, c_vec[6]); store_overflow(c_ptr + 3 * ldc + 4, c_vec[7]); } template MEGDNN_ATTRIBUTE_TARGET("sse4.1") static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m( const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const int ldc, const int k, const int remain_m) { constexpr int k_step = 2; __m128i a_vec[2]; __m128i b_vec[2]; __m128i c_vec[4 * 2]; __m128i c_temp[4]; b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr); b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[0] = _mm_setzero_si128(); c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]); c_vec[1] = _mm_setzero_si128(); c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[2] = _mm_setzero_si128(); c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]); c_vec[3] = _mm_setzero_si128(); c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[4] = _mm_setzero_si128(); c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]); c_vec[5] = _mm_setzero_si128(); c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[6] = _mm_setzero_si128(); c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]); c_vec[7] = _mm_setzero_si128(); c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]); pack_a_ptr += 8; pack_b_ptr += 16; for (int iter_k = 2; iter_k < k; iter_k += k_step) { b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr); b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]); c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]); c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]); c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]); c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]); pack_a_ptr += 8; pack_b_ptr += 16; } store_overflow(c_ptr, c_vec[0]); store_overflow(c_ptr + 4, c_vec[1]); switch (remain_m) { case 2: store_overflow(c_ptr + ldc, c_vec[2]); store_overflow(c_ptr + ldc + 4, c_vec[3]); break; case 3: store_overflow(c_ptr + ldc, c_vec[2]); store_overflow(c_ptr + ldc + 4, c_vec[3]); store_overflow(c_ptr + 2 * ldc, c_vec[4]); store_overflow(c_ptr + 2 * ldc + 4, c_vec[5]); break; case 4: store_overflow(c_ptr + ldc, c_vec[2]); store_overflow(c_ptr + ldc + 4, c_vec[3]); store_overflow(c_ptr + 2 * ldc, c_vec[4]); store_overflow(c_ptr + 2 * ldc + 4, c_vec[5]); store_overflow(c_ptr + 3 * ldc, c_vec[6]); store_overflow(c_ptr + 3 * ldc + 4, c_vec[7]); default: break; } } template MEGDNN_ATTRIBUTE_TARGET("sse4.1") static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n( const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const int ldc, const int k, int remain_n) { constexpr int k_step = 2; __m128i a_vec[2]; __m128i b_vec[2]; __m128i c_vec[4 * 2]; __m128i c_temp[4]; b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr); b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[0] = _mm_setzero_si128(); c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]); c_vec[1] = _mm_setzero_si128(); c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[2] = _mm_setzero_si128(); c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]); c_vec[3] = _mm_setzero_si128(); c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[4] = _mm_setzero_si128(); c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]); c_vec[5] = _mm_setzero_si128(); c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[6] = _mm_setzero_si128(); c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]); c_vec[7] = _mm_setzero_si128(); c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]); pack_a_ptr += 8; pack_b_ptr += 16; for (int iter_k = 2; iter_k < k; iter_k += k_step) { b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr); b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]); c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]); c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]); c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]); c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]); pack_a_ptr += 8; pack_b_ptr += 16; } if (remain_n >= 4) { store_overflow(c_ptr, c_vec[0]); store_overflow(c_ptr + ldc, c_vec[2]); store_overflow(c_ptr + 2 * ldc, c_vec[4]); store_overflow(c_ptr + 3 * ldc, c_vec[6]); c_ptr += 4; remain_n -= 4; c_vec[0] = c_vec[1]; c_vec[2] = c_vec[3]; c_vec[4] = c_vec[5]; c_vec[6] = c_vec[7]; } store_overflow(c_ptr, c_vec[0], remain_n); store_overflow(c_ptr + ldc, c_vec[2], remain_n); store_overflow(c_ptr + 2 * ldc, c_vec[4], remain_n); store_overflow(c_ptr + 3 * ldc, c_vec[6], remain_n); } template MEGDNN_ATTRIBUTE_TARGET("sse4.1") static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n( const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr, const int ldc, const int k, int remain_m, int remain_n) { constexpr int k_step = 2; __m128i a_vec[2]; __m128i b_vec[2]; __m128i c_vec[4 * 2]; __m128i c_temp[4]; b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr); b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[0] = _mm_setzero_si128(); c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]); c_vec[1] = _mm_setzero_si128(); c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[2] = _mm_setzero_si128(); c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]); c_vec[3] = _mm_setzero_si128(); c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[4] = _mm_setzero_si128(); c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]); c_vec[5] = _mm_setzero_si128(); c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[6] = _mm_setzero_si128(); c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]); c_vec[7] = _mm_setzero_si128(); c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]); pack_a_ptr += 8; pack_b_ptr += 16; for (int iter_k = 2; iter_k < k; iter_k += k_step) { b_vec[0] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr); b_vec[1] = _mm_cvtepi8_epi16_from_ptr(pack_b_ptr + 8); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 2)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[0] = _mm_add_epi32(c_vec[0], c_temp[0]); c_vec[1] = _mm_add_epi32(c_vec[1], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[2] = _mm_add_epi32(c_vec[2], c_temp[2]); c_vec[3] = _mm_add_epi32(c_vec[3], c_temp[3]); a_vec[0] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 4)); a_vec[1] = _mm_set1_epi32(*(int32_t*)(pack_a_ptr + 6)); c_temp[0] = _mm_madd_epi16(a_vec[0], b_vec[0]); c_temp[1] = _mm_madd_epi16(a_vec[0], b_vec[1]); c_vec[4] = _mm_add_epi32(c_vec[4], c_temp[0]); c_vec[5] = _mm_add_epi32(c_vec[5], c_temp[1]); c_temp[2] = _mm_madd_epi16(a_vec[1], b_vec[0]); c_temp[3] = _mm_madd_epi16(a_vec[1], b_vec[1]); c_vec[6] = _mm_add_epi32(c_vec[6], c_temp[2]); c_vec[7] = _mm_add_epi32(c_vec[7], c_temp[3]); pack_a_ptr += 8; pack_b_ptr += 16; } int index_array[4]{0, 2, 4, 6}; if (remain_n >= 4) { for (int m = 0; m < remain_m; ++m) { store_overflow(c_ptr + m * ldc, c_vec[index_array[m]]); } c_ptr += 4; remain_n -= 4; c_vec[0] = c_vec[1]; c_vec[2] = c_vec[3]; c_vec[4] = c_vec[5]; c_vec[6] = c_vec[7]; } for (int m = 0; m < remain_m; ++m) { store_overflow(c_ptr + m * ldc, c_vec[index_array[m]], remain_n); } } static inline void gemm_s8s8s32_sse_4x8x2_pack_an( dt_int16* out, const dt_int8* in, int ldin, int m_start, int m_max, int k_start, int k_max) { constexpr int tile_m = 4; constexpr int tile_k_step = 8; constexpr int tile_k = 2; constexpr int tile_len = tile_m * tile_k_step; const int k_size = k_max - k_start; const int m_end = (m_max - m_start) / tile_m * tile_m + m_start; const int m_remain = m_max - m_end; for (int m = m_start; m < m_end; m += tile_m) { const dt_int8* in0 = in + m * ldin + k_start; const dt_int8* in1 = in0 + ldin; const dt_int8* in2 = in1 + ldin; const dt_int8* in3 = in2 + ldin; int remain_k = k_size; for (; remain_k >= tile_k_step; remain_k -= tile_k_step) { transpose_4x8_k2_int8_to_int16(in0, in1, in2, in3, out); out += tile_len; in0 += tile_k_step; in1 += tile_k_step; in2 += tile_k_step; in3 += tile_k_step; } if (remain_k > 0) { transpose_4xk_int8_to_int16_pad(in0, in1, in2, in3, out, remain_k); out += tile_m * round_up(remain_k, tile_k); } } if (m_remain > 0) { dt_int8 zerobuff[tile_k_step]; std::memset(zerobuff, 0, sizeof(int8_t) * tile_k_step); const dt_int8* in0 = in + m_end * ldin + k_start; const dt_int8* in1 = in0 + ldin; const dt_int8* in2 = in1 + ldin; const dt_int8* in3 = &zerobuff[0]; int in1_step = tile_k_step; int in2_step = tile_k_step; if (m_remain < 3) { in2 = &zerobuff[0]; in2_step = 0; } if (m_remain < 2) { in1 = &zerobuff[0]; in1_step = 0; } int remain_k = k_size; for (; remain_k >= tile_k_step; remain_k -= tile_k_step) { transpose_4x8_k2_int8_to_int16(in0, in1, in2, in3, out); out += tile_len; in0 += tile_k_step; in1 += in1_step; in2 += in2_step; } if (remain_k > 0) { transpose_4xk_int8_to_int16_pad(in0, in1, in2, in3, out, remain_k); out += tile_m * round_up(remain_k, tile_k); in0 += tile_k_step; in1 += in1_step; in2 += in2_step; } } } static inline void gemm_s8s8s32_sse_4x8x2_pack_bn( dt_int8* out, const dt_int8* in, int ldin, int n_start, int n_max, int k_start, int k_max) { constexpr int tile_n = 8; constexpr int tile_k = 2; constexpr int tile_len = tile_n * tile_k; const int k_size = k_max - k_start; const int k_end = k_size / tile_k * tile_k + k_start; const int k_remain = k_max - k_end; const int n_size = n_max - n_start; const int n_end = n_size / tile_n * tile_n + n_start; const int n_remain = n_max - n_end; const int pack_line_len = round_up(k_size, tile_k) * tile_n; int k = k_start; for (; k < k_end; k += tile_k) { int8_t* outptr = out; for (int n = n_start; n < n_end; n += tile_n) { const dt_int8* inptr_0 = in + k * ldin + n; const dt_int8* inptr_1 = inptr_0 + ldin; transpose_2x8_no_inc(inptr_0, inptr_1, outptr); outptr += pack_line_len; } if (n_end < n_max) { naive_transpose_kn_pad( outptr, in + k * ldin + n_end, ldin, tile_k, n_remain, tile_k, tile_n); } out += tile_len; } if (k_remain > 0) { int8_t* outptr = out; dt_int8 zerobuff[tile_n]; std::memset(zerobuff, 0, sizeof(int8_t) * tile_n); for (int n = n_start; n < n_end; n += tile_n) { const dt_int8* inptr_0 = in + k * ldin + n; const dt_int8* inptr_1 = &zerobuff[0]; transpose_2x8_no_inc(inptr_0, inptr_1, outptr); outptr += pack_line_len; } if (n_end < n_max) { naive_transpose_kn_pad( outptr, in + k * ldin + n_end, ldin, k_remain, n_remain, tile_k, tile_n); } } } static inline void gemm_s8s8s32_sse_4x8x2_pack_bt( dt_int8* out, const dt_int8* in, int ldin, int n_start, int n_max, int k_start, int k_max) { constexpr int tile_n = 8; constexpr int tile_k = 2; constexpr int tile_k_step = 16; const int k_size = k_max - k_start; const int k_end = k_size / tile_k_step * tile_k_step + k_start; const int k_remain = k_max - k_end; const int n_size = n_max - n_start; const int n_end = n_size / tile_n * tile_n + n_start; const int n_remain = n_max - n_end; for (int n = n_start; n < n_end; n += tile_n) { const dt_int8* in0 = in + n * ldin + k_start; const dt_int8* in1 = in0 + ldin; const dt_int8* in2 = in1 + ldin; const dt_int8* in3 = in2 + ldin; const dt_int8* in4 = in3 + ldin; const dt_int8* in5 = in4 + ldin; const dt_int8* in6 = in5 + ldin; const dt_int8* in7 = in6 + ldin; for (int k = k_start; k < k_end; k += tile_k_step) { transpose_8x16_k2(out, in0, in1, in2, in3, in4, in5, in6, in7); in0 += tile_k_step; in1 += tile_k_step; in2 += tile_k_step; in3 += tile_k_step; in4 += tile_k_step; in5 += tile_k_step; in6 += tile_k_step; in7 += tile_k_step; out += tile_n * tile_k_step; } naive_transpose_8xk_k2(out, in0, in1, in2, in3, in4, in5, in6, in7, k_remain); out += tile_n * round_up(k_remain, tile_k); } if (n_remain > 0) { const dt_int8* in0 = in + n_end * ldin + k_start; naive_transpose_nk_k2(out, in0, ldin, n_remain, k_size, tile_n); } } static inline void gemm_s8s8s32_sse_4x8x2_pack_at( dt_int16* out, const dt_int8* in, int ldin, int m_start, int m_max, int k_start, int k_max) { constexpr int tile_m = 8; constexpr int tile_m_step = 4; constexpr int tile_k = 2; const int k_size = k_max - k_start; const int k_end = k_size / tile_k * tile_k + k_start; const int k_remain = k_max - k_end; const int m_size = m_max - m_start; const int m_end = m_size / tile_m * tile_m + m_start; const int pack_line_len = round_up(k_size, tile_k) * tile_m_step; int k = k_start; for (; k < k_end; k += tile_k) { dt_int16* outptr = out; for (int m = m_start; m < m_end; m += tile_m) { const dt_int8* inptr_0 = in + k * ldin + m; const dt_int8* inptr_1 = inptr_0 + ldin; transpose_km_2x8_k2_tile4_int8_to_int16( inptr_0, inptr_1, outptr, pack_line_len); outptr += (tile_m / tile_m_step) * pack_line_len; } if (m_end < m_max) { for (int m = m_end; m < m_max; m += tile_m_step) { const int m_remain = m_max - m >= tile_m_step ? tile_m_step : m_max - m; naive_transpose_kn_pad( outptr, in + k * ldin + m, ldin, tile_k, m_remain, tile_k, tile_m_step); outptr += pack_line_len; } } out += tile_m_step * tile_k; } if (k_remain > 0) { dt_int16* outptr = out; dt_int8 zerobuff[tile_m]; std::memset(zerobuff, 0, sizeof(int8_t) * tile_m); for (int n = m_start; n < m_end; n += tile_m) { const dt_int8* inptr_0 = in + k * ldin + n; const dt_int8* inptr_1 = &zerobuff[0]; transpose_km_2x8_k2_tile4_int8_to_int16( inptr_0, inptr_1, outptr, pack_line_len); outptr += (tile_m / tile_m_step) * pack_line_len; } if (m_end < m_max) { for (int m = m_end; m < m_max; m += tile_m_step) { const int m_remain = m_max - m >= tile_m_step ? tile_m_step : m_max - m; naive_transpose_kn_pad( outptr, in + k * ldin + m, ldin, k_remain, m_remain, tile_k, tile_m_step); outptr += pack_line_len; } } } } } // namespace matmul_sse_4x8x2 } // namespace x86 } // namespace megdnn // vim: syntax=cpp.doxygen