From 580a2753321b5ca8716d771c364ae1b39b757832 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 20 May 2020 20:47:21 +0800 Subject: [PATCH] feat(dnn/arm): add nchw44 fp32 direct stride 1 GitOrigin-RevId: 65f54a4f7ea754b16588baa1dc017ebe6599940d --- dnn/src/arm_common/conv_bias/fp32/algos.h | 16 + ...44_algo.cpp => f32_direct_nchw44_algo.cpp} | 110 ++-- .../fp32/f32_direct_stride1_nchw44_kern.cpp | 571 ++++++++++++++++++ .../fp32/f32_direct_stride1_nchw44_kern.h | 40 ++ .../fp32/f32_direct_stride2_nchw44_kern.cpp | 9 +- dnn/src/arm_common/conv_bias/opr_impl.cpp | 5 +- dnn/src/arm_common/conv_bias/opr_impl.h | 3 +- .../elemwise_helper/kimpl/sigmoid.h | 7 +- dnn/test/arm_common/conv_bias.cpp | 8 +- .../arm_common/conv_bias_multi_thread.cpp | 19 +- 10 files changed, 731 insertions(+), 57 deletions(-) rename dnn/src/arm_common/conv_bias/fp32/{f32_direct_stride2_nchw44_algo.cpp => f32_direct_nchw44_algo.cpp} (76%) create mode 100644 dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp create mode 100644 dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index eb92bb026..dd172ed40 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -178,6 +178,22 @@ public: const NCBKernSizeParam& param) const override; }; +class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + +public: + AlgoF32DirectNCHW44() {} + bool is_reproducible() const override { return true; } + const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; } + bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; +}; class ConvBiasImpl::AlgoF32DirectStride2NCHW44 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp similarity index 76% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_algo.cpp rename to dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp index 627c6ebcd..652722c8f 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_algo.cpp + * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -12,10 +12,9 @@ #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/fp32/algos.h" +#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" -#include "src/arm_common/conv_bias/fp32/strategy.h" #include "src/arm_common/elemwise_op.h" -#include "src/common/opr_delegate.h" #include "midout.h" @@ -25,7 +24,7 @@ using conv_fun = std::function; -MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride2) +MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1) namespace { // block_helper is used to calculate oh block size static inline int block_helper(const int nthread, const int amount, @@ -79,7 +78,7 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { return {nullptr, {src_size * param.nr_threads}}; }; -template +template static void do_conv_kern(WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, const ConvBiasImpl::NCBKernIndex& ncb_index, @@ -125,11 +124,17 @@ static void do_conv_kern(WorkspaceBundle bundle, const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); float* sptr = reinterpret_cast((int8_t*)bundle.get(0) + ncb_index.thread_id * src_size); - - conv_bias::pack_src_fp32_nchw44_stride2( - sptr, origin_sptr, ph, pw, remain_right_pad, - ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, - src_bottom_pad, ic, ih * iw); + if (stride == 1) { + conv_bias::pack_src_fp32_nchw44_stride1( + sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw); + } else { + conv_bias::pack_src_fp32_nchw44_stride2( + sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw); + } const float* fptr = kern_param.filter(group_id) + oc_idx * fh * fw * ic; @@ -142,46 +147,59 @@ static void do_conv_kern(WorkspaceBundle bundle, kern_param.bias(batch_id, group_id) + bias_offset; Op op; + if (stride == 1) { +#define KERN1_NCHW44_CONV(filter) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_fp32_nchw44< \ + \ + bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \ + ih_real, iw2, oh, oh_block_real, ow, op, ph, pw) + + DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); +#undef KERN1_NCHW44_CONV + } else { #define KERN1_NCHW44_CONV(filter) \ conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw44< \ \ bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \ ih_real, iw2, oh, oh_block_real, ow, op, ph, pw) - DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); + DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); #undef KERN1_NCHW44_CONV + } } } // namespace -/* ===================== stride2 algo ===================== */ -bool ConvBiasImpl::AlgoF32DirectStride2NCHW44::usable( - fallback::ConvBiasImpl*, const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +/* ===================== stride1 algo ===================== */ +bool ConvBiasImpl::AlgoF32DirectNCHW44::usable(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param, + AlgoSelectionStrategy) const { auto&& fm = param.filter_meta; auto fh = fm.spatial[0]; int oc = fm.ocpg; + int ic = fm.icpg; bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 && param.filter_type.enumv() == DTypeEnum::Float32 && (param.dst_type.enumv() == DTypeEnum::Float32))) && (fm.format == param::Convolution::Format::NCHW44); - bool ok_src_dst = (oc % 4 == 0 && oc >= 4); + bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic % 4 == 0 && ic >= 4); bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && (fh == 2 || fh == 3 || fh == 5 || fh == 7); bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 2 && fm.stride[1] == 2; + ((fm.stride[0] == 1 && fm.stride[1] == 1) || + (fm.stride[0] == 2 && fm.stride[1] == 2)); bool ok_conv = !fm.should_flip; bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; return avaible; } -size_t ConvBiasImpl::AlgoF32DirectStride2NCHW44::get_workspace( +size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { return get_bundle(param).total_size_in_bytes(); } SmallVector -ConvBiasImpl::AlgoF32DirectStride2NCHW44::dispatch_kerns( +ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns( fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { auto fm = param.filter_meta; const int batch = param.n; @@ -190,27 +208,43 @@ ConvBiasImpl::AlgoF32DirectStride2NCHW44::dispatch_kerns( conv_fun do_conv_fun = nullptr; // NOTE: remain_w is not used to gen hash of midout for compatible with // shape runtime -#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride2, \ - midout_iv(#filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride1, \ + midout_iv(#filter #bias_mode #stride #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); -#define GET_OP_PARAM(filter, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(filter, bias_mode, NoneOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(filter, bias_mode, ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_STRIDE_PARAM(filter, bias_mode, op) \ + switch (fm.stride[0]) { \ + case 1: \ + DO_CONV_KERN_FUN(filter, bias_mode, op, 1); \ + break; \ + case 2: \ + DO_CONV_KERN_FUN(filter, bias_mode, op, 2); \ + break; \ + \ + default: \ + megdnn_assert(0); \ + } + +#define GET_OP_PARAM(filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + GET_STRIDE_PARAM(filter, bias_mode, NoneOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + GET_STRIDE_PARAM(filter, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + GET_STRIDE_PARAM(filter, bias_mode, HSwishOp) \ + break; \ + case param::ConvBias::NonlineMode::SIGMOID: \ + GET_STRIDE_PARAM(filter, bias_mode, SigmoidOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(filter) \ diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp new file mode 100644 index 000000000..9daf67a1c --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp @@ -0,0 +1,571 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +namespace { + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight); +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step, lane) \ + c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 8]); \ + c[1][step] = Func::template impl(c[1][step], weight[1][lane], \ + src[(step + src_idx) % 8]); + + UNROLL_CALL_RAW(8, cb, 0); + UNROLL_CALL_RAW(8, cb, 1); + UNROLL_CALL_RAW(8, cb, 2); + UNROLL_CALL_RAW(8, cb, 3); +#undef cb + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step, lane) \ + c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 4]); \ + c[1][step] = Func::template impl(c[1][step], weight[1][lane], \ + src[(step + src_idx) % 4]); + + UNROLL_CALL_RAW(4, cb, 0); + UNROLL_CALL_RAW(4, cb, 1); + UNROLL_CALL_RAW(4, cb, 2); + UNROLL_CALL_RAW(4, cb, 3); +#undef cb + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step, lane) \ + c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 8]); + + UNROLL_CALL_RAW(8, cb, 0); + UNROLL_CALL_RAW(8, cb, 1); + UNROLL_CALL_RAW(8, cb, 2); + UNROLL_CALL_RAW(8, cb, 3); +#undef cb + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { +#define cb(step, lane) \ + c[0][step] = Func::template impl(c[0][step], weight[0][lane], \ + src[(step + src_idx) % 4]); + + UNROLL_CALL_RAW(4, cb, 0); + UNROLL_CALL_RAW(4, cb, 1); + UNROLL_CALL_RAW(4, cb, 2); + UNROLL_CALL_RAW(4, cb, 3); +#undef cb + } +}; + +template +inline void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl(c, src, weight); +}; +template +struct OCHelper { +public: + static const int val = -1; +}; + +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; +#if MEGDNN_AARCH64 +template <> +struct OCHelper<8> { +public: + static const int val = 2; +}; +#endif + +/** + * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel + * */ +template +struct KerNeonXXs1Nchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op); +}; + +template +struct KerNeonXXs1Nchw44FP32 { + static void impl(const float32_t* src_ptr_origin, + const float32_t* weight_ptr, const float32_t* bias_ptr, + float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, + const Op& op) { + constexpr int ic_step = 4; + constexpr int filter_size = 2; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + + constexpr int ld_weight = oc_step * oc_step; + const int ld_bias = bias_mode == BiasMode::BIAS ? ld_dst_oc : oc_step; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_fh = oc_step * oc_step * filter_size; + const int ld_src_ic = ih * iw; + const int ld_src_iw = iw * oc_step; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { + float32x4_t src[ow_block]; + float32x4_t weight[c_dim][ic_step]; + load_helper(src, src_ptr, + 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src_ptr += ld_src_iw; + weight_ptr += ld_weight_fh; + } + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +template +struct KerNeonXXs1Nchw44FP32 { + static void impl(const float32_t* src_ptr_origin, + const float32_t* weight_ptr, const float32_t* bias_ptr, + float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, + const Op& op) { + constexpr int ic_step = 4; + constexpr int filter_size = 3; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + + constexpr int ld_weight = oc_step * oc_step; + const int ld_bias = bias_mode == BiasMode::BIAS ? ld_dst_oc : oc_step; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_fh = oc_step * oc_step * filter_size; + const int ld_src_ic = ih * iw; + const int ld_src_iw = iw * oc_step; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { + float32x4_t src[ow_block]; + float32x4_t weight[c_dim][ic_step]; + load_helper(src, src_ptr, + 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src_ptr += ld_src_iw; + weight_ptr += ld_weight_fh; + } + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; +template +struct KerNeonXXs1Nchw44FP32 { + static void impl(const float32_t* src_ptr_origin, + const float32_t* weight_ptr, const float32_t* bias_ptr, + float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, + const Op& op) { + constexpr int ic_step = 4; + constexpr int filter_size = 5; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + + constexpr int ld_weight = oc_step * oc_step; + const int ld_bias = bias_mode == BiasMode::BIAS ? ld_dst_oc : oc_step; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_fh = oc_step * oc_step * filter_size; + const int ld_src_ic = ih * iw; + const int ld_src_iw = iw * oc_step; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { + float32x4_t src[ow_block]; + float32x4_t weight[c_dim][ic_step]; + load_helper(src, src_ptr, + 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src_ptr += ld_src_iw; + weight_ptr += ld_weight_fh; + } + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +template +struct KerNeonXXs1Nchw44FP32 { + static void impl(const float32_t* src_ptr_origin, + const float32_t* weight_ptr, const float32_t* bias_ptr, + float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, + const Op& op) { + constexpr int ic_step = 4; + constexpr int filter_size = 7; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + + constexpr int ld_weight = oc_step * oc_step; + const int ld_bias = bias_mode == BiasMode::BIAS ? ld_dst_oc : oc_step; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_fh = oc_step * oc_step * filter_size; + const int ld_src_ic = ih * iw; + const int ld_src_iw = iw * oc_step; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][ow_block]; + init_ocx_ow8(c, bias_ptr, ld_bias); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { + const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; + for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { + float32x4_t src[ow_block]; + float32x4_t weight[c_dim][ic_step]; + load_helper(src, src_ptr, + 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<5, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + + src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<6, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, + weight); + src_ptr += ld_src_iw; + weight_ptr += ld_weight_fh; + } + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +} // namespace + +void conv_bias::pack_src_fp32_nchw44_stride1( + float* sptr_base, const float* sptr_origin, const int, const int pw, + const int pad_right, const int ih, const int iw, const int iw2, + const int pad_top, const int pad_bottom, const int ic, + const int ic_stride) { + constexpr int ic_step = 4; + rep_step(ic_idx, ic, ic_step) { + const float* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); + sptr_base += iw2 * pad_top * ic_step; + rep(ih_idx, ih) { + memset(sptr_base, 0, sizeof(float) * pw * ic_step); + sptr_base += pw * ic_step; + memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step); + sptr_base += iw * ic_step; + sptr += iw * ic_step; + memset(sptr_base, 0, sizeof(float) * pad_right * ic_step); + sptr_base += pad_right * ic_step; + } + memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); + sptr_base += iw2 * pad_bottom * ic_step; + } +} + +template +static void conv_direct_stride1_fp32_nchw44( + const float32_t* src, const float32_t* filter, const float32_t* bias, + float32_t*, float32_t* dst, const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, const int ow, + const Op& op, const int, const int) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 4; +#if MEGDNN_ARMV7 + constexpr int big_oc_step = 4; +#else + constexpr int big_oc_step = 8; +#endif + constexpr int oc_step = 4; + constexpr int ih_step = 1; + constexpr int oh_step = 1; + constexpr int ow_step = 8; + constexpr int stride_h = 1; + constexpr int stride_w = 1; + + const int img_stride = oh * ow; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs1Nchw44FP32::impl; \ + kern_small_oc_remain = \ + KerNeonXXs1Nchw44FP32::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + const int bias_offset = + bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; + KerNeonXXs1Nchw44FP32::impl(src + src_offset, + filter + weight_offset, + bias + bias_offset, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + const int bias_offset = + bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + bias_offset, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + int oc_idx = oc_end; + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + const int bias_offset = + bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; + KerNeonXXs1Nchw44FP32::impl(src + src_offset, + filter + weight_offset, + bias + bias_offset, + dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + const int bias_offset = + bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + bias_offset, dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + } + } +} + +#define CONSTRUCT_FUNC(filter_size) \ + template \ + void conv_bias:: \ + conv_direct_stride1_##filter_size##x##filter_size##_fp32_nchw44( \ + const float32_t* src, const float32_t* filter, \ + const float32_t* bias, float32_t* temp, float32_t* dst, \ + const int oc, const int ic, const int ih, const int iw, \ + const int oh, const int oh_block, const int ow, \ + const Op& op, const int ph, const int pw) { \ + conv_direct_stride1_fp32_nchw44( \ + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ + ow, op, ph, pw); \ + } +CONSTRUCT_FUNC(2); +CONSTRUCT_FUNC(3); +CONSTRUCT_FUNC(5); +CONSTRUCT_FUNC(7); +#undef CONSTRUCT_FUNC + +#define INSTANTIATION(stride, i, bias, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \ + bias, Op>(const float32_t*, const float32_t*, const float32_t*, \ + float32_t*, float32_t*, const int, const int, const int, \ + const int, const int, const int, const int, const Op&, \ + const int, const int); + +#define FOR_OP(stride, i, bias) \ + INSTANTIATION(stride, i, bias, NoneOp) \ + INSTANTIATION(stride, i, bias, ReluOp) \ + INSTANTIATION(stride, i, bias, HSwishOp) \ + INSTANTIATION(stride, i, bias, SigmoidOp) + +#define FOR_BIAS(stride, i) \ + FOR_OP(stride, i, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + FOR_OP(stride, i, BiasMode::BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +FOR_FILTER(stride1) + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef FOR_REMAIN +#undef INSTANTIATION diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h new file mode 100644 index 000000000..c58d3d914 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h @@ -0,0 +1,40 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" +namespace megdnn { +namespace arm_common { +namespace conv_bias { +#define KERN(stride, i, layout) \ + template \ + void conv_direct_##stride##_##i##x##i##_fp32_##layout( \ + const float* src, const float* filter, const float* bias, \ + float* temp, float* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op, const int ph, const int pw); + +KERN(stride1, 2, nchw44) +KERN(stride1, 3, nchw44) +KERN(stride1, 5, nchw44) +KERN(stride1, 7, nchw44) +#undef KERN + +void pack_src_fp32_nchw44_stride1(float* sptr_base, const float* sptr_origin, + const int ph, const int pw, + const int pad_right, const int ih, + const int iw, const int iw2, + const int pad_top, const int pad_bottom, + const int ic, const int ic_stride); +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp index 7f607ca09..2b2b73d8c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp @@ -721,10 +721,11 @@ CONSTRUCT_FUNC(7); const int, const int, const int, const int, const Op&, \ const int, const int); -#define FOR_OP(stride, i, bias) \ - INSTANTIATION(stride, i, bias, NoneOp) \ - INSTANTIATION(stride, i, bias, ReluOp) \ - INSTANTIATION(stride, i, bias, HSwishOp) +#define FOR_OP(stride, i, bias) \ + INSTANTIATION(stride, i, bias, NoneOp) \ + INSTANTIATION(stride, i, bias, ReluOp) \ + INSTANTIATION(stride, i, bias, HSwishOp) \ + INSTANTIATION(stride, i, bias, SigmoidOp) #define FOR_BIAS(stride, i) \ FOR_OP(stride, i, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 6ba512642..ecefafac3 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -67,7 +67,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoF32Direct f32_direct_large_group{true}; AlgoF32Direct f32_direct_small_group{false}; - AlgoF32DirectStride2NCHW44 f32_direct_stride2_nchw44; + AlgoF32DirectNCHW44 f32_direct_nchw44; AlgoF32DirectStride2 f32_direct_stride2_large_group{true}; AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; AlgoF32DirectStride1 f32_direct_stride1_large_group{true}; @@ -126,8 +126,7 @@ public: direct_algos.emplace_back(&i8x8x16_stride2_large_group); direct_algos.emplace_back(&i8x8x16_stride2_small_group); direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); - - direct_algos.emplace_back(&f32_direct_stride2_nchw44); + direct_algos.emplace_back(&f32_direct_nchw44); direct_algos.emplace_back(&f32_direct_stride1_large_group); direct_algos.emplace_back(&f32_direct_stride1_small_group); direct_algos.emplace_back(&f32_direct_stride2_large_group); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index ba2cbe782..939abd7be 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -66,10 +66,11 @@ private: #endif class AlgoF32Direct; class AlgoF32DirectStride1; + class AlgoF32DirectNCHW44; class AlgoF32DirectStride2; class AlgoF32DirectStride2NCHWNCHW44; - class AlgoF32DirectStride2NCHW44; + class AlgoI8x8x16Direct; class AlgoI8x8x16Stride2; class AlgoI8x8x16Stride2Filter2; diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h b/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h index 60be900b3..c79283733 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -43,6 +44,10 @@ struct SigmoidOp; vst1q_##_func_suffix(dst, vitem.val[0]); \ vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ } \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ _neon_type2 operator()(const _neon_type2& src) const { \ return {{operator()(src.val[0]), operator()(src.val[1])}}; \ } \ diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index c558cf21c..f666c68e1 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -203,11 +203,9 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { run(1, 128, 256, 28, 28, 3, 2, false); run(1, 256, 512, 14, 14, 3, 2, false); - run(1, 64, 128, 56, 56, 7, 2, false); - run(1, 128, 256, 28, 28, 7, 2, false); - run(1, 256, 512, 14, 14, 7, 2, false); - - run(1, 64, 64, 48, 48, 3, 2, false); + run(1, 128, 128, 28, 28, 3, 1, false); + run(1, 256, 256, 14, 14, 3, 1, false); + run(1, 512, 512, 7, 7, 3, 1, false); } else { for (size_t stride : {1, 2}) { printf("stride %zu\n", stride); diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 85b55dec8..064105e3c 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -72,7 +72,8 @@ std::vector get_int8_quint8_conv_bias_args( std::vector get_nchw44_conv_bias_args( std::vector kernel_vec, size_t stride, bool no_pad = false, bool no_bias = false, bool no_nonlinemode = false, - bool is_input_nchw = false, bool support_full_bias = false) { + bool is_input_nchw = false, bool support_full_bias = false, + bool support_sigmoid = false) { using namespace conv_bias; using NLMode = param::ConvBias::NonlineMode; std::vector args; @@ -151,6 +152,9 @@ std::vector get_nchw44_conv_bias_args( nonlinemode.emplace_back(NLMode::RELU); nonlinemode.emplace_back(NLMode::H_SWISH); } + if (support_sigmoid) { + nonlinemode.emplace_back(NLMode::SIGMOID); + } std::vector bias_mode = { megdnn::BiasMode::BROADCAST_CHANNEL_BIAS}; @@ -337,11 +341,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) { get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(), "F32DIRECT_SMALL_GROUP"); } +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1) { + check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, + false, false, true, true), + handle(), "F32_CONV_NCHW44_DIRECT"); +} TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, - false, false, true), - handle(), "F32_CONV_NCHW44_DIRECT_S2"); + false, false, true, true), + handle(), "F32_CONV_NCHW44_DIRECT"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1_LARGE_GROUP) { @@ -682,8 +691,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) { size_t conv_bias_workspace_in_bytes = conv_bias_opr->get_workspace_in_bytes( tensors[0].layout, filter_transform_layout, - tensors[2].layout, tensors[3].layout, - tensors[4].layout, nullptr); + tensors[2].layout, tensors[3].layout, tensors[4].layout, + nullptr); WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(), conv_bias_workspace_in_bytes, -- GitLab