From 90ca85541e9017d5e1457a7eca498c3f55ec1357 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 9 Apr 2020 17:56:04 +0800 Subject: [PATCH] feat(dnn/x86): add avx2 int8 stride1 chanwise multithread conv GitOrigin-RevId: 8f310c3d139dfc27a4083f354597363681b73ba5 --- dnn/src/common/unroll_macro.h | 9 + dnn/src/x86/conv_bias/int8/algos.cpp | 60 + dnn/src/x86/conv_bias/int8/algos.h | 23 + .../x86/conv_bias/int8/avx2_chanwise_kern.cpp | 1593 +++++++++++++++++ .../x86/conv_bias/int8/avx2_chanwise_kern.h | 39 + .../conv_bias/int8/avx2_chanwise_stride1.cpp | 251 +++ .../conv_bias/int8/avx2_chanwise_stride1.h | 42 + .../int8/avx2_direct_conv_stride1.cpp | 1 - .../int8/avx2_direct_conv_stride2.cpp | 1 - dnn/src/x86/conv_bias/int8/common_helper.h | 1 + dnn/src/x86/conv_bias/opr_impl.cpp | 6 + dnn/src/x86/conv_bias/opr_impl.h | 1 + dnn/src/x86/elemwise_helper/kimpl/typecvt.h | 26 + dnn/test/x86/conv_bias.cpp | 220 +++ 14 files changed, 2271 insertions(+), 2 deletions(-) create mode 100644 dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.cpp create mode 100644 dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.h create mode 100644 dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp create mode 100644 dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.h diff --git a/dnn/src/common/unroll_macro.h b/dnn/src/common/unroll_macro.h index 936286fe9..1e6549bf0 100644 --- a/dnn/src/common/unroll_macro.h +++ b/dnn/src/common/unroll_macro.h @@ -40,6 +40,15 @@ UNROLL_RAW16(cb, v0, ##a) \ cb(16, ##a) cb(17, ##a) cb(18, ##a) cb(19, ##a) cb(20, ##a) cb(21, ##a) \ cb(22, ##a) cb(23, ##a) +#define UNROLL_RAW25(cb, v0, a...) \ + UNROLL_RAW24(cb, v0, ##a) \ + cb(24, ##a) +#define UNROLL_RAW49(cb, v0, a...) \ + UNROLL_RAW25(cb, v0, ##a) \ + cb(25, ##a) cb(26, ##a) cb(27, ##a) cb(28, ##a) cb(29, ##a) cb(30, ##a) \ + cb(31, ##a) cb(32, ##a) cb(33, ##a) cb(34, ##a) cb(35, ##a) cb(36, ##a) \ + cb(37, ##a) cb(38, ##a) cb(39, ##a) cb(40, ##a) cb(41, ##a) cb(42, ##a) \ + cb(43, ##a) cb(44, ##a) cb(45, ##a) cb(46, ##a) cb(47, ##a) cb(48, ##a) #define UNROLL_CALL0(step, cb, v...) UNROLL_RAW##step(cb, 0, ##v) #define UNROLL_CALL1(step, cb, v...) UNROLL_CALL0(step, cb, ##v) diff --git a/dnn/src/x86/conv_bias/int8/algos.cpp b/dnn/src/x86/conv_bias/int8/algos.cpp index 3777d8873..24487fd50 100644 --- a/dnn/src/x86/conv_bias/int8/algos.cpp +++ b/dnn/src/x86/conv_bias/int8/algos.cpp @@ -15,6 +15,7 @@ #include "src/fallback/convolution/img2col_helper.h" #include "src/x86/conv_bias/int8/avx2_direct_conv_stride1.h" #include "src/x86/conv_bias/int8/avx2_direct_conv_stride2.h" +#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h" #include "src/x86/conv_bias/opr_impl.h" #include "src/x86/conv_bias/postprocess_helper.h" #include "src/x86/handle.h" @@ -31,6 +32,65 @@ using namespace dnnl; using namespace megdnn; using namespace x86; +bool ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::usable( + FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + bool aviliable = + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + param.dst_type.enumv() == DTypeEnum::QuantizedS8) || + (((param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int32) || + (param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + param.dst_type.enumv() == DTypeEnum::QuantizedS32)))) && + fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + (FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.stride[0] == 1 && + fm.stride[1] == 1 && (fm.icpg == 1) && (fm.ocpg == 1) && + is_supported(SIMDType::AVX2); + return aviliable; +} + +WorkspaceBundle ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_bundle( + const NCBKernSizeParam& param) { + size_t nr_threads = param.nr_threads; + size_t IH2, IW2, OH2, OW2; + size_t src_size = 0, dst_size = 0, int32_temp = 0; + + avx2_chanwise_stride1::get_rectified_size(param, IH2, IW2, OH2, OW2); + + if (avx2_chanwise_stride1::need_src_copy(param)) { + src_size = IH2 * IW2 * sizeof(int8_t) * nr_threads; + } + if (avx2_chanwise_stride1::need_dst_copy(param)) { + dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; + } + bool dst_need_convert = param.dst_type.enumv() == DTypeEnum::QuantizedS8; + + if (dst_need_convert) { + int32_temp = OH2 * OW2 * sizeof(int32_t) * nr_threads; + } + return dst_need_convert + ? WorkspaceBundle(nullptr, {src_size, dst_size, int32_temp}) + : WorkspaceBundle(nullptr, {src_size, dst_size}); +} + +size_t ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_workspace( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_kimpls( + const NCBKernSizeParam& param) const { + auto bundle = get_bundle(param); + return avx2_chanwise_stride1::get_kimpls(param, bundle); +} + bool ConvBiasImpl::AlgoDirectAvx2Stride1Int8::usable( FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { diff --git a/dnn/src/x86/conv_bias/int8/algos.h b/dnn/src/x86/conv_bias/int8/algos.h index cf3eb4280..5a63c0b7a 100644 --- a/dnn/src/x86/conv_bias/int8/algos.h +++ b/dnn/src/x86/conv_bias/int8/algos.h @@ -13,6 +13,29 @@ namespace megdnn { namespace x86 { + +/* ===================== avx2 stride1 chanwise algo ===================== */ +class ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); + +public: + bool is_reproducible() const override { return true; } + const char* name() const override { + return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"; + } + bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(FallbackConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override { + return get_kimpls(param); + } + void* type() const override; +}; + /* ===================== avx2 stride1 direct algo ===================== */ class ConvBiasImpl::AlgoDirectAvx2Stride1Int8 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.cpp b/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.cpp new file mode 100644 index 000000000..e00c848e8 --- /dev/null +++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.cpp @@ -0,0 +1,1593 @@ +/** + * \file src/x86/conv_bias/int8/avx2_chanwise_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/x86/conv_bias/int8/avx2_chanwise_kern.h" +#include +#include "src/common/unroll_macro.h" +#include "src/x86/conv_bias/int8/common_helper.h" +#include "src/x86/elemwise_op.h" +#ifdef WIN32CMAKE +#include +#endif + +namespace megdnn { +namespace x86 { +namespace avx2_chanwise_stride1 { + +#define load_filter(i) __m128i k_##i = _mm_set1_epi8(*(filter + i)); +#define load_src0(i) \ + __m256i cvt16_src##i##0 = _mm256_cvtepi8_epi16_from_ptr(r##i); +#define load_src1(i) \ + __m256i cvt16_src##i##1 = _mm256_cvtepi8_epi16_from_ptr(r##i + 1); +#define load_src2(i) \ + __m256i cvt16_src##i##2 = _mm256_cvtepi8_epi16_from_ptr(r##i + 2); +#define load_src3(i) \ + __m256i cvt16_src##i##3 = _mm256_cvtepi8_epi16_from_ptr(r##i + 3); +#define load_src4(i) \ + __m256i cvt16_src##i##4 = _mm256_cvtepi8_epi16_from_ptr(r##i + 4); +#define load_src5(i) \ + __m256i cvt16_src##i##5 = _mm256_cvtepi8_epi16_from_ptr(r##i + 5); +#define load_src6(i) \ + __m256i cvt16_src##i##6 = _mm256_cvtepi8_epi16_from_ptr(r##i + 6); +#define load_src7(i) \ + __m256i cvt16_src##i##7 = _mm256_cvtepi8_epi16_from_ptr(r##i + 7); + +template +void avx2_chanwise_direct_stride1_2x2_int8(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + size_t tail_step = IW - OW; + int8_t* dst0 = dst; + int8_t* dst1 = dst + OW; + int32_t* out_ptr0 = temp; + int32_t* out_ptr1 = temp + OW; + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + 2 * IW; + + UNROLL_CALL0(4, load_filter) + +#define pack_filter(i, j) __m128i k_##i##j = _mm_unpacklo_epi8(k_##i, k_##j) + pack_filter(0, 1); + pack_filter(2, 3); + + __m256i bias_val; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_val = _mm256_set1_epi32(*(bias)); + } else { + bias_val = _mm256_set1_epi32(0); + } +#define cvt_filter(i, j) __m256i filter_##i##j = _mm256_cvtepi8_epi16(k_##i##j) + cvt_filter(0, 1); + cvt_filter(2, 3); + + size_t width = OW >> 4; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + size_t w = 0; + for (; w < width; w++) { + UNROLL_CALL0(3, load_src0) + UNROLL_CALL0(3, load_src1) + + __m256i sum0_odd, sum0_even, sum1_odd, sum1_even; + __m256i tmp0_odd, tmp0_even, tmp1_odd, tmp1_even, tmp2_odd, + tmp2_even, tmp3_odd, tmp3_even; + + tmp0_odd = _mm256_madd_epi16(cvt16_src00, filter_01); + tmp0_even = _mm256_madd_epi16(cvt16_src01, filter_01); + + tmp1_odd = _mm256_madd_epi16(cvt16_src10, filter_23); + tmp1_even = _mm256_madd_epi16(cvt16_src11, filter_23); + + tmp3_odd = _mm256_madd_epi16(cvt16_src10, filter_01); + tmp3_even = _mm256_madd_epi16(cvt16_src11, filter_01); + + tmp2_odd = _mm256_madd_epi16(cvt16_src20, filter_23); + tmp2_even = _mm256_madd_epi16(cvt16_src21, filter_23); + + sum0_odd = _mm256_add_epi32(tmp0_odd, tmp1_odd); + sum0_even = _mm256_add_epi32(tmp0_even, tmp1_even); + + __m256i sum_odd = _mm256_unpacklo_epi32(sum0_odd, sum0_even); + __m256i sum_even = _mm256_unpackhi_epi32(sum0_odd, sum0_even); + + //! switch_mask_low = {00100000} = 32 + //! switch_mask_high = {00110001} = 49 + __m256i sum_left = _mm256_permute2f128_si256(sum_odd, sum_even, 32); + __m256i sum_right = + _mm256_permute2f128_si256(sum_odd, sum_even, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + sum1_odd = _mm256_add_epi32(tmp3_odd, tmp2_odd); + sum1_even = _mm256_add_epi32(tmp3_even, tmp2_even); + + __m256i sum_1_odd = _mm256_unpacklo_epi32(sum1_odd, sum1_even); + __m256i sum_1_even = _mm256_unpackhi_epi32(sum1_odd, sum1_even); + + __m256i sum_1_left = + _mm256_permute2f128_si256(sum_1_odd, sum_1_even, 32); + __m256i sum_1_right = + _mm256_permute2f128_si256(sum_1_odd, sum_1_even, 49); + + sum_1_left = _mm256_add_epi32(sum_1_left, bias_val); + sum_1_right = _mm256_add_epi32(sum_1_right, bias_val); + + if (is_quantized) { + op({{sum_1_left, sum_1_right}}, + reinterpret_cast(dst1)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr1), sum_1_left); + _mm256_storeu_si256((__m256i*)(out_ptr1 + 8), sum_1_right); + } + r0 += 16; + r1 += 16; + r2 += 16; + dst0 += 16; + dst1 += 16; + out_ptr0 += 16; + out_ptr1 += 16; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + + dst0 += OW; + dst1 += OW; + out_ptr0 += OW; + out_ptr1 += OW; + } + + for (; h < OH; h++) { + size_t w = 0; + for (; w < width; w++) { + UNROLL_CALL0(2, load_src0) + UNROLL_CALL0(2, load_src1) + + __m256i sum0_odd, sum0_even; + __m256i tmp0_odd, tmp0_even, tmp1_odd, tmp1_even; + + tmp0_odd = _mm256_madd_epi16(cvt16_src00, filter_01); + tmp0_even = _mm256_madd_epi16(cvt16_src01, filter_01); + + tmp1_odd = _mm256_madd_epi16(cvt16_src10, filter_23); + tmp1_even = _mm256_madd_epi16(cvt16_src11, filter_23); + + sum0_odd = _mm256_add_epi32(tmp0_odd, tmp1_odd); + sum0_even = _mm256_add_epi32(tmp0_even, tmp1_even); + + __m256i sum_odd = _mm256_unpacklo_epi32(sum0_odd, sum0_even); + __m256i sum_even = _mm256_unpackhi_epi32(sum0_odd, sum0_even); + + __m256i sum_left = _mm256_permute2f128_si256(sum_odd, sum_even, 32); + __m256i sum_right = + _mm256_permute2f128_si256(sum_odd, sum_even, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + r0 += 16; + r1 += 16; + dst0 += 16; + out_ptr0 += 16; + } + r0 += tail_step; + r1 += tail_step; + } + MEGDNN_MARK_USED_VAR(IH); +#undef pack_filter +#undef cvt_filter +} + +template +void avx2_chanwise_direct_stride1_3x3_int8(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + size_t tail_step = IW - OW; + int32_t* out_ptr0 = temp; + int32_t* out_ptr1 = temp + OW; + int8_t* dst0 = dst; + int8_t* dst1 = dst + OW; + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + 2 * IW; + const int8_t* r3 = src + 3 * IW; + + uint8_t fill_zero = 0; + UNROLL_CALL0(9, load_filter) + + __m128i k_fill = _mm_set1_epi8(fill_zero); + + __m128i k01 = _mm_unpacklo_epi8(k_0, k_1); + __m128i k20 = _mm_unpacklo_epi8(k_2, k_fill); + + __m128i k34 = _mm_unpacklo_epi8(k_3, k_4); + __m128i k50 = _mm_unpacklo_epi8(k_5, k_fill); + + __m128i k67 = _mm_unpacklo_epi8(k_6, k_7); + __m128i k80 = _mm_unpacklo_epi8(k_8, k_fill); + + __m256i bias_val; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_val = _mm256_set1_epi32(*(bias)); + } else { + bias_val = _mm256_set1_epi32(0); + } + + //! cvt i8 --> i16 + __m256i filter_01 = _mm256_cvtepi8_epi16(k01); + __m256i filter_20 = _mm256_cvtepi8_epi16(k20); + __m256i filter_34 = _mm256_cvtepi8_epi16(k34); + __m256i filter_50 = _mm256_cvtepi8_epi16(k50); + __m256i filter_67 = _mm256_cvtepi8_epi16(k67); + __m256i filter_80 = _mm256_cvtepi8_epi16(k80); + + size_t width = OW >> 4; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + size_t w = 0; + for (; w < width; w++) { + UNROLL_CALL0(4, load_src0) + UNROLL_CALL0(4, load_src1) + UNROLL_CALL0(4, load_src2) + UNROLL_CALL0(4, load_src3) + + __m256i sum00_odd, sum00_even, sum11_odd, sum11_even, sum22_odd, + sum22_even; + __m256i sum11_odd_01, sum11_even_01, sum22_odd_01, sum22_even_01, + sum33_odd, sum33_even; + __m256i temp0, temp1; + + temp0 = _mm256_madd_epi16(cvt16_src00, filter_01); + temp1 = _mm256_madd_epi16(cvt16_src02, filter_20); + sum00_odd = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src01, filter_01); + temp1 = _mm256_madd_epi16(cvt16_src03, filter_20); + sum00_even = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src10, filter_34); + temp1 = _mm256_madd_epi16(cvt16_src12, filter_50); + sum11_odd = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src11, filter_34); + temp1 = _mm256_madd_epi16(cvt16_src13, filter_50); + sum11_even = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src10, filter_01); + temp1 = _mm256_madd_epi16(cvt16_src12, filter_20); + sum11_odd_01 = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src11, filter_01); + temp1 = _mm256_madd_epi16(cvt16_src13, filter_20); + sum11_even_01 = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src20, filter_67); + temp1 = _mm256_madd_epi16(cvt16_src22, filter_80); + sum22_odd = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src21, filter_67); + temp1 = _mm256_madd_epi16(cvt16_src23, filter_80); + sum22_even = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src20, filter_34); + temp1 = _mm256_madd_epi16(cvt16_src22, filter_50); + sum22_odd_01 = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src21, filter_34); + temp1 = _mm256_madd_epi16(cvt16_src23, filter_50); + sum22_even_01 = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src30, filter_67); + temp1 = _mm256_madd_epi16(cvt16_src32, filter_80); + sum33_odd = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src31, filter_67); + temp1 = _mm256_madd_epi16(cvt16_src33, filter_80); + sum33_even = _mm256_add_epi32(temp0, temp1); + + sum00_odd = _mm256_add_epi32(sum00_odd, sum11_odd); + sum00_odd = _mm256_add_epi32(sum00_odd, sum22_odd); + + sum00_even = _mm256_add_epi32(sum00_even, sum11_even); + sum00_even = _mm256_add_epi32(sum00_even, sum22_even); + + __m256i sum_odd = _mm256_unpacklo_epi32(sum00_odd, sum00_even); + __m256i sum_even = _mm256_unpackhi_epi32(sum00_odd, sum00_even); + + __m256i sum_left = _mm256_permute2f128_si256(sum_odd, sum_even, 32); + __m256i sum_right = + _mm256_permute2f128_si256(sum_odd, sum_even, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + sum11_odd_01 = _mm256_add_epi32(sum11_odd_01, sum22_odd_01); + sum11_odd_01 = _mm256_add_epi32(sum11_odd_01, sum33_odd); + + sum11_even_01 = _mm256_add_epi32(sum11_even_01, sum22_even_01); + sum11_even_01 = _mm256_add_epi32(sum11_even_01, sum33_even); + + __m256i sum_oh1_odd = + _mm256_unpacklo_epi32(sum11_odd_01, sum11_even_01); + __m256i sum_oh1_even = + _mm256_unpackhi_epi32(sum11_odd_01, sum11_even_01); + + __m256i sum1_left = + _mm256_permute2f128_si256(sum_oh1_odd, sum_oh1_even, 32); + __m256i sum1_right = + _mm256_permute2f128_si256(sum_oh1_odd, sum_oh1_even, 49); + + sum1_left = _mm256_add_epi32(sum1_left, bias_val); + sum1_right = _mm256_add_epi32(sum1_right, bias_val); + + if (is_quantized) { + op({{sum1_left, sum1_right}}, + reinterpret_cast(dst1)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr1), sum1_left); + _mm256_storeu_si256((__m256i*)(out_ptr1 + 8), sum1_right); + } + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + dst0 += 16; + dst1 += 16; + out_ptr0 += 16; + out_ptr1 += 16; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + dst0 += OW; + dst1 += OW; + out_ptr0 += OW; + out_ptr1 += OW; + } + + for (; h < OH; h++) { + size_t w = 0; + for (; w < width; w++) { + UNROLL_CALL0(3, load_src0) + UNROLL_CALL0(3, load_src1) + UNROLL_CALL0(3, load_src2) + UNROLL_CALL0(3, load_src3) + + __m256i sum00_odd, sum00_even, sum11_odd, sum11_even, sum22_odd, + sum22_even; + __m256i temp0, temp1; + + temp0 = _mm256_madd_epi16(cvt16_src00, filter_01); + temp1 = _mm256_madd_epi16(cvt16_src02, filter_20); + sum00_odd = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src01, filter_01); + temp1 = _mm256_madd_epi16(cvt16_src03, filter_20); + sum00_even = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src10, filter_34); + temp1 = _mm256_madd_epi16(cvt16_src12, filter_50); + sum11_odd = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src11, filter_34); + temp1 = _mm256_madd_epi16(cvt16_src13, filter_50); + sum11_even = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src20, filter_67); + temp1 = _mm256_madd_epi16(cvt16_src22, filter_80); + sum22_odd = _mm256_add_epi32(temp0, temp1); + + temp0 = _mm256_madd_epi16(cvt16_src21, filter_67); + temp1 = _mm256_madd_epi16(cvt16_src23, filter_80); + sum22_even = _mm256_add_epi32(temp0, temp1); + + sum00_odd = _mm256_add_epi32(sum00_odd, sum11_odd); + sum00_odd = _mm256_add_epi32(sum00_odd, sum22_odd); + + sum00_even = _mm256_add_epi32(sum00_even, sum11_even); + sum00_even = _mm256_add_epi32(sum00_even, sum22_even); + + __m256i sum_odd = _mm256_unpacklo_epi32(sum00_odd, sum00_even); + __m256i sum_even = _mm256_unpackhi_epi32(sum00_odd, sum00_even); + + __m256i sum_left = _mm256_permute2f128_si256(sum_odd, sum_even, 32); + __m256i sum_right = + _mm256_permute2f128_si256(sum_odd, sum_even, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + r0 += 16; + r1 += 16; + r2 += 16; + dst0 += 16; + out_ptr0 += 16; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } +} + +template +void avx2_chanwise_direct_stride1_5x5_int8(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + size_t tail_step = IW - OW; + int8_t* dst0 = dst; + int8_t* dst1 = dst + OW; + int32_t* out_ptr0 = temp; + int32_t* out_ptr1 = temp + OW; + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + 2 * IW; + const int8_t* r3 = src + 3 * IW; + const int8_t* r4 = src + 4 * IW; + const int8_t* r5 = src + 5 * IW; + + uint8_t fill_zero = 0; + UNROLL_CALL0(25, load_filter) + + __m128i k_fill = _mm_set1_epi8(fill_zero); + + __m128i k01 = _mm_unpacklo_epi8(k_0, k_1); + __m128i k23 = _mm_unpacklo_epi8(k_2, k_3); + __m128i k40 = _mm_unpacklo_epi8(k_4, k_fill); + + __m128i k56 = _mm_unpacklo_epi8(k_5, k_6); + __m128i k78 = _mm_unpacklo_epi8(k_7, k_8); + __m128i k90 = _mm_unpacklo_epi8(k_9, k_fill); + + __m128i k1011 = _mm_unpacklo_epi8(k_10, k_11); + __m128i k1213 = _mm_unpacklo_epi8(k_12, k_13); + __m128i k140 = _mm_unpacklo_epi8(k_14, k_fill); + + __m128i k1516 = _mm_unpacklo_epi8(k_15, k_16); + __m128i k1718 = _mm_unpacklo_epi8(k_17, k_18); + __m128i k190 = _mm_unpacklo_epi8(k_19, k_fill); + + __m128i k2021 = _mm_unpacklo_epi8(k_20, k_21); + __m128i k2223 = _mm_unpacklo_epi8(k_22, k_23); + __m128i k240 = _mm_unpacklo_epi8(k_24, k_fill); + + __m256i bias_val; + //! load bias + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_val = _mm256_set1_epi32(*(bias)); + } else { + bias_val = _mm256_set1_epi32(0); + } + + //! cvt i8 --> i16 + __m256i filter_01 = _mm256_cvtepi8_epi16(k01); + __m256i filter_23 = _mm256_cvtepi8_epi16(k23); + __m256i filter_40 = _mm256_cvtepi8_epi16(k40); + + __m256i filter_56 = _mm256_cvtepi8_epi16(k56); + __m256i filter_78 = _mm256_cvtepi8_epi16(k78); + __m256i filter_90 = _mm256_cvtepi8_epi16(k90); + + __m256i filter_1011 = _mm256_cvtepi8_epi16(k1011); + __m256i filter_1213 = _mm256_cvtepi8_epi16(k1213); + __m256i filter_140 = _mm256_cvtepi8_epi16(k140); + + __m256i filter_1516 = _mm256_cvtepi8_epi16(k1516); + __m256i filter_1718 = _mm256_cvtepi8_epi16(k1718); + __m256i filter_190 = _mm256_cvtepi8_epi16(k190); + + __m256i filter_2021 = _mm256_cvtepi8_epi16(k2021); + __m256i filter_2223 = _mm256_cvtepi8_epi16(k2223); + __m256i filter_240 = _mm256_cvtepi8_epi16(k240); + + size_t width = OW >> 4; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + size_t w = 0; + for (; w < width; w++) { + UNROLL_CALL0(6, load_src0) + UNROLL_CALL0(6, load_src1) + UNROLL_CALL0(6, load_src2) + UNROLL_CALL0(6, load_src3) + UNROLL_CALL0(6, load_src4) + UNROLL_CALL0(6, load_src5) + + __m256i sum0_odd, sum0_even, sum1_odd, sum1_even, sum2_odd, + sum2_even, sum3_odd, sum3_even, sum4_odd, sum4_even; + + __m256i sum10_odd, sum10_even, sum20_odd, sum20_even, sum30_odd, + sum30_even, sum40_odd, sum40_even, sum5_odd, sum5_even; + + //! cal src0 + __m256i dot1, dot2, dot3; + dot1 = _mm256_madd_epi16(cvt16_src00, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src02, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src04, filter_40); + sum0_odd = _mm256_add_epi32(dot1, dot2); + sum0_odd = _mm256_add_epi32(sum0_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src01, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src03, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src05, filter_40); + sum0_even = _mm256_add_epi32(dot1, dot2); + sum0_even = _mm256_add_epi32(sum0_even, dot3); + + //! cal src1 + dot1 = _mm256_madd_epi16(cvt16_src10, filter_56); + dot2 = _mm256_madd_epi16(cvt16_src12, filter_78); + dot3 = _mm256_madd_epi16(cvt16_src14, filter_90); + sum1_odd = _mm256_add_epi32(dot1, dot2); + sum1_odd = _mm256_add_epi32(sum1_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src11, filter_56); + dot2 = _mm256_madd_epi16(cvt16_src13, filter_78); + dot3 = _mm256_madd_epi16(cvt16_src15, filter_90); + sum1_even = _mm256_add_epi32(dot1, dot2); + sum1_even = _mm256_add_epi32(sum1_even, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src10, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src12, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src14, filter_40); + sum10_odd = _mm256_add_epi32(dot1, dot2); + sum10_odd = _mm256_add_epi32(sum10_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src11, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src13, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src15, filter_40); + sum10_even = _mm256_add_epi32(dot1, dot2); + sum10_even = _mm256_add_epi32(sum10_even, dot3); + + //! cal src2 + dot1 = _mm256_madd_epi16(cvt16_src20, filter_1011); + dot2 = _mm256_madd_epi16(cvt16_src22, filter_1213); + dot3 = _mm256_madd_epi16(cvt16_src24, filter_140); + sum2_odd = _mm256_add_epi32(dot1, dot2); + sum2_odd = _mm256_add_epi32(sum2_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src21, filter_1011); + dot2 = _mm256_madd_epi16(cvt16_src23, filter_1213); + dot3 = _mm256_madd_epi16(cvt16_src25, filter_140); + sum2_even = _mm256_add_epi32(dot1, dot2); + sum2_even = _mm256_add_epi32(sum2_even, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src20, filter_56); + dot2 = _mm256_madd_epi16(cvt16_src22, filter_78); + dot3 = _mm256_madd_epi16(cvt16_src24, filter_90); + sum20_odd = _mm256_add_epi32(dot1, dot2); + sum20_odd = _mm256_add_epi32(sum20_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src21, filter_56); + dot2 = _mm256_madd_epi16(cvt16_src23, filter_78); + dot3 = _mm256_madd_epi16(cvt16_src25, filter_90); + sum20_even = _mm256_add_epi32(dot1, dot2); + sum20_even = _mm256_add_epi32(sum20_even, dot3); + + //! cal src3 + dot1 = _mm256_madd_epi16(cvt16_src30, filter_1516); + dot2 = _mm256_madd_epi16(cvt16_src32, filter_1718); + dot3 = _mm256_madd_epi16(cvt16_src34, filter_190); + sum3_odd = _mm256_add_epi32(dot1, dot2); + sum3_odd = _mm256_add_epi32(sum3_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src31, filter_1516); + dot2 = _mm256_madd_epi16(cvt16_src33, filter_1718); + dot3 = _mm256_madd_epi16(cvt16_src35, filter_190); + sum3_even = _mm256_add_epi32(dot1, dot2); + sum3_even = _mm256_add_epi32(sum3_even, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src30, filter_1011); + dot2 = _mm256_madd_epi16(cvt16_src32, filter_1213); + dot3 = _mm256_madd_epi16(cvt16_src34, filter_140); + sum30_odd = _mm256_add_epi32(dot1, dot2); + sum30_odd = _mm256_add_epi32(sum30_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src31, filter_1011); + dot2 = _mm256_madd_epi16(cvt16_src33, filter_1213); + dot3 = _mm256_madd_epi16(cvt16_src35, filter_140); + sum30_even = _mm256_add_epi32(dot1, dot2); + sum30_even = _mm256_add_epi32(sum30_even, dot3); + + //! cal src4 + dot1 = _mm256_madd_epi16(cvt16_src40, filter_2021); + dot2 = _mm256_madd_epi16(cvt16_src42, filter_2223); + dot3 = _mm256_madd_epi16(cvt16_src44, filter_240); + sum4_odd = _mm256_add_epi32(dot1, dot2); + sum4_odd = _mm256_add_epi32(sum4_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src41, filter_2021); + dot2 = _mm256_madd_epi16(cvt16_src43, filter_2223); + dot3 = _mm256_madd_epi16(cvt16_src45, filter_240); + sum4_even = _mm256_add_epi32(dot1, dot2); + sum4_even = _mm256_add_epi32(sum4_even, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src40, filter_1516); + dot2 = _mm256_madd_epi16(cvt16_src42, filter_1718); + dot3 = _mm256_madd_epi16(cvt16_src44, filter_190); + sum40_odd = _mm256_add_epi32(dot1, dot2); + sum40_odd = _mm256_add_epi32(sum40_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src41, filter_1516); + dot2 = _mm256_madd_epi16(cvt16_src43, filter_1718); + dot3 = _mm256_madd_epi16(cvt16_src45, filter_190); + sum40_even = _mm256_add_epi32(dot1, dot2); + sum40_even = _mm256_add_epi32(sum40_even, dot3); + + //! cal src5 + dot1 = _mm256_madd_epi16(cvt16_src50, filter_2021); + dot2 = _mm256_madd_epi16(cvt16_src52, filter_2223); + dot3 = _mm256_madd_epi16(cvt16_src54, filter_240); + sum5_odd = _mm256_add_epi32(dot1, dot2); + sum5_odd = _mm256_add_epi32(sum5_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src51, filter_2021); + dot2 = _mm256_madd_epi16(cvt16_src53, filter_2223); + dot3 = _mm256_madd_epi16(cvt16_src55, filter_240); + sum5_even = _mm256_add_epi32(dot1, dot2); + sum5_even = _mm256_add_epi32(sum5_even, dot3); + + __m256i sum_odd, sum_even; + + sum_odd = _mm256_add_epi32(sum0_odd, sum1_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum2_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum3_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum4_odd); + + sum_even = _mm256_add_epi32(sum0_even, sum1_even); + sum_even = _mm256_add_epi32(sum_even, sum2_even); + sum_even = _mm256_add_epi32(sum_even, sum3_even); + sum_even = _mm256_add_epi32(sum_even, sum4_even); + + __m256i sum_odd_0 = _mm256_unpacklo_epi32(sum_odd, sum_even); + __m256i sum_even_0 = _mm256_unpackhi_epi32(sum_odd, sum_even); + + __m256i sum_left = + _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 32); + __m256i sum_right = + _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + __m256i sum_odd_oh1, sum_even_oh1; + + sum_odd_oh1 = _mm256_add_epi32(sum10_odd, sum20_odd); + sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum30_odd); + sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum40_odd); + sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum5_odd); + + sum_even_oh1 = _mm256_add_epi32(sum10_even, sum20_even); + sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum30_even); + sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum40_even); + sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum5_even); + + __m256i sum_odd_1 = + _mm256_unpacklo_epi32(sum_odd_oh1, sum_even_oh1); + __m256i sum_even_1 = + _mm256_unpackhi_epi32(sum_odd_oh1, sum_even_oh1); + + sum_left = _mm256_permute2f128_si256(sum_odd_1, sum_even_1, 32); + sum_right = _mm256_permute2f128_si256(sum_odd_1, sum_even_1, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst1)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr1), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr1 + 8), sum_right); + } + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + dst0 += 16; + dst1 += 16; + out_ptr0 += 16; + out_ptr1 += 16; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + dst0 += OW; + dst1 += OW; + out_ptr0 += OW; + out_ptr1 += OW; + } + + for (; h < OH; h++) { + size_t w = 0; + for (; w < width; w++) { + UNROLL_CALL0(5, load_src0) + UNROLL_CALL0(5, load_src1) + UNROLL_CALL0(5, load_src2) + UNROLL_CALL0(5, load_src3) + UNROLL_CALL0(5, load_src4) + UNROLL_CALL0(5, load_src5) + + __m256i sum0_odd, sum0_even, sum1_odd, sum1_even, sum2_odd, + sum2_even, sum3_odd, sum3_even, sum4_odd, sum4_even; + + //! cal src0 + __m256i dot1, dot2, dot3; + dot1 = _mm256_madd_epi16(cvt16_src00, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src02, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src04, filter_40); + sum0_odd = _mm256_add_epi32(dot1, dot2); + sum0_odd = _mm256_add_epi32(sum0_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src01, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src03, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src05, filter_40); + sum0_even = _mm256_add_epi32(dot1, dot2); + sum0_even = _mm256_add_epi32(sum0_even, dot3); + + //! cal src1 + dot1 = _mm256_madd_epi16(cvt16_src10, filter_56); + dot2 = _mm256_madd_epi16(cvt16_src12, filter_78); + dot3 = _mm256_madd_epi16(cvt16_src14, filter_90); + sum1_odd = _mm256_add_epi32(dot1, dot2); + sum1_odd = _mm256_add_epi32(sum1_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src11, filter_56); + dot2 = _mm256_madd_epi16(cvt16_src13, filter_78); + dot3 = _mm256_madd_epi16(cvt16_src15, filter_90); + sum1_even = _mm256_add_epi32(dot1, dot2); + sum1_even = _mm256_add_epi32(sum1_even, dot3); + + //! cal src2 + dot1 = _mm256_madd_epi16(cvt16_src20, filter_1011); + dot2 = _mm256_madd_epi16(cvt16_src22, filter_1213); + dot3 = _mm256_madd_epi16(cvt16_src24, filter_140); + sum2_odd = _mm256_add_epi32(dot1, dot2); + sum2_odd = _mm256_add_epi32(sum2_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src21, filter_1011); + dot2 = _mm256_madd_epi16(cvt16_src23, filter_1213); + dot3 = _mm256_madd_epi16(cvt16_src25, filter_140); + sum2_even = _mm256_add_epi32(dot1, dot2); + sum2_even = _mm256_add_epi32(sum2_even, dot3); + + //! cal src3 + dot1 = _mm256_madd_epi16(cvt16_src30, filter_1516); + dot2 = _mm256_madd_epi16(cvt16_src32, filter_1718); + dot3 = _mm256_madd_epi16(cvt16_src34, filter_190); + sum3_odd = _mm256_add_epi32(dot1, dot2); + sum3_odd = _mm256_add_epi32(sum3_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src31, filter_1516); + dot2 = _mm256_madd_epi16(cvt16_src33, filter_1718); + dot3 = _mm256_madd_epi16(cvt16_src35, filter_190); + sum3_even = _mm256_add_epi32(dot1, dot2); + sum3_even = _mm256_add_epi32(sum3_even, dot3); + + //! cal src4 + dot1 = _mm256_madd_epi16(cvt16_src40, filter_2021); + dot2 = _mm256_madd_epi16(cvt16_src42, filter_2223); + dot3 = _mm256_madd_epi16(cvt16_src44, filter_240); + sum4_odd = _mm256_add_epi32(dot1, dot2); + sum4_odd = _mm256_add_epi32(sum4_odd, dot3); + + dot1 = _mm256_madd_epi16(cvt16_src41, filter_2021); + dot2 = _mm256_madd_epi16(cvt16_src43, filter_2223); + dot3 = _mm256_madd_epi16(cvt16_src45, filter_240); + sum4_even = _mm256_add_epi32(dot1, dot2); + sum4_even = _mm256_add_epi32(sum4_even, dot3); + + __m256i sum_odd, sum_even; + + sum_odd = _mm256_add_epi32(sum0_odd, sum1_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum2_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum3_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum4_odd); + + sum_even = _mm256_add_epi32(sum0_even, sum1_even); + sum_even = _mm256_add_epi32(sum_even, sum2_even); + sum_even = _mm256_add_epi32(sum_even, sum3_even); + sum_even = _mm256_add_epi32(sum_even, sum4_even); + + __m256i sum_odd_0 = _mm256_unpacklo_epi32(sum_odd, sum_even); + __m256i sum_even_0 = _mm256_unpackhi_epi32(sum_odd, sum_even); + + __m256i sum_left = + _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 32); + __m256i sum_right = + _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + dst0 += 16; + out_ptr0 += 16; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } +} + +template +void avx2_chanwise_direct_stride1_7x7_int8(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + size_t tail_step = IW - OW; + int8_t* dst0 = dst; + int8_t* dst1 = dst + OW; + int32_t* out_ptr0 = temp; + int32_t* out_ptr1 = temp + OW; + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + 2 * IW; + const int8_t* r3 = src + 3 * IW; + const int8_t* r4 = src + 4 * IW; + const int8_t* r5 = src + 5 * IW; + const int8_t* r6 = src + 6 * IW; + const int8_t* r7 = src + 7 * IW; + + uint8_t fill_zero = 0; + UNROLL_CALL0(49, load_filter) + + __m128i k_fill = _mm_set1_epi8(fill_zero); + + __m128i k01 = _mm_unpacklo_epi8(k_0, k_1); + __m128i k23 = _mm_unpacklo_epi8(k_2, k_3); + __m128i k45 = _mm_unpacklo_epi8(k_4, k_5); + __m128i k60 = _mm_unpacklo_epi8(k_6, k_fill); + + __m128i k78 = _mm_unpacklo_epi8(k_7, k_8); + __m128i k910 = _mm_unpacklo_epi8(k_9, k_10); + __m128i k1112 = _mm_unpacklo_epi8(k_11, k_12); + __m128i k130 = _mm_unpacklo_epi8(k_13, k_fill); + + __m128i k1415 = _mm_unpacklo_epi8(k_14, k_15); + __m128i k1617 = _mm_unpacklo_epi8(k_16, k_17); + __m128i k1819 = _mm_unpacklo_epi8(k_18, k_19); + __m128i k200 = _mm_unpacklo_epi8(k_20, k_fill); + + __m128i k2122 = _mm_unpacklo_epi8(k_21, k_22); + __m128i k2324 = _mm_unpacklo_epi8(k_23, k_24); + __m128i k2526 = _mm_unpacklo_epi8(k_25, k_26); + __m128i k270 = _mm_unpacklo_epi8(k_27, k_fill); + + __m128i k2829 = _mm_unpacklo_epi8(k_28, k_29); + __m128i k3031 = _mm_unpacklo_epi8(k_30, k_31); + __m128i k3233 = _mm_unpacklo_epi8(k_32, k_33); + __m128i k340 = _mm_unpacklo_epi8(k_34, k_fill); + + __m128i k3536 = _mm_unpacklo_epi8(k_35, k_36); + __m128i k3738 = _mm_unpacklo_epi8(k_37, k_38); + __m128i k3940 = _mm_unpacklo_epi8(k_39, k_40); + __m128i k410 = _mm_unpacklo_epi8(k_41, k_fill); + + __m128i k4243 = _mm_unpacklo_epi8(k_42, k_43); + __m128i k4445 = _mm_unpacklo_epi8(k_44, k_45); + __m128i k4647 = _mm_unpacklo_epi8(k_46, k_47); + __m128i k480 = _mm_unpacklo_epi8(k_48, k_fill); + + __m256i bias_val; + //! load bias + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_val = _mm256_set1_epi32(*(bias)); + } else { + bias_val = _mm256_set1_epi32(0); + } + + //! cvt i8 --> i16 + __m256i filter_01 = _mm256_cvtepi8_epi16(k01); + __m256i filter_23 = _mm256_cvtepi8_epi16(k23); + __m256i filter_45 = _mm256_cvtepi8_epi16(k45); + __m256i filter_60 = _mm256_cvtepi8_epi16(k60); + + __m256i filter_78 = _mm256_cvtepi8_epi16(k78); + __m256i filter_910 = _mm256_cvtepi8_epi16(k910); + __m256i filter_1112 = _mm256_cvtepi8_epi16(k1112); + __m256i filter_130 = _mm256_cvtepi8_epi16(k130); + + __m256i filter_1415 = _mm256_cvtepi8_epi16(k1415); + __m256i filter_1617 = _mm256_cvtepi8_epi16(k1617); + __m256i filter_1819 = _mm256_cvtepi8_epi16(k1819); + __m256i filter_200 = _mm256_cvtepi8_epi16(k200); + + __m256i filter_2122 = _mm256_cvtepi8_epi16(k2122); + __m256i filter_2324 = _mm256_cvtepi8_epi16(k2324); + __m256i filter_2526 = _mm256_cvtepi8_epi16(k2526); + __m256i filter_270 = _mm256_cvtepi8_epi16(k270); + + __m256i filter_2829 = _mm256_cvtepi8_epi16(k2829); + __m256i filter_3031 = _mm256_cvtepi8_epi16(k3031); + __m256i filter_3233 = _mm256_cvtepi8_epi16(k3233); + __m256i filter_340 = _mm256_cvtepi8_epi16(k340); + + __m256i filter_3536 = _mm256_cvtepi8_epi16(k3536); + __m256i filter_3738 = _mm256_cvtepi8_epi16(k3738); + __m256i filter_3940 = _mm256_cvtepi8_epi16(k3940); + __m256i filter_410 = _mm256_cvtepi8_epi16(k410); + + __m256i filter_4243 = _mm256_cvtepi8_epi16(k4243); + __m256i filter_4445 = _mm256_cvtepi8_epi16(k4445); + __m256i filter_4647 = _mm256_cvtepi8_epi16(k4647); + __m256i filter_480 = _mm256_cvtepi8_epi16(k480); + + size_t width = OW >> 4; + size_t h = 0; + for (; h + 1 < OH; h += 2) { + size_t w = 0; + for (; w < width; w++) { + UNROLL_CALL0(8, load_src0) + UNROLL_CALL0(8, load_src1) + UNROLL_CALL0(8, load_src2) + UNROLL_CALL0(8, load_src3) + UNROLL_CALL0(8, load_src4) + UNROLL_CALL0(8, load_src5) + UNROLL_CALL0(8, load_src6) + UNROLL_CALL0(8, load_src7) + + __m256i sum0_odd, sum0_even, sum1_odd, sum1_even, sum2_odd, + sum2_even, sum3_odd, sum3_even, sum4_odd, sum4_even, + sum5_odd, sum5_even, sum6_odd, sum6_even; + + __m256i sum10_odd, sum10_even, sum20_odd, sum20_even, sum30_odd, + sum30_even, sum40_odd, sum40_even, sum50_odd, sum50_even, + sum60_odd, sum60_even, sum7_odd, sum7_even; + + //! cal src0 + __m256i dot1, dot2, dot3, dot4; + dot1 = _mm256_madd_epi16(cvt16_src00, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src02, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src04, filter_45); + dot4 = _mm256_madd_epi16(cvt16_src06, filter_60); + sum0_odd = _mm256_add_epi32(dot1, dot2); + sum0_odd = _mm256_add_epi32(sum0_odd, dot3); + sum0_odd = _mm256_add_epi32(sum0_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src01, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src03, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src05, filter_45); + dot4 = _mm256_madd_epi16(cvt16_src07, filter_60); + sum0_even = _mm256_add_epi32(dot1, dot2); + sum0_even = _mm256_add_epi32(sum0_even, dot3); + sum0_even = _mm256_add_epi32(sum0_even, dot4); + + //! cal src1 + dot1 = _mm256_madd_epi16(cvt16_src10, filter_78); + dot2 = _mm256_madd_epi16(cvt16_src12, filter_910); + dot3 = _mm256_madd_epi16(cvt16_src14, filter_1112); + dot4 = _mm256_madd_epi16(cvt16_src16, filter_130); + sum1_odd = _mm256_add_epi32(dot1, dot2); + sum1_odd = _mm256_add_epi32(sum1_odd, dot3); + sum1_odd = _mm256_add_epi32(sum1_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src11, filter_78); + dot2 = _mm256_madd_epi16(cvt16_src13, filter_910); + dot3 = _mm256_madd_epi16(cvt16_src15, filter_1112); + dot4 = _mm256_madd_epi16(cvt16_src17, filter_130); + sum1_even = _mm256_add_epi32(dot1, dot2); + sum1_even = _mm256_add_epi32(sum1_even, dot3); + sum1_even = _mm256_add_epi32(sum1_even, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src10, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src12, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src14, filter_45); + dot4 = _mm256_madd_epi16(cvt16_src16, filter_60); + sum10_odd = _mm256_add_epi32(dot1, dot2); + sum10_odd = _mm256_add_epi32(sum10_odd, dot3); + sum10_odd = _mm256_add_epi32(sum10_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src11, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src13, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src15, filter_45); + dot4 = _mm256_madd_epi16(cvt16_src17, filter_60); + sum10_even = _mm256_add_epi32(dot1, dot2); + sum10_even = _mm256_add_epi32(sum10_even, dot3); + sum10_even = _mm256_add_epi32(sum10_even, dot4); + + //! cal src2 + dot1 = _mm256_madd_epi16(cvt16_src20, filter_1415); + dot2 = _mm256_madd_epi16(cvt16_src22, filter_1617); + dot3 = _mm256_madd_epi16(cvt16_src24, filter_1819); + dot4 = _mm256_madd_epi16(cvt16_src26, filter_200); + sum2_odd = _mm256_add_epi32(dot1, dot2); + sum2_odd = _mm256_add_epi32(sum2_odd, dot3); + sum2_odd = _mm256_add_epi32(sum2_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src21, filter_1415); + dot2 = _mm256_madd_epi16(cvt16_src23, filter_1617); + dot3 = _mm256_madd_epi16(cvt16_src25, filter_1819); + dot4 = _mm256_madd_epi16(cvt16_src27, filter_200); + sum2_even = _mm256_add_epi32(dot1, dot2); + sum2_even = _mm256_add_epi32(sum2_even, dot3); + sum2_even = _mm256_add_epi32(sum2_even, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src20, filter_78); + dot2 = _mm256_madd_epi16(cvt16_src22, filter_910); + dot3 = _mm256_madd_epi16(cvt16_src24, filter_1112); + dot4 = _mm256_madd_epi16(cvt16_src26, filter_130); + sum20_odd = _mm256_add_epi32(dot1, dot2); + sum20_odd = _mm256_add_epi32(sum20_odd, dot3); + sum20_odd = _mm256_add_epi32(sum20_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src21, filter_78); + dot2 = _mm256_madd_epi16(cvt16_src23, filter_910); + dot3 = _mm256_madd_epi16(cvt16_src25, filter_1112); + dot4 = _mm256_madd_epi16(cvt16_src27, filter_130); + sum20_even = _mm256_add_epi32(dot1, dot2); + sum20_even = _mm256_add_epi32(sum20_even, dot3); + sum20_even = _mm256_add_epi32(sum20_even, dot4); + + //! cal src3 + dot1 = _mm256_madd_epi16(cvt16_src30, filter_2122); + dot2 = _mm256_madd_epi16(cvt16_src32, filter_2324); + dot3 = _mm256_madd_epi16(cvt16_src34, filter_2526); + dot4 = _mm256_madd_epi16(cvt16_src36, filter_270); + sum3_odd = _mm256_add_epi32(dot1, dot2); + sum3_odd = _mm256_add_epi32(sum3_odd, dot3); + sum3_odd = _mm256_add_epi32(sum3_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src31, filter_2122); + dot2 = _mm256_madd_epi16(cvt16_src33, filter_2324); + dot3 = _mm256_madd_epi16(cvt16_src35, filter_2526); + dot4 = _mm256_madd_epi16(cvt16_src37, filter_270); + sum3_even = _mm256_add_epi32(dot1, dot2); + sum3_even = _mm256_add_epi32(sum3_even, dot3); + sum3_even = _mm256_add_epi32(sum3_even, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src30, filter_1415); + dot2 = _mm256_madd_epi16(cvt16_src32, filter_1617); + dot3 = _mm256_madd_epi16(cvt16_src34, filter_1819); + dot4 = _mm256_madd_epi16(cvt16_src36, filter_200); + sum30_odd = _mm256_add_epi32(dot1, dot2); + sum30_odd = _mm256_add_epi32(sum30_odd, dot3); + sum30_odd = _mm256_add_epi32(sum30_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src31, filter_1415); + dot2 = _mm256_madd_epi16(cvt16_src33, filter_1617); + dot3 = _mm256_madd_epi16(cvt16_src35, filter_1819); + dot4 = _mm256_madd_epi16(cvt16_src37, filter_200); + sum30_even = _mm256_add_epi32(dot1, dot2); + sum30_even = _mm256_add_epi32(sum30_even, dot3); + sum30_even = _mm256_add_epi32(sum30_even, dot4); + + //! cal src4 + dot1 = _mm256_madd_epi16(cvt16_src40, filter_2829); + dot2 = _mm256_madd_epi16(cvt16_src42, filter_3031); + dot3 = _mm256_madd_epi16(cvt16_src44, filter_3233); + dot4 = _mm256_madd_epi16(cvt16_src46, filter_340); + sum4_odd = _mm256_add_epi32(dot1, dot2); + sum4_odd = _mm256_add_epi32(sum4_odd, dot3); + sum4_odd = _mm256_add_epi32(sum4_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src41, filter_2829); + dot2 = _mm256_madd_epi16(cvt16_src43, filter_3031); + dot3 = _mm256_madd_epi16(cvt16_src45, filter_3233); + dot4 = _mm256_madd_epi16(cvt16_src47, filter_340); + sum4_even = _mm256_add_epi32(dot1, dot2); + sum4_even = _mm256_add_epi32(sum4_even, dot3); + sum4_even = _mm256_add_epi32(sum4_even, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src40, filter_2122); + dot2 = _mm256_madd_epi16(cvt16_src42, filter_2324); + dot3 = _mm256_madd_epi16(cvt16_src44, filter_2526); + dot4 = _mm256_madd_epi16(cvt16_src46, filter_270); + sum40_odd = _mm256_add_epi32(dot1, dot2); + sum40_odd = _mm256_add_epi32(sum40_odd, dot3); + sum40_odd = _mm256_add_epi32(sum40_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src41, filter_2122); + dot2 = _mm256_madd_epi16(cvt16_src43, filter_2324); + dot3 = _mm256_madd_epi16(cvt16_src45, filter_2526); + dot4 = _mm256_madd_epi16(cvt16_src47, filter_270); + sum40_even = _mm256_add_epi32(dot1, dot2); + sum40_even = _mm256_add_epi32(sum40_even, dot3); + sum40_even = _mm256_add_epi32(sum40_even, dot4); + + //! cal src5 + dot1 = _mm256_madd_epi16(cvt16_src50, filter_3536); + dot2 = _mm256_madd_epi16(cvt16_src52, filter_3738); + dot3 = _mm256_madd_epi16(cvt16_src54, filter_3940); + dot4 = _mm256_madd_epi16(cvt16_src56, filter_410); + sum5_odd = _mm256_add_epi32(dot1, dot2); + sum5_odd = _mm256_add_epi32(sum5_odd, dot3); + sum5_odd = _mm256_add_epi32(sum5_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src51, filter_3536); + dot2 = _mm256_madd_epi16(cvt16_src53, filter_3738); + dot3 = _mm256_madd_epi16(cvt16_src55, filter_3940); + dot4 = _mm256_madd_epi16(cvt16_src57, filter_410); + sum5_even = _mm256_add_epi32(dot1, dot2); + sum5_even = _mm256_add_epi32(sum5_even, dot3); + sum5_even = _mm256_add_epi32(sum5_even, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src50, filter_2829); + dot2 = _mm256_madd_epi16(cvt16_src52, filter_3031); + dot3 = _mm256_madd_epi16(cvt16_src54, filter_3233); + dot4 = _mm256_madd_epi16(cvt16_src56, filter_340); + sum50_odd = _mm256_add_epi32(dot1, dot2); + sum50_odd = _mm256_add_epi32(sum50_odd, dot3); + sum50_odd = _mm256_add_epi32(sum50_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src51, filter_2829); + dot2 = _mm256_madd_epi16(cvt16_src53, filter_3031); + dot3 = _mm256_madd_epi16(cvt16_src55, filter_3233); + dot4 = _mm256_madd_epi16(cvt16_src57, filter_340); + sum50_even = _mm256_add_epi32(dot1, dot2); + sum50_even = _mm256_add_epi32(sum50_even, dot3); + sum50_even = _mm256_add_epi32(sum50_even, dot4); + + //! cal src6 + dot1 = _mm256_madd_epi16(cvt16_src60, filter_4243); + dot2 = _mm256_madd_epi16(cvt16_src62, filter_4445); + dot3 = _mm256_madd_epi16(cvt16_src64, filter_4647); + dot4 = _mm256_madd_epi16(cvt16_src66, filter_480); + sum6_odd = _mm256_add_epi32(dot1, dot2); + sum6_odd = _mm256_add_epi32(sum6_odd, dot3); + sum6_odd = _mm256_add_epi32(sum6_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src61, filter_4243); + dot2 = _mm256_madd_epi16(cvt16_src63, filter_4445); + dot3 = _mm256_madd_epi16(cvt16_src65, filter_4647); + dot4 = _mm256_madd_epi16(cvt16_src67, filter_480); + sum6_even = _mm256_add_epi32(dot1, dot2); + sum6_even = _mm256_add_epi32(sum6_even, dot3); + sum6_even = _mm256_add_epi32(sum6_even, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src60, filter_3536); + dot2 = _mm256_madd_epi16(cvt16_src62, filter_3738); + dot3 = _mm256_madd_epi16(cvt16_src64, filter_3940); + dot4 = _mm256_madd_epi16(cvt16_src66, filter_410); + sum60_odd = _mm256_add_epi32(dot1, dot2); + sum60_odd = _mm256_add_epi32(sum60_odd, dot3); + sum60_odd = _mm256_add_epi32(sum60_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src61, filter_3536); + dot2 = _mm256_madd_epi16(cvt16_src63, filter_3738); + dot3 = _mm256_madd_epi16(cvt16_src65, filter_3940); + dot4 = _mm256_madd_epi16(cvt16_src67, filter_410); + sum60_even = _mm256_add_epi32(dot1, dot2); + sum60_even = _mm256_add_epi32(sum60_even, dot3); + sum60_even = _mm256_add_epi32(sum60_even, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src70, filter_4243); + dot2 = _mm256_madd_epi16(cvt16_src72, filter_4445); + dot3 = _mm256_madd_epi16(cvt16_src74, filter_4647); + dot4 = _mm256_madd_epi16(cvt16_src76, filter_480); + sum7_odd = _mm256_add_epi32(dot1, dot2); + sum7_odd = _mm256_add_epi32(sum7_odd, dot3); + sum7_odd = _mm256_add_epi32(sum7_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src71, filter_4243); + dot2 = _mm256_madd_epi16(cvt16_src73, filter_4445); + dot3 = _mm256_madd_epi16(cvt16_src75, filter_4647); + dot4 = _mm256_madd_epi16(cvt16_src77, filter_480); + sum7_even = _mm256_add_epi32(dot1, dot2); + sum7_even = _mm256_add_epi32(sum7_even, dot3); + sum7_even = _mm256_add_epi32(sum7_even, dot4); + + __m256i sum_odd, sum_even; + + //! add src0 ~ src6 + sum_odd = _mm256_add_epi32(sum0_odd, sum1_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum2_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum3_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum4_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum5_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum6_odd); + + sum_even = _mm256_add_epi32(sum0_even, sum1_even); + sum_even = _mm256_add_epi32(sum_even, sum2_even); + sum_even = _mm256_add_epi32(sum_even, sum3_even); + sum_even = _mm256_add_epi32(sum_even, sum4_even); + sum_even = _mm256_add_epi32(sum_even, sum5_even); + sum_even = _mm256_add_epi32(sum_even, sum6_even); + + __m256i sum_odd_0 = _mm256_unpacklo_epi32(sum_odd, sum_even); + __m256i sum_even_0 = _mm256_unpackhi_epi32(sum_odd, sum_even); + + __m256i sum_left = + _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 32); + __m256i sum_right = + _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + __m256i sum_odd_oh1, sum_even_oh1; + + //! add src1 ~ src7 + sum_odd_oh1 = _mm256_add_epi32(sum10_odd, sum20_odd); + sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum30_odd); + sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum40_odd); + sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum50_odd); + sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum60_odd); + sum_odd_oh1 = _mm256_add_epi32(sum_odd_oh1, sum7_odd); + + sum_even_oh1 = _mm256_add_epi32(sum10_even, sum20_even); + sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum30_even); + sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum40_even); + sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum50_even); + sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum60_even); + sum_even_oh1 = _mm256_add_epi32(sum_even_oh1, sum7_even); + + __m256i sum_odd_1 = + _mm256_unpacklo_epi32(sum_odd_oh1, sum_even_oh1); + __m256i sum_even_1 = + _mm256_unpackhi_epi32(sum_odd_oh1, sum_even_oh1); + + sum_left = _mm256_permute2f128_si256(sum_odd_1, sum_even_1, 32); + sum_right = _mm256_permute2f128_si256(sum_odd_1, sum_even_1, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst1)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr1), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr1 + 8), sum_right); + } + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + r6 += 16; + r7 += 16; + dst0 += 16; + dst1 += 16; + out_ptr0 += 16; + out_ptr1 += 16; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + r6 += tail_step + IW; + r7 += tail_step + IW; + + dst0 += OW; + dst1 += OW; + out_ptr0 += OW; + out_ptr1 += OW; + } + + for (; h < OH; h++) { + size_t w = 0; + for (; w < width; w++) { + UNROLL_CALL0(7, load_src0) + UNROLL_CALL0(7, load_src1) + UNROLL_CALL0(7, load_src2) + UNROLL_CALL0(7, load_src3) + UNROLL_CALL0(7, load_src4) + UNROLL_CALL0(7, load_src5) + UNROLL_CALL0(7, load_src6) + UNROLL_CALL0(7, load_src7) + __m256i sum0_odd, sum0_even, sum1_odd, sum1_even, sum2_odd, + sum2_even, sum3_odd, sum3_even, sum4_odd, sum4_even, + sum5_odd, sum5_even, sum6_odd, sum6_even; + + //! cal src0 + __m256i dot1, dot2, dot3, dot4; + dot1 = _mm256_madd_epi16(cvt16_src00, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src02, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src04, filter_45); + dot4 = _mm256_madd_epi16(cvt16_src06, filter_60); + sum0_odd = _mm256_add_epi32(dot1, dot2); + sum0_odd = _mm256_add_epi32(sum0_odd, dot3); + sum0_odd = _mm256_add_epi32(sum0_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src01, filter_01); + dot2 = _mm256_madd_epi16(cvt16_src03, filter_23); + dot3 = _mm256_madd_epi16(cvt16_src05, filter_45); + dot4 = _mm256_madd_epi16(cvt16_src07, filter_60); + sum0_even = _mm256_add_epi32(dot1, dot2); + sum0_even = _mm256_add_epi32(sum0_even, dot3); + sum0_even = _mm256_add_epi32(sum0_even, dot4); + + //! cal src1 + dot1 = _mm256_madd_epi16(cvt16_src10, filter_78); + dot2 = _mm256_madd_epi16(cvt16_src12, filter_910); + dot3 = _mm256_madd_epi16(cvt16_src14, filter_1112); + dot4 = _mm256_madd_epi16(cvt16_src16, filter_130); + sum1_odd = _mm256_add_epi32(dot1, dot2); + sum1_odd = _mm256_add_epi32(sum1_odd, dot3); + sum1_odd = _mm256_add_epi32(sum1_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src11, filter_78); + dot2 = _mm256_madd_epi16(cvt16_src13, filter_910); + dot3 = _mm256_madd_epi16(cvt16_src15, filter_1112); + dot4 = _mm256_madd_epi16(cvt16_src17, filter_130); + sum1_even = _mm256_add_epi32(dot1, dot2); + sum1_even = _mm256_add_epi32(sum1_even, dot3); + sum1_even = _mm256_add_epi32(sum1_even, dot4); + + //! cal src2 + dot1 = _mm256_madd_epi16(cvt16_src20, filter_1415); + dot2 = _mm256_madd_epi16(cvt16_src22, filter_1617); + dot3 = _mm256_madd_epi16(cvt16_src24, filter_1819); + dot4 = _mm256_madd_epi16(cvt16_src26, filter_200); + sum2_odd = _mm256_add_epi32(dot1, dot2); + sum2_odd = _mm256_add_epi32(sum2_odd, dot3); + sum2_odd = _mm256_add_epi32(sum2_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src21, filter_1415); + dot2 = _mm256_madd_epi16(cvt16_src23, filter_1617); + dot3 = _mm256_madd_epi16(cvt16_src25, filter_1819); + dot4 = _mm256_madd_epi16(cvt16_src27, filter_200); + sum2_even = _mm256_add_epi32(dot1, dot2); + sum2_even = _mm256_add_epi32(sum2_even, dot3); + sum2_even = _mm256_add_epi32(sum2_even, dot4); + + //! cal src3 + dot1 = _mm256_madd_epi16(cvt16_src30, filter_2122); + dot2 = _mm256_madd_epi16(cvt16_src32, filter_2324); + dot3 = _mm256_madd_epi16(cvt16_src34, filter_2526); + dot4 = _mm256_madd_epi16(cvt16_src36, filter_270); + sum3_odd = _mm256_add_epi32(dot1, dot2); + sum3_odd = _mm256_add_epi32(sum3_odd, dot3); + sum3_odd = _mm256_add_epi32(sum3_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src31, filter_2122); + dot2 = _mm256_madd_epi16(cvt16_src33, filter_2324); + dot3 = _mm256_madd_epi16(cvt16_src35, filter_2526); + dot4 = _mm256_madd_epi16(cvt16_src37, filter_270); + sum3_even = _mm256_add_epi32(dot1, dot2); + sum3_even = _mm256_add_epi32(sum3_even, dot3); + sum3_even = _mm256_add_epi32(sum3_even, dot4); + + //! cal src4 + dot1 = _mm256_madd_epi16(cvt16_src40, filter_2829); + dot2 = _mm256_madd_epi16(cvt16_src42, filter_3031); + dot3 = _mm256_madd_epi16(cvt16_src44, filter_3233); + dot4 = _mm256_madd_epi16(cvt16_src46, filter_340); + sum4_odd = _mm256_add_epi32(dot1, dot2); + sum4_odd = _mm256_add_epi32(sum4_odd, dot3); + sum4_odd = _mm256_add_epi32(sum4_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src41, filter_2829); + dot2 = _mm256_madd_epi16(cvt16_src43, filter_3031); + dot3 = _mm256_madd_epi16(cvt16_src45, filter_3233); + dot4 = _mm256_madd_epi16(cvt16_src47, filter_340); + sum4_even = _mm256_add_epi32(dot1, dot2); + sum4_even = _mm256_add_epi32(sum4_even, dot3); + sum4_even = _mm256_add_epi32(sum4_even, dot4); + + //! cal src5 + dot1 = _mm256_madd_epi16(cvt16_src50, filter_3536); + dot2 = _mm256_madd_epi16(cvt16_src52, filter_3738); + dot3 = _mm256_madd_epi16(cvt16_src54, filter_3940); + dot4 = _mm256_madd_epi16(cvt16_src56, filter_410); + sum5_odd = _mm256_add_epi32(dot1, dot2); + sum5_odd = _mm256_add_epi32(sum5_odd, dot3); + sum5_odd = _mm256_add_epi32(sum5_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src51, filter_3536); + dot2 = _mm256_madd_epi16(cvt16_src53, filter_3738); + dot3 = _mm256_madd_epi16(cvt16_src55, filter_3940); + dot4 = _mm256_madd_epi16(cvt16_src57, filter_410); + sum5_even = _mm256_add_epi32(dot1, dot2); + sum5_even = _mm256_add_epi32(sum5_even, dot3); + sum5_even = _mm256_add_epi32(sum5_even, dot4); + + //! cal src6 + dot1 = _mm256_madd_epi16(cvt16_src60, filter_4243); + dot2 = _mm256_madd_epi16(cvt16_src62, filter_4445); + dot3 = _mm256_madd_epi16(cvt16_src64, filter_4647); + dot4 = _mm256_madd_epi16(cvt16_src66, filter_480); + sum6_odd = _mm256_add_epi32(dot1, dot2); + sum6_odd = _mm256_add_epi32(sum6_odd, dot3); + sum6_odd = _mm256_add_epi32(sum6_odd, dot4); + + dot1 = _mm256_madd_epi16(cvt16_src61, filter_4243); + dot2 = _mm256_madd_epi16(cvt16_src63, filter_4445); + dot3 = _mm256_madd_epi16(cvt16_src65, filter_4647); + dot4 = _mm256_madd_epi16(cvt16_src67, filter_480); + sum6_even = _mm256_add_epi32(dot1, dot2); + sum6_even = _mm256_add_epi32(sum6_even, dot3); + sum6_even = _mm256_add_epi32(sum6_even, dot4); + + __m256i sum_odd, sum_even; + + //! add src0 ~ src6 + sum_odd = _mm256_add_epi32(sum0_odd, sum1_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum2_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum3_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum4_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum5_odd); + sum_odd = _mm256_add_epi32(sum_odd, sum6_odd); + + sum_even = _mm256_add_epi32(sum0_even, sum1_even); + sum_even = _mm256_add_epi32(sum_even, sum2_even); + sum_even = _mm256_add_epi32(sum_even, sum3_even); + sum_even = _mm256_add_epi32(sum_even, sum4_even); + sum_even = _mm256_add_epi32(sum_even, sum5_even); + sum_even = _mm256_add_epi32(sum_even, sum6_even); + + __m256i sum_odd_0 = _mm256_unpacklo_epi32(sum_odd, sum_even); + __m256i sum_even_0 = _mm256_unpackhi_epi32(sum_odd, sum_even); + + __m256i sum_left = + _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 32); + __m256i sum_right = + _mm256_permute2f128_si256(sum_odd_0, sum_even_0, 49); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + r0 += 16; + r1 += 16; + r2 += 16; + r3 += 16; + r4 += 16; + r5 += 16; + r6 += 16; + dst0 += 16; + out_ptr0 += 16; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } +} +#undef load_filter +#undef load_src0 +#undef load_src1 +#undef load_src2 +#undef load_src3 +#undef load_src4 +#undef load_src5 +#undef load_src6 +#undef load_src7 + +#define INSTANTIATION(stride, i, bias, is_quantized, Op) \ + template void avx2_chanwise_direct_##stride##_##i##x##i##_int8< \ + bias, is_quantized, Op>(const int8_t*, const int8_t*, \ + const int32_t*, int32_t*, int8_t*, \ + const size_t, const size_t, const size_t, \ + const size_t, const Op&); + +#define FOR_OP(stride, i, is_quantized, bias) \ + INSTANTIATION(stride, i, bias, is_quantized, \ + TypeCvtOp) \ + INSTANTIATION(stride, i, bias, is_quantized, \ + ReluOp) \ + INSTANTIATION(stride, i, bias, is_quantized, \ + HSwishOp) + +#define FOR_BIAS(stride, i, is_quantized) \ + FOR_OP(stride, i, is_quantized, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, is_quantized, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_QUANTIZED(stride, i) \ + FOR_BIAS(stride, i, true) \ + FOR_BIAS(stride, i, false) + +#define FOR_FILTER(stride) \ + FOR_QUANTIZED(stride, 2) \ + FOR_QUANTIZED(stride, 3) \ + FOR_QUANTIZED(stride, 5) \ + FOR_QUANTIZED(stride, 7) + +#define FOR_STRIDE FOR_FILTER(stride1) + +FOR_STRIDE + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_QUANTIZED +#undef FOR_BIAS +#undef FOR_OP +#undef INSTANTIATION +} // namespace avx2_chanwise_stride1 +} // namespace x86 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.h b/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.h new file mode 100644 index 000000000..571f4f394 --- /dev/null +++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.h @@ -0,0 +1,39 @@ +/** + * \file src/x86/conv_bias/int8/avx2_chanwsie_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include "src/x86/conv_bias/opr_impl.h" + +namespace megdnn { +namespace x86 { +namespace avx2_chanwise_stride1 { + +#define KERN(stride, i) \ + template \ + MEGDNN_ATTRIBUTE_TARGET("avx2") \ + void avx2_chanwise_direct_##stride##_##i##x##i##_int8( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, \ + const size_t OH, const size_t OW, const Op& op); + +KERN(stride1, 2) +KERN(stride1, 3) +KERN(stride1, 5) +KERN(stride1, 7) + +#undef KERN + +} // namespace avx2_chanwise_stride1 +} // namespace x86 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp new file mode 100644 index 000000000..25fb140b8 --- /dev/null +++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp @@ -0,0 +1,251 @@ +/** + * \file src/x86/conv_bias/int8/avx2_chanwsie_stride1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h" +#include "src/x86/conv_bias/int8/avx2_chanwise_kern.h" +#include "src/x86/elemwise_op.h" + +namespace megdnn { +namespace x86 { +namespace avx2_chanwise_stride1 { + +bool need_dst_copy(const NCBKernSizeParam& param) { + return param.osz[1] % 16; +} +bool need_src_copy(const NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + return (fm.padding[0] != 0 || fm.padding[1] != 0) ? true + : need_dst_copy(param); +} +void get_rectified_size(const NCBKernSizeParam& param, size_t& IH2, size_t& IW2, + size_t& OH2, size_t& OW2) { + auto&& fm = param.filter_meta; + auto SW = fm.stride[1]; + auto OH = param.osz[0]; + auto OW = param.osz[1]; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + + OH2 = OH; + OW2 = (OW + 15) & ~15; + IH2 = SW * OH + FH - SW; + IW2 = SW * OW2 + FW - SW; +} +void copy_padding_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + size_t padding_group_size = IH2 * IW2; + bundle.set(kern_param.workspace_ptr); + + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1], + channel_id = ncb_index.ndrange_id[2]; + size_t workspace_group_id = ncb_index.thread_id; + const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); + if (need_src_copy_var) { + int8_t* sptr_base = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size; + std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); + rep(ih, IH) { + std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + } +}; +template +void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t IH2, IW2, OH2, OW2; + get_rectified_size(kern_param, IH2, IW2, OH2, OW2); + bool need_src_copy_var = need_src_copy(kern_param); + bool need_dst_copy_var = need_dst_copy(kern_param); + bool need_post_process = + kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + + Op op = Op(1.0f, 4.0f); + if (need_post_process) { + float scale_bias = + kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + size_t padding_group_size = IH2 * IW2; + + bundle.set(kern_param.workspace_ptr); + + size_t workspace_group_id = ncb_index.thread_id; + size_t group_id = ncb_index.ndrange_id[0], + batch_id = ncb_index.ndrange_id[1]; + + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int8_t* fptr = + kern_param.filter(group_id); + void* dst = kern_param.dst(batch_id, group_id); + const int32_t* bptr = kern_param.bias(batch_id, group_id); + if (need_src_copy_var) { + sptr = static_cast(bundle.get(0)) + + workspace_group_id * padding_group_size; + } + void* dptr = nullptr; + int32_t* tptr = nullptr; + if (need_dst_copy_var) { + dptr = reinterpret_cast( + reinterpret_cast(bundle.get(1)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size()); + } else { + dptr = dst; + } + +#define KERN_NEED_POST_PROCESS(filter) \ + avx2_chanwise_direct_stride1_##filter##x##filter##_int8( \ + sptr, fptr, bptr, tptr, static_cast(dptr), IH2, IW2, OH2, \ + OW2, op) + +#define KERN_NO_POST_PROCESS(filter) \ + avx2_chanwise_direct_stride1_##filter##x##filter##_int8( \ + sptr, fptr, bptr, static_cast(dptr), nullptr, IH2, IW2, \ + OH2, OW2, op) + + if (need_post_process) { + tptr = static_cast(bundle.get(2)) + + ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size(); + DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS) + } else { + DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS) + } + +#undef KERN_NEED_POST_PROCESS +#undef KERN_NO_POST_PROCESS + if (need_dst_copy_var) { + rep(oh, OH) { + std::memcpy(reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); + } + } +}; +SmallVector get_kimpls(const NCBKernSizeParam& kern_param, + WorkspaceBundle bundle) { + MEGDNN_MARK_USED_VAR(kern_param); + auto fm = kern_param.filter_meta; + size_t group = fm.group; + size_t n = kern_param.n; + + SmallVector ncb_kerns; + conv_fun do_conv_fun = nullptr; + +#define DO_CONV_KERN_FUN(filter, bias_mode, is_quantized, op) \ + do_conv_fun = conv_kimpl; + +#define GET_OP_PARAM(i, bias_mode, is_quantized) \ + switch (kern_param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, is_quantized, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, is_quantized, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, is_quantized, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(i, is_quantized) \ + switch (kern_param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS, is_quantized) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS, is_quantized) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_QUANTIZED(i) \ + switch (kern_param.dst_type.enumv()) { \ + case DTypeEnum::QuantizedS8: \ + GET_BIAS_MODE_PARAM(i, true) \ + break; \ + case DTypeEnum::QuantizedS32: \ + GET_BIAS_MODE_PARAM(i, false) \ + break; \ + case DTypeEnum::Int32: \ + GET_BIAS_MODE_PARAM(i, false) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_KERN() \ + switch (kern_param.filter_meta.spatial[0]) { \ + case 2: \ + GET_QUANTIZED(2) \ + break; \ + case 3: \ + GET_QUANTIZED(3) \ + break; \ + case 5: \ + GET_QUANTIZED(5) \ + break; \ + case 7: \ + GET_QUANTIZED(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + + auto exec_one_group = [bundle, do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + copy_padding_kern(bundle, kern_param, ncb_index); + do_conv_fun(bundle, kern_param, ncb_index); + }; + ncb_kerns.push_back({exec_one_group, {group, n, 1_z}}); + + return ncb_kerns; +} + +} // namespace avx2_chanwise_stride1 +} // namespace x86 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.h b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.h new file mode 100644 index 000000000..f57b23ab1 --- /dev/null +++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.h @@ -0,0 +1,42 @@ +/** + * \file src/x86/conv_bias/int8/avx2_chanwsie_stride1.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once + +#include "src/x86/conv_bias/opr_impl.h" + +namespace megdnn { +namespace x86 { +namespace avx2_chanwise_stride1 { +using NCBKern = fallback::ConvBiasImpl::NCBKern; +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +using conv_fun = std::function; + +bool need_dst_copy(const NCBKernSizeParam& param); + +bool need_src_copy(const NCBKernSizeParam& param); + +void get_rectified_size(const NCBKernSizeParam& param, size_t& IH2, size_t& IW2, + size_t& OH2, size_t& OW2); + +SmallVector get_kimpls(const NCBKernSizeParam& param, + WorkspaceBundle bundle); + +} // namespace avx2_chanwise_stride1 +} // namespace x86 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride1.cpp b/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride1.cpp index 865f9399a..a5d5baea8 100644 --- a/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride1.cpp +++ b/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride1.cpp @@ -10,7 +10,6 @@ */ #include "src/x86/conv_bias/int8/avx2_direct_conv_stride1.h" -#include "src/common/unroll_macro.h" #include "src/x86/conv_bias/int8/common_helper.h" #include "src/x86/conv_bias/postprocess_helper.h" diff --git a/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride2.cpp b/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride2.cpp index 28f3a4b76..aceb285a7 100644 --- a/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride2.cpp +++ b/dnn/src/x86/conv_bias/int8/avx2_direct_conv_stride2.cpp @@ -10,7 +10,6 @@ */ #include "src/x86/conv_bias/int8/avx2_direct_conv_stride2.h" -#include "src/common/unroll_macro.h" #include "src/x86/conv_bias/int8/common_helper.h" #include "src/x86/conv_bias/postprocess_helper.h" diff --git a/dnn/src/x86/conv_bias/int8/common_helper.h b/dnn/src/x86/conv_bias/int8/common_helper.h index 39d4e002f..4fd875ee7 100644 --- a/dnn/src/x86/conv_bias/int8/common_helper.h +++ b/dnn/src/x86/conv_bias/int8/common_helper.h @@ -11,6 +11,7 @@ #pragma once #include +#include "src/common/unroll_macro.h" #include "megdnn/arch.h" #ifdef WIN32CMAKE #include diff --git a/dnn/src/x86/conv_bias/opr_impl.cpp b/dnn/src/x86/conv_bias/opr_impl.cpp index 40261739b..2669ef2fd 100644 --- a/dnn/src/x86/conv_bias/opr_impl.cpp +++ b/dnn/src/x86/conv_bias/opr_impl.cpp @@ -65,6 +65,10 @@ void* ConvBiasImpl::AlgoAVX2DirectConvStride2::type() const { return x86_algo_type; } +void* ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::type() const { + return x86_algo_type; +} + class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDirect stride1_direct_large_group{true}; AlgoDirect stride1_direct_small_group{false}; @@ -72,6 +76,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDirectStride2 stride2_direct_small_group{false}; AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8; AlgoAVX2DirectConvStride2 avx2_stride2_direct; + AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8; AlgoMatrixMul matmul; #if defined(MEGDNN_X86_WITH_MKL_DNN) AlgoMkldnnMatmulQint8 mkldnn_matmul_qint8; @@ -94,6 +99,7 @@ public: all_algos.emplace_back(&stride2_direct_small_group); all_algos.emplace_back(&avx2_stride1_direct_int8); all_algos.emplace_back(&avx2_stride2_direct); + all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); all_algos.emplace_back(&matmul); static CpuOprDelegationStorage<> storage; diff --git a/dnn/src/x86/conv_bias/opr_impl.h b/dnn/src/x86/conv_bias/opr_impl.h index 88f1ad4f5..dc83ef0e3 100644 --- a/dnn/src/x86/conv_bias/opr_impl.h +++ b/dnn/src/x86/conv_bias/opr_impl.h @@ -31,6 +31,7 @@ public: class AlgoMatrixMul; class AlgoDirectAvx2Stride1Int8; class AlgoAVX2DirectConvStride2; + class AlgoChanWiseAvx2Stride1Qint8; #if defined(MEGDNN_X86_WITH_MKL_DNN) class AlgoMkldnnConv; class AlgoMkldnnQint8; diff --git a/dnn/src/x86/elemwise_helper/kimpl/typecvt.h b/dnn/src/x86/elemwise_helper/kimpl/typecvt.h index eed5a5e6b..e1885bb6d 100644 --- a/dnn/src/x86/elemwise_helper/kimpl/typecvt.h +++ b/dnn/src/x86/elemwise_helper/kimpl/typecvt.h @@ -257,6 +257,32 @@ struct TypeCvtOp } }; +template <> +struct TypeCvtOp + : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + constexpr static size_t SIMD_WIDTH = 8; + + MEGDNN_ATTRIBUTE_TARGET("avx2") + void operator()(const __m256ix2& vsrc, dt_qint8* dst) const { + _mm_store_si128((__m128i*)(dst), (operator()(vsrc))); + } + + MEGDNN_ATTRIBUTE_TARGET("avx2") + __m128i operator()(const __m256ix2& vsrc) const { + auto cvtps_src0 = _mm256_cvtepi32_ps(vsrc.val[0]); + auto cvtps_src1 = _mm256_cvtepi32_ps(vsrc.val[1]); + auto vitem0 = _mm256_mul_ps(cvtps_src0, _mm256_set1_ps(this->scale)); + auto vitem1 = _mm256_mul_ps(cvtps_src1, _mm256_set1_ps(this->scale)); + return QConverter::convert<__m128i, __m256x2>({{vitem0, vitem1}}); + } + + void operator()(src_ctype src, dst_ctype* dst) { + *reinterpret_cast(dst) = saturate( + std::round(src.as_int32() * scale), -128, 127); + } +}; + template <> struct TypeCvtOp : UnaryOpBase { diff --git a/dnn/test/x86/conv_bias.cpp b/dnn/test/x86/conv_bias.cpp index 6550f25f1..8cdad38a5 100644 --- a/dnn/test/x86/conv_bias.cpp +++ b/dnn/test/x86/conv_bias.cpp @@ -40,6 +40,165 @@ TEST_F(X86, CONV_BIAS_FORWARD) { .execs({arg.src, arg.filter, arg.bias, {}, {}}); } } + +TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { + using namespace conv_bias; + std::vector args; + + auto run = [&](size_t ic, size_t w, size_t h, size_t kernel, size_t p, + NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + param.sparse = param::ConvBias::Sparse::GROUP; + //! no bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{ic, 1, 1, kernel, kernel}, TensorShape{}); + //! bias channel + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{ic, 1, 1, kernel, kernel}, + TensorShape{1, ic, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t pad : {0, 1}) + for (size_t ic : {1, 5, 17, 20}) + for (size_t h : {7, 16, 38, 40}) + for (size_t w : {16, 25, 40, 55}) + for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) + run(ic, w, h, kernel, pad, nonline_mode); + + Checker checker(handle()); + UniformIntRNG rng{-50, 50}; + checker.set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()) + .set_dtype(4, dtype::Int32()) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_epsilon(1e-3); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1")); + for (auto&& arg : args) { + checker.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}); + } +} + +TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { + using namespace conv_bias; + std::vector args; + + auto run = [&](size_t ic, size_t w, size_t h, size_t kernel, size_t p, + NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + param.sparse = param::ConvBias::Sparse::GROUP; + //! no bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{ic, 1, 1, kernel, kernel}, TensorShape{}); + //! bias channel + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{ic, 1, 1, kernel, kernel}, + TensorShape{1, ic, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t pad : {0, 1}) + for (size_t ic : {1, 3, 5, 7, 17}) + for (size_t h : {10, 17, 25, 30}) + for (size_t w : {19, 28, 58, 168}) + for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) + run(ic, w, h, kernel, pad, nonline_mode); + + Checker checker(handle()); + UniformIntRNG rng{-50, 50}; + checker.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, {}) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_epsilon(1e-3); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1")); + for (auto&& arg : args) { + checker.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}); + } +} + +TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) { + using namespace conv_bias; + std::vector args; + + auto run = [&](size_t ic, size_t w, size_t h, size_t kernel, size_t p, + NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + + param.sparse = param::ConvBias::Sparse::GROUP; + //! no bias + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{ic, 1, 1, kernel, kernel}, TensorShape{}); + //! bias channel + args.emplace_back(param, TensorShape{2, ic, h, w}, + TensorShape{ic, 1, 1, kernel, kernel}, + TensorShape{1, ic, 1, 1}); + }; + + for (size_t kernel : {2, 3, 5, 7}) + for (size_t pad : {0, 1}) + for (size_t ic : {1, 3, 5, 7, 17}) + for (size_t h : {10, 15, 17, 30}) + for (size_t w : {19, 28, 58, 168}) + for (NonlineMode nonline_mode : + {NonlineMode::IDENTITY, NonlineMode::H_SWISH, + NonlineMode::RELU}) + run(ic, w, h, kernel, pad, nonline_mode); + + Checker checker(handle()); + UniformIntRNG rng{-50, 50}; + checker.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)) + .set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_epsilon(1e-3); + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1")); + for (auto&& arg : args) { + checker.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}); + } +} + TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE1_INT8x8x32) { using namespace conv_bias; std::vector args; @@ -1556,6 +1715,67 @@ void benchmark_impl_comp(const param::ConvBias param, } } // namespace +TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8) { + constexpr size_t RUNS = 50; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector data_type = {dtype::Int8(), dtype::Int8(), + dtype::Int32(), dtype::Int32()}; + + std::vector, float>> + shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t H, size_t W, size_t FS) { + param.pad_h = FS / 2; + param.pad_w = FS / 2; + + SmallVector shapes{ + {N, IC, H, W}, {IC, 1, 1, FS, FS}, {}, {}, {}}; + TensorShape dst{N, IC, (H + 2 * param.pad_h - FS) + 1, + (W + 2 * param.pad_w - FS) + 1}; + float computations = (FS * FS * dst.total_nr_elems() * 2) * 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 112, 112, 7); + bench_case(1, 144, 56, 56, 7); + bench_case(1, 192, 28, 28, 7); + bench_case(1, 384, 28, 28, 7); + bench_case(1, 576, 14, 14, 7); + bench_case(1, 960, 7, 7, 7); + + bench_case(1, 32, 112, 112, 5); + bench_case(1, 144, 56, 56, 5); + bench_case(1, 192, 28, 28, 5); + bench_case(1, 384, 28, 28, 5); + bench_case(1, 576, 14, 14, 5); + bench_case(1, 960, 7, 7, 5); + + bench_case(1, 32, 112, 112, 3); + bench_case(1, 144, 56, 56, 3); + bench_case(1, 192, 28, 28, 3); + bench_case(1, 384, 28, 28, 3); + bench_case(1, 576, 14, 14, 3); + bench_case(1, 960, 7, 7, 3); + + bench_case(1, 32, 112, 112, 2); + bench_case(1, 144, 56, 56, 2); + bench_case(1, 192, 28, 28, 2); + bench_case(1, 384, 28, 28, 2); + bench_case(1, 576, 14, 14, 2); + bench_case(1, 960, 7, 7, 2); + + std::string algo_name = "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"; + printf("Benchmark X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1\n"); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, + {4, {4, 5, 6, 7}}, {1, {4}}, data_type); + benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, + {1, {4}}, data_type); + shapes_and_computation.clear(); +} + TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) { constexpr size_t RUNS = 50; param::ConvBias param; -- GitLab