Commit 25b6a131 authored by Megvii Engine Team, committed by Xu Xinran

feat(dnn/x86): add x86 avx2 8x8x16 matmul

GitOrigin-RevId: d2172c50b244a0683b00710a88bc507d02a9734f
Parent 273f891b
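For orientation, a minimal scalar sketch of what the new 8x8x16 kernel computes (a hypothetical reference, not part of this commit): int8 operands, 32-bit accumulation (the AVX2 kernels below accumulate via _mm256_madd_epi16), and truncation of each accumulator to its low 16 bits on store (see store_overflow<int16_t> further down).

#include <cstddef>
#include <cstdint>

// Hypothetical scalar reference for the s8s8s16 GEMM added here: int8 inputs,
// 32-bit accumulation, result truncated to the low 16 bits on store.
static void gemm_s8s8s16_ref(const int8_t* a, const int8_t* b, int16_t* c,
                             size_t m, size_t n, size_t k, size_t lda,
                             size_t ldb, size_t ldc) {
    for (size_t i = 0; i < m; ++i) {
        for (size_t j = 0; j < n; ++j) {
            int32_t acc = 0;  // matches the kernel's int32 accumulators
            for (size_t p = 0; p < k; ++p)
                acc += int32_t(a[i * lda + p]) * int32_t(b[p * ldb + j]);
            c[i * ldc + j] = static_cast<int16_t>(acc);  // truncating store
        }
    }
}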
...@@ -318,6 +318,79 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
}
} // namespace

void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2(
const MatrixMulImpl::KernParam& kern_param) {
MEGDNN_MARK_USED_VAR(kern_param);
MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_4x16x2, midout_iv(1)) {
constexpr int cacheline = 64;
const size_t m = kern_param.M;
const size_t n = kern_param.N;
const size_t k = kern_param.K;
const bool trans_a = kern_param.trA;
const bool trans_b = kern_param.trB;
const size_t lda = kern_param.LDA;
const size_t ldb = kern_param.LDB;
const size_t ldc = kern_param.LDC;
auto a_type = kern_param.A_type;
auto b_type = kern_param.B_type;
auto c_type = kern_param.C_type;
const auto a_ptr = kern_param.A<dt_int8>();
const auto b_ptr = kern_param.B<dt_int8>();
auto c_ptr = kern_param.C<dt_int16>();
x86::matmul::gemm_avx2_s8s8s16_4x16x2 strategy(m, n, k, a_type, b_type,
c_type);
megdnn::matmul::GemmInterleaved<x86::matmul::gemm_avx2_s8s8s16_4x16x2>(
m, n, k, trans_a, trans_b, strategy, cacheline)
.execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_kern(
const KernSizeParam&) const {
return gemm_s8s8s16_avx2_4x16x2;
}
bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable(
const KernSizeParam& kern_size_param) const {
bool is_ab_same =
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv();
bool is_type_ok =
((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
kern_size_param.C_type.enumv() == DTypeEnum::Int16) ||
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
bool is_mode_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
is_supported(SIMDType::AVX2);
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
return is_param_ok;
}
bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const {
return true;
}
size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
const KernSizeParam& kern_param) const {
constexpr int cacheline = 64;
const size_t m = kern_param.M;
const size_t n = kern_param.N;
const size_t k = kern_param.K;
const bool trans_a = kern_param.trA;
const bool trans_b = kern_param.trB;
auto a_type = kern_param.A_type;
auto b_type = kern_param.B_type;
auto c_type = kern_param.C_type;
x86::matmul::gemm_avx2_s8s8s16_4x16x2 strategy(m, n, k, a_type, b_type,
c_type);
return megdnn::matmul::GemmInterleaved<
x86::matmul::gemm_avx2_s8s8s16_4x16x2>(
m, n, k, trans_a, trans_b, strategy, cacheline)
.get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, 8,
x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
        const KernSizeParam&) const {
...
...@@ -6,13 +6,14 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include "src/fallback/matrix_mul/gemm_common.h"
#include "src/x86/matrix_mul/opr_impl.h"

namespace megdnn {
namespace x86 {
...@@ -71,6 +72,23 @@ public:
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase {
private:
static void gemm_s8s8s16_avx2_4x16x2(
const MatrixMulImpl::KernParam& kern_param);
static MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 m_algo;
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "X86_INT8X8X16_AVX2"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
bool preferred(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
...
...@@ -6,16 +6,17 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include <x86intrin.h>
#ifdef WIN32
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include <cmath>
#include <cstdint>
...@@ -787,19 +788,49 @@ static inline void transpose_4x8_k2_int8_to_int16(const int8_t* inptr0,
MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline __v8si _m256_continue_mask_v8si(const int& x) {
    // clang-format off
    static __v8si map[9] = {
            {00, 00, 00, 00, 00, 00, 00, 00},
            {-1, 00, 00, 00, 00, 00, 00, 00},
            {-1, -1, 00, 00, 00, 00, 00, 00},
            {-1, -1, -1, 00, 00, 00, 00, 00},
            {-1, -1, -1, -1, 00, 00, 00, 00},
            {-1, -1, -1, -1, -1, 00, 00, 00},
            {-1, -1, -1, -1, -1, -1, 00, 00},
            {-1, -1, -1, -1, -1, -1, -1, 00},
            {-1, -1, -1, -1, -1, -1, -1, -1}};
    return map[x];
    // clang-format on
}

MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline __m256i _m256_continue_mask(const int& x) {
    return (__m256i)_m256_continue_mask_v8si(x);
}
MEGDNN_ATTRIBUTE_TARGET("sse2")
static inline __m128i _mm_continue_mask(const int& x) {
static __v16qi map[17] = {
{00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
};
return (__m128i)map[x];
}
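
The two helpers above build prefix masks (-1 in the first x lanes or bytes, 0 in the rest) for tail handling; the kernels below hand them to masked stores so only the valid remainder columns are written. A minimal usage sketch, assuming it lives in this file so MEGDNN_ATTRIBUTE_TARGET("avx2") applies (store_tail_s32 is a hypothetical name):

MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void store_tail_s32(int32_t* dst, __m256i v, int remain) {
    // -1 (all bits set) in the first `remain` int32 lanes, 0 elsewhere
    __m256i mask = _m256_continue_mask(remain);
    // lanes whose mask MSB is 0 leave memory untouched
    _mm256_maskstore_epi32(dst, mask, v);
}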
MEGDNN_ATTRIBUTE_TARGET("sse2") MEGDNN_ATTRIBUTE_TARGET("sse2")
static inline void transpose_4xk_int8_to_int16_pad(const int8_t* inptr0, static inline void transpose_4xk_int8_to_int16_pad(const int8_t* inptr0,
const int8_t* inptr1, const int8_t* inptr1,
......
...@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/common/utils.h"
...@@ -18,10 +19,9 @@ using namespace megdnn;
using namespace x86;
using namespace x86::matmul;
static inline void gemm_packa(dt_int16* out, const dt_int8* in, int ldin,
                              int y0, int ymax, int k0, int kmax,
                              bool transpose) {
    if (transpose) {
        matmul_avx2_4x16x2::gemm_s8s8s32_avx2_4x16x2_pack_at(out, in, ldin, y0,
                                                             ymax, k0, kmax);
...@@ -30,10 +30,8 @@ void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
                                                             ymax, k0, kmax);
    }
}

static inline void gemm_packb(dt_int8* out, const dt_int8* in, int ldin, int x0,
                              int xmax, int k0, int kmax, bool transpose) {
    if (transpose) {
        matmul_avx2_4x16x2::gemm_s8s8s32_avx2_4x16x2_pack_bt(out, in, ldin, x0,
                                                             xmax, k0, kmax);
...@@ -42,20 +40,11 @@ void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                                             xmax, k0, kmax);
    }
}

template <typename CType>
static inline void gemm_kern(const dt_int16* pack_a_ptr,
                             const dt_int8* pack_b_ptr, size_t m, size_t n,
                             size_t k, CType* c_ptr, size_t ldc,
                             bool is_first_k) {
    constexpr size_t m_tile = 4;
    constexpr size_t n_tile = 16;
    constexpr size_t k_tile = 2;
...@@ -109,4 +98,62 @@ void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr,
        }
    }
}
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_avx2_s8s8s32_4x16x2);
void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
}
void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
}
void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr,
const dt_int8* pack_b_ptr, size_t m,
size_t n, size_t k, dt_int32* c_ptr,
size_t ldc, bool is_first_k,
const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
megdnn_assert(is_first_k == true);
gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
}
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_avx2_s8s8s16_4x16x2);
void gemm_avx2_s8s8s16_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
}
void gemm_avx2_s8s8s16_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
}
void gemm_avx2_s8s8s16_4x16x2::kern(const dt_int16* pack_a_ptr,
const dt_int8* pack_b_ptr, size_t m,
size_t n, size_t k, dt_int16* c_ptr,
size_t ldc, bool is_first_k,
const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS16)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
megdnn_assert(is_first_k == true);
gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
}
// vim: syntax=cpp.doxygen
...@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include <immintrin.h>
...@@ -20,11 +21,47 @@ namespace megdnn {
namespace x86 {
namespace matmul_avx2_4x16x2 {
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("avx2")
void store_overflow(void* ptr, __m256i a);
template <>
void store_overflow<int16_t>(void* ptr, __m256i a) {
static __m256i idx = _mm256_setr_epi32(0, 2, 4, 6, 0, 0, 0, 0);
a = _mm256_shufflelo_epi16(a, 0x08);
a = _mm256_shufflehi_epi16(a, 0x08);
a = _mm256_permutevar8x32_epi32(a, idx);
_mm_storeu_si128((__m128i*)ptr, _mm256_extractf128_si256(a, 0));
}
template <>
void store_overflow<int32_t>(void* ptr, __m256i a) {
_mm256_storeu_si256((__m256i*)(ptr), a);
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("avx2")
void store_overflow(void* ptr, __m256i a, int remain);
template <>
void store_overflow<int16_t>(void* ptr, __m256i a, int remain) {
__m128i mask = _mm_continue_mask(remain * sizeof(int16_t));
static __m256i idx = _mm256_setr_epi32(0, 2, 4, 6, 0, 0, 0, 0);
a = _mm256_shufflelo_epi16(a, 0x08);
a = _mm256_shufflehi_epi16(a, 0x08);
a = _mm256_permutevar8x32_epi32(a, idx);
_mm_maskmoveu_si128(_mm256_extractf128_si256(a, 0), mask,
reinterpret_cast<char*>(ptr));
}
template <>
void store_overflow<int32_t>(void* ptr, __m256i a, int remain) {
__m256i mask = _m256_continue_mask(remain);
_mm256_maskstore_epi32(reinterpret_cast<int32_t*>(ptr), mask, a);
}
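
The int16 path above packs the low 16 bits of each 32-bit accumulator: the shufflelo/shufflehi pair with immediate 0x08 gathers words 0 and 2 of every 64-bit group, and _mm256_permutevar8x32_epi32 compacts dwords {0, 2, 4, 6} into the low 128 bits before the store. A scalar equivalent, for illustration only (store_overflow_s16_ref is a hypothetical name):

// What store_overflow<int16_t>(ptr, a) computes, lane by lane: each of the
// eight int32 lanes of `a` is truncated to its low 16 bits.
static inline void store_overflow_s16_ref(int16_t* ptr, const int32_t* lanes) {
    for (int i = 0; i < 8; ++i)
        ptr[i] = static_cast<int16_t>(lanes[i]);
}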
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("avx2") MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr, static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr,
const int8_t* pack_b_ptr, const int8_t* pack_b_ptr,
int32_t* c_ptr, CType* c_ptr,
const uint32_t ldc, const uint32_t ldc,
const uint32_t k) { const uint32_t k) {
constexpr uint32_t k_step = 2; constexpr uint32_t k_step = 2;
...@@ -104,19 +141,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr,
        pack_b_ptr += 32;
    }

    store_overflow<CType>(c_ptr, c_vec[0]);
    store_overflow<CType>(c_ptr + 8, c_vec[1]);
    store_overflow<CType>(c_ptr + ldc, c_vec[2]);
    store_overflow<CType>(c_ptr + ldc + 8, c_vec[3]);
    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
    store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5]);
    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
    store_overflow<CType>(c_ptr + 3 * ldc + 8, c_vec[7]);
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("avx2") MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n( static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
const uint32_t ldc, const uint32_t k, const uint32_t remain_n) { const uint32_t ldc, const uint32_t k, const uint32_t remain_n) {
constexpr uint32_t k_step = 2; constexpr uint32_t k_step = 2;
...@@ -173,15 +210,15 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
        pack_b_ptr += 32;
    }

    store_overflow<CType>(c_ptr, c_vec[0], remain_n);
    store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("avx2") MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n( static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
const uint32_t ldc, const uint32_t k, const uint32_t remain_m, const uint32_t ldc, const uint32_t k, const uint32_t remain_m,
uint32_t remain_n) { uint32_t remain_n) {
constexpr uint32_t k_step = 2; constexpr uint32_t k_step = 2;
...@@ -239,29 +276,29 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
        pack_b_ptr += 32;
    }

    store_overflow<CType>(c_ptr, c_vec[0], remain_n);
    switch (remain_m) {
        case 2:
            store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
            break;
        case 3:
            store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
            break;
        case 4:
            store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
            store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
            break;
        default:
            break;
    }
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("avx2") MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m( static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
const uint32_t ldc, const uint32_t k, const uint32_t remain_m) { const uint32_t ldc, const uint32_t k, const uint32_t remain_m) {
constexpr uint32_t k_step = 2; constexpr uint32_t k_step = 2;
...@@ -339,34 +376,36 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
        pack_a_ptr += 8;
        pack_b_ptr += 32;
    }

    store_overflow<CType>(c_ptr, c_vec[0]);
    store_overflow<CType>(c_ptr + 8, c_vec[1]);
    switch (remain_m) {
        case 2:
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 8, c_vec[3]);
            break;
        case 3:
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 8, c_vec[3]);
            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
            store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5]);
            break;
        case 4:
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 8, c_vec[3]);
            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
            store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5]);
            store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
            store_overflow<CType>(c_ptr + 3 * ldc + 8, c_vec[7]);
        default:
            break;
    }
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("avx2") MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n( static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n(
const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
const uint32_t ldc, const uint32_t k, uint32_t remain_n) { const uint32_t ldc, const uint32_t k, uint32_t remain_n) {
constexpr uint32_t k_step = 2; constexpr uint32_t k_step = 2;
...@@ -446,29 +485,28 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n(
    }

    if (remain_n >= 8) {
        store_overflow<CType>(c_ptr, c_vec[0]);
        store_overflow<CType>(c_ptr + ldc, c_vec[2]);
        store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
        store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
        remain_n -= 8;
        if (remain_n > 0) {
            store_overflow<CType>(c_ptr + 8, c_vec[1], remain_n);
            store_overflow<CType>(c_ptr + ldc + 8, c_vec[3], remain_n);
            store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5], remain_n);
            store_overflow<CType>(c_ptr + 3 * ldc + 8, c_vec[7], remain_n);
        }
    } else {
        store_overflow<CType>(c_ptr, c_vec[0], remain_n);
        store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
        store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
        store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
    }
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("avx2") MEGDNN_ATTRIBUTE_TARGET("avx2")
static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n( static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr, const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
const uint32_t ldc, const uint32_t k, const uint32_t remain_m, const uint32_t ldc, const uint32_t k, const uint32_t remain_m,
uint32_t remain_n) { uint32_t remain_n) {
constexpr uint32_t k_step = 2; constexpr uint32_t k_step = 2;
...@@ -549,19 +587,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
    }

    if (remain_n >= 8) {
        store_overflow<CType>(c_ptr, c_vec[0]);
        switch (remain_m) {
            case 2:
                store_overflow<CType>(c_ptr + ldc, c_vec[2]);
                break;
            case 3:
                store_overflow<CType>(c_ptr + ldc, c_vec[2]);
                store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
                break;
            case 4:
                store_overflow<CType>(c_ptr + ldc, c_vec[2]);
                store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
                store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
                break;
            default:
                break;
...@@ -569,43 +607,41 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
        remain_n -= 8;
        if (remain_n > 0) {
            store_overflow<CType>(c_ptr + 8, c_vec[1], remain_n);
            switch (remain_m) {
                case 2:
                    store_overflow<CType>(c_ptr + ldc + 8, c_vec[3], remain_n);
                    break;
                case 3:
                    store_overflow<CType>(c_ptr + ldc + 8, c_vec[3], remain_n);
                    store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5],
                                          remain_n);
                    break;
                case 4:
                    store_overflow<CType>(c_ptr + ldc + 8, c_vec[3], remain_n);
                    store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5],
                                          remain_n);
                    store_overflow<CType>(c_ptr + 3 * ldc + 8, c_vec[7],
                                          remain_n);
                    break;
                default:
                    break;
            }
        }
    } else {
        store_overflow<CType>(c_ptr, c_vec[0], remain_n);
        switch (remain_m) {
            case 2:
                store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
                break;
            case 3:
                store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
                store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
                break;
            case 4:
                store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
                store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
                store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
                break;
            default:
                break;
...@@ -833,4 +869,5 @@ static inline void gemm_s8s8s32_avx2_4x16x2_pack_at(dt_int16* out,
} // namespace x86
} // namespace megdnn

// vim: syntax=cpp.doxygen
\ No newline at end of file
...@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
...@@ -29,6 +30,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
                                          4, 16, 2, false, false,
                                          gemm_avx2_s8s8s32_4x16x2);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int32,
4, 16, 2, false, false,
gemm_avx2_s8s8s16_4x16x2);

MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
                                          4, 8, 2, false, false,
                                          gemm_sse_s8s8s32_4x8x2);
...
...@@ -37,6 +37,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
    AlgoInt8x8x32AVX2M4N16K2 algoint8x8x32avx2_m4n16k2;
    AlgoInt8x8x32AVX2M2N4K16 algoint8x8x32avx2_m2n4k16;
    AlgoInt8x8x32SSEM4N8K2 algoint8x8x32sse_m4n8k2;
AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
    AlgoF32MK8_8x8 algof32mk8_8x8;

public:
...@@ -47,6 +48,7 @@ public:
#endif
        }
        all_algos.emplace_back(&algoint8x8x32avx2_m4n16k2);
all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2);
        all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16);
        all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
        all_algos.emplace_back(&algof32mk8_8x8);
...
...@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once
...@@ -54,6 +55,7 @@ protected:
    class AlgoInt8x8x32AVX2M2N4K16;
    class AlgoInt8x8x32AVX2M4N16K2;
    class AlgoInt8x8x32SSEM4N8K2;
class AlgoInt8x8x16AVX2;
    class AlgoPack;
    class AlgoF32MK8_8x8;
};
...
...@@ -752,7 +752,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) {
    }
}

TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X) {
    using namespace conv_bias;
    std::vector<TestArg> args;
...@@ -807,6 +807,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
                .set_param(arg.param)                                      \
                .execs({arg.src, arg.filter, {}, {}, {}});                 \
    }
#define cb2(algo_name) \
checker.set_before_exec_callback( \
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \
checker.set_dtype(0, dtype::Int8()); \
checker.set_dtype(1, dtype::Int8()); \
checker.set_dtype(2, dtype::Int16()); \
checker.set_dtype(4, dtype::Int16()); \
for (auto&& arg : args) { \
checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); \
}
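// Hedged reading of the dtype slots used by cb2 (an assumption based on
// ConvBias execs taking {src, filter, bias, z, dst}): slot 0 = src (Int8),
// slot 1 = filter (Int8), slot 2 = bias (Int16), slot 4 = dst (Int16);
// slot 3 (the z tensor) is left at its default here.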

#if MEGDNN_X86_WITH_MKL_DNN
    if (megdnn::x86::is_supported(x86::SIMDType::VNNI)) {
...@@ -821,12 +831,14 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
    if (megdnn::x86::is_supported(x86::SIMDType::AVX2)) {
        cb("IM2COLMATMUL:X86_INT8X8X32_AVX2_2X4X16");
        cb("IM2COLMATMUL:X86_INT8X8X32_AVX2_4X16X2");
        cb2("IM2COLMATMUL:X86_INT8X8X16_AVX2");
    }

    if (::megdnn::x86::is_supported(::megdnn::x86::SIMDType::SSE4_2)) {
        cb("IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2");
    }

#undef cb
#undef cb2
}

TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) {
...@@ -1964,6 +1976,39 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) {
    shapes_and_computation.clear();
}
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_8816) {
constexpr size_t RUNS = 30;
param::ConvBias param;
param.stride_h = 1;
param.stride_w = 1;
param.sparse = param::ConvBias::Sparse::DENSE;
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(),
dtype::Int16(), dtype::Int16()};
std::vector<std::pair<SmallVector<TensorShape>, float>>
shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS) {
param.pad_h = FS / 2;
param.pad_w = FS / 2;
SmallVector<TensorShape> shapes{
{N, IC, H, W}, {OC, IC, FS, FS}, {}, {}, {}};
TensorShape dst{N, OC, (H + 2 * param.pad_h - FS) / param.stride_h + 1,
(W + 2 * param.pad_w - FS) / param.stride_w + 1};
float computations = (IC * FS * FS * dst.total_nr_elems() * 2) * 1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};
bench_case(1, 48, 192, 15, 15, 1);
std::string algo_name = "IM2COLMATMUL:X86_INT8X8X16_AVX2";
benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
{4, {4, 5, 6, 7}}, {1, {4}}, data_type);
shapes_and_computation.clear();
}
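
As a worked check of the computations figure for the single case benchmarked above (hedged: assuming benchmark_impl divides it by the measured time in milliseconds, so Mflop / ms yields Gflop/s):

// bench_case(1, 48, 192, 15, 15, 1): FS = 1 gives pad_h = pad_w = 0, so
// dst = {1, 192, 15, 15} and dst.total_nr_elems() = 1 * 192 * 15 * 15 = 43200;
// computations = 48 * 1 * 1 * 43200 * 2 * 1e-6 = 4.1472 Mflop per iteration.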

TEST_F(X86_BENCHMARK_MULTI_THREADS,
       BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8_STRIDE2) {
    constexpr size_t RUNS = 50;
...@@ -1985,7 +2030,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS,
        SmallVector<TensorShape> shapes{
                {N, IC, H, W}, {OC, IC, FS, FS}, {}, {}, {}};
        TensorShape dst{N, OC, (H + 2 * param.pad_h - FS) / param.stride_h + 1,
                        (W + 2 * param.pad_w - FS) / param.stride_w + 1};
        float computations = (IC * FS * FS * dst.total_nr_elems() * 2) * 1e-6;
        shapes_and_computation.push_back(std::make_pair(shapes, computations));
    };
...
...@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "test/x86/fixture.h"
...@@ -369,6 +370,63 @@ TEST_F(X86, CONVOLUTION_DIRECT_MKLDNN_C8) {
#endif

#if MEGDNN_WITH_BENCHMARK
TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x16) {
using namespace convolution;
using Param = param::Convolution;
std::vector<TestArg> args;
auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
size_t stride, size_t group = 1) {
Param param;
param.stride_h = stride;
param.stride_w = stride;
param.pad_h = kernel / 2;
param.pad_w = kernel / 2;
if (group > 1) {
param.sparse = param::Convolution::Sparse::GROUP;
args.emplace_back(
param, TensorShape{1, ic, h, w},
TensorShape{group, oc / group, ic / group, kernel, kernel});
} else {
param.sparse = param::Convolution::Sparse::DENSE;
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel});
}
};
run(48, 96, 15, 15, 1, 1);
run(64, 64, 60, 60, 3, 1);
run(64, 64, 60, 60, 3, 1, 64);
constexpr size_t RUN = 30;
Benchmarker<Convolution> benchmark(handle());
benchmark.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int16());
benchmark.set_display(false);
benchmark.set_times(RUN);
for (auto&& arg : args) {
TensorLayout dst_layout;
auto opr = handle()->create_operator<Convolution>();
opr->param() = arg.param;
opr->deduce_layout({arg.src, dtype::Float32()},
{arg.filter, dtype::Float32()}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float icpg = arg.filter.ndim == 4 ? arg.filter[1] : arg.filter[2];
float filter = arg.filter.ndim == 4 ? arg.filter[2] : arg.filter[3];
float computations = dst_layout.total_nr_elems() * icpg * filter *
filter * 2.0 / (1024 * 1024 * 1024) * 1e3;
auto used_int =
benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) /
RUN;
printf("%s %s: int: %f ms %f Gflops \n", arg.src.to_string().c_str(),
arg.filter.to_string().c_str(), used_int,
computations / used_int);
}
}

#if MEGDNN_X86_WITH_MKL_DNN
TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN) {
    using namespace convolution;
...@@ -419,7 +477,6 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN) {
        float computations = dst_layout.total_nr_elems() * arg.filter[1] *
                             arg.filter[2] * arg.filter[3] * 2.0 /
                             (1024 * 1024 * 1024) * 1e3;
        auto used_int =
                benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) /
                RUN;
...
...@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "test/x86/fixture.h"
...@@ -47,6 +48,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "X86_INT8X8X32_AVX2_4X16X2");
}
TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "X86_INT8X8X16_AVX2");
}

TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "X86_INT8X8X32_SSE_4X8X2");
...@@ -116,6 +121,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
    benchmarker_avx2_4x16x2.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X32_AVX2_4X16X2"));
Benchmarker<MatrixMul> benchmarker_avx2_4x16x2_8816(handle());
benchmarker_avx2_4x16x2_8816.set_display(false)
.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_rng(0, rng.get())
.set_rng(1, rng.get());
benchmarker_avx2_4x16x2_8816.set_before_exec_callback(
AlgoChecker<MatrixMul>("X86_INT8X8X16_AVX2"));

    Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle());
    benchmarker_avx2_2x4x16.set_display(false)
            .set_times(RUNS)
...@@ -183,6 +199,12 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
              << "k2_speed_up " << float_used / avx2_used_4x16x2
              << ", k16_speed_up " << float_used / avx2_used_2x4x16
              << ",";
auto avx2_used_4x16x2_8816 =
benchmarker_avx2_4x16x2_8816.exec({{M, K}, {K, N}, {}}) /
RUNS;
std::cout << "avx2_8816: " << avx2_used_4x16x2_8816
<< " ms, 8816 throughput "
<< computations / avx2_used_4x16x2_8816 << " Gflops,";
    }
    if (is_supported(SIMDType::SSE4_1)) {
        auto sse_used =
...