diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index fc5dcad5c27d5f31c2c5739fce8906453fa45678..8da31ae5448df35daac98ddb7398d04682ee2192 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -189,6 +189,28 @@ public: fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param) const override; }; + +class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { +public: + AlgoDotS8Direct_NCHW44() {} + + bool is_reproducible() const override { return true; } + const char* name() const override { + return "ARMDOTS8DIRECT_NCHW44"; + } + bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(FallbackConvBiasImpl*, + const NCBKernSizeParam&) const override; + + SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + + bool is_preferred(megdnn::fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; +}; #endif class ConvBiasImpl::AlgoS8WinogradF23_8x8 final : public AlgoBase { diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3538b352c42c83841ad260d714e71f6d1e567eb0 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp @@ -0,0 +1,370 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#ifdef __ARM_FEATURE_DOTPROD + +#include "src/arm_common/elemwise_helper/kimpl/typecvt.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" +#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h" +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_nchw44 { + +template <> +void copy_packed_src_int8_nchw44<1>(int8_t* dst, const int dst_step, + const int8_t* src, const int src_step, + const int ic, const int ic_step, + const int ih, const int pad_left, + const int pad_right, const int pad_top, + const int pad_bottom) { + constexpr int IC_PACK_SIZE = 4; + rep_step(ic_idx, ic, IC_PACK_SIZE) { + const int8_t* i_src = src + ic_idx * ic_step; + //! pad top + int bytes_pad_top = pad_top * dst_step * IC_PACK_SIZE * sizeof(int8_t); + memset(dst, 0, bytes_pad_top); + dst += bytes_pad_top / sizeof(int8_t); + rep(ih_idx, ih) { + int bytes_row_in_dst = dst_step * IC_PACK_SIZE * sizeof(int8_t); + memset(dst, 0, bytes_row_in_dst); + + //! left elements + int pad_left_elements = pad_left * IC_PACK_SIZE; + //! copy row [ih_idx, x] + int bytes_copy = src_step * IC_PACK_SIZE * sizeof(int8_t); + memcpy(dst + pad_left_elements, i_src, bytes_copy); + + //! dst move to next row + dst += bytes_row_in_dst / sizeof(int8_t); + //! src move to next row + i_src += bytes_copy / sizeof(int8_t); + } + //! 
pad bottom + int bytes_pad_bottom = + pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t); + memset(dst, 0, bytes_pad_bottom); + dst += bytes_pad_bottom / sizeof(int8_t); + } +} + +template <> +void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, + const int8_t* src, const int src_step, + const int ic, const int ic_step, + const int ih, const int pad_left, + const int pad_right, const int pad_top, + const int pad_bottom) { + constexpr int IC_PACK_SIZE = 4; + int odd_start = megdnn::div_ceil(dst_step, 2); + bool nochange = pad_left % 2 == 0; + rep_step(ic_idx, ic, IC_PACK_SIZE) { + const int32_t* i_src = + reinterpret_cast(src + ic_idx * ic_step); + int bytes_pad_top = pad_top * dst_step * IC_PACK_SIZE * sizeof(int8_t); + memset(dst, 0, bytes_pad_top); + dst += bytes_pad_top / sizeof(int8_t); + rep(ih_idx, ih) { + int bytes_row_in_dst = dst_step * IC_PACK_SIZE * sizeof(int8_t); + memset(dst, 0, bytes_row_in_dst); + + int32_t* dst_even = reinterpret_cast(dst) + pad_left / 2 + + pad_left % 2; + int32_t* dst_odd = + reinterpret_cast(dst) + odd_start + pad_left / 2; + int i_src_idx = 0; + + if (nochange) { + for (; i_src_idx + 7 < src_step; i_src_idx += 8) { + int32x4x2_t tmp; + tmp = vld2q_s32(i_src + i_src_idx); + vst1q_s32(dst_even, tmp.val[0]); + vst1q_s32(dst_odd, tmp.val[1]); + dst_even += 4; + dst_odd += 4; + } + } else { + for (; i_src_idx + 7 < src_step; i_src_idx += 8) { + int32x4x2_t tmp; + tmp = vld2q_s32(i_src + i_src_idx); + vst1q_s32(dst_even, tmp.val[1]); + vst1q_s32(dst_odd, tmp.val[0]); + dst_even += 4; + dst_odd += 4; + } + } + + for (; i_src_idx < src_step; ++i_src_idx) { + if (nochange) { + if (i_src_idx % 2 == 0) { + *dst_even = *(i_src + i_src_idx); + dst_even++; + } else { + *dst_odd = *(i_src + i_src_idx); + dst_odd++; + } + } else { + if (i_src_idx % 2 == 0) { + *dst_odd = *(i_src + i_src_idx); + dst_odd++; + } else { + *dst_even = *(i_src + i_src_idx); + dst_even++; + } + } + } + //! dst move to next row + dst += bytes_row_in_dst / sizeof(int8_t); + //! src move to next row + i_src += src_step; + } + //! pad bottom + int bytes_pad_bottom = + pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t); + memset(dst, 0, bytes_pad_bottom); + dst += bytes_pad_bottom / sizeof(int8_t); + } +} + +template +void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, + const int8_t* src, const int ih, const int iw, + const int8_t* filter, const int32_t* bias, + const int oh_size, const int oc, const int ic, + const Op& op) { + constexpr int FH = filter_size; + constexpr int FW = filter_size; + constexpr int IC_PACK_SIZE = 4; + constexpr int OC_PACK_SIZE = 4; + +#if MEGDNN_AARCH64 + constexpr int OC_BIG_INTERVAL = 12; + constexpr int OC_MID_INTERVAL = 8; + constexpr int OC_SMA_INTERVAL = 4; +#else + constexpr int OC_BIG_INTERVAL = 4; + constexpr int OC_MID_INTERVAL = 4; + constexpr int OC_SMA_INTERVAL = 4; +#endif + + constexpr int OW_INTERVAL = 8; + constexpr int SH = stride; + + const int dst_numbers_per_channel = oh * ow; + const int ow_remain = ow % OW_INTERVAL; + const int ow_end_idx = ow - ow_remain; + const int oc_remain = + oc % OC_BIG_INTERVAL; //! 
NCHW44 means oc_remain = 4 or 8 + const int oc_end_idx = oc - oc_remain; + const int dst_numbers_4channel_packed = + dst_numbers_per_channel * OC_PACK_SIZE; + + using remain_fun = std::function; + + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_mid_oc_remain = nullptr; + remain_fun kern_sma_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + kern_mid_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + kern_sma_oc_remain = \ + KernNeonSdotNCHW44::impl; \ + break; + UNROLL_CALL_RAW(8, cb); +#undef cb + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + + //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC] + //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL, + //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates + //! [OW_INTERVAL, 1, OC_INTERVAL] each time + for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) { + const int filter_offset_in_element = oc_idx * ic * FH * FW; + for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { + for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + KernNeonSdotNCHW44:: + impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + if (ow_remain) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + kern_big_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + } + } + +#ifdef MEGDNN_AARCH64 + //! 
oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32 + if (oc_remain) { + int oc_idx = oc_end_idx; + const int filter_offset_in_element = oc_idx * ic * FH * FW; + for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { + for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + if (oc_remain == 8) { + KernNeonSdotNCHW44< + dst_type, stride, bias_mode, Op, OW_INTERVAL, + filter_size, OC_MID_INTERVAL, + OW_INTERVAL>::impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, + iw, + filter + + filter_offset_in_element, + bias + bias_offset_in_element, + ic, op); + } else { + KernNeonSdotNCHW44< + dst_type, stride, bias_mode, Op, OW_INTERVAL, + filter_size, OC_SMA_INTERVAL, + OW_INTERVAL>::impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, + iw, + filter + + filter_offset_in_element, + bias + bias_offset_in_element, + ic, op); + } + } + if (ow_remain) { + const int src_offset_in_element = + (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; + const int dst_offset_in_element = + oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; + const int bias_offset_in_element = oc_idx; + if (oc_remain == 8) { + kern_mid_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } else { + kern_sma_oc_remain(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); + } + } + } + } +#endif +} + +#define CONSTRUCT_FUNC(filter_size) \ + template \ + void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \ + dst_type* dst, const int oh, const int ow, const int8_t* src, \ + const int ih, const int iw, const int8_t* weight, \ + const int32_t* bias, const int oh_size, const int oc, \ + const int ic, const Op& op) { \ + conv_direct_sdot_int8_nchw44( \ + dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \ + } + +CONSTRUCT_FUNC(2); +CONSTRUCT_FUNC(3); +CONSTRUCT_FUNC(5); +CONSTRUCT_FUNC(7); +#undef CONSTRUCT_FUNC + +#define INSTANTIATION(dst_type, stride, i, bias_mode, Op) \ + template void conv_direct_##i##x##i##_int8_nchw44( \ + dst_type * dst, const int oh, const int ow, const int8_t* src, \ + const int ih, const int iw, const int8_t* weight, \ + const int32_t* bias, const int oh_size, const int oc, \ + const int ic, const Op& op); + +#define FOR_OP(stride, i, bias_mode) \ + INSTANTIATION(dt_int8, stride, i, bias_mode, \ + TypeCvtOp) \ + INSTANTIATION(dt_int32, stride, i, bias_mode, \ + NoneOp) \ + INSTANTIATION(dt_int8, stride, i, bias_mode, \ + ReluOp) \ + INSTANTIATION(dt_int8, stride, i, bias_mode, \ + HSwishOp) + +#define FOR_BIAS(stride, i) \ + FOR_OP(stride, i, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +FOR_FILTER(1) +FOR_FILTER(2) + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef FOR_REMAIN +#undef INSTANTIATION + +} // namespace direct_dotprod_nchw44 +} // namespace 
arm_common +} // namespace megdnn +#endif + +//vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h new file mode 100644 index 0000000000000000000000000000000000000000..809befd02ed161c55815968118e6959dc908e081 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h @@ -0,0 +1,87 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#if __ARM_FEATURE_DOTPROD + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" + +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_nchw44 { + +using BiasMode = ConvBiasForward::BiasMode; + +/** + * @brief : do direct conv with no side effect + * input buffer's size is [ih, iw] + * output buffer's size is [oh, ow] + * filter layout is [OC/4, IC/4, FH, FW, 4, 4] + * + * @param : [output ptr] dst + * [input] oh -> dst rows + * [input] ow -> dst cols + * [input ptr] src + * [input] ih -> rows of src used by this this kernel + * [input] iw -> src step in elements [iw2] + * [input ptr] filter + * [input ptr] bias + * [input] oh_size -> rows of result generated by this kernel + * [input] oc -> output channels + * [input] ic -> intput channels + * [input] op -> post process operator + * @return none + */ + +#define KERN(filter_size) \ + template \ + void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \ + dst_type* dst, const int oh, const int ow, const int8_t* src, \ + const int ih, const int iw, const int8_t* weight, \ + const int32_t* bias, const int oh_size, const int oc, \ + const int ic, const Op& op) + +KERN(2); +KERN(3); +KERN(5); +KERN(7); + +#undef KERN +/** + * @brief : copy data from src to dst for direct conv with no side effect + * @param : [output ptr] dst + * [input] dst_step -> step of dst in numbers of elements + * [input ptr] src + * [input] src_step -> step of src in numbers of elements + * [input] ic -> input channels + * [input] ic_step -> step of ic in numbers of elements + * [input] ih -> totle rows to copy + * [input] pad_left -> cols padding at left + * [input] pad_right -> cols padding at right + * [input] pad_top -> rows padding at top + * [input] pad_bottom -> rows padding at bottom + * @return none + */ +template +void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step, + const int8_t* src, const int src_step, + const int ic, const int ic_step, const int ih, + const int pad_left, const int pad_right, + const int pad_top, const int pad_bottom); + +} // namespace direct_dotprod_nchw44 +} // namespace arm_common +} // namespace megdnn + +#endif + +//vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5dc8d5023c9a2eb9ec17620e4e922528f71fad8f --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp @@ -0,0 +1,341 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_dotpord_nchw44_algo.cpp + * MegEngine is Licensed 
under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#if __ARM_FEATURE_DOTPROD + +#include "src/arm_common/conv_bias/block_helper.h" +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" +#include "src/arm_common/elemwise_op.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; + +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) + +using direct_fun = std::function; + +namespace { + +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih, + int& iw, int& oh, int& ow) { + int IC = param.filter_meta.icpg; + int IW = param.isz[1]; + int OH = param.osz[0]; + int OW = param.osz[1]; + + oh = OH; + ow = OW; + + constexpr int cacheline = 64 / sizeof(int8_t); + int oh_tile_size = + l2_block_helper(param.nr_threads, OH, IC * IW * sizeof(int8_t) * 2); + auto&& fm = param.filter_meta; + const int SH = static_cast(fm.stride[0]); + const int FH = static_cast(fm.spatial[0]); + const int PW = static_cast(fm.padding[1]); + ih = oh_tile_size * SH + FH - SH; + iw = round_up(IW + 2 * PW, cacheline); +} + +static inline int get_perthread_cache_bytes(const int ic, const int ih, + const int iw) { + // border_size is used to avoid read illegal memory + int border_size = 64 * 2; + return ic * ih * iw * sizeof(int8_t) + border_size; +} + +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + int IC = fm.icpg; + int ih2, iw2, oh2, ow2; + get_rectified_size(param, ih2, iw2, oh2, ow2); + + int bytes_of_copy_per_thread = get_perthread_cache_bytes(IC, ih2, iw2); + return {nullptr, {bytes_of_copy_per_thread * param.nr_threads}}; +} + +template +static void conv_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& ncb_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { + const int OH = ncb_param.osz[0]; + const int OW = ncb_param.osz[1]; + const int FH = ncb_param.filter_meta.spatial[0]; + const int IC = ncb_param.filter_meta.icpg; + const int OC = ncb_param.filter_meta.ocpg; + const int IH = ncb_param.isz[0]; + const int IW = ncb_param.isz[1]; + const int SH = ncb_param.filter_meta.stride[0]; + const int PH = ncb_param.filter_meta.padding[0]; + const int PW = ncb_param.filter_meta.padding[1]; + + int ih2 = 0; + int iw2 = 0; + int oh2 = 0; + int ow2 = 0; + get_rectified_size(ncb_param, ih2, iw2, oh2, ow2); + + constexpr int IC_PACK_SIZE = 4; + constexpr int OC_PACK_SIZE = 4; + bundle.set(ncb_param.workspace_ptr); + + const int batch_id = ncb_index.ndrange_id[0]; + const int group_id = ncb_index.ndrange_id[1]; + const int oh_tile_id = ncb_index.ndrange_id[2]; + const int thread_id = ncb_index.thread_id; + + const int oh_tile_size = l2_block_helper(ncb_param.nr_threads, OH, + IC * IW * sizeof(int8_t) * 2); + const int oh_start_row = oh_tile_id * oh_tile_size; + const int ih_start_row = std::max(oh_start_row * SH - PH, 0); + + const int oh_real_size = std::min(OH - oh_start_row, oh_tile_size); + const int ih_real_size = oh_real_size * SH + FH - SH; + + const int rows_padding_at_top = std::max(PH - oh_start_row * SH, 0); + const int rows_padding_at_bottom = + std::max((oh_start_row + oh_real_size - 1) * SH + FH - IH - PH, 
0); + const int cols_padding_at_left = PW; + const int cols_padding_at_right = std::max(iw2 - IW - PW, 0); + + //! src layout{IC/4, IH, IW, 4} + const int bytes_of_src_offset = + ih_start_row * IW * IC_PACK_SIZE * sizeof(int8_t); + const int8_t* copy_src = static_cast( + ncb_param.src(batch_id, group_id) + bytes_of_src_offset); + + const int bytes_of_copy_per_thread = + get_perthread_cache_bytes(IC, ih2, iw2); + int8_t* copy_dst = reinterpret_cast(bundle.get(0)) + + thread_id * bytes_of_copy_per_thread; + + const int rows_copy_from_src = + ih_real_size - rows_padding_at_top - rows_padding_at_bottom; + + direct_dotprod_nchw44::copy_packed_src_int8_nchw44( + copy_dst, iw2, copy_src, IW, IC, IH * IW, rows_copy_from_src, + cols_padding_at_left, cols_padding_at_right, rows_padding_at_top, + rows_padding_at_bottom); + + const int8_t* weights = ncb_param.filter(group_id); + + dst_type* dst = ncb_param.dst(batch_id, group_id) + + oh_start_row * OW * OC_PACK_SIZE; + + //! only broadcast or no_bias + const int32_t* bias = ncb_param.bias(batch_id, group_id); + + Op op = Op(1.0f, 4.0f); + if (ncb_param.dst_type.enumv() == DTypeEnum::QuantizedS8) { + float scale_bias = + ncb_param.bias_type.param().scale; + float scale_dst = ncb_param.dst_type.param().scale; + op = Op(scale_bias, scale_dst); + } + +#define KERN1_NCHW44_CONV(filter) \ + direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \ + dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst, \ + ih_real_size, iw2, weights, bias, \ + oh_real_size, OC, IC, op); + DISPATCH_FILTER(filter_size, KERN1_NCHW44_CONV); +#undef KERN1_NCHW44_CONV +} + +} // namespace + +bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( + FallbackConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + auto SH = fm.stride[0]; + auto SW = fm.stride[1]; + auto OC = fm.ocpg; + auto IC = fm.icpg; + + //! src and filter are qint8, dst is qint8. 
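+    //! note: dst may also be qint32 (raw accumulator output without
+    //! requantization), and plain int8 src/filter with int32 dst is
+    //! accepted as the non-quantized path, as checked below.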
+ bool data_type_ok = param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || + param.dst_type.enumv() == DTypeEnum::QuantizedS32); + + if (param.bias_type.valid()) { + data_type_ok &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; + } + + data_type_ok |= param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + param.dst_type.enumv() == DTypeEnum::Int32; + + bool layout_ok = fm.format == param::Convolution::Format::NCHW44_DOT && + IC % 4 == 0 && OC % 4 == 0; + + bool param_ok = !fm.should_flip && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && FH == FW && + (FH >= 2 && FH <= 7); + + bool stride_ok = SH == SW && (SH == 1 || SH == 2); + + return data_type_ok && layout_ok && param_ok && stride_ok; +} + +bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( + megdnn::fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + return true; +} + +size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( + FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, + midout_iv("ALGODOTS8DIRECT_NCHW44"_hash)) { + auto fm = param.filter_meta; + size_t BATCH = param.n; + size_t GROUP = fm.group; + WorkspaceBundle wbundle = get_bundle(param); + direct_fun kernel = nullptr; + bool quantized = param.dst_type.enumv() == DTypeEnum::QuantizedS8; + +#define DO_CONV_KERN_FUN(dst_type, filter, bias_mode, op, stride) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, \ + midout_iv(#dst_type #filter #bias_mode #op##_hash)) { \ + kernel = conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(i, bias_mode, stride) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + if (quantized) { \ + DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ + TypeCvtOp, \ + stride) \ + } else { \ + DO_CONV_KERN_FUN(dt_int32, i, bias_mode, \ + NoneOp, \ + stride) \ + } \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + if (quantized) { \ + DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ + ReluOp, \ + stride) \ + } else { \ + megdnn_assert("No support NoQuantized RELU"); \ + } \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + if (quantized) { \ + DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ + HSwishOp, \ + stride) \ + } else { \ + megdnn_assert("No support NoQuantized H_SWISH"); \ + } \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_STRIDE_PARAM(filter, bias_mode) \ + switch (fm.stride[0]) { \ + case 1: \ + GET_OP_PARAM(filter, bias_mode, 1); \ + break; \ + case 2: \ + GET_OP_PARAM(filter, bias_mode, 2); \ + break; \ + default: \ + megdnn_assert(0); \ + } + +#define GET_BIAS_MODE_PARAM(filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_STRIDE_PARAM(filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_STRIDE_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define SELECT_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + 
default: \ + megdnn_assert(0); \ + break; \ + } + + SELECT_CONV_KERN() + +#undef DO_CONV_KERN_FUN +#undef GET_OP_PARAM +#undef GET_STRIDE_PARAM +#undef GET_BIAS_MODE_PARAM +#undef SELECT_CONV_KERN + + megdnn_assert(kernel); + + SmallVector ret_kerns; + int OH = param.osz[0]; + int IC = param.filter_meta.icpg; + int IW = param.isz[1]; + int oh_tile_size = l2_block_helper(param.nr_threads, OH, + IC * IW * sizeof(int8_t) * 2); + size_t oh_tiles = static_cast(div_ceil(OH, oh_tile_size)); + + auto do_conv = [wbundle, kernel](const NCBKernParam& ncb_param, + const NCBKernIndex& ncb_index) { + kernel(wbundle, ncb_param, std::move(ncb_index)); + }; + + ret_kerns.push_back({do_conv, {BATCH, GROUP, oh_tiles}}); + return ret_kerns; + } + MIDOUT_END(); + return {}; +} + +#endif + +//vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h new file mode 100644 index 0000000000000000000000000000000000000000..7b3be20e3a167758c43fa58ff9b703ebe6810c10 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h @@ -0,0 +1,430 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#ifdef __ARM_FEATURE_DOTPROD + +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/intrinsic_helper.h" +#include "src/arm_common/neon_struct.h" +#include "src/common/unroll_macro.h" + +namespace megdnn { +namespace arm_common { +namespace direct_dotprod_nchw44 { + +constexpr int SIMD_LEN = 16; +constexpr int IC_PACK_SIZE = 4; +constexpr int OC_PACK_SIZE = 4; +constexpr int filter_next_col = + IC_PACK_SIZE * OC_PACK_SIZE; //! 
[OC/4, IC/4, FH, FW, 4OC, 4IC] + +template +inline void init_ocx_ow8(int32x4_t c[][8], const int32_t* bias_ptr, + int oc_step) { + static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number."); + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { +#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step); + switch (row) { + case 3: + UNROLL_CALL_RAW(8, BIAS_INIT, 2); + case 2: + UNROLL_CALL_RAW(8, BIAS_INIT, 1); + default: + UNROLL_CALL_RAW(8, BIAS_INIT, 0); + } +#undef BIAS_INIT + } else { +#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0); + switch (row) { + case 3: + UNROLL_CALL_RAW(8, BIAS_INIT, 2); + case 2: + UNROLL_CALL_RAW(8, BIAS_INIT, 1); + default: + UNROLL_CALL_RAW(8, BIAS_INIT, 0); + } +#undef BIAS_INIT + } +} + +#define cb11(col) \ + op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); + +#define cb21(col) \ + op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ + op(res[1][col], \ + reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); + +#define cb31(col) \ + op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ + op(res[1][col], \ + reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); \ + op(res[2][col], reinterpret_cast(dst_ptr + ld_dst_oc + \ + ld_dst_oc + col / 2 * 8)); + +#define cb12(step) \ + op({{res[0][2 * step], res[0][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + step * 8)); + +#define cb22(step) \ + op({{res[0][2 * step], res[0][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + step * 8)); \ + op({{res[1][2 * step], res[1][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + ld_dst_oc + step * 8)); + +#define cb32(step) \ + op({{res[0][2 * step], res[0][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + step * 8)); \ + op({{res[1][2 * step], res[1][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + ld_dst_oc + step * 8)); \ + op({{res[2][2 * step], res[2][2 * step + 1]}}, \ + reinterpret_cast(dst_ptr + 2 * ld_dst_oc + step * 8)); + +template +struct StoreOCxOWx { + static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, + const int ld_dst_oc); +}; + +template +struct StoreOCxOWx<1, ow_remain, Op, T> { + static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, + const int ld_dst_oc) { + switch (ow_remain) { + case 8: + UNROLL_CALL_RAW(4, cb12); + break; + case 7: + cb11(6); + case 6: + UNROLL_CALL_RAW(3, cb12); + break; + case 5: + cb11(4); + case 4: + UNROLL_CALL_RAW(2, cb12); + break; + case 3: + cb11(2); + case 2: + UNROLL_CALL_RAW(1, cb12); + break; + case 1: + cb11(0); + default: + break; + } + } +}; + +template +struct StoreOCxOWx<2, ow_remain, Op, T> { + static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, + const int ld_dst_oc) { + switch (ow_remain) { + case 8: + UNROLL_CALL_RAW(4, cb22); + break; + case 7: + cb21(6); + case 6: + UNROLL_CALL_RAW(3, cb22); + break; + case 5: + cb21(4); + case 4: + UNROLL_CALL_RAW(2, cb22); + break; + case 3: + cb21(2); + case 2: + UNROLL_CALL_RAW(1, cb22); + break; + case 1: + cb21(0); + default: + break; + } + } +}; + +template +struct StoreOCxOWx<3, ow_remain, Op, T> { + static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, + const int ld_dst_oc) { + switch (ow_remain) { + case 8: + UNROLL_CALL_RAW(4, cb32); + break; + case 7: + cb31(6); + case 6: + UNROLL_CALL_RAW(3, cb32); + break; + case 5: + cb31(4); + case 4: + UNROLL_CALL_RAW(2, cb32); + break; + case 3: + cb31(2); + case 2: + UNROLL_CALL_RAW(1, cb32); + break; + case 1: + cb31(0); + default: + break; + } + } +}; + +#undef cb11 +#undef cb21 +#undef cb31 +#undef cb12 +#undef cb22 +#undef cb32 + 
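+//! write the accumulated res vectors back to dst through the post-process op;
+//! the StoreOCxOWx specialization is picked by the number of 4-OC blocks and
+//! by how many of the 8 ow positions are valid.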
+template +inline void store_ocx_owx_remain_static(int32x4_t res[][8], const Op& op, + T* dst_ptr, const int ld_dst_oc) { + StoreOCxOWx::impl(res, op, dst_ptr, ld_dst_oc); +} + +template +struct ShiftCalHelper { + static void impl(T& res, T2& src, T3& weight) { +#define cb(step) \ + res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \ + res[res_row][step], weight[weight_idx], \ + src[src_row][(src_start_idx + step) / 4]); + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; + +template +inline void cal_helper(T& res, T2& src, T3& weight) { + ShiftCalHelper::impl(res, src, weight); +}; + +/** + * oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x) + * gemm like kernel + * */ +template +struct KernNeonSdotNCHW44 { + static void impl(dst_type* dst, const int dst_step, const int8_t* src, + const int ih, const int iw, const int8_t* filter, + const int32_t* bias, const int ic, const Op& op); +}; + +template +struct KernNeonSdotNCHW44 { + static void impl(dst_type* dst, const int dst_step, const int8_t* src, + const int ih, const int iw, const int8_t* filter, + const int32_t* bias, const int ic, const Op& op) { + constexpr int FH = filter_size; + constexpr int FW = filter_size; + constexpr int filter_next_row = + FW * OC_PACK_SIZE * + IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + + const int filter_next_4oc = + FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + const int src_next_ic = ih * iw; + const int src_next_row = iw * IC_PACK_SIZE; + + constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1; + constexpr int LOOP = oc_interval / 4; + + int32x4_t res[3][ow_interval]; + init_ocx_ow8(res, bias, OC_PACK_SIZE); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { + const int8_t* i_src = src + ic_idx * src_next_ic; + const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; + for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { + int8x16_t src[1][4]; + int8x16_t weight[3]; + + load_helper(src, i_src, 0); + +//! do not use switch order 3,2,1 because it will slow the speed. 
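+//! CALC_PART(step): for each 4-OC block in use (LOOP of them), load the
+//! 16-byte filter slice of kernel column `step` and accumulate dot products
+//! into res with Vdotq_laneq_s32.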
+#define CALC_PART(step) \ + switch (LOOP) { \ + case 1: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ + break; \ + case 2: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ + weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ + filter_next_col * step); \ + cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \ + break; \ + case 3: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ + weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ + filter_next_col * step); \ + cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \ + weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ + filter_next_col * step); \ + cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight); \ + break; \ + default: \ + break; \ + } + + switch (filter_size) { + case 2: + UNROLL_CALL_RAW(2, CALC_PART); + break; + case 3: + UNROLL_CALL_RAW(3, CALC_PART); + break; + case 5: + UNROLL_CALL_RAW(5, CALC_PART); + break; + case 7: + UNROLL_CALL_RAW(7, CALC_PART); + break; + default: + break; + } +#undef CALC_PART + + i_filter += filter_next_row; + i_src += src_next_row; + } + } + store_ocx_owx_remain_static(res, op, dst, + dst_step); + } +}; + +template +struct KernNeonSdotNCHW44 { + static void impl(dst_type* dst, const int dst_step, const int8_t* src, + const int ih, const int iw, const int8_t* filter, + const int32_t* bias, const int ic, const Op& op) { + constexpr int FH = filter_size; + constexpr int FW = filter_size; + constexpr int filter_next_row = + FW * OC_PACK_SIZE * + IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + + const int filter_next_4oc = + FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + const int src_next_ic = ih * iw; + const int src_next_row = iw * IC_PACK_SIZE; + + constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1; + constexpr int LOOP = oc_interval / 4; + + int32x4_t res[3][ow_interval]; + init_ocx_ow8(res, bias, OC_PACK_SIZE); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { + const int8_t* i_src = src + ic_idx * src_next_ic; + const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; + for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { + int8x16_t src[2][3]; + int8x16_t weight[3]; + const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE; + + load_helper(src, i_src, offset); + +//! do not use switch order 3,2,1 because it will slow the speed. 
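+//! stride == 2: even and odd input columns are packed into separate halves of
+//! src, so filter column `step` reads src row `step % 2` starting at element
+//! `step / 2`.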
+#define CALC_PART(step) \ + switch (LOOP) { \ + case 1: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ + weight); \ + break; \ + case 2: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ + weight); \ + weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ + filter_next_col * step); \ + cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \ + weight); \ + break; \ + case 3: \ + weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ + weight); \ + weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ + filter_next_col * step); \ + cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \ + weight); \ + weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ + filter_next_col * step); \ + cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \ + weight); \ + break; \ + default: \ + break; \ + } + + switch (filter_size) { + case 2: + UNROLL_CALL_RAW(2, CALC_PART); + break; + case 3: + UNROLL_CALL_RAW(3, CALC_PART); + break; + case 5: + UNROLL_CALL_RAW(5, CALC_PART); + break; + case 7: + UNROLL_CALL_RAW(7, CALC_PART); + break; + default: + break; + } +#undef CALC_PART + + i_filter += filter_next_row; + i_src += src_next_row; + } + } + store_ocx_owx_remain_static(res, op, dst, + dst_step); + } +}; + +} // namespace direct_dotprod_nchw44 +} // namespace arm_common +} // namespace megdnn + +#endif + +//vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index daa7568daf5cb9af72467165d4618c16609fd923..67e2d3f7747094763635a02037dc1e77969cd346 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -536,6 +536,7 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, #undef BAIS_INIT } } + /////////////////////////init_ocx_ow8//////////////////// inline float32x4_t neon_vdupq_n(float val) { diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 38ec7b7a8fc012701cd45fbc1c1e19c99aa6c849..caed88db66a65b41ebd39869508c52f108263623 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -64,6 +64,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false}; AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true}; AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; + + AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; #endif AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; @@ -103,6 +105,8 @@ public: direct_algos.emplace_back(&du8_direct_stride1_small_group); direct_algos.emplace_back(&du8_direct_stride2_large_group); direct_algos.emplace_back(&du8_direct_stride2_small_group); + + direct_algos.emplace_back(&ds8_direct_nchw44); #endif direct_algos.emplace_back(&qu8_direct_stride2_large_group); direct_algos.emplace_back(&qu8_direct_stride2_small_group); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 35c9f35396f6c3c83f312cd62d60f8d1a9073fe0..3482be1dc736ee40dd4aff45646d9d26aef4b298 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -67,6 +67,8 
@@ private: class AlgoDotS8DirectStride2; class AlgoDotU8DirectStride1; class AlgoDotU8DirectStride2; + + class AlgoDotS8Direct_NCHW44; #endif class AlgoF32Direct; class AlgoF32DirectStride1; diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 6ce361e93d2853646893e7e0180b4290f50826f1..78b4d388a9d24056ca6db40a1bdefcde84a4735d 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -1809,6 +1809,81 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) { used1 / used0); } } + +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, + size_t p, size_t stride, NonlineMode nonline_mode) { + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + param.format = param::ConvBias::Format::NCHW44_DOT; + + //! channel bias + args.emplace_back(param, TensorShape{1, ic/4, h, w, 4}, + TensorShape{oc/4, ic/4, kernel, kernel, 4, 4}, + TensorShape{1, oc/4, 1, 1, 4}); + }; + for (size_t stride : {1, 2}) + for (size_t kernel : {2, 3, 5, 7}) + for(size_t oc : {64}) + for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) { + run(oc, oc, 56, 56, kernel, kernel / 2, stride, nonline_mode); + } + + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("ARMDOTS8DIRECT_NCHW44")); + + Benchmarker benchmark1(handle()); + benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout({arg.src, dtype::Int8()}, + {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * 8.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: Direct use: %f ms %f Gflops normal: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + used0, computations / used0, used1, computations / used1, + used1 / used0); + } +} + #endif #endif diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index ed9e2072c86b2257630103e47522c74db4dac91a..9e4d5f34aa308219814f3d53989475def68937dc 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -155,7 +155,7 @@ std::vector get_nchw44_conv_bias_args( if (support_sigmoid) { nonlinemode.emplace_back(NLMode::SIGMOID); } - + std::vector bias_mode = { megdnn::BiasMode::BROADCAST_CHANNEL_BIAS}; if (no_bias) { @@ -672,6 +672,63 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) { get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), "ARMDOTU8STRD2_SMALL_GROUP"); } + +/******************************dot int8x8x8 nchw44 ***********************/ +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) { + using namespace conv_bias; + std::vector args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1); + for (auto&& arg : args) + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true); + for (auto&& arg : args) + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) { + using namespace conv_bias; + std::vector args = + get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true); + for (auto&& arg : args) + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) { + using namespace conv_bias; + //! test qint8x8x8 + std::vector args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2); + for (auto&& arg : args) + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) { + using namespace conv_bias; + //! test qint8x8x8 + std::vector args = + get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true); + for (auto&& arg : args) + arg.param.format = param::ConvBias::Format::NCHW44_DOT; + checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) { + using namespace conv_bias; + //! 
test int8x8x32
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true);
+    for (auto&& arg : args)
+        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
+    checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
+}
+
 #endif

 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
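Note on the packed-filter addressing used throughout the new kernels: the weights are laid out as [OC/4, IC/4, FH, FW, 4(OC), 4(IC)], and the pointer strides in direct_dotprod_nchw44.cpp / direct_dotprod_nchw44_kern.h (filter_next_col, filter_next_row, filter_next_4oc, and the per-ic_idx step) all follow from that layout. The snippet below is a minimal standalone sketch, not part of the patch; the helper name filter_offset and the sample sizes are made up for illustration, but the checked strides mirror the constants in the kernels.

// Standalone sketch (not part of the patch): index arithmetic for the
// NCHW44_DOT filter layout [OC/4, IC/4, FH, FW, 4(OC), 4(IC)].
#include <cassert>
#include <cstddef>

// Offset (in int8 elements) of filter element (oc, ic, fh, fw) in the packed
// layout: the innermost 4x4 block pairs 4 output channels with 4 input
// channels.
static size_t filter_offset(size_t oc, size_t ic, size_t fh, size_t fw,
                            size_t IC, size_t FH, size_t FW) {
    const size_t oc_blk = oc / 4, oc_in = oc % 4;
    const size_t ic_blk = ic / 4, ic_in = ic % 4;
    return ((((oc_blk * (IC / 4) + ic_blk) * FH + fh) * FW + fw) * 4 + oc_in) *
                   4 +
           ic_in;
}

int main() {
    constexpr size_t IC = 8, FH = 3, FW = 3;
    // next kernel column: one 4x4 block = 16 bytes (filter_next_col).
    assert(filter_offset(0, 0, 0, 1, IC, FH, FW) -
                   filter_offset(0, 0, 0, 0, IC, FH, FW) ==
           16);
    // next kernel row: FW * 16 (filter_next_row).
    assert(filter_offset(0, 0, 1, 0, IC, FH, FW) -
                   filter_offset(0, 0, 0, 0, IC, FH, FW) ==
           FW * 16);
    // next 4-IC block: FH * FW * 16 (the per-ic_idx filter step).
    assert(filter_offset(0, 4, 0, 0, IC, FH, FW) -
                   filter_offset(0, 0, 0, 0, IC, FH, FW) ==
           FH * FW * 16);
    // next 4-OC block: FH * FW * IC * 4 (filter_next_4oc).
    assert(filter_offset(4, 0, 0, 0, IC, FH, FW) -
                   filter_offset(0, 0, 0, 0, IC, FH, FW) ==
           FH * FW * IC * 4);
    return 0;
}

Pairing four output channels with four input channels in the innermost 4x4 block is what lets each SDOT instruction consume one 16-byte filter load against a single broadcast input lane, which is the point of the NCHW44_DOT format.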