diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h index ffed95c57b1e790e4d80847a77b2210beacc634f..c1e3711663e9578f7f36be42f2b685023a425795 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -48,6 +49,7 @@ class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { const NCBKernParam& kern_param, const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); + public: bool is_reproducible() const override { return true; } const char* name() const override { return "I8816STRD2"; } @@ -84,6 +86,21 @@ public: const NCBKernSizeParam& param) const override; }; +class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + +public: + AlgoI8x8x16DirectNCHWNCHW44() {} + bool is_reproducible() const override { return true; } + const char* name() const override { return "I8816_CONV_NCHW_NCHW44"; } + bool usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + const NCBKernSizeParam& param) const override; +}; + } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d23f13f94394aaf7b2d888db026de8e6f34d90bd --- /dev/null +++ 
b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp @@ -0,0 +1,357 @@ +/** + * \file + dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. + */ + +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/block_helper.h" +#include "src/arm_common/conv_bias/int8x8x16/algos.h" +#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; + +MIDOUT_DECL(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44) +namespace { +static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, + const int iw2) { + //! 
border_size is used to avoid read illegal memory + constexpr int iw_expand = 8; + int border_size = 64 * 2; + return ic * ih2 * iw2 * sizeof(int8_t) * iw_expand + border_size; +} +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, + int& iw2, int& oh2, int& ow2) { + int iw = param.isz[1]; + int oh = param.osz[0]; + int ow = param.osz[1]; + + oh2 = oh; + ow2 = ow; + + constexpr int iw_expand = 8; + auto&& fm = param.filter_meta; + const int stride_h = static_cast(fm.stride[0]); + const int filter_h = static_cast(fm.spatial[0]); + const int ic = fm.icpg; + iw2 = iw + 2 * static_cast(fm.padding[1]); + int block_oh = l2_block_helper(param.nr_threads, oh, + ic * iw2 * stride_h * iw_expand); + + ih2 = block_oh * stride_h + filter_h - stride_h; +} + +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + int group = fm.group; + int ic = fm.icpg; + int oc = fm.ocpg; + int fh = fm.spatial[0]; + int fw = fm.spatial[1]; + int stride = fm.stride[0]; + int ih2, iw2, oh2, ow2; + get_rectified_size(param, ih2, iw2, oh2, ow2); + + constexpr int pack_oc = 8; + const int weight_expand = stride == 1 ? 
2 : 1; + size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); + size_t weight_size = group * round_up(oc, 8) * ic * fh * fw * + sizeof(int8_t) * weight_expand; + size_t bisa_size = 0; + if (param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && + oc % pack_oc != 0) { + bisa_size = round_up(oc, 8) * sizeof(int16_t); + } + return {nullptr, {src_size * param.nr_threads, weight_size, bisa_size}}; +}; + +static inline void copy_pad_src(int8_t* sptr_base, const int8_t* sptr_origin, + int ph, int pw, int pad_right, int ih, int iw, + int iw2, int pad_top, int pad_bottom, int ic, + int ic_stride) { + constexpr int iw_expand = 8; + MEGDNN_MARK_USED_VAR(ph); + rep(ic_idx, ic) { + const int8_t* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, sizeof(int8_t) * iw2 * pad_top * iw_expand); + sptr_base += iw2 * pad_top * iw_expand; + rep(ih_idx, ih) { + memset(sptr_base, 0, sizeof(int8_t) * pw * iw_expand); + sptr_base += pw * iw_expand; + memcpy_s8_dup(sptr_base, sptr, iw); + sptr_base += iw * iw_expand; + sptr += iw; + memset(sptr_base, 0, sizeof(int8_t) * pad_right * iw_expand); + sptr_base += pad_right * iw_expand; + } + memset(sptr_base, 0, sizeof(int8_t) * iw2 * pad_bottom * iw_expand); + sptr_base += iw2 * pad_bottom * iw_expand; + } +} +static void pack_weight(const WorkspaceBundle& bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { + const int group_id = ncb_index.ndrange_id[0]; + int fh = kern_param.filter_meta.spatial[0]; + int fw = kern_param.filter_meta.spatial[1]; + int oc = kern_param.filter_meta.ocpg; + int ic = kern_param.filter_meta.icpg; + int oc_block = oc; + int stride = kern_param.filter_meta.stride[0]; + constexpr int oc_idx = 0; + const int8_t* fptr = + kern_param.filter(group_id) + oc_idx * fh * fw * ic; + auto packed_weight = reinterpret_cast(bundle.get(1)) + + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; + switch (stride) { + case 1: + 
i8i8i16_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44<1>( + fptr, packed_weight, oc_block, fh, fw, ic); + break; + case 2: + i8i8i16_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44<2>( + fptr, packed_weight, oc_block, fh, fw, ic); + break; + default: + break; + } + constexpr int pack_oc = 8; + if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && + oc % pack_oc != 0) { + auto packed_bias = reinterpret_cast(bundle.get(2)); + memcpy(packed_bias, kern_param.bias_ptr, + round_up(oc, 8) * sizeof(int16_t)); + } +} + +template +static void do_conv_kern(const WorkspaceBundle& bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange&, const CpuNDRange&) { + const int oh = kern_param.osz[0]; + const int ow = kern_param.osz[1]; + const int fh = kern_param.filter_meta.spatial[0]; + const int fw = kern_param.filter_meta.spatial[1]; + const int ic = kern_param.filter_meta.icpg; + const int oc = kern_param.filter_meta.ocpg; + const int ih = kern_param.isz[0]; + const int iw = kern_param.isz[1]; + const int stride_h = stride; + const int ph = kern_param.filter_meta.padding[0]; + const int pw = kern_param.filter_meta.padding[1]; + int ih2 = 0; + int iw2 = 0; + int oh2 = 0; + int ow2 = 0; + get_rectified_size(kern_param, ih2, iw2, oh2, ow2); + + constexpr int src_expand = 8; + constexpr int weight_expand = stride == 1 ? 
2 : 1; + constexpr int pack_c = 4; + const int batch_id = ncb_index.ndrange_id[0]; + const int group_id = ncb_index.ndrange_id[1]; + constexpr int oc_idx = 0; + int oc_block = oc; + int oh_block = l2_block_helper(kern_param.nr_threads, oh, + ic * iw2 * stride_h * src_expand); + const int oh_idx = ncb_index.ndrange_id[2]; + const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); + const int ih_real = oh_block_real * stride_h + fh - stride_h; + const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); + const int src_bottom_pad = std::max( + (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, + 0); + const int remain_right_pad = std::max(iw2 - iw - pw, 0); + const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw; + const int8_t* origin_sptr = + static_cast( + kern_param.src(batch_id, group_id, 0, 1, 1)) + + src_offset; + const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); + int8_t* sptr = reinterpret_cast((int8_t*)bundle.get(0) + + ncb_index.thread_id * src_size); + + copy_pad_src(sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw); + //! pack weight + auto packed_weight = + reinterpret_cast(bundle.get(1)) + + (group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw) * + weight_expand; + //! 
get param + int16_t* dst = kern_param.dst(batch_id, group_id) + + oh_idx * oh_block * ow * pack_c; + const int16_t* bptr = + kern_param.bias(batch_id, group_id) + oc_idx; + constexpr int pack_oc = 8; + if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && + oc % pack_oc != 0) { + bptr = reinterpret_cast(bundle.get(2)); + } + Op op; + + i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44< + bias_mode, Op, filter_size, stride>( + sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, + oh, oh_block_real, ow, op, ph, pw); +} + +} // namespace + +bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + auto&& fm = param.filter_meta; + auto fh = fm.spatial[0]; + int oc = fm.ocpg; + bool ok_type = ((param.src_type.enumv() == DTypeEnum::Int8 && + param.filter_type.enumv() == DTypeEnum::Int8 && + (param.dst_type.enumv() == DTypeEnum::Int16))) && + (fm.format == param::Convolution::Format::NCHW44); + bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1; + bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && + (fh == 2 || fh == 3 || fh == 5 || fh == 7); + bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == fm.stride[1] && + (fm.stride[0] == 2 || fm.stride[0] == 1); + bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS && + param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY; + bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; + return avaible; +} + +size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace( + const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::dispatch_kerns( + const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + const int batch = param.n; + const int group = fm.group; + WorkspaceBundle bundle = get_bundle(param); + conv_fun 
do_conv_fun = nullptr; + //! NOTE: remain_w is not used to gen hash of midout for compatible with + //! shape runtime +#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44, \ + midout_iv(#stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, NoneOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } +#define GET_BIAS_MODE_PARAM(stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_KERN(stride) \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(stride, 2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(stride, 3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(stride, 5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(stride, 7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + switch (param.filter_meta.stride[0]) { + case 1: + DISPATCH_CONV_KERN(1); + break; + case 2: + DISPATCH_CONV_KERN(2); + break; + default: + megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", + param.filter_meta.stride[0]) + .c_str()); + break; + } + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + constexpr int iw_expand = 8; + SmallVector ret_kerns; + int oh = param.osz[0]; + int ih2, iw2, oh2, ow2; + const int stride_h = static_cast(fm.stride[0]); + const int ic = fm.icpg; + get_rectified_size(param, ih2, iw2, oh2, ow2); + int oh_block = 
l2_block_helper(param.nr_threads, oh, + ic * iw2 * stride_h * iw_expand); + + auto do_pack_weight = [bundle](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + bundle.set(kern_param.workspace_ptr); + pack_weight(bundle, kern_param, ncb_index); + }; + ret_kerns.push_back({do_pack_weight, {static_cast(group)}}); + CpuNDRange ncb_range = {static_cast(batch), + static_cast(group), + static_cast(div_ceil(oh, oh_block))}; + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + bundle.set(kern_param.workspace_ptr); + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + + return ret_kerns; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h new file mode 100644 index 0000000000000000000000000000000000000000..9eb0fd56f1969987bcd62fd8ff8930eb78ce5a2e --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h @@ -0,0 +1,151 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "megdnn/arch.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace i8i8i16_direct_nchw_nchw44 { +/** + * @brief + * stride2 from [oc / 4, fh, fw, ic, 4] to [oc / 8, ic, fh, fw, 8] + * stride1 from [oc / 4, fh, fw, ic, 4] to [oc / 8, ic, fh, fw, 16] + * @param in_ptr + * @param dst_ptr + * @param oc + * @param kh + * @param kw + * @param ic + */ +template +inline void pack_weight_int8_nchw_nchw44(const int8_t* in_ptr, int8_t* dst_ptr, + const int oc, const int kh, + const int kw, const int ic); +template <> +inline void pack_weight_int8_nchw_nchw44<2>(const int8_t* in_ptr, + int8_t* dst_ptr, const int oc, + const int kh, const int kw, + const int ic) { + constexpr int in_pack_oc = 4; + constexpr int out_pack_oc = 8; + constexpr int out_pair = 2; + const int filter_size = kh * kw; + const int in_oc_stride = filter_size * ic; + const int oc_remain = oc % out_pack_oc; + const int oc_end = oc - oc_remain; + int32_t* pack_dst_ptr = (int32_t*)dst_ptr; + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += out_pack_oc) { + const int32_t* in_oc0_ptr = (int32_t*)(in_ptr + oc_idx * in_oc_stride); + const int32_t* in_oc1_ptr = + (int32_t*)(in_ptr + (oc_idx + in_pack_oc) * in_oc_stride); + + for (int filter_idx = 0; filter_idx < filter_size; ++filter_idx) { + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + int32_t temp0 = *in_oc0_ptr++; + int32_t temp1 = *in_oc1_ptr++; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 0] = temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 1] = temp1; + } + } + pack_dst_ptr += ic * filter_size * out_pair; + } + if (oc_remain > 0) { + const int32_t* in_oc0_ptr = 
(int32_t*)(in_ptr + oc_end * in_oc_stride); + + for (int filter_idx = 0; filter_idx < filter_size; ++filter_idx) { + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + int32_t temp0 = *in_oc0_ptr++; + int32_t temp1 = 0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 0] = temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 1] = temp1; + } + } + } +} +template <> +inline void pack_weight_int8_nchw_nchw44<1>(const int8_t* in_ptr, + int8_t* dst_ptr, const int oc, + const int kh, const int kw, + const int ic) { + constexpr int in_pack_oc = 4; + constexpr int out_pack_oc = 8; + constexpr int out_pair = 4; + const int filter_size = kh * kw; + const int in_oc_stride = filter_size * ic; + const int oc_remain = oc % out_pack_oc; + const int oc_end = oc - oc_remain; + int32_t* pack_dst_ptr = (int32_t*)dst_ptr; + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += out_pack_oc) { + const int32_t* in_oc0_ptr = (int32_t*)(in_ptr + oc_idx * in_oc_stride); + const int32_t* in_oc1_ptr = + (int32_t*)(in_ptr + (oc_idx + in_pack_oc) * in_oc_stride); + + for (int filter_idx = 0; filter_idx < filter_size; ++filter_idx) { + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + int32_t temp0 = *in_oc0_ptr++; + int32_t temp1 = *in_oc1_ptr++; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 0] = temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 1] = temp1; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 2] = temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 3] = temp1; + } + } + pack_dst_ptr += ic * filter_size * out_pair; + } + if (oc_remain > 0) { + const int32_t* in_oc0_ptr = (int32_t*)(in_ptr + oc_end * in_oc_stride); + + for (int filter_idx = 0; filter_idx < filter_size; ++filter_idx) { + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + int32_t temp0 = *in_oc0_ptr++; + int32_t temp1 = 0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 0] = temp0; + 
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 1] = temp1; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 2] = temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + + 3] = temp1; + } + } + } +} + +template +void conv_direct_i8i8i16_nchw_nchw44(const int8_t* src, const int8_t* filter, + const int16_t* bias, int8_t*, int16_t* dst, + const int oc, const int ic, const int ih, + const int iw, const int oh, + const int oh_block, const int ow, + const Op& op, const int, const int); + +} // namespace i8i8i16_direct_nchw_nchw44 + +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..1f33ed79a0bf0c835d821529540e16b0290ce3c5 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h @@ -0,0 +1,456 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#pragma once +#include "megdnn/arch.h" +#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +namespace { +/** + * @brief kernel helper to do core computation + * + * @tparam src_idx src reg offset + * @tparam weight_idx weight reg offset + * @tparam c_dim first dim of c reg + * @tparam ow_block output width + * @tparam half_adv half calculation + * @tparam stride + * @tparam T + * @tparam T2 + * @tparam T3 + * @tparam T4 + */ +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); +}; + +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step] = vmlal_s8(c[0][step], vget_low_s8(weight[0][weight_idx]), \ + vget_low_s8(src[step + src_idx])); \ + c[0][step] = vmlal_high_s8(c[0][step], weight[0][weight_idx], \ + src[step + src_idx]); + + UNROLL_CALL_RAW(8, cb); + +#undef cb + } +}; + +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][step] = vmlal_s8(c[0][step], vget_low_s8(weight[0][weight_idx]), \ + vget_low_s8(src[step + src_idx])); + + UNROLL_CALL_RAW(8, cb); + +#undef cb + } +}; + +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { + //! for compatible with stride2 kernel, step, weight_idx, src_idx should + //! 
mul 2 +#define cb(step) \ + c[0][2 * step] = \ + vmlal_s8(c[0][2 * step], vget_low_s8(weight[0][2 * weight_idx]), \ + vget_low_s8(src[step + src_idx])); \ + c[0][2 * step + 1] = \ + vmlal_high_s8(c[0][2 * step + 1], weight[0][2 * weight_idx], \ + src[step + src_idx]); + + UNROLL_CALL_RAW(4, cb); + +#undef cb +#define cb(step) \ + c[0][2 * step] = \ + vmlal_high_s8(c[0][2 * step], weight[0][2 * weight_idx + 1], \ + src[step + src_idx]); \ + c[0][2 * step + 1] = vmlal_s8(c[0][2 * step + 1], \ + vget_low_s8(weight[0][2 * weight_idx + 1]), \ + vget_low_s8(src[step + 1 + src_idx])); + + UNROLL_CALL_RAW(4, cb); + +#undef cb + } +}; + +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { +#define cb(step) \ + c[0][2 * step] = \ + vmlal_s8(c[0][2 * step], vget_low_s8(weight[0][2 * weight_idx]), \ + vget_low_s8(src[step + src_idx])); \ + c[0][2 * step + 1] = \ + vmlal_high_s8(c[0][2 * step + 1], weight[0][2 * weight_idx], \ + src[step + src_idx]); + + UNROLL_CALL_RAW(4, cb); + +#undef cb + } +}; + +template +struct OCHelper { +public: + static const int val = 0; +}; +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; +template <> +struct OCHelper<8> { +public: + static const int val = 2; +}; + +template +MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl(c, src, weight); +}; +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op&, const int remain_ow); +}; +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op&, const int remain_ow) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 2; + constexpr int iw_expand = 8; + constexpr 
int simd_len = 16; + constexpr int filter_pack_oc = stride == 1 ? 16 : 8; + + const int ld_src_ic = ih * iw * iw_expand; + const int ld_src_iw = iw * iw_expand; + + constexpr int ld_weight_ic = filter_pack_oc * filter_size * filter_size; + constexpr int c_dim = 1; + constexpr int reg_pair = stride; + constexpr int div_pad = stride - 1; + constexpr int iw_reg = + ow_block + (filter_size - stride + div_pad) / reg_pair; + constexpr int filter_reg = (filter_size + div_pad) / reg_pair; + int16x8_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, 0); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + + int8x16_t src[iw_reg]; + int8x16_t weight[1][filter_reg]; +#define cb(step) \ + load_helper( \ + src, src_ptr + step * ld_src_iw, 0); \ + load_helper( \ + weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ + cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); + UNROLL_CALL_RAW(2, cb) +#undef cb + + src_ptr += ld_src_iw; + weight_ptr += ld_weight_ic; + } + constexpr int output_c_group = OCHelper::val; + store_oc4_ow8_remain_static( + c, dst_ptr, ld_dst_oc, remain_ow); + }; +}; + +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op&, const int remain_ow) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 3; + constexpr int iw_expand = 8; + constexpr int simd_len = 16; + constexpr int filter_pack_oc = stride == 1 ? 
16 : 8; + constexpr int c_dim = 1; + + const int ld_src_ic = ih * iw * iw_expand; + const int ld_src_iw = iw * iw_expand; + + constexpr int ld_weight_ic = filter_pack_oc * filter_size * filter_size; + constexpr int reg_pair = stride; + constexpr int div_pad = stride - 1; + constexpr int iw_reg = + ow_block + (filter_size - stride + div_pad) / reg_pair; + constexpr int filter_reg = (filter_size + div_pad) / reg_pair; + int16x8_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, 0); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + + int8x16_t src[iw_reg]; + int8x16_t weight[1][filter_reg]; +#define cb(step) \ + load_helper( \ + src, src_ptr + step * ld_src_iw, 0); \ + load_helper( \ + weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ + cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, ow_block, true, stride>(c, src, weight); + UNROLL_CALL_RAW(3, cb) +#undef cb + + src_ptr += ld_src_iw; + weight_ptr += ld_weight_ic; + } + constexpr int output_c_group = OCHelper::val; + store_oc4_ow8_remain_static( + c, dst_ptr, ld_dst_oc, remain_ow); + }; +}; + +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op&, const int remain_ow) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 5; + constexpr int iw_expand = 8; + constexpr int simd_len = 16; + constexpr int filter_pack_oc = stride == 1 ? 
16 : 8; + constexpr int c_dim = 1; + + const int ld_src_ic = ih * iw * iw_expand; + const int ld_src_iw = iw * iw_expand; + + constexpr int ld_weight_ic = filter_pack_oc * filter_size * filter_size; + constexpr int reg_pair = stride; + constexpr int div_pad = stride - 1; + constexpr int iw_reg = + ow_block + (filter_size - stride + div_pad) / reg_pair; + constexpr int filter_reg = (filter_size + div_pad) / reg_pair; + int16x8_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, 0); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + + int8x16_t src[iw_reg]; + int8x16_t weight[1][filter_reg]; +#define cb(step) \ + load_helper( \ + src, src_ptr + step * ld_src_iw, 0); \ + load_helper( \ + weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ + cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<2, 2, c_dim, ow_block, true, stride>(c, src, weight); + UNROLL_CALL_RAW(5, cb) +#undef cb + + src_ptr += ld_src_iw; + weight_ptr += ld_weight_ic; + } + constexpr int output_c_group = OCHelper::val; + store_oc4_ow8_remain_static( + c, dst_ptr, ld_dst_oc, remain_ow); + }; +}; + +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, + int iw, int ld_dst_oc, const Op&, const int remain_ow) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 7; + constexpr int iw_expand = 8; + constexpr int simd_len = 16; + constexpr int filter_pack_oc = stride == 1 ? 
16 : 8; + constexpr int c_dim = 1; + + const int ld_src_ic = ih * iw * iw_expand; + const int ld_src_iw = iw * iw_expand; + + constexpr int ld_weight_ic = filter_pack_oc * filter_size * filter_size; + constexpr int reg_pair = stride; + constexpr int div_pad = stride - 1; + constexpr int iw_reg = + ow_block + (filter_size - stride + div_pad) / reg_pair; + constexpr int filter_reg = (filter_size + div_pad) / reg_pair; + int16x8_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, 0); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const int8_t* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + + int8x16_t src[iw_reg]; + int8x16_t weight[1][filter_reg]; +#define cb(step) \ + load_helper( \ + src, src_ptr + step * ld_src_iw, 0); \ + load_helper( \ + weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ + cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<2, 2, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<3, 3, c_dim, ow_block, true, stride>(c, src, weight); + UNROLL_CALL_RAW(7, cb) +#undef cb + + src_ptr += ld_src_iw; + weight_ptr += ld_weight_ic; + } + constexpr int output_c_group = OCHelper::val; + store_oc4_ow8_remain_static( + c, dst_ptr, ld_dst_oc, remain_ow); + }; +}; + +} // namespace + +template +void i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, int8_t*, + int16_t* dst, const int oc, const int ic, const int ih, const int iw, + const int oh, const int oh_block, const int ow, const Op& op, const int, + const int) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 1; + constexpr int big_oc_step = 8; + constexpr int oc_step = 4; + constexpr int ih_step = 1; + constexpr int oh_step = 1; + constexpr int ow_step = 8; + constexpr int stride_h = stride; + constexpr int stride_w = stride; + constexpr int 
iw_expand = 8; + constexpr int weight_expand = stride == 1 ? 2 : 1; + + const int img_stride = oh * ow; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const int weight_offset = (oc_idx * ic * fh * fw) * weight_expand; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * iw_expand; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44I8I8I16< + bias_mode, Op, filter_size, big_oc_step, stride, + ow_step>::impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op, ow_step); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * iw_expand; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + KerNeonXXs2NchwNchw44I8I8I16< + bias_mode, Op, filter_size, big_oc_step, stride, + ow_step>::impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op, ow_remain); + } + } + } + if (oc_remain > 0) { + int oc_idx = oc_end; + const int weight_offset = (oc_idx * ic * fh * fw) * weight_expand; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * iw_expand; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44I8I8I16< + bias_mode, Op, filter_size, oc_step, stride, + ow_step>::impl(src + src_offset, filter + 
weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op, ow_step); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * iw_expand; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + KerNeonXXs2NchwNchw44I8I8I16< + bias_mode, Op, filter_size, oc_step, stride, + ow_step>::impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op, ow_remain); + } + } + } +} + +#define INSTANTIATION(stride, filter_size, bias_mode, Op) \ + template void i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44< \ + bias_mode, Op, filter_size, stride>( \ + const int8_t* src, const int8_t* filter, const int16_t* bias, \ + int8_t*, int16_t* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op, const int, const int); + +#define FOR_OP(stride, filter, bias) \ + INSTANTIATION(stride, filter, bias, NoneOp) + +#define INSTANCE_CONV(filter, stride) \ + FOR_OP(stride, filter, BiasMode::NO_BIAS) \ + FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adc7486c5918bfe01fad5cafc9a0e9a0a867aa24 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp @@ -0,0 +1,19 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" +INSTANCE_CONV(2, 1); +INSTANCE_CONV(3, 1); +INSTANCE_CONV(5, 1); +INSTANCE_CONV(7, 1); + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6733876925d5d4b41c60dfea0f1d27a011dcec06 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp @@ -0,0 +1,19 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl_s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
+ */ +#include "src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h" +INSTANCE_CONV(2, 2); +INSTANCE_CONV(3, 2); +INSTANCE_CONV(5, 2); +INSTANCE_CONV(7, 2); + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index e27f8e2946b3b08de8201d11a76bff9398bec828..772d3f8e0fbe44606bf0ec620631ea887af38f3a 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -375,6 +375,89 @@ __ai void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr, StoreOcxOw8Remain::impl(c, op, dst_ptr, ld_dst_oc); } +////////////////////Store_OCX_OW8_Remain///////////////////////// +template +struct StoreOc4Ow8Remain { + static __ai void impl(T& c, T2 dst_ptr, int ld_dst_oc, const int ow_remain); +}; + +#define cb(step) \ + vst1q_lane_s64((int64_t*)(dst_ptr + step * 4), \ + vreinterpretq_s64_s16(c[0][step]), 0); \ + vst1q_lane_s64((int64_t*)(dst_ptr + step * 4 + ld_dst_oc), \ + vreinterpretq_s64_s16(c[0][step]), 1); + +#define cb2(step) \ + vst1q_lane_s64((int64_t*)(dst_ptr + step * 4), \ + vreinterpretq_s64_s16(c[0][step]), 0); + +#define cb_case(step) \ + case step: \ + UNROLL_CALL_RAW(step, cb); \ + break; + +#define cb_case2(step) \ + case step: \ + UNROLL_CALL_RAW(step, cb2); \ + break; +template +struct StoreOc4Ow8Remain<1, 8, 2, 2, T, T2, T3> { + static __ai void impl(T& c, T2 dst_ptr, int ld_dst_oc, + const int ow_remain) { + if (ow_remain == 8) { + UNROLL_CALL_RAW(8, cb) + } else { + switch (ow_remain) { + cb_case(7); + cb_case(6); + cb_case(5); + cb_case(4); + cb_case(3); + cb_case(2); + cb_case(1); + + default: + break; + } + } + } +}; +template +struct StoreOc4Ow8Remain<1, 8, 2, 1, T, T2, T3> { + static __ai void impl(T& c, T2 dst_ptr, int, const int ow_remain) { + if (ow_remain == 8) { + UNROLL_CALL_RAW(8, cb2) + } else { + switch (ow_remain) { + cb_case2(7); + cb_case2(6); + 
cb_case2(5); + cb_case2(4); + cb_case2(3); + cb_case2(2); + cb_case2(1); + + default: + break; + } + } + } +}; + +#undef cb +#undef cb2 +#undef cb_case +#undef cb_case2 + +template +__ai void store_oc4_ow8_remain_static(T& c, T2 dst_ptr, const int ld_dst_oc, + const int ow_remain) { + StoreOc4Ow8Remain::impl( + c, dst_ptr, ld_dst_oc, ow_remain); +} + ////////////////////Store_OC8_OW8_Remain///////////////////////// template @@ -548,13 +631,18 @@ __ai float32x4_t neon_vdupq_n(float val) { __ai int32x4_t neon_vdupq_n(int val) { return vdupq_n_s32(val); } +__ai int16x8_t neon_vdupq_n(int16_t val) { + return vdupq_n_s16(val); +} __ai float32x4_t neon_vld1q(const float* ptr) { return vld1q_f32(ptr); } - __ai int32x4_t neon_vld1q(const int* ptr) { return vld1q_s32(ptr); } +__ai int16x8_t neon_vld1q(const int16_t* ptr) { + return vld1q_s16(ptr); +} template struct InitOcxOw8 { @@ -725,6 +813,39 @@ __ai void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) { } /////////////////////////////////////// +static inline void memcpy_s8_dup(int8_t* outptr, const int8_t* inptr, + int count) { + constexpr int expand = 8; + for (; count >= 8; count -= 8) { + int8x8_t in = vld1_s8(inptr); + int8x8_t in0 = vdup_lane_s8(in, 0); + int8x8_t in1 = vdup_lane_s8(in, 1); + int8x8_t in2 = vdup_lane_s8(in, 2); + int8x8_t in3 = vdup_lane_s8(in, 3); + int8x8_t in4 = vdup_lane_s8(in, 4); + int8x8_t in5 = vdup_lane_s8(in, 5); + int8x8_t in6 = vdup_lane_s8(in, 6); + int8x8_t in7 = vdup_lane_s8(in, 7); + + vst1_s8(outptr + 0 * 8, in0); + vst1_s8(outptr + 1 * 8, in1); + vst1_s8(outptr + 2 * 8, in2); + vst1_s8(outptr + 3 * 8, in3); + vst1_s8(outptr + 4 * 8, in4); + vst1_s8(outptr + 5 * 8, in5); + vst1_s8(outptr + 6 * 8, in6); + vst1_s8(outptr + 7 * 8, in7); + + inptr += 8; + outptr += 8 * expand; + } + for (; count > 0; --count) { + int8x8_t in0 = vld1_dup_s8(inptr++); + vst1_s8(outptr, in0); + outptr += 1 * expand; + } +} + } // namespace } // namespace megdnn #undef __ai diff --git 
a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index c7192803a0576b25d8794ad995e148dcc1045f80..b71588c524b4d6ff63407b67fa3d7c3a1b171b3e 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -71,6 +71,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoI8x8x16Direct i8x8x16_direct; AlgoI8x8x16Stride2 i8x8x16_stride2; AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; + AlgoI8x8x16DirectNCHWNCHW44 i8x8x16_nchw_nchw44; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC AlgoF16Direct f16_direct; AlgoF16DirectStride1 f16_direct_stride1; @@ -107,6 +108,7 @@ public: direct_algos.emplace_back(&i8x8x16_direct); direct_algos.emplace_back(&i8x8x16_stride2_filter2); direct_algos.emplace_back(&i8x8x16_stride2); + direct_algos.emplace_back(&i8x8x16_nchw_nchw44); direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); direct_algos.emplace_back(&f32_chanel_wise_nchw44); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 780c07b90b025738e58a4dee33022f7665d98ecb..4a6144ee3b9c08d5f0cbc7493151aaa67e2cbfe0 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -81,6 +81,7 @@ private: class AlgoI8x8x16Direct; class AlgoI8x8x16Stride2; class AlgoI8x8x16Stride2Filter2; + class AlgoI8x8x16DirectNCHWNCHW44; class AlgoS8WinogradF23_8x8; class AlgoS8CF32WinogradF23_4x4_NCHW44; class AlgoS8WinogradF23_8x8_NCHW44; diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index ceb2920289fb466b0250c4fce1006d50f508331d..7f7b72f0f9f8fa687659e95fc52fa8909f5658a1 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -9,7 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/fallback/convolution/opr_impl.h" #include "src/common/algo_chooser.h" #include "src/common/metahelper.h" #include "src/common/opr_delegate.h" @@ -19,6 +18,7 @@ #include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" #include "src/fallback/conv_bias/im2col/algos.h" #include "src/fallback/conv_bias/opr_impl.h" +#include "src/fallback/convolution/opr_impl.h" #include "src/naive/convolution/algorithms.h" #include "src/naive/handle.h" @@ -479,7 +479,8 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, //! four format of weight layout //! 1. {oc/4, ic/4, fh, fw, 4, 4}, //! 2. {g, oc/4, ic/4, fh, fw, 4, 4}, - //! 3. {g/4, fh, fw, 1, 1, 4}, 4. {oc/4, fh, fw, ic, 4} + //! 3. {g/4, fh, fw, 1, 1, 4}, + //! 4. {oc/4, fh, fw, ic, 4} megdnn_assert((icpg % 4 == 0 && ocpg % 4 == 0) || (group % 4 == 0 && icpg == 1 && ocpg == 1 && pack_group_size > 1) || diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 7afa3408b399e980129bbb137ec51e8581bad57c..dc1eebb796f79d6e63d6f05f0fc15955c62c9a35 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -116,7 +116,8 @@ CB_TEST(H_SWISH); #if MEGDNN_WITH_BENCHMARK static void benchmark_convbias(Handle* handle, std::string int_name, - std::string float_name, bool is_fp32 = false) { + std::string float_name, bool is_fp32 = false, + bool is_8x8x16 = false) { constexpr size_t RUNS = 30; Benchmarker benchmarker_int(handle); @@ -142,6 +143,13 @@ static void benchmark_convbias(Handle* handle, std::string int_name, .set_dtype(2, dtype::Float32()) .set_dtype(4, dtype::Float32()) .set_display(false); + } else if (is_8x8x16) { + benchmarker_nchw44.set_times(RUNS) + .set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int16()) + .set_dtype(4, dtype::Int16()) + .set_display(false); } else { benchmarker_nchw44.set_times(RUNS) .set_dtype(0, dtype::QuantizedS8(2.5)) @@ -163,6 +171,9 @@ static void 
benchmark_convbias(Handle* handle, std::string int_name, size_t FS, size_t stride, bool input_nchw = false) { param::ConvBias param; param.nonlineMode = param::ConvBias::NonlineMode::RELU; + if (is_8x8x16) { + param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; + } param.stride_h = stride; param.stride_w = stride; @@ -235,6 +246,7 @@ static void benchmark_convbias(Handle* handle, std::string int_name, run(1, 512, 512, 7, 7, 3, 1, false); } else { run(1, 1, 4, 112, 112, 2, 2, true); + run(1, 3, 8, 224, 224, 3, 2, true); run(1, 3, 32, 224, 224, 3, 2, true); run(1, 3, 32, 224, 224, 5, 2, true); run(1, 3, 64, 224, 224, 7, 2, true); @@ -271,11 +283,15 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { "IM2COLMATMUL:AARCH64_F32K8X12X1:192", true); benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); + benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", + "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false, true); #else benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", "IM2COLMATMUL:ARMV7_F32:192", true); benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", "IM2COLMATMUL:ARMV7_F32:192", false); + benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", + "IM2COLMATMUL:ARMV7_F32:192", false, true); #endif } TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index f4f35b508def941e7d813642ee313cecf65b199b..91336e86cd0e38c68c0d326e9c912dd0343081ea 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -449,7 +449,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) { get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(), "I8816STRD2"); } - +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) { + 
checker_conv_bias_int8x8x16( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true, + true), + handle(), "I8816_CONV_NCHW_NCHW44"); +} +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) { + checker_conv_bias_int8x8x16( + get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true, + true), + handle(), "I8816_CONV_NCHW_NCHW44"); +} /**********************************algo 8-8-32 direct************************/ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) { checker_conv_bias_int8x8x32_multi(