diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index 0229898cd7fee51f680006f6b373c1574b553779..7c1bd69283e6443704dce69603d3be8237ded198 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -293,11 +293,11 @@ public: const NCBKernSizeParam& param) const override; }; -class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase { +class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; public: - AlgoF32DirectStride2NCHWNCHW44() {} + AlgoF32DirectNCHWNCHW44() {} bool is_reproducible() const override { return true; } const char* name() const override { return "F32_CONV_NCHW_NCHW44"; } bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp similarity index 82% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp rename to dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp index f6635cd01bb4031df05b26ead49fcaafc8f91875..8fc0962d059c7208782b5cddde2ac133bb7e4409 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp @@ -1,6 +1,6 @@ /** * \file - dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp + dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -13,7 +13,7 @@ #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/fp32/algos.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h" +#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" #include "src/arm_common/conv_bias/fp32/strategy.h" #include "src/arm_common/elemwise_op.h" #include "src/common/opr_delegate.h" @@ -26,7 +26,7 @@ using conv_fun = std::function; -MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2) +MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44) namespace { static inline int block_helper(const int nthread, const int amount, const int per_unit_bytes) { @@ -120,11 +120,10 @@ static void pack_weight(WorkspaceBundle bundle, kern_param.filter(group_id) + oc_idx * fh * fw * ic; auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; - conv_bias::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, - fw, ic); + pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, fw, ic); } -template +template static void do_conv_kern(WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, const ConvBiasImpl::NCBKernIndex& ncb_index, @@ -137,7 +136,7 @@ static void do_conv_kern(WorkspaceBundle bundle, const int oc = kern_param.filter_meta.ocpg; const int ih = kern_param.isz[0]; const int iw = kern_param.isz[1]; - const int stride_h = kern_param.filter_meta.stride[0]; + const int stride_h = stride; const int ph = kern_param.filter_meta.padding[0]; const int pw = kern_param.filter_meta.padding[1]; int ih2 = 0; @@ -181,21 +180,15 @@ static void do_conv_kern(WorkspaceBundle bundle, const float* bptr = kern_param.bias(batch_id, group_id) + oc_idx; Op op; -#define KERN1_NCHW44_CONV(filter) \ - conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw_nchw44< \ - \ - bias_mode, Op>(sptr, packed_weight, bptr, nullptr, dst, oc_block, \ - ic, ih_real, iw2, oh, oh_block_real, ow, op, ph, \ - pw) - DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); -#undef KERN1_NCHW44_CONV + conv_direct_fp32_nchw_nchw44( + sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, + oh, oh_block_real, ow, op, ph, pw); } } // namespace -/* ===================== stride2 algo ===================== */ -bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable( +bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( fallback::ConvBiasImpl*, const NCBKernSizeParam& param, AlgoSelectionStrategy) const { auto&& fm = param.filter_meta; @@ -209,19 +202,20 @@ bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable( bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && (fh == 2 || fh == 3 || fh == 5 || fh == 7); bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 2 && fm.stride[1] == 2; + fm.stride[0] == fm.stride[1] && + (fm.stride[0] == 1 || fm.stride[0] == 2); bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; return avaible; } -size_t ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::get_workspace( +size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { return get_bundle(param).total_size_in_bytes(); } SmallVector -ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns( +ConvBiasImpl::AlgoF32DirectNCHWNCHW44::dispatch_kerns( fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { auto fm = param.filter_meta; const int batch = param.n; @@ -230,61 +224,73 @@ ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns( conv_fun do_conv_fun = nullptr; // NOTE: remain_w is not used to gen hash of midout for compatible with // shape runtime -#define DO_CONV_KERN_FUN(filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2, \ - midout_iv(#filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ + MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44, \ + midout_iv(#stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); -#define GET_OP_PARAM(filter, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(filter, bias_mode, NoneOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(filter, bias_mode, ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, NoneOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(stride, filter, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } -#define GET_BIAS_MODE_PARAM(filter) \ - switch (param.bias_mode) { \ - case BiasMode::NO_BIAS: \ - GET_OP_PARAM(filter, BiasMode::NO_BIAS) \ - break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_BIAS_MODE_PARAM(stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } -#define DISPATCH_CONV_KERN() \ +#define DISPATCH_CONV_KERN(stride) \ switch (param.filter_meta.spatial[0]) { \ case 2: \ - GET_BIAS_MODE_PARAM(2) \ + GET_BIAS_MODE_PARAM(stride, 2) \ break; \ case 3: \ - GET_BIAS_MODE_PARAM(3) \ + GET_BIAS_MODE_PARAM(stride, 3) \ break; \ case 5: \ - GET_BIAS_MODE_PARAM(5) \ + GET_BIAS_MODE_PARAM(stride, 5) \ break; \ case 7: \ - GET_BIAS_MODE_PARAM(7) \ + GET_BIAS_MODE_PARAM(stride, 7) \ break; \ default: \ megdnn_assert(0); \ break; \ } - DISPATCH_CONV_KERN(); + switch (param.filter_meta.stride[0]) { + case 1: + DISPATCH_CONV_KERN(1); + break; + case 2: + DISPATCH_CONV_KERN(2); + break; + default: + megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", + param.filter_meta.stride[0]) + .c_str()); + break; + } #undef DO_CONV_KERN_FUN #undef GET_REMAIN_W_PARAM diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp rename to dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h index 501c1dcd76324f5d740febecd7bcd5a0fa497d3b..205855f55560193fd8c7c5c8871fed96bea5bdce 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h @@ -1,6 +1,5 @@ /** - * \file - * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp + * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -10,30 +9,36 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ - -#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h" +#pragma once #include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" -using namespace megdnn; -using namespace arm_common; +namespace megdnn { +namespace arm_common { namespace { - -template +/** + *\brief ShiftCalHelper is core calculate code + *\tparam src_idx is offset for src regs + *\tparam weight_idx is offset for weight regs + *\tparam T is type of output regs + *\tparam T2 is type of src regs + *\tparam T3 is type of weight regs + */ +template struct ShiftCalHelper { static void impl(T& c, T2& src, T3& weight); }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static void impl(T& c, T2& src, T3& weight) { - constexpr int stride = 2; #define cb(step) \ c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ c[0][step], weight[0][weight_idx], \ @@ -46,11 +51,10 @@ struct ShiftCalHelper { #undef cb } }; -template -struct ShiftCalHelper { +template +struct ShiftCalHelper { static void impl(T& c, T2& src, T3& weight) { - constexpr int stride = 2; #define cb(step) \ c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ c[0][step], weight[0][weight_idx], \ @@ -61,10 +65,10 @@ struct ShiftCalHelper { } }; -template +template inline void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl( + ShiftCalHelper::impl( c, src, weight); }; template @@ -86,16 +90,18 @@ public: }; /** * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel - * */ + **/ template + int oc_block, int stride, int ow_block> struct KerNeonXXs2NchwNchw44FP32 { static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op); }; -template -struct KerNeonXXs2NchwNchw44FP32 { +template +struct KerNeonXXs2NchwNchw44FP32 { static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -103,7 +109,9 @@ struct KerNeonXXs2NchwNchw44FP32 { constexpr int filter_size = 7; constexpr int oc_step = 4; constexpr int simd_len = 4; - constexpr int src_reg_size = 6; + constexpr int src_reg_size = + (ow_block * stride + filter_size - stride + simd_len - 1) / + simd_len; constexpr int ld_weight_fw = oc_step * filter_size; const int ld_weight_oc = oc_step * filter_size * filter_size * ic; @@ -117,18 +125,18 @@ struct KerNeonXXs2NchwNchw44FP32 { float32x4_t src[src_reg_size]; float32x4_t weight[c_dim][filter_size]; -#define KERNEL_CB(step) \ - load_helper( \ - src, src_ptr + step * iw, 0); \ - load_helper( \ - weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<5, 5, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<6, 6, c_dim, Vfmaq_laneq_f32>(c, src, weight); +#define KERNEL_CB(step) \ + load_helper( \ + src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<5, 5, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<6, 6, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); UNROLL_CALL_RAW(7, KERNEL_CB) #undef KERNEL_CB @@ -140,8 +148,10 @@ struct KerNeonXXs2NchwNchw44FP32 { ld_dst_oc); } }; -template -struct KerNeonXXs2NchwNchw44FP32 { +template +struct KerNeonXXs2NchwNchw44FP32 { static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -149,7 +159,9 @@ struct KerNeonXXs2NchwNchw44FP32 { constexpr int filter_size = 5; constexpr int oc_step = 4; constexpr int simd_len = 4; - constexpr int src_reg_size = 5; + constexpr int src_reg_size = + (ow_block * stride + filter_size - stride + simd_len - 1) / + simd_len; constexpr int ld_weight_fw = oc_step * filter_size; const int ld_weight_oc = oc_step * filter_size * filter_size * ic; @@ -163,16 +175,16 @@ struct KerNeonXXs2NchwNchw44FP32 { float32x4_t src[src_reg_size]; float32x4_t weight[c_dim][filter_size]; -#define KERNEL_CB(step) \ - load_helper( \ - src, src_ptr + step * iw, 0); \ - load_helper( \ - weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ - cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight); +#define KERNEL_CB(step) \ + load_helper( \ + src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ + cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); UNROLL_CALL_RAW(5, KERNEL_CB) #undef KERNEL_CB @@ -184,8 +196,10 @@ struct KerNeonXXs2NchwNchw44FP32 { } }; -template -struct KerNeonXXs2NchwNchw44FP32 { +template +struct KerNeonXXs2NchwNchw44FP32 { static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -193,7 +207,9 @@ struct KerNeonXXs2NchwNchw44FP32 { constexpr int filter_size = 3; constexpr int oc_step = 4; constexpr int simd_len = 4; - constexpr int src_reg_size = 5; + constexpr int src_reg_size = + (ow_block * stride + filter_size - stride + simd_len - 1) / + simd_len; constexpr int ld_weight_fw = oc_step * filter_size; const int ld_weight_oc = oc_step * filter_size * filter_size * ic; @@ -211,27 +227,27 @@ struct KerNeonXXs2NchwNchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); // row 1 load_helper( src, src_ptr + iw, 0); load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); // row 2 load_helper( src, src_ptr + 2 * iw, 0); load_helper( weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); - cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); src_ptr += ld_src_ic; weight_ptr += ld_weight_ic; @@ -241,8 +257,10 @@ struct KerNeonXXs2NchwNchw44FP32 { } }; -template -struct KerNeonXXs2NchwNchw44FP32 { +template +struct KerNeonXXs2NchwNchw44FP32 { static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -250,7 +268,9 @@ struct KerNeonXXs2NchwNchw44FP32 { constexpr int filter_size = 2; constexpr int oc_step = 4; constexpr int simd_len = 4; - constexpr int src_reg_size = 4; + constexpr int src_reg_size = + (ow_block * stride + filter_size - stride + simd_len - 1) / + simd_len; constexpr int ld_weight_fw = oc_step * filter_size; const int ld_weight_oc = oc_step * filter_size * filter_size * ic; @@ -268,16 +288,16 @@ struct KerNeonXXs2NchwNchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); // row 1 load_helper( src, src_ptr + iw, 0); load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); - cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); src_ptr += ld_src_ic; weight_ptr += ld_weight_ic; @@ -286,13 +306,9 @@ struct KerNeonXXs2NchwNchw44FP32 { ld_dst_oc); } }; - -} // namespace - -void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, - float32_t* dst_ptr, const int oc, - const int kh, const int kw, - const int ic) { +void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr, + const int oc, const int kh, const int kw, + const int ic) { constexpr int oc_step = 4; const int filter_oc_stride = kh * kw * ic; const int filter_ic_stride = kh * kw * oc_step; @@ -312,8 +328,8 @@ void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, } } -template -static void conv_direct_stride2_fp32_nchw_nchw44( +template +static void conv_direct_fp32_nchw_nchw44( const float32_t* src, const float32_t* filter, const float32_t* bias, float32_t*, float32_t* dst, const int oc, const int ic, const int ih, const int iw, const int oh, const int oh_block, const int ow, @@ -326,8 +342,8 @@ static void conv_direct_stride2_fp32_nchw_nchw44( constexpr int ih_step = 1; constexpr int oh_step = 1; constexpr int ow_step = 8; - constexpr int stride_h = 2; - constexpr int stride_w = 2; + constexpr int stride_h = stride; + constexpr int stride_w = stride; constexpr int pack_iw_len = 1; const int img_stride = oh * ow; @@ -345,14 +361,14 @@ static void conv_direct_stride2_fp32_nchw_nchw44( remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44FP32::impl; \ - kern_small_oc_remain = \ - KerNeonXXs2NchwNchw44FP32::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs2NchwNchw44FP32::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2NchwNchw44FP32::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -368,12 +384,13 @@ static void conv_direct_stride2_fp32_nchw_nchw44( ic_step * pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, 0, filter_size, - big_oc_step>::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); + KerNeonXXs2NchwNchw44FP32::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); } if (ow_remain > 0) { const int src_offset = @@ -397,8 +414,9 @@ static void conv_direct_stride2_fp32_nchw_nchw44( ic_step * pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44FP32::impl(src + src_offset, + KerNeonXXs2NchwNchw44FP32::impl(src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, @@ -417,56 +435,7 @@ static void conv_direct_stride2_fp32_nchw_nchw44( } } } - -#define CONSTRUCT_FUNC(filter_size) \ - template \ - void conv_bias:: \ - conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw_nchw44( \ - const float32_t* src, const float32_t* filter, \ - const float32_t* bias, float32_t* temp, float32_t* dst, \ - const int oc, const int ic, const int ih, const int iw, \ - const int oh, const int oh_block, const int ow, \ - const Op& op, const int ph, const int pw) { \ - conv_direct_stride2_fp32_nchw_nchw44( \ - src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ - ow, op, ph, pw); \ - } - -CONSTRUCT_FUNC(2); -CONSTRUCT_FUNC(3); -CONSTRUCT_FUNC(5); -CONSTRUCT_FUNC(7); -#undef CONSTRUCT_FUNC - -#define INSTANTIATION(stride, i, bias, Op) \ - template void conv_bias:: \ - conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44( \ - const float32_t*, const float32_t*, const float32_t*, \ - float32_t*, float32_t*, const int, const int, const int, \ - const int, const int, const int, const int, const Op&, \ - const int, const int); - -#define FOR_OP(stride, i, bias) \ - INSTANTIATION(stride, i, bias, NoneOp) \ - INSTANTIATION(stride, i, bias, ReluOp) \ - INSTANTIATION(stride, i, bias, HSwishOp) - -#define FOR_BIAS(stride, i) \ - FOR_OP(stride, i, BiasMode::NO_BIAS) \ - FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) - -#define FOR_FILTER(stride) \ - FOR_BIAS(stride, 2) \ - FOR_BIAS(stride, 3) \ - FOR_BIAS(stride, 5) \ - FOR_BIAS(stride, 7) - -FOR_FILTER(stride2) - -#undef FOR_STRIDE -#undef FOR_FILTER -#undef FOR_IC -#undef FOR_BIAS -#undef FOR_NONLINEAR -#undef FOR_REMAIN -#undef INSTANTIATION +} // namespace +} // namespace arm_common +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h deleted file mode 100644 index ec3fca8100d38a62f31d66ab11032f9be6b3a6c6..0000000000000000000000000000000000000000 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include "src/arm_common/conv_bias/opr_impl.h" -#include "src/fallback/conv_bias/common.h" - -namespace megdnn { -namespace arm_common { -namespace conv_bias { -#define KERN(stride, i, layout) \ - template \ - void conv_direct_##stride##_##i##x##i##_fp32_nchw_##layout( \ - const float* src, const float* filter, const float* bias, \ - float* temp, float* dst, const int oc, const int ic, const int ih, \ - const int iw, const int oh, const int oh_block, const int ow, \ - const Op& op, const int ph, const int pw); - -KERN(stride2, 2, nchw44) -KERN(stride2, 3, nchw44) -KERN(stride2, 5, nchw44) -KERN(stride2, 7, nchw44) -#undef KERN -void pack_weight_fp32_nchw_nchw44(const float_t* in_ptr, float_t* dst_ptr, - const int oc, const int kh, const int kw, - const int ic); - -} // namespace conv_bias -} // namespace arm_common -} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index f8fd4f204f4f85f5baad3468e7bef2acab51fcaf..38ec7b7a8fc012701cd45fbc1c1e19c99aa6c849 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -66,7 +66,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; #endif - AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44; + AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; AlgoF32DirectNCHW44 f32_direct_nchw44; diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 1e92a0fb0605d29d3c168c6c19abee6cd4f08f52..35c9f35396f6c3c83f312cd62d60f8d1a9073fe0 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -71,7 +71,7 @@ private: class AlgoF32Direct; class AlgoF32DirectStride1; class AlgoF32DirectStride2; - class AlgoF32DirectStride2NCHWNCHW44; + class AlgoF32DirectNCHWNCHW44; class AlgoF32ChannelWiseNCHW44; class AlgoF32DirectNCHW44; diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 7651dd7ed94fb85a74dc576074121bc440d8f4ed..6ce361e93d2853646893e7e0180b4290f50826f1 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -204,6 +204,11 @@ static void benchmark_convbias(Handle* handle, std::string int_name, run(1, 3, 32, 224, 224, 3, 2, true); run(1, 3, 64, 224, 224, 7, 2, true); + run(1, 1, 4, 112, 112, 2, 1, true); + run(1, 3, 32, 224, 224, 3, 1, true); + run(1, 3, 64, 224, 224, 3, 1, true); + run(1, 3, 64, 224, 224, 7, 1, true); + run(1, 64, 128, 56, 56, 3, 2, false); run(1, 128, 256, 28, 28, 3, 2, false); run(1, 256, 512, 14, 14, 3, 2, false); diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 6f4fae34bb007845f1cc05c2ccbe6c19adb18e53..ed9e2072c86b2257630103e47522c74db4dac91a 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -392,6 +392,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) { check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, true), handle(), "F32_CONV_NCHW_NCHW44"); + check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, + false, true), + handle(), "F32_CONV_NCHW_NCHW44"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) { check_conv_bias( @@ -824,13 +827,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) { auto conv_bias_opr = handle->create_operator(); conv_bias_opr->param() = param; - conv_bias_opr->param().format = param::ConvBias::Format::NCHW44_WINOGRAD; + conv_bias_opr->param().format = + param::ConvBias::Format::NCHW44_WINOGRAD; conv_bias_opr->param().output_block_size = m; size_t conv_bias_workspace_in_bytes = conv_bias_opr->get_workspace_in_bytes( tensors[0].layout, filter_transform_layout, - tensors[2].layout, tensors[3].layout, - tensors[4].layout, nullptr); + tensors[2].layout, tensors[3].layout, tensors[4].layout, + nullptr); WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(), conv_bias_workspace_in_bytes,