diff --git a/dnn/src/x86/conv_bias/int8/algos.cpp b/dnn/src/x86/conv_bias/int8/algos.cpp
index 96a079696fd9d35096b10bae2a3380451591d819..21caf2f1d9a094d17f4590a1cf83a5dd4c8c5de6 100644
--- a/dnn/src/x86/conv_bias/int8/algos.cpp
+++ b/dnn/src/x86/conv_bias/int8/algos.cpp
@@ -6,16 +6,18 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/x86/conv_bias/int8/algos.h"
 #include "src/common/opr_delegate.h"
 #include "src/common/utils.h"
 #include "src/fallback/convolution/img2col_helper.h"
+#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h"
+#include "src/x86/conv_bias/int8/avx2_chanwise_stride2.h"
 #include "src/x86/conv_bias/int8/avx2_direct_conv_stride1.h"
 #include "src/x86/conv_bias/int8/avx2_direct_conv_stride2.h"
-#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h"
 #include "src/x86/conv_bias/opr_impl.h"
 #include "src/x86/conv_bias/postprocess_helper.h"
 #include "src/x86/handle.h"
@@ -38,6 +40,7 @@ bool ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::usable(
     auto&& fm = param.filter_meta;
     auto FH = fm.spatial[0];
     bool aviliable =
+            (param.bias_mode != BiasMode::BIAS) &&
             ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
               param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
               param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
@@ -61,12 +64,12 @@ WorkspaceBundle ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_bundle(
     size_t IH2, IW2, OH2, OW2;
     size_t src_size = 0, dst_size = 0, int32_temp = 0;
 
-    avx2_chanwise_stride1::get_rectified_size(param, IH2, IW2, OH2, OW2);
+    get_rectified_size(param, IH2, IW2, OH2, OW2);
 
-    if (avx2_chanwise_stride1::need_src_copy(param)) {
+    if (need_src_copy(param)) {
         src_size = IH2 * IW2 * sizeof(int8_t) * nr_threads;
     }
-    if (avx2_chanwise_stride1::need_dst_copy(param)) {
+    if (need_dst_copy(param)) {
         dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads;
     }
     bool dst_need_convert = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
@@ -91,6 +94,66 @@ ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_kimpls(
     return avx2_chanwise_stride1::get_kimpls(param, bundle);
 }
 
+bool ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::usable(
+        FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
+        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
+    auto&& fm = param.filter_meta;
+    auto FH = fm.spatial[0];
+    bool aviliable =
+            (param.bias_mode != BiasMode::BIAS) &&
+            ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
+              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
+              param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
+             (((param.src_type.enumv() == DTypeEnum::Int8 &&
+                param.filter_type.enumv() == DTypeEnum::Int8 &&
+                param.dst_type.enumv() == DTypeEnum::Int32) ||
+               (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
+                param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
+                param.dst_type.enumv() == DTypeEnum::QuantizedS32)))) &&
+            fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
+            fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
+            (FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.stride[0] == 2 &&
+            fm.stride[1] == 2 && (fm.icpg == 1) && (fm.ocpg == 1) &&
+            is_supported(SIMDType::AVX2);
+    return aviliable;
+}
+
+WorkspaceBundle ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::get_bundle(
+        const NCBKernSizeParam& param) {
+    size_t nr_threads = param.nr_threads;
+    size_t IH2, IW2, OH2, OW2;
+    size_t src_size = 0, dst_size = 0, int32_temp = 0;
+
+    get_rectified_size(param, IH2, IW2, OH2, OW2);
+
+    if (need_src_copy(param)) {
+        src_size = IH2 * IW2 * sizeof(int8_t) * nr_threads;
+    }
+    if (need_dst_copy(param)) {
+        dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads;
+    }
+    bool dst_need_convert = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
+
+    if (dst_need_convert) {
+        int32_temp = OH2 * OW2 * sizeof(int32_t) * nr_threads;
+    }
+    return dst_need_convert
+                   ? WorkspaceBundle(nullptr, {src_size, dst_size, int32_temp})
+                   : WorkspaceBundle(nullptr, {src_size, dst_size});
+}
+
+size_t ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::get_workspace(
+        FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
+    return get_bundle(param).total_size_in_bytes();
+}
+
+SmallVector<ConvBiasImpl::NCBKern>
+ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::get_kimpls(
+        const NCBKernSizeParam& param) const {
+    auto bundle = get_bundle(param);
+    return avx2_chanwise_stride2::get_kimpls(param, bundle);
+}
+
 bool ConvBiasImpl::AlgoDirectAvx2Stride1Int8::usable(
         FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
         AlgoSelectionStrategy /*algo_selection_strategy*/) const {
diff --git a/dnn/src/x86/conv_bias/int8/algos.h b/dnn/src/x86/conv_bias/int8/algos.h
index 00135dd3a88efa9fa0a13055771294fb3ef40afc..a7fe96c95a1ac5405401f7d15d8cdce1085388ef 100644
--- a/dnn/src/x86/conv_bias/int8/algos.h
+++ b/dnn/src/x86/conv_bias/int8/algos.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "src/x86/conv_bias/opr_impl.h"
@@ -36,6 +37,28 @@ public:
     void* type() const override;
 };
 
+/* ===================== avx2 stride2 chanwise algo ===================== */
+class ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8 final : public AlgoBase {
+    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
+    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
+
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override {
+        return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2";
+    }
+    bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
+                AlgoSelectionStrategy algo_selection_strategy) const override;
+    size_t get_workspace(FallbackConvBiasImpl* opr,
+                         const NCBKernSizeParam& param) const override;
+    virtual SmallVector<NCBKern> dispatch_kerns(
+            fallback::ConvBiasImpl*,
+            const NCBKernSizeParam& param) const override {
+        return get_kimpls(param);
+    }
+    void* type() const override;
+};
+
 /* ===================== avx2 stride1 direct algo ===================== */
 class ConvBiasImpl::AlgoDirectAvx2Stride1Int8 final : public AlgoBase {
     SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
@@ -125,7 +148,7 @@ public:
     void* type() const override;
 };
 #endif
-/* ===================== avx2 int8 direct conv stride2 algo ===================== */
+/* ================== avx2 int8 direct conv stride2 algo ================== */
 class ConvBiasImpl::AlgoAVX2DirectConvStride2 final : public AlgoBase {
     SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
     static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.cpp b/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.cpp
index 335fda4f56cf9aed77a9c27fded5911f5467cd77..ae8eb6bce5579c39cd412a358d1fd86e0cf21d31 100644
--- a/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.cpp
+++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.cpp
@@ -21,8 +21,6 @@
 namespace megdnn {
 namespace x86 {
 
-namespace avx2_chanwise_stride1 {
-
 #define load_filter(i) __m128i k_##i = _mm_set1_epi8(*(filter + i));
 #define load_src0(i) \
     __m256i cvt16_src##i##0 = _mm256_cvtepi8_epi16_from_ptr(r##i);
@@ -40,6 +38,15 @@ namespace avx2_chanwise_stride1 {
     __m256i cvt16_src##i##6 = _mm256_cvtepi8_epi16_from_ptr(r##i + 6);
 #define load_src7(i) \
     __m256i cvt16_src##i##7 = _mm256_cvtepi8_epi16_from_ptr(r##i + 7);
+#define load_src16(i) \
+    __m256i cvt16_src##i##16 = _mm256_cvtepi8_epi16_from_ptr(r##i + 16);
+#define load_src18(i) \
+    __m256i cvt16_src##i##18 = _mm256_cvtepi8_epi16_from_ptr(r##i + 18);
+#define load_src20(i) \
+    __m256i cvt16_src##i##20 = _mm256_cvtepi8_epi16_from_ptr(r##i + 20);
+#define load_src22(i) \
+    __m256i cvt16_src##i##22 = _mm256_cvtepi8_epi16_from_ptr(r##i + 22);
+namespace avx2_chanwise_stride1 {
 
 template <BiasMode bias_mode, bool is_quantized, typename Op>
 void avx2_chanwise_direct_stride1_2x2_int8(const int8_t* src,
@@ -1534,16 +1541,6 @@ void avx2_chanwise_direct_stride1_7x7_int8(const int8_t* src,
         r6 += tail_step;
     }
 }
-#undef load_filter
-#undef load_src0
-#undef load_src1
-#undef load_src2
-#undef load_src3
-#undef load_src4
-#undef load_src5
-#undef load_src6
-#undef load_src7
-
 #define INSTANTIATION(stride, i, bias, is_quantized, Op)                 \
     template void avx2_chanwise_direct_##stride##_##i##x##i##_int8<      \
             bias, is_quantized, Op>(const int8_t*, const int8_t*,        \
@@ -1587,6 +1584,697 @@ FOR_STRIDE
 #undef FOR_OP
 #undef INSTANTIATION
 }  // namespace avx2_chanwise_stride1
+
+namespace avx2_chanwise_stride2 {
+
+template <BiasMode bias_mode, bool is_quantized, typename Op>
+void avx2_chanwise_direct_stride2_2x2_int8(const
int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + size_t tail_step = IW - OW * 2; + int8_t* dst0 = dst; + int32_t* out_ptr0 = temp; + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + + UNROLL_CALL0(4, load_filter) + +#define pack_filter(i, j) __m128i k_##i##j = _mm_unpacklo_epi8(k_##i, k_##j) + pack_filter(0, 1); + pack_filter(2, 3); + + __m256i bias_val; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_val = _mm256_set1_epi32(*(bias)); + } else { + bias_val = _mm256_set1_epi32(0); + } +#define cvt_filter(i, j) __m256i filter_##i##j = _mm256_cvtepi8_epi16(k_##i##j) + cvt_filter(0, 1); + cvt_filter(2, 3); + + size_t width = OW >> 4; + for (size_t h = 0; h < OH; h++) { + for (size_t w = 0; w < width; w++) { + UNROLL_CALL0(2, load_src0) + UNROLL_CALL0(2, load_src16) + + __m256i t0_left, t0_right, t1_left, t1_right, sum_left, sum_right; + + t0_left = _mm256_madd_epi16(cvt16_src00, filter_01); + t0_right = _mm256_madd_epi16(cvt16_src016, filter_01); + + t1_left = _mm256_madd_epi16(cvt16_src10, filter_23); + t1_right = _mm256_madd_epi16(cvt16_src116, filter_23); + + sum_left = _mm256_add_epi32(t0_left, t1_left); + sum_right = _mm256_add_epi32(t0_right, t1_right); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + r0 += 32; + r1 += 32; + dst0 += 16; + out_ptr0 += 16; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + } + + MEGDNN_MARK_USED_VAR(IH); +#undef pack_filter +#undef cvt_filter +} + +template +void avx2_chanwise_direct_stride2_3x3_int8(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + size_t tail_step = IW - OW * 2; + int32_t* out_ptr0 = temp; + int8_t* dst0 = dst; + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + 2 * IW; + + uint8_t fill_zero = 0; + UNROLL_CALL0(9, load_filter) + + __m128i k_fill = _mm_set1_epi8(fill_zero); + + __m128i k01 = _mm_unpacklo_epi8(k_0, k_1); + __m128i k20 = _mm_unpacklo_epi8(k_2, k_fill); + + __m128i k34 = _mm_unpacklo_epi8(k_3, k_4); + __m128i k50 = _mm_unpacklo_epi8(k_5, k_fill); + + __m128i k67 = _mm_unpacklo_epi8(k_6, k_7); + __m128i k80 = _mm_unpacklo_epi8(k_8, k_fill); + + __m256i bias_val; + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_val = _mm256_set1_epi32(*(bias)); + } else { + bias_val = _mm256_set1_epi32(0); + } + + //! 
cvt i8 --> i16 + __m256i filter_01 = _mm256_cvtepi8_epi16(k01); + __m256i filter_20 = _mm256_cvtepi8_epi16(k20); + __m256i filter_34 = _mm256_cvtepi8_epi16(k34); + __m256i filter_50 = _mm256_cvtepi8_epi16(k50); + __m256i filter_67 = _mm256_cvtepi8_epi16(k67); + __m256i filter_80 = _mm256_cvtepi8_epi16(k80); + + size_t width = OW >> 4; + for (size_t h = 0; h < OH; h++) { + for (size_t w = 0; w < width; w++) { + UNROLL_CALL0(3, load_src0) + UNROLL_CALL0(3, load_src2) + UNROLL_CALL0(3, load_src16) + UNROLL_CALL0(3, load_src18) + + __m256i temp, t0_left, t0_right, t1_left, t1_right, t2_left, + t2_right, sum_left, sum_right; + + t0_left = _mm256_madd_epi16(cvt16_src00, filter_01); + temp = _mm256_madd_epi16(cvt16_src02, filter_20); + t0_left = _mm256_add_epi32(t0_left, temp); + + t0_right = _mm256_madd_epi16(cvt16_src016, filter_01); + temp = _mm256_madd_epi16(cvt16_src018, filter_20); + t0_right = _mm256_add_epi32(t0_right, temp); + + t1_left = _mm256_madd_epi16(cvt16_src10, filter_34); + temp = _mm256_madd_epi16(cvt16_src12, filter_50); + t1_left = _mm256_add_epi32(t1_left, temp); + + t1_right = _mm256_madd_epi16(cvt16_src116, filter_34); + temp = _mm256_madd_epi16(cvt16_src118, filter_50); + t1_right = _mm256_add_epi32(t1_right, temp); + + t2_left = _mm256_madd_epi16(cvt16_src20, filter_67); + temp = _mm256_madd_epi16(cvt16_src22, filter_80); + t2_left = _mm256_add_epi32(t2_left, temp); + + t2_right = _mm256_madd_epi16(cvt16_src216, filter_67); + temp = _mm256_madd_epi16(cvt16_src218, filter_80); + t2_right = _mm256_add_epi32(t2_right, temp); + + sum_left = _mm256_add_epi32(t0_left, t1_left); + sum_left = _mm256_add_epi32(sum_left, t2_left); + sum_right = _mm256_add_epi32(t0_right, t1_right); + sum_right = _mm256_add_epi32(sum_right, t2_right); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + r0 += 32; + r1 += 32; + r2 += 32; + dst0 += 16; + out_ptr0 += 16; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + } +} + +template +void avx2_chanwise_direct_stride2_5x5_int8(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + size_t tail_step = IW - OW * 2; + int8_t* dst0 = dst; + int32_t* out_ptr0 = temp; + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + 2 * IW; + const int8_t* r3 = src + 3 * IW; + const int8_t* r4 = src + 4 * IW; + + uint8_t fill_zero = 0; + UNROLL_CALL0(25, load_filter) + + __m128i k_fill = _mm_set1_epi8(fill_zero); + + __m128i k01 = _mm_unpacklo_epi8(k_0, k_1); + __m128i k23 = _mm_unpacklo_epi8(k_2, k_3); + __m128i k40 = _mm_unpacklo_epi8(k_4, k_fill); + + __m128i k56 = _mm_unpacklo_epi8(k_5, k_6); + __m128i k78 = _mm_unpacklo_epi8(k_7, k_8); + __m128i k90 = _mm_unpacklo_epi8(k_9, k_fill); + + __m128i k1011 = _mm_unpacklo_epi8(k_10, k_11); + __m128i k1213 = _mm_unpacklo_epi8(k_12, k_13); + __m128i k140 = _mm_unpacklo_epi8(k_14, k_fill); + + __m128i k1516 = _mm_unpacklo_epi8(k_15, k_16); + __m128i k1718 = _mm_unpacklo_epi8(k_17, k_18); + __m128i k190 = _mm_unpacklo_epi8(k_19, k_fill); + + __m128i k2021 = _mm_unpacklo_epi8(k_20, k_21); + __m128i k2223 = _mm_unpacklo_epi8(k_22, k_23); + __m128i 
k240 = _mm_unpacklo_epi8(k_24, k_fill); + + __m256i bias_val; + //! load bias + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_val = _mm256_set1_epi32(*(bias)); + } else { + bias_val = _mm256_set1_epi32(0); + } + + //! cvt i8 --> i16 + __m256i filter_01 = _mm256_cvtepi8_epi16(k01); + __m256i filter_23 = _mm256_cvtepi8_epi16(k23); + __m256i filter_40 = _mm256_cvtepi8_epi16(k40); + + __m256i filter_56 = _mm256_cvtepi8_epi16(k56); + __m256i filter_78 = _mm256_cvtepi8_epi16(k78); + __m256i filter_90 = _mm256_cvtepi8_epi16(k90); + + __m256i filter_1011 = _mm256_cvtepi8_epi16(k1011); + __m256i filter_1213 = _mm256_cvtepi8_epi16(k1213); + __m256i filter_140 = _mm256_cvtepi8_epi16(k140); + + __m256i filter_1516 = _mm256_cvtepi8_epi16(k1516); + __m256i filter_1718 = _mm256_cvtepi8_epi16(k1718); + __m256i filter_190 = _mm256_cvtepi8_epi16(k190); + + __m256i filter_2021 = _mm256_cvtepi8_epi16(k2021); + __m256i filter_2223 = _mm256_cvtepi8_epi16(k2223); + __m256i filter_240 = _mm256_cvtepi8_epi16(k240); + + size_t width = OW >> 4; + for (size_t h = 0; h < OH; h++) { + for (size_t w = 0; w < width; w++) { + UNROLL_CALL0(5, load_src0) + UNROLL_CALL0(5, load_src2) + UNROLL_CALL0(5, load_src4) + UNROLL_CALL0(5, load_src16) + UNROLL_CALL0(5, load_src18) + UNROLL_CALL0(5, load_src20) + + __m256i temp, t0_left, t0_right, t1_left, t1_right, t2_left, + t2_right, t3_left, t3_right, t4_left, t4_right, sum_left, + sum_right; + + t0_left = _mm256_madd_epi16(cvt16_src00, filter_01); + temp = _mm256_madd_epi16(cvt16_src02, filter_23); + t0_left = _mm256_add_epi32(t0_left, temp); + temp = _mm256_madd_epi16(cvt16_src04, filter_40); + t0_left = _mm256_add_epi32(t0_left, temp); + + t0_right = _mm256_madd_epi16(cvt16_src016, filter_01); + temp = _mm256_madd_epi16(cvt16_src018, filter_23); + t0_right = _mm256_add_epi32(t0_right, temp); + temp = _mm256_madd_epi16(cvt16_src020, filter_40); + t0_right = _mm256_add_epi32(t0_right, temp); + + t1_left = _mm256_madd_epi16(cvt16_src10, filter_56); + temp = _mm256_madd_epi16(cvt16_src12, filter_78); + t1_left = _mm256_add_epi32(t1_left, temp); + temp = _mm256_madd_epi16(cvt16_src14, filter_90); + t1_left = _mm256_add_epi32(t1_left, temp); + + t1_right = _mm256_madd_epi16(cvt16_src116, filter_56); + temp = _mm256_madd_epi16(cvt16_src118, filter_78); + t1_right = _mm256_add_epi32(t1_right, temp); + temp = _mm256_madd_epi16(cvt16_src120, filter_90); + t1_right = _mm256_add_epi32(t1_right, temp); + + t2_left = _mm256_madd_epi16(cvt16_src20, filter_1011); + temp = _mm256_madd_epi16(cvt16_src22, filter_1213); + t2_left = _mm256_add_epi32(t2_left, temp); + temp = _mm256_madd_epi16(cvt16_src24, filter_140); + t2_left = _mm256_add_epi32(t2_left, temp); + + t2_right = _mm256_madd_epi16(cvt16_src216, filter_1011); + temp = _mm256_madd_epi16(cvt16_src218, filter_1213); + t2_right = _mm256_add_epi32(t2_right, temp); + temp = _mm256_madd_epi16(cvt16_src220, filter_140); + t2_right = _mm256_add_epi32(t2_right, temp); + + t3_left = _mm256_madd_epi16(cvt16_src30, filter_1516); + temp = _mm256_madd_epi16(cvt16_src32, filter_1718); + t3_left = _mm256_add_epi32(t3_left, temp); + temp = _mm256_madd_epi16(cvt16_src34, filter_190); + t3_left = _mm256_add_epi32(t3_left, temp); + + t3_right = _mm256_madd_epi16(cvt16_src316, filter_1516); + temp = _mm256_madd_epi16(cvt16_src318, filter_1718); + t3_right = _mm256_add_epi32(t3_right, temp); + temp = _mm256_madd_epi16(cvt16_src320, filter_190); + t3_right = _mm256_add_epi32(t3_right, temp); + + t4_left = _mm256_madd_epi16(cvt16_src40, 
filter_2021); + temp = _mm256_madd_epi16(cvt16_src42, filter_2223); + t4_left = _mm256_add_epi32(t4_left, temp); + temp = _mm256_madd_epi16(cvt16_src44, filter_240); + t4_left = _mm256_add_epi32(t4_left, temp); + + t4_right = _mm256_madd_epi16(cvt16_src416, filter_2021); + temp = _mm256_madd_epi16(cvt16_src418, filter_2223); + t4_right = _mm256_add_epi32(t4_right, temp); + temp = _mm256_madd_epi16(cvt16_src420, filter_240); + t4_right = _mm256_add_epi32(t4_right, temp); + + sum_left = _mm256_add_epi32(t0_left, t1_left); + sum_left = _mm256_add_epi32(sum_left, t2_left); + sum_left = _mm256_add_epi32(sum_left, t3_left); + sum_left = _mm256_add_epi32(sum_left, t4_left); + sum_right = _mm256_add_epi32(t0_right, t1_right); + sum_right = _mm256_add_epi32(sum_right, t2_right); + sum_right = _mm256_add_epi32(sum_right, t3_right); + sum_right = _mm256_add_epi32(sum_right, t4_right); + + sum_left = _mm256_add_epi32(sum_left, bias_val); + sum_right = _mm256_add_epi32(sum_right, bias_val); + + if (is_quantized) { + op({{sum_left, sum_right}}, reinterpret_cast(dst0)); + + } else { + _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left); + _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right); + } + + r0 += 32; + r1 += 32; + r2 += 32; + r3 += 32; + r4 += 32; + dst0 += 16; + out_ptr0 += 16; + } + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + } +} + +template +void avx2_chanwise_direct_stride2_7x7_int8(const int8_t* src, + const int8_t* filter, + const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, + const size_t IW, const size_t OH, + const size_t OW, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); + size_t tail_step = IW - OW * 2; + int8_t* dst0 = dst; + int32_t* out_ptr0 = temp; + const int8_t* r0 = src; + const int8_t* r1 = src + IW; + const int8_t* r2 = src + 2 * IW; + const int8_t* r3 = src + 3 * IW; + const int8_t* r4 = src + 4 * IW; + const int8_t* r5 = src + 5 * IW; + const int8_t* r6 = src + 6 * IW; + + uint8_t fill_zero = 0; + UNROLL_CALL0(49, load_filter) + + __m128i k_fill = _mm_set1_epi8(fill_zero); + + __m128i k01 = _mm_unpacklo_epi8(k_0, k_1); + __m128i k23 = _mm_unpacklo_epi8(k_2, k_3); + __m128i k45 = _mm_unpacklo_epi8(k_4, k_5); + __m128i k60 = _mm_unpacklo_epi8(k_6, k_fill); + + __m128i k78 = _mm_unpacklo_epi8(k_7, k_8); + __m128i k910 = _mm_unpacklo_epi8(k_9, k_10); + __m128i k1112 = _mm_unpacklo_epi8(k_11, k_12); + __m128i k130 = _mm_unpacklo_epi8(k_13, k_fill); + + __m128i k1415 = _mm_unpacklo_epi8(k_14, k_15); + __m128i k1617 = _mm_unpacklo_epi8(k_16, k_17); + __m128i k1819 = _mm_unpacklo_epi8(k_18, k_19); + __m128i k200 = _mm_unpacklo_epi8(k_20, k_fill); + + __m128i k2122 = _mm_unpacklo_epi8(k_21, k_22); + __m128i k2324 = _mm_unpacklo_epi8(k_23, k_24); + __m128i k2526 = _mm_unpacklo_epi8(k_25, k_26); + __m128i k270 = _mm_unpacklo_epi8(k_27, k_fill); + + __m128i k2829 = _mm_unpacklo_epi8(k_28, k_29); + __m128i k3031 = _mm_unpacklo_epi8(k_30, k_31); + __m128i k3233 = _mm_unpacklo_epi8(k_32, k_33); + __m128i k340 = _mm_unpacklo_epi8(k_34, k_fill); + + __m128i k3536 = _mm_unpacklo_epi8(k_35, k_36); + __m128i k3738 = _mm_unpacklo_epi8(k_37, k_38); + __m128i k3940 = _mm_unpacklo_epi8(k_39, k_40); + __m128i k410 = _mm_unpacklo_epi8(k_41, k_fill); + + __m128i k4243 = _mm_unpacklo_epi8(k_42, k_43); + __m128i k4445 = _mm_unpacklo_epi8(k_44, k_45); + __m128i k4647 = _mm_unpacklo_epi8(k_46, k_47); + __m128i k480 = _mm_unpacklo_epi8(k_48, k_fill); + + __m256i bias_val; + //! 
load bias + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_val = _mm256_set1_epi32(*(bias)); + } else { + bias_val = _mm256_set1_epi32(0); + } + + //! cvt i8 --> i16 + __m256i filter_01 = _mm256_cvtepi8_epi16(k01); + __m256i filter_23 = _mm256_cvtepi8_epi16(k23); + __m256i filter_45 = _mm256_cvtepi8_epi16(k45); + __m256i filter_60 = _mm256_cvtepi8_epi16(k60); + + __m256i filter_78 = _mm256_cvtepi8_epi16(k78); + __m256i filter_910 = _mm256_cvtepi8_epi16(k910); + __m256i filter_1112 = _mm256_cvtepi8_epi16(k1112); + __m256i filter_130 = _mm256_cvtepi8_epi16(k130); + + __m256i filter_1415 = _mm256_cvtepi8_epi16(k1415); + __m256i filter_1617 = _mm256_cvtepi8_epi16(k1617); + __m256i filter_1819 = _mm256_cvtepi8_epi16(k1819); + __m256i filter_200 = _mm256_cvtepi8_epi16(k200); + + __m256i filter_2122 = _mm256_cvtepi8_epi16(k2122); + __m256i filter_2324 = _mm256_cvtepi8_epi16(k2324); + __m256i filter_2526 = _mm256_cvtepi8_epi16(k2526); + __m256i filter_270 = _mm256_cvtepi8_epi16(k270); + + __m256i filter_2829 = _mm256_cvtepi8_epi16(k2829); + __m256i filter_3031 = _mm256_cvtepi8_epi16(k3031); + __m256i filter_3233 = _mm256_cvtepi8_epi16(k3233); + __m256i filter_340 = _mm256_cvtepi8_epi16(k340); + + __m256i filter_3536 = _mm256_cvtepi8_epi16(k3536); + __m256i filter_3738 = _mm256_cvtepi8_epi16(k3738); + __m256i filter_3940 = _mm256_cvtepi8_epi16(k3940); + __m256i filter_410 = _mm256_cvtepi8_epi16(k410); + + __m256i filter_4243 = _mm256_cvtepi8_epi16(k4243); + __m256i filter_4445 = _mm256_cvtepi8_epi16(k4445); + __m256i filter_4647 = _mm256_cvtepi8_epi16(k4647); + __m256i filter_480 = _mm256_cvtepi8_epi16(k480); + + size_t width = OW >> 4; + for (size_t h = 0; h < OH; h++) { + for (size_t w = 0; w < width; w++) { + UNROLL_CALL0(7, load_src0) + UNROLL_CALL0(7, load_src2) + UNROLL_CALL0(7, load_src4) + UNROLL_CALL0(7, load_src6) + UNROLL_CALL0(7, load_src16) + UNROLL_CALL0(7, load_src18) + UNROLL_CALL0(7, load_src20) + UNROLL_CALL0(7, load_src22) + + __m256i temp, t0_left, t0_right, t1_left, t1_right, t2_left, + t2_right, t3_left, t3_right, t4_left, t4_right, sum_left, + t5_left, t5_right, t6_left, t6_right, sum_right; + + t0_left = _mm256_madd_epi16(cvt16_src00, filter_01); + temp = _mm256_madd_epi16(cvt16_src02, filter_23); + t0_left = _mm256_add_epi32(t0_left, temp); + temp = _mm256_madd_epi16(cvt16_src04, filter_45); + t0_left = _mm256_add_epi32(t0_left, temp); + temp = _mm256_madd_epi16(cvt16_src06, filter_60); + t0_left = _mm256_add_epi32(t0_left, temp); + + t0_right = _mm256_madd_epi16(cvt16_src016, filter_01); + temp = _mm256_madd_epi16(cvt16_src018, filter_23); + t0_right = _mm256_add_epi32(t0_right, temp); + temp = _mm256_madd_epi16(cvt16_src020, filter_45); + t0_right = _mm256_add_epi32(t0_right, temp); + temp = _mm256_madd_epi16(cvt16_src022, filter_60); + t0_right = _mm256_add_epi32(t0_right, temp); + + t1_left = _mm256_madd_epi16(cvt16_src10, filter_78); + temp = _mm256_madd_epi16(cvt16_src12, filter_910); + t1_left = _mm256_add_epi32(t1_left, temp); + temp = _mm256_madd_epi16(cvt16_src14, filter_1112); + t1_left = _mm256_add_epi32(t1_left, temp); + temp = _mm256_madd_epi16(cvt16_src16, filter_130); + t1_left = _mm256_add_epi32(t1_left, temp); + + t1_right = _mm256_madd_epi16(cvt16_src116, filter_78); + temp = _mm256_madd_epi16(cvt16_src118, filter_910); + t1_right = _mm256_add_epi32(t1_right, temp); + temp = _mm256_madd_epi16(cvt16_src120, filter_1112); + t1_right = _mm256_add_epi32(t1_right, temp); + temp = _mm256_madd_epi16(cvt16_src122, filter_130); + t1_right = 
_mm256_add_epi32(t1_right, temp); + + t2_left = _mm256_madd_epi16(cvt16_src20, filter_1415); + temp = _mm256_madd_epi16(cvt16_src22, filter_1617); + t2_left = _mm256_add_epi32(t2_left, temp); + temp = _mm256_madd_epi16(cvt16_src24, filter_1819); + t2_left = _mm256_add_epi32(t2_left, temp); + temp = _mm256_madd_epi16(cvt16_src26, filter_200); + t2_left = _mm256_add_epi32(t2_left, temp); + + t2_right = _mm256_madd_epi16(cvt16_src216, filter_1415); + temp = _mm256_madd_epi16(cvt16_src218, filter_1617); + t2_right = _mm256_add_epi32(t2_right, temp); + temp = _mm256_madd_epi16(cvt16_src220, filter_1819); + t2_right = _mm256_add_epi32(t2_right, temp); + temp = _mm256_madd_epi16(cvt16_src222, filter_200); + t2_right = _mm256_add_epi32(t2_right, temp); + + t3_left = _mm256_madd_epi16(cvt16_src30, filter_2122); + temp = _mm256_madd_epi16(cvt16_src32, filter_2324); + t3_left = _mm256_add_epi32(t3_left, temp); + temp = _mm256_madd_epi16(cvt16_src34, filter_2526); + t3_left = _mm256_add_epi32(t3_left, temp); + temp = _mm256_madd_epi16(cvt16_src36, filter_270); + t3_left = _mm256_add_epi32(t3_left, temp); + + t3_right = _mm256_madd_epi16(cvt16_src316, filter_2122); + temp = _mm256_madd_epi16(cvt16_src318, filter_2324); + t3_right = _mm256_add_epi32(t3_right, temp); + temp = _mm256_madd_epi16(cvt16_src320, filter_2526); + t3_right = _mm256_add_epi32(t3_right, temp); + temp = _mm256_madd_epi16(cvt16_src322, filter_270); + t3_right = _mm256_add_epi32(t3_right, temp); + + t4_left = _mm256_madd_epi16(cvt16_src40, filter_2829); + temp = _mm256_madd_epi16(cvt16_src42, filter_3031); + t4_left = _mm256_add_epi32(t4_left, temp); + temp = _mm256_madd_epi16(cvt16_src44, filter_3233); + t4_left = _mm256_add_epi32(t4_left, temp); + temp = _mm256_madd_epi16(cvt16_src46, filter_340); + t4_left = _mm256_add_epi32(t4_left, temp); + + t4_right = _mm256_madd_epi16(cvt16_src416, filter_2829); + temp = _mm256_madd_epi16(cvt16_src418, filter_3031); + t4_right = _mm256_add_epi32(t4_right, temp); + temp = _mm256_madd_epi16(cvt16_src420, filter_3233); + t4_right = _mm256_add_epi32(t4_right, temp); + temp = _mm256_madd_epi16(cvt16_src422, filter_340); + t4_right = _mm256_add_epi32(t4_right, temp); + + t5_left = _mm256_madd_epi16(cvt16_src50, filter_3536); + temp = _mm256_madd_epi16(cvt16_src52, filter_3738); + t5_left = _mm256_add_epi32(t5_left, temp); + temp = _mm256_madd_epi16(cvt16_src54, filter_3940); + t5_left = _mm256_add_epi32(t5_left, temp); + temp = _mm256_madd_epi16(cvt16_src56, filter_410); + t5_left = _mm256_add_epi32(t5_left, temp); + + t5_right = _mm256_madd_epi16(cvt16_src516, filter_3536); + temp = _mm256_madd_epi16(cvt16_src518, filter_3738); + t5_right = _mm256_add_epi32(t5_right, temp); + temp = _mm256_madd_epi16(cvt16_src520, filter_3940); + t5_right = _mm256_add_epi32(t5_right, temp); + temp = _mm256_madd_epi16(cvt16_src522, filter_410); + t5_right = _mm256_add_epi32(t5_right, temp); + + t6_left = _mm256_madd_epi16(cvt16_src60, filter_4243); + temp = _mm256_madd_epi16(cvt16_src62, filter_4445); + t6_left = _mm256_add_epi32(t6_left, temp); + temp = _mm256_madd_epi16(cvt16_src64, filter_4647); + t6_left = _mm256_add_epi32(t6_left, temp); + temp = _mm256_madd_epi16(cvt16_src66, filter_480); + t6_left = _mm256_add_epi32(t6_left, temp); + + t6_right = _mm256_madd_epi16(cvt16_src616, filter_4243); + temp = _mm256_madd_epi16(cvt16_src618, filter_4445); + t6_right = _mm256_add_epi32(t6_right, temp); + temp = _mm256_madd_epi16(cvt16_src620, filter_4647); + t6_right = _mm256_add_epi32(t6_right, temp); + temp = 
_mm256_madd_epi16(cvt16_src622, filter_480);
+            t6_right = _mm256_add_epi32(t6_right, temp);
+
+            sum_left = _mm256_add_epi32(t0_left, t1_left);
+            sum_left = _mm256_add_epi32(sum_left, t2_left);
+            sum_left = _mm256_add_epi32(sum_left, t3_left);
+            sum_left = _mm256_add_epi32(sum_left, t4_left);
+            sum_left = _mm256_add_epi32(sum_left, t5_left);
+            sum_left = _mm256_add_epi32(sum_left, t6_left);
+            sum_right = _mm256_add_epi32(t0_right, t1_right);
+            sum_right = _mm256_add_epi32(sum_right, t2_right);
+            sum_right = _mm256_add_epi32(sum_right, t3_right);
+            sum_right = _mm256_add_epi32(sum_right, t4_right);
+            sum_right = _mm256_add_epi32(sum_right, t5_right);
+            sum_right = _mm256_add_epi32(sum_right, t6_right);
+
+            sum_left = _mm256_add_epi32(sum_left, bias_val);
+            sum_right = _mm256_add_epi32(sum_right, bias_val);
+
+            if (is_quantized) {
+                op({{sum_left, sum_right}}, reinterpret_cast<__m256i*>(dst0));
+
+            } else {
+                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
+                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
+            }
+
+            r0 += 32;
+            r1 += 32;
+            r2 += 32;
+            r3 += 32;
+            r4 += 32;
+            r5 += 32;
+            r6 += 32;
+            dst0 += 16;
+            out_ptr0 += 16;
+        }
+        r0 += tail_step + IW;
+        r1 += tail_step + IW;
+        r2 += tail_step + IW;
+        r3 += tail_step + IW;
+        r4 += tail_step + IW;
+        r5 += tail_step + IW;
+        r6 += tail_step + IW;
+    }
+}
+#define INSTANTIATION(stride, i, bias, is_quantized, Op)                 \
+    template void avx2_chanwise_direct_##stride##_##i##x##i##_int8<      \
+            bias, is_quantized, Op>(const int8_t*, const int8_t*,        \
+                                    const int32_t*, int32_t*, int8_t*,   \
+                                    const size_t, const size_t,          \
+                                    const size_t, const size_t,          \
+                                    const Op&);
+
+#define FOR_OP(stride, i, is_quantized, bias)                            \
+    INSTANTIATION(stride, i, bias, is_quantized,                         \
+                  TypeCvtOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32        \
+                                    MEGDNN_COMMA dt_qint8>)              \
+    INSTANTIATION(stride, i, bias, is_quantized,                         \
+                  ReluOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32           \
+                                 MEGDNN_COMMA dt_qint8>)                 \
+    INSTANTIATION(stride, i, bias, is_quantized,                         \
+                  HSwishOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32         \
+                                   MEGDNN_COMMA dt_qint8>)
+
+#define FOR_BIAS(stride, i, is_quantized)                  \
+    FOR_OP(stride, i, is_quantized, BiasMode::NO_BIAS)     \
+    FOR_OP(stride, i, is_quantized, BiasMode::BROADCAST_CHANNEL_BIAS)
+
+#define FOR_QUANTIZED(stride, i) \
+    FOR_BIAS(stride, i, true)    \
+    FOR_BIAS(stride, i, false)
+
+#define FOR_FILTER(stride)   \
+    FOR_QUANTIZED(stride, 2) \
+    FOR_QUANTIZED(stride, 3) \
+    FOR_QUANTIZED(stride, 5) \
+    FOR_QUANTIZED(stride, 7)
+
+#define FOR_STRIDE FOR_FILTER(stride2)
+
+FOR_STRIDE
+
+#undef FOR_STRIDE
+#undef FOR_FILTER
+#undef FOR_QUANTIZED
+#undef FOR_BIAS
+#undef FOR_OP
+#undef INSTANTIATION
+}  // namespace avx2_chanwise_stride2
+#undef load_filter
+#undef load_src0
+#undef load_src1
+#undef load_src2
+#undef load_src3
+#undef load_src4
+#undef load_src5
+#undef load_src6
+#undef load_src16
+#undef load_src18
+#undef load_src20
+#undef load_src22
 }  // namespace x86
 }  // namespace megdnn
diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.h b/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.h
index 571f4f394a6a14b4e6a093d6e587bec50ac064c1..00676b735eb257de488bbc3c0a8a9a8fea7aabfa 100644
--- a/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.h
+++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_kern.h
@@ -33,6 +33,25 @@
 KERN(stride1, 7)
 
 #undef KERN
 
 }  // namespace avx2_chanwise_stride1
+
+namespace avx2_chanwise_stride2 {
+
+#define KERN(stride, i)                                                    \
+    template <BiasMode bias_mode, bool is_quantized, typename Op>          \
+    MEGDNN_ATTRIBUTE_TARGET("avx2")                                        \
+    void avx2_chanwise_direct_##stride##_##i##x##i##_int8(                 \
+            const int8_t* src, const int8_t* filter, const int32_t* bias,  \
+            int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,  \
+            const size_t OH, const size_t OW, const Op& op);
+
+KERN(stride2, 2)
+KERN(stride2, 3)
+KERN(stride2, 5)
+KERN(stride2, 7)
+
+#undef KERN
+
+}  // namespace avx2_chanwise_stride2
 }  // namespace x86
 }  // namespace megdnn
diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp
index 25fb140b86ed18477b4848b1a6bf9a8abf571b33..19d18ee1e285a153c4ba88f406b49a8151cac68f 100644
--- a/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp
+++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.cpp
@@ -18,57 +18,6 @@
 namespace megdnn {
 namespace x86 {
 namespace avx2_chanwise_stride1 {
-bool need_dst_copy(const NCBKernSizeParam& param) {
-    return param.osz[1] % 16;
-}
-bool need_src_copy(const NCBKernSizeParam& param) {
-    auto&& fm = param.filter_meta;
-    return (fm.padding[0] != 0 || fm.padding[1] != 0) ? true
-                                                      : need_dst_copy(param);
-}
-void get_rectified_size(const NCBKernSizeParam& param, size_t& IH2, size_t& IW2,
-                        size_t& OH2, size_t& OW2) {
-    auto&& fm = param.filter_meta;
-    auto SW = fm.stride[1];
-    auto OH = param.osz[0];
-    auto OW = param.osz[1];
-    auto FH = fm.spatial[0];
-    auto FW = fm.spatial[1];
-
-    OH2 = OH;
-    OW2 = (OW + 15) & ~15;
-    IH2 = SW * OH + FH - SW;
-    IW2 = SW * OW2 + FW - SW;
-}
-void copy_padding_kern(WorkspaceBundle bundle,
-                       const ConvBiasImpl::NCBKernParam& kern_param,
-                       const ConvBiasImpl::NCBKernIndex& ncb_index) {
-    size_t IH = kern_param.isz[0];
-    size_t IW = kern_param.isz[1];
-    size_t PH = kern_param.filter_meta.padding[0];
-    size_t PW = kern_param.filter_meta.padding[1];
-
-    size_t IH2, IW2, OH2, OW2;
-    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
-    bool need_src_copy_var = need_src_copy(kern_param);
-    size_t padding_group_size = IH2 * IW2;
-    bundle.set(kern_param.workspace_ptr);
-
-    size_t group_id = ncb_index.ndrange_id[0],
-           batch_id = ncb_index.ndrange_id[1],
-           channel_id = ncb_index.ndrange_id[2];
-    size_t workspace_group_id = ncb_index.thread_id;
-    const int8_t* sptr = kern_param.src<int8_t>(batch_id, group_id, channel_id);
-    if (need_src_copy_var) {
-        int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
-                            workspace_group_id * padding_group_size;
-        std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2);
-        rep(ih, IH) {
-            std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
-                        sizeof(int8_t) * IW);
-        }
-    }
-};
 
 template <size_t filter, BiasMode bias_mode, bool is_quantized, typename Op>
 void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) {
@@ -97,8 +46,7 @@ void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param,
            batch_id = ncb_index.ndrange_id[1];
 
     const int8_t* sptr = kern_param.src<int8_t>(batch_id, group_id);
-    const int8_t* fptr =
-            kern_param.filter<int8_t>(group_id);
+    const int8_t* fptr = kern_param.filter<int8_t>(group_id);
     void* dst = kern_param.dst<void>(batch_id, group_id);
     const int32_t* bptr = kern_param.bias<int32_t>(batch_id, group_id);
     if (need_src_copy_var) {
@@ -130,9 +78,9 @@ void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param,
     if (need_post_process) {
         tptr = static_cast<int32_t*>(bundle.get(2)) +
                ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size();
-            DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
+        DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
     } else {
-            DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
+        DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
     }
 
 #undef KERN_NEED_POST_PROCESS
diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.h b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.h
index f57b23ab1935376191a5d35f926fe144cc0b7723..518501b297fd9092358d5b22118274a484aa8697 100644
--- a/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.h
+++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride1.h
@@ -11,27 +11,15 @@
  */
 
 #pragma once
+#include "src/x86/conv_bias/int8/common_helper.h"
 #include "src/x86/conv_bias/opr_impl.h"
 
 namespace megdnn {
 namespace x86 {
 namespace avx2_chanwise_stride1 {
-using NCBKern = fallback::ConvBiasImpl::NCBKern;
-using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
-using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
-using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
-
 using conv_fun = std::function<void(WorkspaceBundle bundle,
                                     const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index)>;
-
-bool need_dst_copy(const NCBKernSizeParam& param);
-
-bool need_src_copy(const NCBKernSizeParam& param);
-
-void get_rectified_size(const NCBKernSizeParam& param, size_t& IH2, size_t& IW2,
-                        size_t& OH2, size_t& OW2);
-
 SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param,
                                 WorkspaceBundle bundle);
diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride2.cpp b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride2.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ca4c70a6a56dc6b96c16a9e405a8a511a88d679d
--- /dev/null
+++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride2.cpp
@@ -0,0 +1,204 @@
+/**
+ * \file src/x86/conv_bias/int8/avx2_chanwise_stride2.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/x86/conv_bias/int8/avx2_chanwise_stride2.h"
+#include "src/x86/conv_bias/int8/avx2_chanwise_kern.h"
+#include "src/x86/elemwise_op.h"
+
+namespace megdnn {
+namespace x86 {
+namespace avx2_chanwise_stride2 {
+
+template <size_t filter, BiasMode bias_mode, bool is_quantized, typename Op>
+void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param,
+                const NCBKernIndex& ncb_index) {
+    size_t OH = kern_param.osz[0];
+    size_t OW = kern_param.osz[1];
+    size_t IH2, IW2, OH2, OW2;
+    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
+    bool need_src_copy_var = need_src_copy(kern_param);
+    bool need_dst_copy_var = need_dst_copy(kern_param);
+    bool need_post_process =
+            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
+
+    Op op = Op(1.0f, 4.0f);
+    if (need_post_process) {
+        float scale_bias =
+                kern_param.bias_type.param<dtype::QuantizedS32>().scale;
+        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
+        op = Op(scale_bias, scale_dst);
+    }
+    size_t padding_group_size = IH2 * IW2;
+
+    bundle.set(kern_param.workspace_ptr);
+
+    size_t workspace_group_id = ncb_index.thread_id;
+    size_t group_id = ncb_index.ndrange_id[0],
+           batch_id = ncb_index.ndrange_id[1];
+
+    const int8_t* sptr = kern_param.src<int8_t>(batch_id, group_id);
+    const int8_t* fptr = kern_param.filter<int8_t>(group_id);
+    void* dst = kern_param.dst<void>(batch_id, group_id);
+    const int32_t* bptr = kern_param.bias<int32_t>(batch_id, group_id);
+    if (need_src_copy_var) {
+        sptr = static_cast<int8_t*>(bundle.get(0)) +
+               workspace_group_id * padding_group_size;
+    }
+    void* dptr = nullptr;
+    int32_t* tptr = nullptr;
+    if (need_dst_copy_var) {
+        dptr = reinterpret_cast<void*>(
+                reinterpret_cast<uintptr_t>(bundle.get(1)) +
+                ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size());
+    } else {
+        dptr = dst;
+    }
+
+#define KERN_NEED_POST_PROCESS(filter)                                        \
+    avx2_chanwise_direct_stride2_##filter##x##filter##_int8<bias_mode, true,  \
+                                                            Op>(              \
+            sptr, fptr, bptr, tptr, static_cast<int8_t*>(dptr), IH2, IW2,     \
+            OH2, OW2, op)
+
+#define KERN_NO_POST_PROCESS(filter)                                          \
+    avx2_chanwise_direct_stride2_##filter##x##filter##_int8<bias_mode, false, \
+                                                            Op>(              \
+            sptr, fptr, bptr, static_cast<int32_t*>(dptr), nullptr, IH2, IW2, \
+            OH2, OW2, op)
+
+    if (need_post_process) {
+        tptr = static_cast<int32_t*>(bundle.get(2)) +
+               ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size();
+        DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
+    } else {
+        DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
+    }
+
+#undef KERN_NEED_POST_PROCESS
+#undef KERN_NO_POST_PROCESS
+    if (need_dst_copy_var) {
+        rep(oh, OH) {
+            std::memcpy(reinterpret_cast<void*>(
+                                reinterpret_cast<uintptr_t>(dst) +
+                                oh * OW * kern_param.dst_type.size()),
+                        reinterpret_cast<void*>(
+                                reinterpret_cast<uintptr_t>(dptr) +
+                                oh * OW2 * kern_param.dst_type.size()),
+                        kern_param.dst_type.size() * OW);
+        }
+    }
+};
+SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& kern_param,
+                                WorkspaceBundle bundle) {
+    MEGDNN_MARK_USED_VAR(kern_param);
+    auto fm = kern_param.filter_meta;
+    size_t group = fm.group;
+    size_t n = kern_param.n;
+
+    SmallVector<NCBKern> ncb_kerns;
+    conv_fun do_conv_fun = nullptr;
+
+#define DO_CONV_KERN_FUN(filter, bias_mode, is_quantized, op) \
+    do_conv_fun = conv_kimpl<filter, bias_mode, is_quantized, op>;
+
+#define GET_OP_PARAM(i, bias_mode, is_quantized)                          \
+    switch (kern_param.nonlineMode) {                                     \
+        case param::ConvBias::NonlineMode::IDENTITY:                      \
+            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                  \
+                             TypeCvtOp<SIMDType::AVX2 MEGDNN_COMMA        \
+                                               dt_qint32 MEGDNN_COMMA     \
+                                               dt_qint8>)                 \
+            break;                                                        \
+        case param::ConvBias::NonlineMode::RELU:                          \
+            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                  \
+                             ReluOp<SIMDType::AVX2 MEGDNN_COMMA           \
+                                            dt_qint32 MEGDNN_COMMA        \
+                                            dt_qint8>)                    \
+            break;                                                        \
+        case param::ConvBias::NonlineMode::H_SWISH:                       \
+            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                  \
+                             HSwishOp<SIMDType::AVX2 MEGDNN_COMMA         \
+                                              dt_qint32 MEGDNN_COMMA      \
+                                              dt_qint8>)                  \
+            break;                                                        \
+        default:                                                          \
+            megdnn_assert(0, "do not support nonlineMode: %d",            \
+                          static_cast<int>(kern_param.nonlineMode));      \
+            break;                                                        \
+    }
+
+#define GET_BIAS_MODE_PARAM(i, is_quantized)                                \
+    switch (kern_param.bias_mode) {                                         \
+        case BiasMode::NO_BIAS:                                             \
+            GET_OP_PARAM(i, BiasMode::NO_BIAS, is_quantized)                \
+            break;                                                          \
+        case BiasMode::BROADCAST_CHANNEL_BIAS:                              \
+            GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS, is_quantized) \
+            break;                                                          \
+        default:                                                            \
+            megdnn_assert(0, "do not support bias mode: %d",                \
+                          static_cast<int>(kern_param.bias_mode));          \
+            break;                                                          \
+    }
+
+#define GET_QUANTIZED(i)                                                  \
+    switch (kern_param.dst_type.enumv()) {                                \
+        case DTypeEnum::QuantizedS8:                                      \
+            GET_BIAS_MODE_PARAM(i, true)                                  \
+            break;                                                        \
+        case DTypeEnum::QuantizedS32:                                     \
+            GET_BIAS_MODE_PARAM(i, false)                                 \
+            break;                                                        \
+        case DTypeEnum::Int32:                                            \
+            GET_BIAS_MODE_PARAM(i, false)                                 \
+            break;                                                        \
+        default:                                                          \
+            megdnn_assert(0, "do not support dtype: %d",                  \
+                          static_cast<int>(kern_param.dst_type.enumv())); \
+            break;                                                        \
+    }
+
+#define DISPATCH_CONV_KERN()                                              \
+    switch (kern_param.filter_meta.spatial[0]) {                          \
+        case 2:                                                           \
+            GET_QUANTIZED(2)                                              \
+            break;                                                        \
+        case 3:                                                           \
+            GET_QUANTIZED(3)                                              \
+            break;                                                        \
+        case 5:                                                           \
+            GET_QUANTIZED(5)                                              \
+            break;                                                        \
+        case 7:                                                           \
+            GET_QUANTIZED(7)                                              \
+            break;                                                        \
+        default:                                                          \
+            megdnn_assert(                                                \
+                    0, "do not support kernel: %d",                       \
+                    static_cast<int>(kern_param.filter_meta.spatial[0])); \
+            break;                                                        \
+    }
+
+    DISPATCH_CONV_KERN();
+
+    auto exec_one_group = [bundle, do_conv_fun](const NCBKernParam& kern_param,
+                                                const NCBKernIndex& ncb_index) {
+        copy_padding_kern(bundle, kern_param, ncb_index);
+        do_conv_fun(bundle, kern_param, ncb_index);
+    };
+    ncb_kerns.push_back({exec_one_group, {group, n, 1_z}});
+
+    return ncb_kerns;
+}
+
+}  // namespace avx2_chanwise_stride2
+}  // namespace x86
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride2.h b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride2.h
new file mode 100644
index 0000000000000000000000000000000000000000..63ee7df4be449ff80e5805a850c44fc32a78d802
--- /dev/null
+++ b/dnn/src/x86/conv_bias/int8/avx2_chanwise_stride2.h
@@ -0,0 +1,30 @@
+/**
+ * \file src/x86/conv_bias/int8/avx2_chanwise_stride2.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+
+#include "src/x86/conv_bias/int8/common_helper.h"
+#include "src/x86/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace x86 {
+namespace avx2_chanwise_stride2 {
+using conv_fun = std::function<void(WorkspaceBundle bundle,
+                                    const NCBKernParam& kern_param,
+                                    const NCBKernIndex& ncb_index)>;
+SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param,
+                                WorkspaceBundle bundle);
+
+}  // namespace avx2_chanwise_stride2
+}  // namespace x86
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/x86/conv_bias/int8/chanwise_helper.h b/dnn/src/x86/conv_bias/int8/chanwise_helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..362ea082b36d9c5c726e71fb7c232d880a90f6b8
--- /dev/null
+++ b/dnn/src/x86/conv_bias/int8/chanwise_helper.h
@@ -0,0 +1,83 @@
+/**
+ * \file dnn/src/x86/conv_bias/int8/chanwise_helper.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+
+#include "megdnn/arch.h"
+#include "src/x86/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace x86 {
+using NCBKern = fallback::ConvBiasImpl::NCBKern;
+using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
+using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
+using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
+
+static inline bool need_dst_copy(const NCBKernSizeParam& param) {
+    return param.osz[1] % 16;
+}
+
+static inline bool need_src_copy(const NCBKernSizeParam& param) {
+    auto&& fm = param.filter_meta;
+    return (fm.padding[0] != 0 || fm.padding[1] != 0) ? true
+                                                      : need_dst_copy(param);
+}
+
+static inline void get_rectified_size(const NCBKernSizeParam& param,
+                                      size_t& IH2, size_t& IW2, size_t& OH2,
+                                      size_t& OW2) {
+    auto&& fm = param.filter_meta;
+    auto SW = fm.stride[1];
+    auto OH = param.osz[0];
+    auto OW = param.osz[1];
+    auto FH = fm.spatial[0];
+    auto FW = fm.spatial[1];
+
+    OH2 = OH;
+    OW2 = (OW + 15) & ~15;
+    IH2 = SW * OH + FH - SW;
+    IW2 = SW * OW2 + FW - SW;
+}
+
+static inline void copy_padding_kern(
+        WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
+        const ConvBiasImpl::NCBKernIndex& ncb_index) {
+    size_t IW = kern_param.isz[1];
+    size_t IH = kern_param.isz[0];
+    size_t PH = kern_param.filter_meta.padding[0];
+    size_t PW = kern_param.filter_meta.padding[1];
+
+    size_t IH2, IW2, OH2, OW2;
+    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
+    bool need_src_copy_var = need_src_copy(kern_param);
+    size_t padding_group_size = IH2 * IW2;
+    bundle.set(kern_param.workspace_ptr);
+
+    size_t group_id = ncb_index.ndrange_id[0],
+           batch_id = ncb_index.ndrange_id[1],
+           channel_id = ncb_index.ndrange_id[2];
+    size_t workspace_group_id = ncb_index.thread_id;
+    const int8_t* sptr =
+            kern_param.src<int8_t>(batch_id, group_id, channel_id);
+    if (need_src_copy_var) {
+        int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
+                            workspace_group_id * padding_group_size;
+        std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2);
+        rep(ih, std::min(IH, IH2)) {
+            std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
+                        sizeof(int8_t) * IW);
+        }
+    }
+};
+
+}  // namespace x86
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/x86/conv_bias/int8/common_helper.h b/dnn/src/x86/conv_bias/int8/common_helper.h
index b01d6a637d2bc27ffd5b33b73e5e073540d69ba1..7fa6e96d0ddab0813c913be3699b7c31a577de3a 100644
--- a/dnn/src/x86/conv_bias/int8/common_helper.h
+++ b/dnn/src/x86/conv_bias/int8/common_helper.h
@@ -6,13 +6,15 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 
 #include <immintrin.h>
-#include "src/common/unroll_macro.h"
 #include "megdnn/arch.h"
+#include "src/common/unroll_macro.h"
+#include "src/x86/conv_bias/int8/chanwise_helper.h"
 #ifdef WIN32
 #include <intrin.h>
 #endif
diff --git a/dnn/src/x86/conv_bias/opr_impl.cpp b/dnn/src/x86/conv_bias/opr_impl.cpp
index 04d91d08d884b8f28e01bc382811d6ccce513a69..2ca8f4174975190d0e5aab67884262eb387663b4 100644
--- a/dnn/src/x86/conv_bias/opr_impl.cpp
+++ b/dnn/src/x86/conv_bias/opr_impl.cpp
@@ -6,17 +6,18 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
*/ #include "src/x86/conv_bias/opr_impl.h" #include #include -#include "src/x86/matrix_mul/opr_impl.h" #include "src/common/metahelper.h" #include "src/common/opr_delegate.h" #include "src/x86/conv_bias/f32/algos.h" #include "src/x86/conv_bias/int8/algos.h" +#include "src/x86/matrix_mul/opr_impl.h" using namespace megdnn; using namespace x86; @@ -69,6 +70,10 @@ void* ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::type() const { return x86_algo_type; } +void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const { + return x86_algo_type; +} + class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDirect stride1_direct_large_group{true}; AlgoDirect stride1_direct_small_group{false}; @@ -77,6 +82,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8; AlgoAVX2DirectConvStride2 avx2_stride2_direct; AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8; + AlgoChanWiseAvx2Stride2Qint8 avx2_stride2_chanwsie_qint8; AlgoMatrixMul matmul; #if MEGDNN_X86_WITH_MKL_DNN AlgoMkldnnMatmulQint8 mkldnn_matmul_qint8; @@ -85,6 +91,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoMkldnnConv mkldnn_conv_fp32; #endif SmallVector> refhold; + public: AlgoPack() { #if MEGDNN_X86_WITH_MKL_DNN @@ -100,6 +107,7 @@ public: all_algos.emplace_back(&avx2_stride1_direct_int8); all_algos.emplace_back(&avx2_stride2_direct); all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); + all_algos.emplace_back(&avx2_stride2_chanwsie_qint8); all_algos.emplace_back(&matmul); static CpuOprDelegationStorage<> storage; @@ -107,7 +115,8 @@ public: auto&& matmul_algos = static_cast(matmul_opr)->algo_pack(); for (auto&& algo : matmul_algos) { - if (algo->type() == nullptr) continue; + if (algo->type() == nullptr) + continue; for (uint32_t tile_size : {8, 16, 24}) { refhold.emplace_back(new AlgoFP32WinogradF63_8x8( static_cast(algo), diff --git a/dnn/src/x86/conv_bias/opr_impl.h b/dnn/src/x86/conv_bias/opr_impl.h index bece5476cdf4b2af01f46ee9f1fa41d781a57cb1..3a948d5b8fc74351ccac28fce909f1e477412d00 100644 --- a/dnn/src/x86/conv_bias/opr_impl.h +++ b/dnn/src/x86/conv_bias/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -32,6 +33,7 @@ public: class AlgoDirectAvx2Stride1Int8; class AlgoAVX2DirectConvStride2; class AlgoChanWiseAvx2Stride1Qint8; + class AlgoChanWiseAvx2Stride2Qint8; #if MEGDNN_X86_WITH_MKL_DNN class AlgoMkldnnConv; class AlgoMkldnnQint8; diff --git a/dnn/test/x86/conv_bias.cpp b/dnn/test/x86/conv_bias.cpp index ef442f5017fa71425fe5651f61aa040dc277e78e..bc65c0a79a4cd8500f06d55e4bd9fe410adf22f5 100644 --- a/dnn/test/x86/conv_bias.cpp +++ b/dnn/test/x86/conv_bias.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "src/x86/utils.h" #include "test/x86/fixture.h" @@ -41,7 +42,8 @@ TEST_F(X86, CONV_BIAS_FORWARD) { } } -TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { +static void avx2_chanwise_direct_int8x8x32(Handle* handle, uint32_t stride, + const char* algo) { using namespace conv_bias; std::vector args; @@ -50,8 +52,8 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { if (w + 2 * p < kernel || h + 2 * p < kernel) return; param::ConvBias param; - param.stride_h = 1; - param.stride_w = 1; + param.stride_h = stride; + param.stride_w = stride; param.pad_h = p; param.pad_w = p; param.nonlineMode = nonline_mode; @@ -74,7 +76,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) run(ic, w, h, kernel, pad, nonline_mode); - Checker checker(handle()); + Checker checker(handle); UniformIntRNG rng{-50, 50}; checker.set_dtype(0, dtype::Int8()) .set_dtype(1, dtype::Int8()) @@ -85,15 +87,25 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { .set_rng(2, &rng) .set_epsilon(1e-3); checker.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker( - "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1")); + conv_bias::ConvBiasAlgoChecker(algo)); for (auto&& arg : args) { checker.set_param(arg.param).exec( {arg.src, arg.filter, arg.bias, {}, {}}); } } -TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { +TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { + avx2_chanwise_direct_int8x8x32(handle(), 1, + "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"); +} + +TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE2_INT8x8x32) { + avx2_chanwise_direct_int8x8x32(handle(), 2, + "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2"); +} + +static void avx2_chanwise_direct_quantizeds32(Handle* handle, uint32_t stride, + const char* algo) { using namespace conv_bias; std::vector args; @@ -102,8 +114,8 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { if (w + 2 * p < kernel || h + 2 * p < kernel) return; param::ConvBias param; - param.stride_h = 1; - param.stride_w = 1; + param.stride_h = stride; + param.stride_w = stride; param.pad_h = p; param.pad_w = p; param.nonlineMode = nonline_mode; @@ -126,7 +138,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) run(ic, w, h, kernel, pad, nonline_mode); - Checker checker(handle()); + Checker checker(handle); UniformIntRNG rng{-50, 50}; checker.set_dtype(0, dtype::QuantizedS8(2.5f)) .set_dtype(1, dtype::QuantizedS8(2.5f)) @@ -137,15 +149,26 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { .set_rng(2, &rng) .set_epsilon(1e-3); checker.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker( - "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1")); + conv_bias::ConvBiasAlgoChecker(algo)); for (auto&& arg : args) { checker.set_param(arg.param).exec( {arg.src, arg.filter, arg.bias, {}, {}}); } } -TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) { +TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { + avx2_chanwise_direct_quantizeds32( + handle(), 1, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"); +} + +TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE2_QuantizedS32) { + avx2_chanwise_direct_quantizeds32( + handle(), 2, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2"); +} + +static void avx2_chanwise_direct_quantizeds8x8x8(Handle* handle, + uint32_t stride, + const char* algo) { 
+static void avx2_chanwise_direct_quantizeds8x8x8(Handle* handle,
+                                                 uint32_t stride,
+                                                 const char* algo) {
     using namespace conv_bias;
     std::vector<TestArg> args;
 
@@ -154,8 +177,8 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) {
         if (w + 2 * p < kernel || h + 2 * p < kernel)
            return;
        param::ConvBias param;
-        param.stride_h = 1;
-        param.stride_w = 1;
+        param.stride_h = stride;
+        param.stride_w = stride;
         param.pad_h = p;
         param.pad_w = p;
         param.nonlineMode = nonline_mode;
@@ -180,7 +203,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) {
                                               NonlineMode::RELU})
                             run(ic, w, h, kernel, pad, nonline_mode);
 
-    Checker<ConvBias> checker(handle());
+    Checker<ConvBias> checker(handle);
     UniformIntRNG rng{-50, 50};
     checker.set_dtype(0, dtype::QuantizedS8(2.5f))
             .set_dtype(1, dtype::QuantizedS8(2.5f))
@@ -191,14 +214,23 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) {
             .set_rng(2, &rng)
             .set_epsilon(1e-3);
     checker.set_before_exec_callback(
-            conv_bias::ConvBiasAlgoChecker<ConvBias>(
-                    "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"));
+            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo));
     for (auto&& arg : args) {
         checker.set_param(arg.param).exec(
                 {arg.src, arg.filter, arg.bias, {}, {}});
     }
 }
 
+TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) {
+    avx2_chanwise_direct_quantizeds8x8x8(
+            handle(), 1, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1");
+}
+
+TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE2_QuantizedS8x8x8) {
+    avx2_chanwise_direct_quantizeds8x8x8(
+            handle(), 2, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2");
+}
+
 TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE1_INT8x8x32) {
     using namespace conv_bias;
     std::vector<TestArg> args;
@@ -343,7 +375,6 @@ TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE1_S8S8S8) {
         args.emplace_back(param, TensorShape{2, 2 * ic, h, w},
                           TensorShape{2, oc / 2, ic, kernel, kernel},
                           TensorShape{1, oc, 1, 1});
-
     };
 
     for (size_t kernel : {2, 3, 5, 7})
@@ -967,8 +998,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
 #if MEGDNN_X86_WITH_MKL_DNN
     if (x86::is_supported(x86::SIMDType::VNNI)) {
         checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
-                        dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
-                        "CONV1x1:X86_INT8X8X32_MKLDNN:24");
+                          dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
+                          "CONV1x1:X86_INT8X8X32_MKLDNN:24");
     }
 #endif
 #if MEGDNN_X86_WITH_VNNI
@@ -983,8 +1014,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
                           dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
                           "CONV1x1:X86_INT8X8X32_AVX2_4X16X2:24");
         checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
-                        dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
-                        "CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24");
+                          dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
+                          "CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24");
     }
     checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
                       dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
@@ -1231,7 +1262,7 @@ TEST_F(X86_MULTI_THREADS, BENCHMARK_CONVBIAS_FP32_MKLDNN) {
 #endif
 
 /************************* Winograd ****************************/
-namespace{
+namespace {
 std::vector<TestArg> get_winograd_mk_nchw88_args() {
     std::vector<TestArg> args;
     param::ConvBias cur_param;
@@ -1265,17 +1296,17 @@ std::vector<TestArg> get_winograd_mk_nchw88_args() {
                             TensorShape{2, oc, ic, 3, 3, 8, 8},
                             TensorShape{1, 2 * oc, 1, 1, 8});*/
                 }}}
-    // clang-format on
-    //! test for multi-thread OC parallel
-    cur_param.sparse = param::ConvBias::Sparse::DENSE;
-    cur_param.pad_h = cur_param.pad_w = 1;
-    args.emplace_back(cur_param, TensorShape{2, 1, 9, 9, 8},
-                      TensorShape{128, 1, 3, 3, 8, 8},
-                      TensorShape{1, 128, 1, 1, 8});
-    /*cur_param.sparse = param::ConvBias::Sparse::GROUP;
-    args.emplace_back(cur_param, TensorShape{2, 2, 9, 9, 8},
-                      TensorShape{2, 128, 1, 3, 3, 8, 8},
-                      TensorShape{1, 2 * 128, 1, 1, 8});*/
+        // clang-format on
+        //! test for multi-thread OC parallel
+        cur_param.sparse = param::ConvBias::Sparse::DENSE;
+        cur_param.pad_h = cur_param.pad_w = 1;
+        args.emplace_back(cur_param, TensorShape{2, 1, 9, 9, 8},
+                          TensorShape{128, 1, 3, 3, 8, 8},
+                          TensorShape{1, 128, 1, 1, 8});
+        /*cur_param.sparse = param::ConvBias::Sparse::GROUP;
+        args.emplace_back(cur_param, TensorShape{2, 2, 9, 9, 8},
+                          TensorShape{2, 128, 1, 3, 3, 8, 8},
+                          TensorShape{1, 2 * 128, 1, 1, 8});*/
     }
     return args;
 }
@@ -1329,7 +1360,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_WINOGRAD_WEIGHT_PREPROCESS) {
         auto conv_bias_opr = handle->create_operator<ConvBias>();
         conv_bias_opr->param() = param;
-        conv_bias_opr->param().format = param::ConvBias::Format::NCHW88_WINOGRAD;
+        conv_bias_opr->param().format =
+                param::ConvBias::Format::NCHW88_WINOGRAD;
         conv_bias_opr->param().output_block_size = m;
         size_t conv_bias_workspace_in_bytes =
                 conv_bias_opr->get_workspace_in_bytes(
@@ -1720,17 +1752,16 @@ void benchmark_impl(const param::ConvBias param,
     }
 }
 
-void benchmark_impl_comp(const param::ConvBias param,
-        std::vector<std::pair<SmallVector<TensorShape>, float>>&
-                shapes_and_computation,
-        const std::string algo_name, const std::string algo_name1,size_t RUNS,
-        TaskExecutorConfig&& multi_thread_config,
-        TaskExecutorConfig&& single_thread_config,std::vector<DType> dtype_v) {
-
+void benchmark_impl_comp(
+        const param::ConvBias param,
+        std::vector<std::pair<SmallVector<TensorShape>, float>>&
+                shapes_and_computation,
+        const std::string algo_name, const std::string algo_name1, size_t RUNS,
+        TaskExecutorConfig&& multi_thread_config,
+        TaskExecutorConfig&& single_thread_config, std::vector<DType> dtype_v) {
     std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(),
                                     dtype::Float32(), dtype::Float32()};
-
     std::vector<float> multi_thread_times, single_thread_times;
     {
         auto multi_thread_hanle =
@@ -1738,10 +1769,10 @@ void benchmark_impl_comp(const param::ConvBias param,
         auto benchmarker = Benchmarker<ConvBias>(multi_thread_hanle.get());
         benchmarker.set_times(RUNS)
                 .set_display(false)
-                .set_dtype(0,dtype_v[0])
-                .set_dtype(1,dtype_v[1])
-                .set_dtype(2,dtype_v[2])
-                .set_dtype(4,dtype_v[3])
+                .set_dtype(0, dtype_v[0])
+                .set_dtype(1, dtype_v[1])
+                .set_dtype(2, dtype_v[2])
+                .set_dtype(4, dtype_v[3])
                 .set_param(param)
                 .set_before_exec_callback(
                         conv_bias::ConvBiasAlgoChecker<ConvBias>(
@@ -1756,10 +1787,10 @@ void benchmark_impl_comp(const param::ConvBias param,
         auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get());
         benchmarker.set_times(RUNS)
                 .set_display(false)
-                .set_dtype(0,dtype_v[0])
-                .set_dtype(1,dtype_v[1])
-                .set_dtype(2,dtype_v[2])
-                .set_dtype(4,dtype_v[3])
+                .set_dtype(0, dtype_v[0])
+                .set_dtype(1, dtype_v[1])
+                .set_dtype(2, dtype_v[2])
+                .set_dtype(4, dtype_v[3])
                 .set_param(param)
                 .set_before_exec_callback(
                         conv_bias::ConvBiasAlgoChecker<ConvBias>(
@@ -1789,11 +1820,13 @@ void benchmark_impl_comp(const param::ConvBias param,
     }
 }
 }  // namespace
-TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8) {
+
+static void benchmark_convbias_chanwise_avx2_int8(uint32_t stride,
+                                                  const char* algo) {
     constexpr size_t RUNS = 50;
     param::ConvBias param;
-    param.stride_h = 1;
-    param.stride_w = 1;
+    param.stride_h = stride;
+    param.stride_w = stride;
     param.sparse = param::ConvBias::Sparse::GROUP;
     std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(),
@@ -1841,14 +1874,23 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8) {
     bench_case(1, 576, 14, 14, 2);
     bench_case(1, 960, 7, 7, 2);
 
-    std::string algo_name = "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1";
-    printf("Benchmark X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1\n");
+    std::string algo_name = algo;
+    printf("Benchmark %s\n", algo);
     benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
                    {4, {4, 5, 6, 7}}, {1, {4}}, data_type);
     benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
                    {1, {4}}, data_type);
     shapes_and_computation.clear();
 }
 
+TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8_S1) {
+    benchmark_convbias_chanwise_avx2_int8(
+            1, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1");
+}
+
+TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8_S2) {
+    benchmark_convbias_chanwise_avx2_int8(
+            2, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2");
+}
 
 TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) {
     constexpr size_t RUNS = 50;
@@ -2129,7 +2171,8 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32) {
     shapes_and_computation.clear();
 }
 
-TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) {
+TEST_F(X86_BENCHMARK_MULTI_THREADS,
+       BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) {
     constexpr size_t RUNS = 50;
 
     param::ConvBias param;
@@ -2143,9 +2186,8 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread)
                                     dtype::Float32(), dtype::Float32()};
     std::vector<std::pair<SmallVector<TensorShape>, float>>
             shapes_and_computation;
-    auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H,
-                          size_t W, size_t FS,
-                          size_t group) {
+    auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
+                          size_t FS, size_t group) {
         SmallVector<TensorShape> shapes{{N, IC, H, W},
                                         {OC / group, IC / group, FS, FS},
                                         {1, OC, 1, 1},
@@ -2167,7 +2209,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread)
     bench_case(1, 32, 32, 100, 100, 3, 1);
     bench_case(1, 32, 32, 80, 80, 3, 1);
     bench_case(1, 32, 32, 80, 80, 3, 1);
-    
+
     bench_case(1, 64, 32, 7, 7, 3, 1);
     bench_case(1, 64, 64, 7, 7, 3, 1);
     bench_case(1, 64, 128, 7, 7, 3, 1);
@@ -2192,10 +2234,10 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread)
     std::string algo_name = "IM2COLMATMUL:X86_F32_MKL_PACKA:192";
     std::string algo_name1 = "IM2COLMATMUL:X86_F32_BLAS:192";
     printf("Benchmark IM2COLMATMUL:X86_F32_BLAS algo\n");
-    benchmark_impl_comp(param, shapes_and_computation, algo_name,algo_name1, RUNS,
-                        {1, {4}}, {1, {4}},data_type);
-    benchmark_impl_comp(param, shapes_and_computation, algo_name,algo_name1, RUNS,
-                        {1, {7}}, {1, {7}},data_type);
+    benchmark_impl_comp(param, shapes_and_computation, algo_name, algo_name1,
+                        RUNS, {1, {4}}, {1, {4}}, data_type);
+    benchmark_impl_comp(param, shapes_and_computation, algo_name, algo_name1,
+                        RUNS, {1, {7}}, {1, {7}}, data_type);
     shapes_and_computation.clear();
 }
@@ -2269,7 +2311,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_INT8X8X32) {
     shapes_and_computation.clear();
 }
 
-namespace{
+namespace {
 std::vector<TestArg> get_winograd_benchmark_args(size_t kernel,
                                                  size_t pack_size) {
     std::vector<TestArg> args;
@@ -2290,14 +2332,14 @@ std::vector<TestArg> get_winograd_benchmark_args(size_t kernel,
         param.pad_h = p;
         param.pad_w = p;
 
-        args.push_back(conv_bias::TestArg{param,
-                TensorShape{1, ic/8, h, w, 8},
-                TensorShape{oc/8, ic/8, kernel, kernel, 8, 8},
-                {1, oc/8, 1, 1, 8}});
-
+        args.push_back(conv_bias::TestArg{
+                param,
+                TensorShape{1, ic / 8, h, w, 8},
+                TensorShape{oc / 8, ic / 8, kernel, kernel, 8, 8},
+                {1, oc / 8, 1, 1, 8}});
     };
     for (size_t ic : {64, 128, 256}) {
-        for (size_t oc : {64,128,256}) {
+        for (size_t oc : {64, 128, 256}) {
             pack(oc, ic, 56, 56, kernel, kernel / 2);
             pack(oc, ic, 14, 14, kernel, kernel / 2);
             pack(oc, ic, 28, 28, kernel, kernel / 2);
@@ -2317,8 +2359,8 @@ std::vector<TestArg> get_winograd_benchmark_args(size_t kernel,
     return args;
 }
 
-void benchmark_winograd(const char* algo_name, Handle* handle,
-                        size_t kernel, size_t pack_size) {
+void benchmark_winograd(const char* algo_name, Handle* handle, size_t kernel,
+                        size_t pack_size) {
     auto&& args = get_winograd_benchmark_args(kernel, pack_size);
     using namespace conv_bias;
     constexpr size_t RUN = 10;
@@ -2361,7 +2403,7 @@ void benchmark_winograd(const char* algo_name, Handle* handle,
                computations / used_winograd, used / used_winograd);
     }
 }
-}
+}  // namespace
 
 TEST_F(X86, BENCHMARK_CONVBIAS_WINOGRAD_F63_8x8) {
     benchmark_winograd("WINOGRAD:X86_F32MK8_8X8:8:6:8", handle(), 3, 8);