diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index 2a63fad089431e31d1403c6cc0265f25b1d68547..eb92bb026b29e70bd5c4391dc603c0508719711a 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -178,6 +178,23 @@ public: const NCBKernSizeParam& param) const override; }; +class ConvBiasImpl::AlgoF32DirectStride2NCHW44 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + +public: + AlgoF32DirectStride2NCHW44() {} + bool is_reproducible() const override { return true; } + const char* name() const override { return "F32_CONV_NCHW44_DIRECT_S2"; } + bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; + class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; bool m_large_group; diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_algo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..627c6ebcdb63cce7d3ab408c33169527f6e1a9dc --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_algo.cpp @@ -0,0 +1,281 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. + */ + +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/fp32/algos.h" +#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride2) +namespace { +// block_helper is used to calculate oh block size +static inline int block_helper(const int nthread, const int amount, + const int size_per_unit) { + constexpr int l2_cache_size = 256 * 1024; + const int block_per_thread = div_ceil(amount, nthread); + const int best_block = std::min( + amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); + const int max_block_num = div_ceil(block_per_thread, best_block); + const int min_block_num = std::max(max_block_num - 1, 1); + const int max_block = div_ceil(block_per_thread, max_block_num); + const int min_block = div_ceil(block_per_thread, min_block_num); + const int max_loss = std::abs(max_block_num * max_block - block_per_thread); + const int min_loss = std::abs(min_block_num * min_block - block_per_thread); + int block = max_loss > min_loss ? min_block : max_block; + return block; +} +static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, + const int iw2) { + // border_size is used to avoid read illegal memory + int border_size = 64 * 2; + return ic * ih2 * iw2 * sizeof(float) + border_size; +} +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, + int& iw2, int& oh2, int& ow2) { + int ic = param.filter_meta.icpg; + int iw = param.isz[1]; + int oh = param.osz[0]; + int ow = param.osz[1]; + + oh2 = oh; + ow2 = ow; + constexpr int cacheline = 64 / sizeof(float); + int block_oh = + block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2); + auto&& fm = param.filter_meta; + const int stride_h = static_cast(fm.stride[0]); + const int filter_h = static_cast(fm.spatial[0]); + ih2 = block_oh * stride_h + filter_h - stride_h; + iw2 = round_up(iw + 2 * static_cast(fm.padding[1]), cacheline); +} + +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + int ic = fm.icpg; + int ih2, iw2, oh2, ow2; + get_rectified_size(param, ih2, iw2, oh2, ow2); + + size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); + return {nullptr, {src_size * param.nr_threads}}; +}; + +template +static void do_conv_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange&, const CpuNDRange&) { + const int oh = kern_param.osz[0]; + const int ow = kern_param.osz[1]; + const int fh = kern_param.filter_meta.spatial[0]; + const int fw = kern_param.filter_meta.spatial[1]; + const int ic = kern_param.filter_meta.icpg; + const int oc = kern_param.filter_meta.ocpg; + const int ih = kern_param.isz[0]; + const int iw = kern_param.isz[1]; + const int stride_h = kern_param.filter_meta.stride[0]; + const int ph = kern_param.filter_meta.padding[0]; + const int pw = kern_param.filter_meta.padding[1]; + int ih2 = 0; + int iw2 = 0; + int oh2 = 0; + int ow2 = 0; + get_rectified_size(kern_param, ih2, iw2, oh2, ow2); + bundle.set(kern_param.workspace_ptr); + + constexpr int pack_c = 4; + const int batch_id = ncb_index.ndrange_id[0]; + const int group_id = ncb_index.ndrange_id[1]; + constexpr int oc_idx = 0; + int oc_block = oc; + int oh_block = block_helper(kern_param.nr_threads, oh2, + ic * iw * sizeof(float) * 2); + const int oh_idx = ncb_index.ndrange_id[2]; + const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); + const int ih_real = oh_block_real * stride_h + fh - stride_h; + const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); + const int src_bottom_pad = std::max( + (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, + 0); + const int remain_right_pad = std::max(iw2 - iw - pw, 0); + const int src_offset = + std::max(oh_idx * oh_block * stride_h - ph, 0) * iw * pack_c; + const float* origin_sptr = static_cast(kern_param.src( + batch_id, group_id, 0, 1, 1)) + + src_offset; + const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); + float* sptr = reinterpret_cast((int8_t*)bundle.get(0) + + ncb_index.thread_id * src_size); + + conv_bias::pack_src_fp32_nchw44_stride2( + sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw); + + const float* fptr = + kern_param.filter(group_id) + oc_idx * fh * fw * ic; + float_t* dst = kern_param.dst(batch_id, group_id) + + oh_idx * oh_block * ow * pack_c; + const int bias_offset = bias_mode == BiasMode::BIAS + ? oh_idx * oh_block * ow * pack_c + : oc_idx; + const float* bptr = + kern_param.bias(batch_id, group_id) + bias_offset; + + Op op; +#define KERN1_NCHW44_CONV(filter) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw44< \ + \ + bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \ + ih_real, iw2, oh, oh_block_real, ow, op, ph, pw) + + DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); +#undef KERN1_NCHW44_CONV +} + +} // namespace + +/* ===================== stride2 algo ===================== */ +bool ConvBiasImpl::AlgoF32DirectStride2NCHW44::usable( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { + auto&& fm = param.filter_meta; + auto fh = fm.spatial[0]; + int oc = fm.ocpg; + bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + (param.dst_type.enumv() == DTypeEnum::Float32))) && + (fm.format == param::Convolution::Format::NCHW44); + bool ok_src_dst = (oc % 4 == 0 && oc >= 4); + bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && + (fh == 2 || fh == 3 || fh == 5 || fh == 7); + bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2; + bool ok_conv = !fm.should_flip; + bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; + return avaible; +} + +size_t ConvBiasImpl::AlgoF32DirectStride2NCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + return get_bundle(param).total_size_in_bytes(); +} + +SmallVector +ConvBiasImpl::AlgoF32DirectStride2NCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + auto fm = param.filter_meta; + const int batch = param.n; + const int group = fm.group; + WorkspaceBundle wbundle = get_bundle(param); + conv_fun do_conv_fun = nullptr; + // NOTE: remain_w is not used to gen hash of midout for compatible with +// shape runtime +#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride2, \ + midout_iv(#filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(filter, bias_mode, NoneOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(filter, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + case BiasMode::BIAS: \ + GET_OP_PARAM(filter, BiasMode::BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define DISPATCH_CONV_KERN() \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(5) \ + break; \ + case 7: \ + GET_BIAS_MODE_PARAM(7) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + DISPATCH_CONV_KERN(); + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN + + megdnn_assert(do_conv_fun); + + SmallVector ret_kerns; + WorkspaceBundle bundle = wbundle; + int oh = param.osz[0]; + int ic = param.filter_meta.icpg; + int iw = param.isz[1]; + int oh_block = + block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2); + CpuNDRange ncb_range = {static_cast(batch), + static_cast(group), + static_cast(div_ceil(oh, oh_block))}; + auto do_conv = [bundle, do_conv_fun, ncb_range]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + return ret_kerns; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f607ca09fa9e67c4c39e1a15597f652c0c15b71 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp @@ -0,0 +1,748 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +namespace { + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight); +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step, lane) \ + c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 8]); \ + c[1][step] = Func::template impl(c[1][step], weight[1][lane], \ + src[(step + src_idx) % 8]); + + UNROLL_CALL_RAW(8, cb, 0); + UNROLL_CALL_RAW(8, cb, 1); + UNROLL_CALL_RAW(8, cb, 2); + UNROLL_CALL_RAW(8, cb, 3); +#undef cb + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step, lane) \ + c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 4]); \ + c[1][step] = Func::template impl(c[1][step], weight[1][lane], \ + src[(step + src_idx) % 4]); + + UNROLL_CALL_RAW(4, cb, 0); + UNROLL_CALL_RAW(4, cb, 1); + UNROLL_CALL_RAW(4, cb, 2); + UNROLL_CALL_RAW(4, cb, 3); +#undef cb + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step, lane) \ + c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 8]); + + UNROLL_CALL_RAW(8, cb, 0); + UNROLL_CALL_RAW(8, cb, 1); + UNROLL_CALL_RAW(8, cb, 2); + UNROLL_CALL_RAW(8, cb, 3); +#undef cb + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step, lane) \ + c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 4]); + + UNROLL_CALL_RAW(4, cb, 0); + UNROLL_CALL_RAW(4, cb, 1); + UNROLL_CALL_RAW(4, cb, 2); + UNROLL_CALL_RAW(4, cb, 3); +#undef cb + } +}; + +template +inline void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl(c, src, weight); +}; +template +struct OCHelper { +public: + static const int val = -1; +}; + +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; +#if MEGDNN_AARCH64 +template <> +struct OCHelper<8> { +public: + static const int val = 2; +}; +#endif +/** + * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel + * */ +template +struct KerNeonXXs2Nchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op, + const float32_t* src_ptr_odd); +}; + +template +struct KerNeonXXs2Nchw44FP32 { + static void impl(const float32_t* src_ptr_origin, + const float32_t* weight_ptr, const float32_t* bias_ptr, + float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, + const Op& op, const float32_t* src_ptr_odd_origin) { + constexpr int loop_ic_step = 4; + constexpr int filter_size = 2; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + + constexpr int ld_weight = oc_step * oc_step; + const int ld_bias = bias_mode == BiasMode::BIAS ? ld_dst_oc : oc_step; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_fh = oc_step * oc_step * filter_size; + const int ld_src_ic = ih * iw; + const int ld_src_iw = iw * oc_step; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; + + float32x4_t src[ow_block]; + float32x4_t weight[c_dim][4]; + /////////row 0///////////// + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, + ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + + load_helper(src, src_ptr_odd, + 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + src_ptr += ld_src_iw; + src_ptr_odd += ld_src_iw; + weight_ptr += ld_weight_fh; + /////////row 1///////////// + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, + ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + + load_helper(src, src_ptr_odd, + 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + src_ptr += ld_src_iw; + src_ptr_odd += ld_src_iw; + weight_ptr += ld_weight_fh; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +template +struct KerNeonXXs2Nchw44FP32 { + static void impl(const float32_t* src_ptr_origin, + const float32_t* weight_ptr, const float32_t* bias_ptr, + float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, + const Op& op, const float32_t* src_ptr_odd_origin) { + constexpr int loop_ic_step = 4; + constexpr int filter_size = 3; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + + constexpr int ld_weight = oc_step * oc_step; + const int ld_bias = bias_mode == BiasMode::BIAS ? ld_dst_oc : oc_step; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_fh = oc_step * oc_step * filter_size; + const int ld_src_ic = ih * iw; + const int ld_src_iw = iw * oc_step; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, ld_bias); + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; + + float32x4_t src[ow_block]; + float32x4_t weight[c_dim][4]; + /////////row 0///////////// + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, + ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + + src[0] = vld1q_f32(src_ptr + ow_block * simd_len); + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + + load_helper(src, src_ptr_odd, + 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + src_ptr += ld_src_iw; + src_ptr_odd += ld_src_iw; + weight_ptr += ld_weight_fh; + /////////row 1///////////// + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, + ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + src[0] = vld1q_f32(src_ptr + ow_block * simd_len); + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + + load_helper(src, src_ptr_odd, + 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + src_ptr += ld_src_iw; + src_ptr_odd += ld_src_iw; + weight_ptr += ld_weight_fh; + //////////row 2///////////// + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, + ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + src[0] = vld1q_f32(src_ptr + ow_block * simd_len); + + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + + load_helper(src, src_ptr_odd, + 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); + src_ptr += ld_src_iw; + src_ptr_odd += ld_src_iw; + weight_ptr += ld_weight_fh; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +template +struct KerNeonXXs2Nchw44FP32 { + static void impl(const float32_t* src_ptr_origin, + const float32_t* weight_ptr, const float32_t* bias_ptr, + float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, + const Op& op, const float32_t* src_ptr_odd_origin) { + constexpr int loop_ic_step = 4; + constexpr int filter_size = 5; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + + constexpr int ld_weight = oc_step * oc_step; + const int ld_bias = bias_mode == BiasMode::BIAS ? ld_dst_oc : oc_step; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_fh = oc_step * oc_step * filter_size; + const int ld_src_ic = ih * iw; + const int ld_src_iw = iw * oc_step; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; + + for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { + float32x4_t src[ow_block]; + float32x4_t weight[c_dim][4]; + // even element + load_helper(src, src_ptr, + 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, + ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[0] = vld1q_f32(src_ptr + ow_block * simd_len); + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); + load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + // odd element + load_helper( + src, src_ptr_odd, 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); + load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src_ptr += ld_src_iw; + src_ptr_odd += ld_src_iw; + weight_ptr += ld_weight_fh; + } + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +/** + * for kernel[7], calculate sequence is kernel[0], kernel[2], kernel[4], + * kernel[6], kernel[1], kernel[3], kernel[5] + * src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9 + **/ +template +struct KerNeonXXs2Nchw44FP32 { + static void impl(const float32_t* src_ptr_origin, + const float32_t* weight_ptr, const float32_t* bias_ptr, + float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, + const Op& op, const float32_t* src_ptr_odd_origin) { + constexpr int loop_ic_step = 4; + constexpr int filter_size = 7; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + + constexpr int ld_weight = oc_step * oc_step; + const int ld_bias = bias_mode == BiasMode::BIAS ? ld_dst_oc : oc_step; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_fh = oc_step * oc_step * filter_size; + const int ld_src_ic = ih * iw; + const int ld_src_iw = iw * oc_step; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; + + for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { + float32x4_t src[ow_block]; + float32x4_t weight[c_dim][4]; + // even element + load_helper(src, src_ptr, + 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, + ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[0] = vld1q_f32(src_ptr + ow_block * simd_len); + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); + load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); + load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + // odd element + load_helper( + src, src_ptr_odd, 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); + load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); + load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); + cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src_ptr += ld_src_iw; + src_ptr_odd += ld_src_iw; + weight_ptr += ld_weight_fh; + } + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +} // namespace +namespace { + +inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr, + const int odd_start, const int src_idx, + const int iw_idx) { + constexpr int ic_step = 4; + const int src_offset = src_idx * ic_step; + const int even_offset = iw_idx / 2 * ic_step; + const int odd_offset = (odd_start + iw_idx / 2) * ic_step; + float32x4_t temp[8]; + temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); + temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); + temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); + temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); + temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); + temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); + temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); + temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); + vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]); + vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]); + vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]); + vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]); + vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]); + vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]); + vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); + vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); +} +void odd_even_split_iw8_odd(float* sptr_base, const float* sptr, + const int odd_start, const int src_idx, + const int iw_idx) { + constexpr int ic_step = 4; + const int src_offset = src_idx * ic_step; + const int even_offset = (iw_idx + 1) / 2 * ic_step; + const int odd_offset = (odd_start + iw_idx / 2) * ic_step; + float32x4_t temp[8]; + temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); + temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); + temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); + temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); + temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); + temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); + temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); + temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); + vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]); + vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]); + vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]); + vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]); + vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]); + vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]); + vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]); + vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]); +} +} // namespace + +void conv_bias::pack_src_fp32_nchw44_stride2( + float* sptr_base, const float* sptr_origin, const int ph, const int pw, + const int pad_right, const int ih, const int iw, const int iw2, + const int pad_top, const int pad_bottom, const int ic, + const int ic_stride) { + constexpr int ic_step = 4; + int odd_start = megdnn::div_ceil(iw2, 2); + float32x4_t zero_v = vdupq_n_f32(0.f); + MEGDNN_MARK_USED_VAR(ph); + bool even_start = pw % 2 == 0; + rep_step(ic_idx, ic, ic_step) { + const float* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); + sptr_base += iw2 * pad_top * ic_step; + rep(ih_idx, ih) { + int iw_idx = 0; + rep(idx, pw) { + if (iw_idx % 2 == 0) { + vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); + } else { + vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, + zero_v); + } + ++iw_idx; + } + int src_idx = 0; + if (even_start) { + for (; src_idx + 7 < iw; src_idx += 8) { + odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx, + iw_idx); + iw_idx += 8; + } + } else { + for (; src_idx + 7 < iw; src_idx += 8) { + odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx, + iw_idx); + iw_idx += 8; + } + } + for (; src_idx < iw; ++src_idx) { + if (iw_idx % 2 == 0) { + vst1q_f32(sptr_base + iw_idx / 2 * ic_step, + vld1q_f32(sptr + src_idx * ic_step)); + } else { + vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, + vld1q_f32(sptr + src_idx * ic_step)); + } + ++iw_idx; + } + rep(idx, pad_right) { + if (iw_idx % 2 == 0) { + vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); + } else { + vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, + zero_v); + } + ++iw_idx; + } + sptr_base += iw2 * ic_step; + sptr += iw * ic_step; + } + memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); + sptr_base += iw2 * pad_bottom * ic_step; + } +} + +template +static void conv_direct_stride2_fp32_nchw44( + const float32_t* src, const float32_t* filter, const float32_t* bias, + float32_t*, float32_t* dst, const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, const int ow, + const Op& op, const int, const int) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; +#if MEGDNN_ARMV7 + constexpr int big_oc_step = 4; +#else + constexpr int big_oc_step = 8; +#endif + constexpr int oc_step = 4; + constexpr int ih_step = 1; + constexpr int oh_step = 1; + constexpr int ow_step = 8; + constexpr int stride_h = 2; + constexpr int stride_w = 2; + + const int img_stride = oh * ow; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + const int odd_start = div_ceil(iw, 2); + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs2Nchw44FP32::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2Nchw44FP32::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = (oh_idx * stride_h * iw + + ow_idx / 2 * stride_w * ih_step) * + ic_step; + const int src_offset_odd = + (oh_idx * stride_h * iw + + ow_idx / 2 * stride_w * ih_step + odd_start) * + ic_step; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + const int bias_offset = + bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; + KerNeonXXs2Nchw44FP32::impl(src + src_offset, + filter + weight_offset, + bias + bias_offset, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op, + src + src_offset_odd); + } + if (ow_remain > 0) { + const int src_offset = (oh_idx * stride_h * iw + + ow_end / 2 * stride_w * ih_step) * + ic_step; + const int src_offset_odd = + (oh_idx * stride_h * iw + + ow_end / 2 * stride_w * ih_step + odd_start) * + ic_step; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + const int bias_offset = + bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + bias_offset, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op, src + src_offset_odd); + } + } + } + if (oc_remain > 0) { + int oc_idx = oc_end; + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = (oh_idx * stride_h * iw + + ow_idx / 2 * stride_w * ih_step) * + ic_step; + const int src_offset_odd = + (oh_idx * stride_h * iw + + ow_idx / 2 * stride_w * ih_step + odd_start) * + ic_step; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + const int bias_offset = + bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; + KerNeonXXs2Nchw44FP32::impl(src + src_offset, + filter + weight_offset, + bias + bias_offset, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op, + src + src_offset_odd); + } + if (ow_remain > 0) { + const int src_offset = (oh_idx * stride_h * iw + + ow_end / 2 * stride_w * ih_step) * + ic_step; + const int src_offset_odd = + (oh_idx * stride_h * iw + + ow_end / 2 * stride_w * ih_step + odd_start) * + ic_step; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + const int bias_offset = + bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + bias_offset, dst + dst_offset, ic, + ih, iw, ld_dst_oc, op, + src + src_offset_odd); + } + } + } +} + +#define CONSTRUCT_FUNC(filter_size) \ + template \ + void conv_bias:: \ + conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw44( \ + const float32_t* src, const float32_t* filter, \ + const float32_t* bias, float32_t* temp, float32_t* dst, \ + const int oc, const int ic, const int ih, const int iw, \ + const int oh, const int oh_block, const int ow, \ + const Op& op, const int ph, const int pw) { \ + conv_direct_stride2_fp32_nchw44( \ + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ + ow, op, ph, pw); \ + } +CONSTRUCT_FUNC(2); +CONSTRUCT_FUNC(3); +CONSTRUCT_FUNC(5); +CONSTRUCT_FUNC(7); +#undef CONSTRUCT_FUNC + +#define INSTANTIATION(stride, i, bias, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \ + bias, Op>(const float32_t*, const float32_t*, const float32_t*, \ + float32_t*, float32_t*, const int, const int, const int, \ + const int, const int, const int, const int, const Op&, \ + const int, const int); + +#define FOR_OP(stride, i, bias) \ + INSTANTIATION(stride, i, bias, NoneOp) \ + INSTANTIATION(stride, i, bias, ReluOp) \ + INSTANTIATION(stride, i, bias, HSwishOp) + +#define FOR_BIAS(stride, i) \ + FOR_OP(stride, i, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + FOR_OP(stride, i, BiasMode::BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +FOR_FILTER(stride2) + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef FOR_REMAIN +#undef INSTANTIATION diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h new file mode 100644 index 0000000000000000000000000000000000000000..a0d852a2c37a75841ade11f7eb675a46e939e21f --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h @@ -0,0 +1,40 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" +namespace megdnn { +namespace arm_common { +namespace conv_bias { +#define KERN(stride, i, layout) \ + template \ + void conv_direct_##stride##_##i##x##i##_fp32_##layout( \ + const float* src, const float* filter, const float* bias, \ + float* temp, float* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op, const int ph, const int pw); + +KERN(stride2, 2, nchw44) +KERN(stride2, 3, nchw44) +KERN(stride2, 5, nchw44) +KERN(stride2, 7, nchw44) +#undef KERN + +void pack_src_fp32_nchw44_stride2(float* sptr_base, const float* sptr_origin, + const int ph, const int pw, + const int pad_right, const int ih, + const int iw, const int iw2, + const int pad_top, const int pad_bottom, + const int ic, const int ic_stride); +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp index ba3b62721472864046a9a8f5b5196d2b8d146e4e..501c1dcd76324f5d740febecd7bcd5a0fa497d3b 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp @@ -111,7 +111,7 @@ struct KerNeonXXs2NchwNchw44FP32 { const int ld_src_ic = ih * iw; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { float32x4_t src[src_reg_size]; @@ -157,7 +157,7 @@ struct KerNeonXXs2NchwNchw44FP32 { const int ld_src_ic = ih * iw; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { float32x4_t src[src_reg_size]; @@ -201,7 +201,7 @@ struct KerNeonXXs2NchwNchw44FP32 { const int ld_src_ic = ih * iw; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { float32x4_t src[src_reg_size]; @@ -258,7 +258,7 @@ struct KerNeonXXs2NchwNchw44FP32 { const int ld_src_ic = ih * iw; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { float32x4_t src[src_reg_size]; diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index a5a6c6c5ef9d63a90407ae747735c36ba7e776c0..a43d47cb9c9d0cf84eae215e7879b5748dbbed0d 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -194,7 +194,20 @@ struct StoreOcxOw8Remain<2, 0, Op, T> { op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24); } }; +template +struct StoreOcxOw8Remain<2, 8, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op({{c[0][4], c[0][5]}}, dst_ptr + 16); + op({{c[0][6], c[0][7]}}, dst_ptr + 24); + op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); + op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); + op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); + op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24); + } +}; template struct StoreOcxOw8Remain<2, 7, Op, T> { static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { @@ -277,6 +290,15 @@ struct StoreOcxOw8Remain<1, 0, Op, T> { op({{c[0][6], c[0][7]}}, dst_ptr + 24); } }; +template +struct StoreOcxOw8Remain<1, 8, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op({{c[0][4], c[0][5]}}, dst_ptr + 16); + op({{c[0][6], c[0][7]}}, dst_ptr + 24); + } +}; template struct StoreOcxOw8Remain<1, 7, Op, T> { @@ -499,46 +521,127 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, } } /////////////////////////init_ocx_ow8//////////////////// -template +template struct InitOcxOw8 { static void impl(T& c, T2 bias_ptr, int oc_step); }; -template -struct InitOcxOw8<2, bias_mode, T, T2> { +template +struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> { + static void impl(T& c, const float32_t*, int) { +#define BAIS_INIT(step) \ + c[0][step] = vdupq_n_f32(0); \ + c[1][step] = vdupq_n_f32(0); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } +}; +template +struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> { + static void impl(T& c, const float32_t*, int) { +#define BAIS_INIT(step) \ + c[0][step] = vdupq_n_f32(0); \ + c[1][step] = vdupq_n_f32(0); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } +}; +template +struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { static void impl(T& c, const float32_t* bias_ptr, int oc_step) { - if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { #define BAIS_INIT(step) \ c[0][step] = vld1q_f32(bias_ptr); \ c[1][step] = vld1q_f32(bias_ptr + oc_step); - UNROLL_CALL_RAW(8, BAIS_INIT); + UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT - } else { -#define BAIS_INIT(step) \ - c[0][step] = vdupq_n_f32(0); \ - c[1][step] = vdupq_n_f32(0); - UNROLL_CALL_RAW(8, BAIS_INIT); + } +}; +template +struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { + static void impl(T& c, const float32_t* bias_ptr, int oc_step) { +#define BAIS_INIT(step) \ + c[0][step] = vld1q_f32(bias_ptr); \ + c[1][step] = vld1q_f32(bias_ptr + oc_step); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } +}; +template +struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> { + static void impl(T& c, const float32_t* bias_ptr, int oc_step) { + constexpr int simd_len = 4; +#define BAIS_INIT(step) \ + c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \ + c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len); + UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT - } } }; -template -struct InitOcxOw8<1, bias_mode, T, T2> { +template +struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> { + static void impl(T& c, const float32_t* bias_ptr, int oc_step) { + constexpr int simd_len = 4; +#define BAIS_INIT(step) \ + c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \ + c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } +}; + +template +struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> { + static void impl(T& c, const float32_t*, int) { +#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } +}; +template +struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> { + static void impl(T& c, const float32_t*, int) { +#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } +}; +template +struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { static void impl(T& c, const float32_t* bias_ptr, int) { - if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { #define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr); - UNROLL_CALL_RAW(8, BAIS_INIT); + UNROLL_CALL_RAW(8, BAIS_INIT); #undef BAIS_INIT - } else { -#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); - UNROLL_CALL_RAW(8, BAIS_INIT); + } +}; +template +struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { + static void impl(T& c, const float32_t* bias_ptr, int) { +#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr); + UNROLL_CALL_RAW(4, BAIS_INIT); +#undef BAIS_INIT + } +}; +template +struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> { + static void impl(T& c, const float32_t* bias_ptr, int) { + constexpr int simd_len = 4; +#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } +}; +template +struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> { + static void impl(T& c, const float32_t* bias_ptr, int) { + constexpr int simd_len = 4; +#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len); + UNROLL_CALL_RAW(4, BAIS_INIT); #undef BAIS_INIT - } } }; -template +template inline void init_ocx_ow8(T& c, T2 bias_ptr, int oc_step) { - InitOcxOw8::impl(c, bias_ptr, oc_step); + InitOcxOw8::impl(c, bias_ptr, oc_step); } /////////////////////init_ocx_ow4///////////////////// template @@ -638,6 +741,20 @@ struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, T2, XT...> { UNROLL_CALL_RAW(6, WEIGHT_CB); } }; +template +struct LoadHelper<7, base_offset, ptr_step, 0, Func, T, T2, XT...> { + static void impl(T& src, T2 ptr, int, XT... args) { + UNROLL_CALL_RAW(7, WEIGHT_CB); + } +}; +template +struct LoadHelper<8, base_offset, ptr_step, 0, Func, T, T2, XT...> { + static void impl(T& src, T2 ptr, int, XT... args) { + UNROLL_CALL_RAW(8, WEIGHT_CB); + } +}; #undef WEIGHT_CB #define WEIGHT_CB(step) \ @@ -674,6 +791,11 @@ struct LoadHelper<7, base_offset, ptr_step, 1, Func, T, T2> { static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(7, WEIGHT_CB); } }; +template +struct LoadHelper<8, base_offset, ptr_step, 1, Func, T, T2> { + static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(8, WEIGHT_CB); } +}; + #undef WEIGHT_CB #define WEIGHT_CB(step) \ @@ -724,6 +846,13 @@ struct LoadHelper<7, base_offset, ptr_step, 2, Func, T, T2> { } }; +template +struct LoadHelper<8, base_offset, ptr_step, 2, Func, T, T2> { + static void impl(T& src, T2 ptr, int oc_offset) { + UNROLL_CALL_RAW(8, WEIGHT_CB); + } +}; + #undef WEIGHT_CB template get_int8_quint8_conv_bias_args( std::vector get_nchw44_conv_bias_args( std::vector kernel_vec, size_t stride, bool no_pad = false, bool no_bias = false, bool no_nonlinemode = false, - bool is_input_nchw = false) { + bool is_input_nchw = false, bool support_full_bias = false) { using namespace conv_bias; using NLMode = param::ConvBias::NonlineMode; std::vector args; auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t kernel, size_t stride, size_t group, NLMode nlmode, - int any_pad = -1) { + megdnn::BiasMode bias_mode, int any_pad = -1) { constexpr int pack_c = 4; const size_t pad = any_pad >= 0 ? any_pad : kernel / 2; - auto bias_mode = no_bias ? megdnn::BiasMode::NO_BIAS - : megdnn::BiasMode::BROADCAST_CHANNEL_BIAS; auto oc_per_group = oc / group; auto ic_per_group = ic / group; bool ok_group = (oc % group == 0 && ic % group == 0) && @@ -116,6 +114,10 @@ std::vector get_nchw44_conv_bias_args( auto bias_tensor_shape = TensorShape{}; if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c}; + } else if (bias_mode == megdnn::BiasMode::BIAS) { + bias_tensor_shape = {n, oc / pack_c, + (h + 2 * pad - kernel) / stride + 1, + (w + 2 * pad - kernel) / stride + 1, pack_c}; } if (group == 1) { param.sparse = param::ConvBias::Sparse::DENSE; @@ -149,19 +151,29 @@ std::vector get_nchw44_conv_bias_args( nonlinemode.emplace_back(NLMode::RELU); nonlinemode.emplace_back(NLMode::H_SWISH); } - for (auto nlmode : nonlinemode) - for (size_t n : {1, 2}) - for (size_t kernel : kernel_vec) - for (size_t oc : {4, 12, 32}) - for (size_t ic : {1, 3, 4, 12, 32}) - for (size_t h : {3, 5, 12}) - for (size_t w : {7, 16, 23}) { - for (size_t group = 1; - group <= std::min(oc, ic); ++group) { - pack(n, oc, ic, h, w, kernel, stride, group, - nlmode); + + std::vector bias_mode = { + megdnn::BiasMode::BROADCAST_CHANNEL_BIAS}; + if (no_bias) { + bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); + } + if (support_full_bias) { + bias_mode.emplace_back(megdnn::BiasMode::BIAS); + } + for (auto bias : bias_mode) + for (auto nlmode : nonlinemode) + for (size_t n : {1, 2}) + for (size_t kernel : kernel_vec) + for (size_t oc : {4, 12, 32}) + for (size_t ic : {1, 3, 4, 12, 32}) + for (size_t h : {3, 5, 12}) + for (size_t w : {7, 16, 23}) { + for (size_t group = 1; + group <= std::min(oc, ic); ++group) { + pack(n, oc, ic, h, w, kernel, stride, + group, nlmode, bias); + } } - } return args; } @@ -325,6 +337,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) { get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(), "F32DIRECT_SMALL_GROUP"); } + +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { + check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, + false, false, true), + handle(), "F32_CONV_NCHW44_DIRECT_S2"); +} + TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) { check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), handle(), "F32STRD1_LARGE_GROUP");