提交 5dbf218d 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn/x86): add sse 8816 matmul

GitOrigin-RevId: ed8d9ee5db3342ec534dc2e8a38b8613b8733b2b
上级 25b6a131
...@@ -184,7 +184,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern( ...@@ -184,7 +184,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern(
} }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni, MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni,
megdnn_x86_matmul_kern, 5, megdnn_x86_matmul_kern,
"AlgoInt8x8x32Vnni"_hash,
x86::matmul::gemm_int8_vnni_12x32x4, x86::matmul::gemm_int8_vnni_12x32x4,
dt_int8, dt_int32, dt_uint8); dt_int8, dt_int32, dt_uint8);
#endif #endif
...@@ -318,6 +319,8 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) { ...@@ -318,6 +319,8 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
} }
} // namespace } // namespace
/*************************AlgoInt8x8x16AVX2********************/
void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2( void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2(
const MatrixMulImpl::KernParam& kern_param) { const MatrixMulImpl::KernParam& kern_param) {
MEGDNN_MARK_USED_VAR(kern_param); MEGDNN_MARK_USED_VAR(kern_param);
...@@ -389,9 +392,86 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace( ...@@ -389,9 +392,86 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
.get_workspace_size(); .get_workspace_size();
} }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, 8, AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash,
x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16); x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);
/*************************AlgoInt8x8x16SSE********************/
void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2(
const MatrixMulImpl::KernParam& kern_param) {
MEGDNN_MARK_USED_VAR(kern_param);
MIDOUT_BEGIN(megdnn_x86_matmul_kern_sse_4x8x2, midout_iv(2)) {
constexpr int cacheline = 64;
const size_t m = kern_param.M;
const size_t n = kern_param.N;
const size_t k = kern_param.K;
const bool trans_a = kern_param.trA;
const bool trans_b = kern_param.trB;
const size_t lda = kern_param.LDA;
const size_t ldb = kern_param.LDB;
const size_t ldc = kern_param.LDC;
auto a_type = kern_param.A_type;
auto b_type = kern_param.B_type;
auto c_type = kern_param.C_type;
const auto a_ptr = kern_param.A<dt_int8>();
const auto b_ptr = kern_param.B<dt_int8>();
auto c_ptr = kern_param.C<dt_int16>();
x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type,
c_type);
megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s16_4x8x2>(
m, n, k, trans_a, trans_b, strategy, cacheline)
.execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16SSE::get_kern(
const KernSizeParam&) const {
return gemm_s8s8s16_sse_4x8x2;
}
bool MatrixMulImpl::AlgoInt8x8x16SSE::usable(
const KernSizeParam& kern_size_param) const {
bool is_ab_same =
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv();
bool is_type_ok =
((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
kern_size_param.C_type.enumv() == DTypeEnum::Int16) ||
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
bool is_mode_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
is_supported(SIMDType::SSE4_1);
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
return is_param_ok;
}
bool MatrixMulImpl::AlgoInt8x8x16SSE::preferred(const KernSizeParam&) const {
return true;
}
size_t MatrixMulImpl::AlgoInt8x8x16SSE::get_workspace(
const KernSizeParam& kern_param) const {
constexpr int cacheline = 64;
const size_t m = kern_param.M;
const size_t n = kern_param.N;
const size_t k = kern_param.K;
const bool trans_a = kern_param.trA;
const bool trans_b = kern_param.trB;
auto a_type = kern_param.A_type;
auto b_type = kern_param.B_type;
auto c_type = kern_param.C_type;
x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type,
c_type);
return megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s16_4x8x2>(
m, n, k, trans_a, trans_b, strategy, cacheline)
.get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE,
megdnn_x86_matmul_kern,
"AlgoInt8x8x16SSE"_hash,
x86::matmul::gemm_sse_s8s8s16_4x8x2,
dt_int8, dt_int16, dt_int16);
/*************************AlgoInt8x8x32AVX2M4N16K2********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern( MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
const KernSizeParam&) const { const KernSizeParam&) const {
return gemm_s8s8s32_avx2_4x16x2; return gemm_s8s8s32_avx2_4x16x2;
...@@ -426,8 +506,9 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace( ...@@ -426,8 +506,9 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
.get_workspace_size(); .get_workspace_size();
} }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern, 8, AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern,
x86::matmul::gemm_avx2_s8s8s32_4x16x2, dt_int8, dt_int32, dt_int16); "AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2,
dt_int8, dt_int32, dt_int16);
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern( MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern(
const KernSizeParam&) const { const KernSizeParam&) const {
...@@ -463,7 +544,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace( ...@@ -463,7 +544,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
.get_workspace_size(); .get_workspace_size();
} }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16, MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
megdnn_x86_matmul_kern, 8, megdnn_x86_matmul_kern,
"AlgoInt8x8x32AVX2M2N4K16"_hash,
x86::matmul::gemm_avx2_s8s8s32_2x4x16, x86::matmul::gemm_avx2_s8s8s32_2x4x16,
dt_int8, dt_int32); dt_int8, dt_int32);
...@@ -501,7 +583,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace( ...@@ -501,7 +583,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
.get_workspace_size(); .get_workspace_size();
} }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2, MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
megdnn_x86_matmul_kern, 9, megdnn_x86_matmul_kern,
"AlgoInt8x8x32SSEM4N8K2"_hash,
x86::matmul::gemm_sse_s8s8s32_4x8x2, x86::matmul::gemm_sse_s8s8s32_4x8x2,
dt_int8, dt_int32, dt_int16); dt_int8, dt_int32, dt_int16);
......
...@@ -76,7 +76,6 @@ class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase { ...@@ -76,7 +76,6 @@ class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase {
private: private:
static void gemm_s8s8s16_avx2_4x16x2( static void gemm_s8s8s16_avx2_4x16x2(
const MatrixMulImpl::KernParam& kern_param); const MatrixMulImpl::KernParam& kern_param);
static MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 m_algo;
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
...@@ -89,6 +88,22 @@ public: ...@@ -89,6 +88,22 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
class MatrixMulImpl::AlgoInt8x8x16SSE : public AlgoBase {
private:
static void gemm_s8s8s16_sse_4x8x2(
const MatrixMulImpl::KernParam& kern_param);
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "X86_INT8X8X16_SSE"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
bool preferred(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
......
...@@ -6,10 +6,17 @@ ...@@ -6,10 +6,17 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include <immintrin.h> #include <immintrin.h>
#ifdef WIN32
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
...@@ -21,10 +28,44 @@ namespace x86 { ...@@ -21,10 +28,44 @@ namespace x86 {
namespace matmul_sse_4x8x2 { namespace matmul_sse_4x8x2 {
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
void store_overflow(void* ptr, __m128i a);
template <>
void store_overflow<int16_t>(void* ptr, __m128i a) {
a = _mm_shufflelo_epi16(a, 0x08);
a = _mm_shufflehi_epi16(a, 0x08);
a = _mm_shuffle_epi32(a, 0x08);
_mm_storel_epi64((__m128i*)ptr, a);
}
template <>
void store_overflow<int32_t>(void* ptr, __m128i a) {
_mm_storeu_si128((__m128i*)(ptr), a);
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
void store_overflow(void* ptr, __m128i a, int remain);
template <>
void store_overflow<int16_t>(void* ptr, __m128i a, int remain) {
__m128i mask = _mm_continue_mask(remain * sizeof(int16_t));
a = _mm_shufflelo_epi16(a, 0x08);
a = _mm_shufflehi_epi16(a, 0x08);
a = _mm_shuffle_epi32(a, 0x08);
_mm_maskmoveu_si128(a, mask, reinterpret_cast<char*>(ptr));
}
template <>
void store_overflow<int32_t>(void* ptr, __m128i a, int remain) {
__m128i mask = _mm_continue_mask(remain * sizeof(int32_t));
_mm_maskmoveu_si128(a, mask, reinterpret_cast<char*>(ptr));
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1") MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr, static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr,
const int8_t* pack_b_ptr, const int8_t* pack_b_ptr,
int32_t* c_ptr, const int ldc, CType* c_ptr, const int ldc,
const int k) { const int k) {
constexpr int k_step = 2; constexpr int k_step = 2;
...@@ -102,20 +143,20 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr, ...@@ -102,20 +143,20 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr,
pack_a_ptr += 8; pack_a_ptr += 8;
pack_b_ptr += 16; pack_b_ptr += 16;
} }
store_overflow<CType>(c_ptr, c_vec[0]);
_mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]); store_overflow<CType>(c_ptr + 4, c_vec[1]);
_mm_storeu_si128((__m128i*)(c_ptr + 4), c_vec[1]); store_overflow<CType>(c_ptr + ldc, c_vec[2]);
_mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
_mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]); store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
_mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
_mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]); store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
_mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]); store_overflow<CType>(c_ptr + 3 * ldc + 4, c_vec[7]);
_mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc + 4), c_vec[7]);
} }
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1") MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m( static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m(
const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
const int ldc, const int k, const int remain_m) { const int ldc, const int k, const int remain_m) {
constexpr int k_step = 2; constexpr int k_step = 2;
...@@ -194,34 +235,35 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m( ...@@ -194,34 +235,35 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m(
pack_b_ptr += 16; pack_b_ptr += 16;
} }
_mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]); store_overflow<CType>(c_ptr, c_vec[0]);
_mm_storeu_si128((__m128i*)(c_ptr + 4), c_vec[1]); store_overflow<CType>(c_ptr + 4, c_vec[1]);
switch (remain_m) { switch (remain_m) {
case 2: case 2:
_mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); store_overflow<CType>(c_ptr + ldc, c_vec[2]);
_mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]); store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
break; break;
case 3: case 3:
_mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); store_overflow<CType>(c_ptr + ldc, c_vec[2]);
_mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]); store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
_mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
_mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]); store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
break; break;
case 4: case 4:
_mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); store_overflow<CType>(c_ptr + ldc, c_vec[2]);
_mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]); store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
_mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
_mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]); store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
_mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]); store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
_mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc + 4), c_vec[7]); store_overflow<CType>(c_ptr + 3 * ldc + 4, c_vec[7]);
default: default:
break; break;
} }
} }
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1") MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n( static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
const int ldc, const int k, int remain_n) { const int ldc, const int k, int remain_n) {
constexpr int k_step = 2; constexpr int k_step = 2;
...@@ -301,10 +343,10 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n( ...@@ -301,10 +343,10 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
} }
if (remain_n >= 4) { if (remain_n >= 4) {
_mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]); store_overflow<CType>(c_ptr, c_vec[0]);
_mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]); store_overflow<CType>(c_ptr + ldc, c_vec[2]);
_mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]); store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
_mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]); store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
c_ptr += 4; c_ptr += 4;
remain_n -= 4; remain_n -= 4;
c_vec[0] = c_vec[1]; c_vec[0] = c_vec[1];
...@@ -312,35 +354,16 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n( ...@@ -312,35 +354,16 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
c_vec[4] = c_vec[5]; c_vec[4] = c_vec[5];
c_vec[6] = c_vec[7]; c_vec[6] = c_vec[7];
} }
store_overflow<CType>(c_ptr, c_vec[0], remain_n);
switch (remain_n) { store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
case 0: store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
break; store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
case 1:
*(c_ptr) = _mm_extract_epi32(c_vec[0], 0);
*(c_ptr + ldc) = _mm_extract_epi32(c_vec[2], 0);
*(c_ptr + 2 * ldc) = _mm_extract_epi32(c_vec[4], 0);
*(c_ptr + 3 * ldc) = _mm_extract_epi32(c_vec[6], 0);
break;
case 2:
case 3:
_mm_storel_epi64((__m128i*)(c_ptr), c_vec[0]);
_mm_storel_epi64((__m128i*)(c_ptr + ldc), c_vec[2]);
_mm_storel_epi64((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
_mm_storel_epi64((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
break;
}
if (remain_n == 3) {
*(c_ptr + 2) = _mm_extract_epi32(c_vec[0], 2);
*(c_ptr + ldc + 2) = _mm_extract_epi32(c_vec[2], 2);
*(c_ptr + 2 * ldc + 2) = _mm_extract_epi32(c_vec[4], 2);
*(c_ptr + 3 * ldc + 2) = _mm_extract_epi32(c_vec[6], 2);
}
} }
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1") MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n( static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
const int ldc, const int k, int remain_m, int remain_n) { const int ldc, const int k, int remain_m, int remain_n) {
constexpr int k_step = 2; constexpr int k_step = 2;
...@@ -421,8 +444,7 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n( ...@@ -421,8 +444,7 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
int index_array[4]{0, 2, 4, 6}; int index_array[4]{0, 2, 4, 6};
if (remain_n >= 4) { if (remain_n >= 4) {
for (int m = 0; m < remain_m; ++m) { for (int m = 0; m < remain_m; ++m) {
_mm_storeu_si128((__m128i*)(c_ptr + m * ldc), store_overflow<CType>(c_ptr + m * ldc, c_vec[index_array[m]]);
c_vec[index_array[m]]);
} }
c_ptr += 4; c_ptr += 4;
remain_n -= 4; remain_n -= 4;
...@@ -431,29 +453,8 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n( ...@@ -431,29 +453,8 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
c_vec[4] = c_vec[5]; c_vec[4] = c_vec[5];
c_vec[6] = c_vec[7]; c_vec[6] = c_vec[7];
} }
for (int m = 0; m < remain_m; ++m) {
switch (remain_n) { store_overflow<CType>(c_ptr + m * ldc, c_vec[index_array[m]], remain_n);
case 0:
break;
case 1:
for (int m = 0; m < remain_m; ++m) {
*(c_ptr + m * ldc) =
_mm_extract_epi32(c_vec[index_array[m]], 0);
}
break;
case 2:
case 3:
for (int m = 0; m < remain_m; ++m) {
_mm_storel_epi64((__m128i*)(c_ptr + m * ldc),
c_vec[index_array[m]]);
}
break;
}
if (remain_n == 3) {
for (int m = 0; m < remain_m; ++m) {
*(c_ptr + m * ldc + 2) =
_mm_extract_epi32(c_vec[index_array[m]], 2);
}
} }
} }
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/common/utils.h" #include "src/common/utils.h"
...@@ -18,11 +19,9 @@ using namespace megdnn; ...@@ -18,11 +19,9 @@ using namespace megdnn;
using namespace x86; using namespace x86;
using namespace x86::matmul; using namespace x86::matmul;
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s32_4x8x2); static inline void gemm_packa(dt_int16* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin, bool transpose) {
int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_at(out, in, ldin, y0, matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_at(out, in, ldin, y0,
ymax, k0, kmax); ymax, k0, kmax);
...@@ -31,10 +30,8 @@ void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin, ...@@ -31,10 +30,8 @@ void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
ymax, k0, kmax); ymax, k0, kmax);
} }
} }
static inline void gemm_packb(dt_int8* out, const dt_int8* in, int ldin, int x0,
void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, int xmax, int k0, int kmax, bool transpose) {
int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_bt(out, in, ldin, x0, matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_bt(out, in, ldin, x0,
xmax, k0, kmax); xmax, k0, kmax);
...@@ -43,20 +40,11 @@ void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, ...@@ -43,20 +40,11 @@ void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
xmax, k0, kmax); xmax, k0, kmax);
} }
} }
template <typename CType>
void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr, static inline void gemm_kern(const dt_int16* pack_a_ptr,
const dt_int8* pack_b_ptr, size_t m, size_t n, const dt_int8* pack_b_ptr, size_t m, size_t n,
size_t k, dt_int32* c_ptr, size_t ldc, size_t k, CType* c_ptr, size_t ldc,
bool is_first_k, const dt_int32*, bool is_first_k) {
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
megdnn_assert(is_first_k == true);
constexpr int m_tile = 4; constexpr int m_tile = 4;
constexpr int n_tile = 8; constexpr int n_tile = 8;
constexpr int k_tile = 2; constexpr int k_tile = 2;
...@@ -99,4 +87,62 @@ void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr, ...@@ -99,4 +87,62 @@ void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr,
} }
} }
} }
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s32_4x8x2);
void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
}
void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
}
void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr,
const dt_int8* pack_b_ptr, size_t m, size_t n,
size_t k, dt_int32* c_ptr, size_t ldc,
bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
megdnn_assert(is_first_k == true);
gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
}
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s16_4x8x2);
void gemm_sse_s8s8s16_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
}
void gemm_sse_s8s8s16_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
}
void gemm_sse_s8s8s16_4x8x2::kern(const dt_int16* pack_a_ptr,
const dt_int8* pack_b_ptr, size_t m, size_t n,
size_t k, dt_int16* c_ptr, size_t ldc,
bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS16)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
megdnn_assert(is_first_k == true);
gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -38,6 +38,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32, ...@@ -38,6 +38,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
4, 8, 2, false, false, 4, 8, 2, false, false,
gemm_sse_s8s8s32_4x8x2); gemm_sse_s8s8s32_4x8x2);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int32,
4, 8, 2, false, false,
gemm_sse_s8s8s16_4x8x2);
} // namespace matmul } // namespace matmul
} // namespace x86 } // namespace x86
} // namespace megdnn } // namespace megdnn
......
...@@ -38,6 +38,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { ...@@ -38,6 +38,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x32AVX2M2N4K16 algoint8x8x32avx2_m2n4k16; AlgoInt8x8x32AVX2M2N4K16 algoint8x8x32avx2_m2n4k16;
AlgoInt8x8x32SSEM4N8K2 algoint8x8x32sse_m4n8k2; AlgoInt8x8x32SSEM4N8K2 algoint8x8x32sse_m4n8k2;
AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2; AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
AlgoInt8x8x16SSE algoint8x8x16sse_m4n8k2;
AlgoF32MK8_8x8 algof32mk8_8x8; AlgoF32MK8_8x8 algof32mk8_8x8;
public: public:
...@@ -51,6 +52,7 @@ public: ...@@ -51,6 +52,7 @@ public:
all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2); all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2);
all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16); all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16);
all_algos.emplace_back(&algoint8x8x32sse_m4n8k2); all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
all_algos.emplace_back(&algoint8x8x16sse_m4n8k2);
all_algos.emplace_back(&algof32mk8_8x8); all_algos.emplace_back(&algof32mk8_8x8);
#if MEGDNN_X86_WITH_MKL_DNN #if MEGDNN_X86_WITH_MKL_DNN
all_algos.emplace_back(&algoint8x8x32mkldnn); all_algos.emplace_back(&algoint8x8x32mkldnn);
......
...@@ -56,6 +56,7 @@ protected: ...@@ -56,6 +56,7 @@ protected:
class AlgoInt8x8x32AVX2M4N16K2; class AlgoInt8x8x32AVX2M4N16K2;
class AlgoInt8x8x32SSEM4N8K2; class AlgoInt8x8x32SSEM4N8K2;
class AlgoInt8x8x16AVX2; class AlgoInt8x8x16AVX2;
class AlgoInt8x8x16SSE;
class AlgoPack; class AlgoPack;
class AlgoF32MK8_8x8; class AlgoF32MK8_8x8;
}; };
......
...@@ -835,6 +835,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X) { ...@@ -835,6 +835,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X) {
} }
if (::megdnn::x86::is_supported(::megdnn::x86::SIMDType::SSE4_2)) { if (::megdnn::x86::is_supported(::megdnn::x86::SIMDType::SSE4_2)) {
cb("IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2"); cb("IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2");
cb2("IM2COLMATMUL:X86_INT8X8X16_SSE");
} }
#undef cb #undef cb
...@@ -1002,7 +1003,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) { ...@@ -1002,7 +1003,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) {
} }
#endif #endif
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X) {
using namespace conv_bias; using namespace conv_bias;
UniformIntRNG rng{-50, 50}; UniformIntRNG rng{-50, 50};
float epsilon = 0.001; float epsilon = 0.001;
...@@ -1028,10 +1029,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { ...@@ -1028,10 +1029,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
"CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24"); "CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24");
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
"CONV1x1:X86_INT8X8X16_AVX2");
} }
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
"CONV1x1:X86_INT8X8X32_SSE_4X8X2:48"); "CONV1x1:X86_INT8X8X32_SSE_4X8X2:48");
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
"CONV1x1:X86_INT8X8X16_SSE");
} }
/************************* End Conv1x1 PackA ************************/ /************************* End Conv1x1 PackA ************************/
......
...@@ -403,6 +403,7 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x16) { ...@@ -403,6 +403,7 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x16) {
benchmark.set_dtype(0, dtype::Int8()) benchmark.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8()) .set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int16()); .set_dtype(2, dtype::Int16());
benchmark.set_before_exec_callback(AlgoChecker<Convolution>(".*"));
benchmark.set_display(false); benchmark.set_display(false);
benchmark.set_times(RUN); benchmark.set_times(RUN);
......
...@@ -52,6 +52,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) { ...@@ -52,6 +52,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "X86_INT8X8X16_AVX2"); handle(), "X86_INT8X8X16_AVX2");
} }
TEST_F(X86, MATRIX_MUL_SSE_8X8X16) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "X86_INT8X8X16_SSE");
}
TEST_F(X86, MATRIX_MUL_SSE_8X8X32) { TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
handle(), "X86_INT8X8X32_SSE_4X8X2"); handle(), "X86_INT8X8X32_SSE_4X8X2");
...@@ -132,6 +136,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) { ...@@ -132,6 +136,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
benchmarker_avx2_4x16x2_8816.set_before_exec_callback( benchmarker_avx2_4x16x2_8816.set_before_exec_callback(
AlgoChecker<MatrixMul>("X86_INT8X8X16_AVX2")); AlgoChecker<MatrixMul>("X86_INT8X8X16_AVX2"));
Benchmarker<MatrixMul> benchmarker_sse_4x8x2_8816(handle());
benchmarker_sse_4x8x2_8816.set_display(false)
.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_rng(0, rng.get())
.set_rng(1, rng.get());
benchmarker_sse_4x8x2_8816.set_before_exec_callback(
AlgoChecker<MatrixMul>("X86_INT8X8X16_SSE"));
Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle()); Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle());
benchmarker_avx2_2x4x16.set_display(false) benchmarker_avx2_2x4x16.set_display(false)
.set_times(RUNS) .set_times(RUNS)
...@@ -212,9 +227,15 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) { ...@@ -212,9 +227,15 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
std::cout << "sse: " << sse_used << " ms, " std::cout << "sse: " << sse_used << " ms, "
<< computations / sse_used << " Gflops, " << computations / sse_used << " Gflops, "
<< "speed_up " << float_used / sse_used << ", "; << "speed_up " << float_used / sse_used << ", ";
auto sse_used_8816 =
benchmarker_sse_4x8x2_8816.exec({{M, K}, {K, N}, {}}) /
RUNS;
std::cout << "sse_8816: " << sse_used_8816 << " ms, "
<< computations / sse_used_8816 << " Gflops, ";
} }
std::cout << std::endl; std::cout << std::endl;
}; };
run(256, 256, 256);
for (size_t M : {8, 64, 112, 256, 512}) { for (size_t M : {8, 64, 112, 256, 512}) {
for (size_t K : {8, 16, 32, 64, 112, 256, 512}) { for (size_t K : {8, 16, 32, 64, 112, 256, 512}) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册