From c9986df596b43e00108cd904de19458b54eaf0b9 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 11 May 2020 14:56:41 +0800
Subject: [PATCH] feat(dnn/arm): add fp32 nchw_nchw44 conv

GitOrigin-RevId: f19fe892d9f3e4c166d4835804bf5fc0ad31ccbc
---
 dnn/src/arm_common/conv_bias/fp32/algos.h     |  22 +-
 .../f32_direct_stride2_nchw_nchw44_algo.cpp   | 317 +++++++++++++
 .../f32_direct_stride2_nchw_nchw44_kern.cpp   | 430 ++++++++++++++++++
 .../f32_direct_stride2_nchw_nchw44_kern.h     |  38 ++
 .../arm_common/conv_bias/intrinsic_helper.h   | 357 ++++++++++++---
 dnn/src/arm_common/conv_bias/neon_struct.h    |  11 +
 dnn/src/arm_common/conv_bias/opr_impl.cpp     |   2 +
 dnn/src/arm_common/conv_bias/opr_impl.h       |   1 +
 .../arm_common/elemwise_helper/kimpl/hswish.h |   7 +-
 .../arm_common/elemwise_helper/kimpl/none.h   |  10 +-
 .../arm_common/elemwise_helper/kimpl/relu.h   |   5 +
 dnn/src/arm_common/simd_macro/marm_neon.h     |  33 ++
 dnn/test/arm_common/conv_bias.cpp             |  84 ++--
 .../arm_common/conv_bias_multi_thread.cpp     |  14 +-
 14 files changed, 1231 insertions(+), 100 deletions(-)
 create mode 100644 dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
 create mode 100644 dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp
 create mode 100644 dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h

diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h
index c97e7fc30..2a63fad08 100644
--- a/dnn/src/arm_common/conv_bias/fp32/algos.h
+++ b/dnn/src/arm_common/conv_bias/fp32/algos.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
@@ -156,7 +157,6 @@ private:
     uint32_t m_tile_size;
 };
-
 class ConvBiasImpl::AlgoF32Direct final : public AlgoBase {
     SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
     bool m_large_group;
@@ -217,6 +217,24 @@ public:
             fallback::ConvBiasImpl* opr,
             const NCBKernSizeParam& param) const override;
 };
+
+class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase {
+    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
+
+public:
+    AlgoF32DirectStride2NCHWNCHW44() {}
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "F32_CONV_NCHW_NCHW44"; }
+    bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
+                AlgoSelectionStrategy algo_selection_strategy) const override;
+
+    size_t get_workspace(fallback::ConvBiasImpl* opr,
+                         const NCBKernSizeParam& param) const override;
+    virtual SmallVector<NCBKern> dispatch_kerns(
+            fallback::ConvBiasImpl* opr,
+            const NCBKernSizeParam& param) const override;
+};
+
 } // namespace arm_common
 } // namespace megdnn

diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
new file mode 100644
index 000000000..8499fef2b
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
@@ -0,0 +1,317 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. + */ + +#include "megdnn/oprs.h" +#include "src/arm_common/conv_bias/fp32/algos.h" +#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h" +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/elemwise_op.h" +#include "src/common/opr_delegate.h" + +#include "midout.h" + +using namespace megdnn; +using namespace arm_common; +using conv_fun = std::function; +MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2) +namespace { +static inline int block_helper(const int nthread, const int amount, + const int per_unit_bytes) { + MEGDNN_MARK_USED_VAR(per_unit_bytes); + const int block_per_thread = div_ceil(amount, nthread); + const int best_block = 16; + const int max_block_num = div_ceil(block_per_thread, best_block); + const int min_block_num = std::max(max_block_num - 1, 1); + const int max_block = div_ceil(block_per_thread, max_block_num); + const int min_block = div_ceil(block_per_thread, min_block_num); + const int max_loss = std::abs(max_block_num * max_block - block_per_thread); + const int min_loss = std::abs(min_block_num * min_block - block_per_thread); + int block = max_loss > min_loss ? min_block : max_block; + return block; +} +static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, + const int iw2) { + // border_size is used to avoid read illegal memory + int border_size = 64 * 2; + return ic * ih2 * iw2 * sizeof(float) + border_size; +} +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, + int& iw2, int& oh2, int& ow2) { + int iw = param.isz[1]; + int oh = param.osz[0]; + int ow = param.osz[1]; + + oh2 = oh; + ow2 = ow; + constexpr int cacheline = 64 / sizeof(float); + int block_oh = block_helper(param.nr_threads, oh, 0); + auto&& fm = param.filter_meta; + const int stride_h = static_cast(fm.stride[0]); + const int filter_h = static_cast(fm.spatial[0]); + ih2 = block_oh * stride_h + filter_h - stride_h; + iw2 = round_up(iw + 2 * static_cast(fm.padding[1]), cacheline); +} + +static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + int group = fm.group; + int ic = fm.icpg; + int oc = fm.ocpg; + int fh = fm.spatial[0]; + int fw = fm.spatial[1]; + int ih2, iw2, oh2, ow2; + get_rectified_size(param, ih2, iw2, oh2, ow2); + + int oh_block = block_helper(param.nr_threads, oh2, 0); + megdnn_assert(oh_block != 0, "oh_block!=0"); + size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); + size_t weight_size = group * oc * ic * fh * fw * sizeof(float); + return {nullptr, {src_size * param.nr_threads, weight_size}}; +}; + +static inline void copy_pad_src(float* sptr_base, const float* sptr_origin, + int ph, int pw, int pad_right, int ih, int iw, + int iw2, int pad_top, int pad_bottom, int ic, + int ic_stride) { + MEGDNN_MARK_USED_VAR(ph); + rep(ic_idx, ic) { + const float* sptr = sptr_origin + ic_idx * ic_stride; + memset(sptr_base, 0, sizeof(float) * iw2 * pad_top); + sptr_base += iw2 * pad_top; + rep(ih_idx, ih) { + memset(sptr_base, 0, sizeof(float) * pw); + sptr_base += pw; + memcpy(sptr_base, sptr, sizeof(float) * iw); + sptr_base += iw; + sptr += iw; + memset(sptr_base, 0, sizeof(float) * pad_right); + sptr_base += pad_right; + } + memset(sptr_base, 0, sizeof(float) * iw2 
* pad_bottom); + sptr_base += iw2 * pad_bottom; + } +} +static void pack_weight(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { + bundle.set(kern_param.workspace_ptr); + const int group_id = ncb_index.ndrange_id[0]; + int fh = kern_param.filter_meta.spatial[0]; + int fw = kern_param.filter_meta.spatial[1]; + int oc = kern_param.filter_meta.ocpg; + int ic = kern_param.filter_meta.icpg; + int oc_block = oc; + int oc_idx = 0; + const float* fptr = + kern_param.filter(group_id) + oc_idx * fh * fw * ic; + auto packed_weight = reinterpret_cast(bundle.get(1)) + + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; + conv_bias::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, + fw, ic); +} + +template +static void do_conv_kern(WorkspaceBundle bundle, + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, + const CpuNDRange&, const CpuNDRange&) { + const int oh = kern_param.osz[0]; + const int ow = kern_param.osz[1]; + const int fh = kern_param.filter_meta.spatial[0]; + const int fw = kern_param.filter_meta.spatial[1]; + const int ic = kern_param.filter_meta.icpg; + const int oc = kern_param.filter_meta.ocpg; + const int ih = kern_param.isz[0]; + const int iw = kern_param.isz[1]; + const int stride_h = kern_param.filter_meta.stride[0]; + const int ph = kern_param.filter_meta.padding[0]; + const int pw = kern_param.filter_meta.padding[1]; + int ih2 = 0; + int iw2 = 0; + int oh2 = 0; + int ow2 = 0; + get_rectified_size(kern_param, ih2, iw2, oh2, ow2); + bundle.set(kern_param.workspace_ptr); + + constexpr int pack_c = 4; + const int batch_id = ncb_index.ndrange_id[0]; + const int group_id = ncb_index.ndrange_id[1]; + int oc_idx = 0; + int oc_block = oc; + int oh_block = block_helper(kern_param.nr_threads, oh2, 0); + const int oh_idx = ncb_index.ndrange_id[2]; + const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); + const int ih_real = oh_block_real * stride_h + fh - stride_h; + const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); + const int src_bottom_pad = std::max( + (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, + 0); + const int remain_right_pad = std::max(iw2 - iw - pw, 0); + const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw; + const float* origin_sptr = static_cast(kern_param.src( + batch_id, group_id, 0, 1, 1)) + + src_offset; + const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); + float* sptr = reinterpret_cast((int8_t*)bundle.get(0) + + ncb_index.thread_id * src_size); + + copy_pad_src(sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw); + // pack weight + auto packed_weight = reinterpret_cast(bundle.get(1)) + + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; + // get param + float_t* dst = kern_param.dst(batch_id, group_id) + + oh_idx * oh_block * ow * pack_c; + const float* bptr = + kern_param.bias(batch_id, group_id) + oc_idx; + Op op; +#define KERN1_NCHW44_CONV(filter) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw_nchw44< \ + \ + bias_mode, Op>(sptr, packed_weight, bptr, nullptr, dst, oc_block, \ + ic, ih_real, iw2, oh, oh_block_real, ow, op, ph, \ + pw) + + DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); +#undef KERN1_NCHW44_CONV +} + +} // namespace + +/* ===================== stride2 algo ===================== */ +bool 
ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
+        fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
+        AlgoSelectionStrategy) const {
+    auto&& fm = param.filter_meta;
+    auto fh = fm.spatial[0];
+    int oc = fm.ocpg;
+    bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 &&
+                     param.filter_type.enumv() == DTypeEnum::Float32 &&
+                     (param.dst_type.enumv() == DTypeEnum::Float32))) &&
+                   (fm.format == param::Convolution::Format::NCHW44);
+    bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
+    bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
+                     (fh == 3 || fh == 5 || fh == 7);
+    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
+                    fm.stride[0] == 2 && fm.stride[1] == 2;
+    bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
+    bool available = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
+    return available;
+}
+
+size_t ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::get_workspace(
+        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
+    return get_bundle(param).total_size_in_bytes();
+}
+
+SmallVector<ConvBiasImpl::NCBKern>
+ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
+        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
+    auto fm = param.filter_meta;
+    const int batch = param.n;
+    const int group = fm.group;
+    WorkspaceBundle wbundle = get_bundle(param);
+    conv_fun do_conv_fun = nullptr;
+    // NOTE: remain_w is deliberately left out of the midout hash so the
+    // selected kernel stays compatible with shapes that vary at runtime.
+#define DO_CONV_KERN_FUN(filter, bias_mode, op)                         \
+    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2,  \
+                 midout_iv(#filter #bias_mode #op##_hash)) {            \
+        do_conv_fun = do_conv_kern<filter, bias_mode, op>;              \
+    }                                                                   \
+    MIDOUT_END();
+
+#define GET_OP_PARAM(filter, bias_mode)                                 \
+    switch (param.nonlineMode) {                                        \
+        case param::ConvBias::NonlineMode::IDENTITY:                    \
+            DO_CONV_KERN_FUN(filter, bias_mode, NoneOp<dt_float32>)     \
+            break;                                                      \
+        case param::ConvBias::NonlineMode::RELU:                        \
+            DO_CONV_KERN_FUN(filter, bias_mode, ReluOp<dt_float32>)     \
+            break;                                                      \
+        case param::ConvBias::NonlineMode::H_SWISH:                     \
+            DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp<dt_float32>)   \
+            break;                                                      \
+        default:                                                        \
+            megdnn_assert(0);                                           \
+            break;                                                      \
+    }
+#define GET_BIAS_MODE_PARAM(filter)                                     \
+    switch (param.bias_mode) {                                          \
+        case BiasMode::NO_BIAS:                                         \
+            GET_OP_PARAM(filter, BiasMode::NO_BIAS)                     \
+            break;                                                      \
+        case BiasMode::BROADCAST_CHANNEL_BIAS:                          \
+            GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS)      \
+            break;                                                      \
+        default:                                                        \
+            megdnn_assert(0);                                           \
+            break;                                                      \
+    }
+
+#define DISPATCH_CONV_KERN()                                            \
+    switch (param.filter_meta.spatial[0]) {                             \
+        case 3:                                                         \
+            GET_BIAS_MODE_PARAM(3)                                      \
+            break;                                                      \
+        case 5:                                                         \
+            GET_BIAS_MODE_PARAM(5)                                      \
+            break;                                                      \
+        case 7:                                                         \
+            GET_BIAS_MODE_PARAM(7)                                      \
+            break;                                                      \
+        default:                                                        \
+            megdnn_assert(0);                                           \
+            break;                                                      \
+    }
+
+    DISPATCH_CONV_KERN();
+
+#undef DO_CONV_KERN_FUN
+#undef GET_REMAIN_W_PARAM
+#undef GET_OP_PARAM
+#undef GET_BIAS_MODE_PARAM
+#undef DISPATCH_CONV_KERN
+
+    megdnn_assert(do_conv_fun);
+
+    SmallVector<NCBKern> ret_kerns;
+    WorkspaceBundle bundle = wbundle;
+    int oh = param.osz[0];
+    int oh_block = block_helper(param.nr_threads, oh, 0);
+    auto do_pack_weight = [bundle](const NCBKernParam& kern_param,
+                                   const NCBKernIndex& ncb_index) {
+        pack_weight(bundle, kern_param, ncb_index);
+    };
+    ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}});
+    CpuNDRange ncb_range = {static_cast<size_t>(batch),
+                            static_cast<size_t>(group),
+                            static_cast<size_t>(div_ceil(oh, oh_block))};
+    auto do_conv = [bundle, do_conv_fun, ncb_range](
+                           const NCBKernParam& kern_param,
+                           const NCBKernIndex& ncb_index) {
+
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, + ncb_range); + }; + ret_kerns.push_back({do_conv, ncb_range}); + + return ret_kerns; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp new file mode 100644 index 000000000..fc9aca9f9 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp @@ -0,0 +1,430 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h" +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +using namespace megdnn; +using namespace arm_common; +namespace { + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight); +}; + +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { + constexpr int stride = 2; +#define cb(step) \ + c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ + c[0][step], weight[0][weight_idx], \ + src[(step * stride + src_idx) / 4]); \ + c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \ + c[1][step], weight[1][weight_idx], \ + src[(step * stride + src_idx) / 4]); + + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; +template +struct ShiftCalHelper { + static void impl(T& c, T2& src, T3& weight) { + constexpr int stride = 2; +#define cb(step) \ + c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ + c[0][step], weight[0][weight_idx], \ + src[(step * stride + src_idx) / 4]); + + UNROLL_CALL_RAW(8, cb); +#undef cb + } +}; + +template +inline void cal_helper(T& c, T2& src, T3& weight) { + ShiftCalHelper::impl( + c, src, weight); +}; +template +struct OCHelper { +public: + static const int val = -1; +}; + +template <> +struct OCHelper<4> { +public: + static const int val = 1; +}; + +template <> +struct OCHelper<8> { +public: + static const int val = 2; +}; +/** + * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel + * */ +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op); +}; +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 7; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + constexpr int src_reg_size = 6; + + constexpr int ld_weight_fw = oc_step * filter_size; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_ic = oc_step * filter_size * filter_size; + const int ld_src_ic = ih * iw; + constexpr int c_dim 
= OCHelper::val; + float32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + float32x4_t src[src_reg_size]; + float32x4_t weight[c_dim][filter_size]; + +#define KERNEL_CB(step) \ + load_helper( \ + src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<5, 5, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<6, 6, c_dim, Vfmaq_laneq_f32>(c, src, weight); + + UNROLL_CALL_RAW(7, KERNEL_CB) +#undef KERNEL_CB + + src_ptr += ld_src_ic; + weight_ptr += ld_weight_ic; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 5; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + constexpr int src_reg_size = 5; + + constexpr int ld_weight_fw = oc_step * filter_size; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_ic = oc_step * filter_size * filter_size; + const int ld_src_ic = ih * iw; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + float32x4_t src[src_reg_size]; + float32x4_t weight[c_dim][filter_size]; + +#define KERNEL_CB(step) \ + load_helper( \ + src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ + cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight); + UNROLL_CALL_RAW(5, KERNEL_CB) +#undef KERNEL_CB + + src_ptr += ld_src_ic; + weight_ptr += ld_weight_ic; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 3; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + constexpr int src_reg_size = 5; + + constexpr int ld_weight_fw = oc_step * filter_size; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_ic = oc_step * filter_size * filter_size; + const int ld_src_ic = ih * iw; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + float32x4_t src[src_reg_size]; + float32x4_t weight[c_dim][filter_size]; + // row 0 + load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); + load_helper<3, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, + ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, 
weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); + + // row 1 + load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); + load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); + + // row 2 + load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + 2 * iw, 0); + load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); + + src_ptr += ld_src_ic; + weight_ptr += ld_weight_ic; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + +} // namespace + +void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, + float32_t* dst_ptr, const int oc, + const int kh, const int kw, + const int ic) { + constexpr int oc_step = 4; + const int filter_oc_stride = kh * kw * ic; + const int filter_ic_stride = kh * kw * oc_step; + for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const float32_t* in_ptr_oc = in_ptr + oc_idx * filter_oc_stride; + float32_t* dst_ptr_oc = dst_ptr + oc_idx * filter_oc_stride; + for (int kh_idx = 0; kh_idx < kh; ++kh_idx) { + for (int kw_idx = 0; kw_idx < kw; ++kw_idx) { + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + float32x4_t vsrc = vld1q_f32(in_ptr_oc); + vst1q_f32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc); + in_ptr_oc += oc_step; + } + dst_ptr_oc += oc_step; + } + } + } +} + +template +static void conv_direct_stride2_fp32_nchw_nchw44( + const float32_t* src, const float32_t* filter, const float32_t* bias, + float32_t*, float32_t* dst, const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, const int ow, + const Op& op, const int, const int) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; + constexpr int ic_step = 1; + constexpr int big_oc_step = 8; + constexpr int oc_step = 4; + constexpr int ih_step = 1; + constexpr int oh_step = 1; + constexpr int ow_step = 8; + constexpr int stride_h = 2; + constexpr int stride_w = 2; + constexpr int pack_iw_len = 1; + + const int img_stride = oh * ow; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + const int ld_dst_oc = oc_step * img_stride; + + using remain_fun = std::function; + remain_fun kern_big_oc_remain = nullptr; + remain_fun kern_small_oc_remain = nullptr; + + switch (ow_remain) { +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + KerNeonXXs2NchwNchw44FP32::impl; \ + kern_small_oc_remain = \ + KerNeonXXs2NchwNchw44FP32::impl; \ + break; + + UNROLL_CALL_RAW(8, cb); + default: + megdnn_assert(0, "no remain %d for kern", ow_remain); + } + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + 
KerNeonXXs2NchwNchw44FP32< + bias_mode, Op, 0, filter_size, + big_oc_step>::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + int oc_idx = oc_end; + const int weight_offset = oc_idx * ic * fh * fw; + for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44FP32::impl(src + src_offset, + filter + weight_offset, + bias + oc_idx, + dst + dst_offset, ic, + ih, iw, ld_dst_oc, op); + } + if (ow_remain > 0) { + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, + iw, ld_dst_oc, op); + } + } + } +} + +#define CONSTRUCT_FUNC(filter_size) \ + template \ + void conv_bias:: \ + conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw_nchw44( \ + const float32_t* src, const float32_t* filter, \ + const float32_t* bias, float32_t* temp, float32_t* dst, \ + const int oc, const int ic, const int ih, const int iw, \ + const int oh, const int oh_block, const int ow, \ + const Op& op, const int ph, const int pw) { \ + conv_direct_stride2_fp32_nchw_nchw44( \ + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ + ow, op, ph, pw); \ + } + +CONSTRUCT_FUNC(3); +CONSTRUCT_FUNC(5); +CONSTRUCT_FUNC(7); +#undef CONSTRUCT_FUNC + +template +void conv_bias::conv_direct_stride2_2x2_fp32_nchw_nchw44( + const float32_t*, const float32_t*, const float32_t*, float32_t*, + float32_t*, const int, const int, const int, const int, const int, + const int, const int, const Op&, const int, const int) { + megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv"); +} + +#define INSTANTIATION(stride, i, bias, Op) \ + template void conv_bias:: \ + conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44( \ + const float32_t*, const float32_t*, const float32_t*, \ + float32_t*, float32_t*, const int, const int, const int, \ + const int, const int, const int, const int, const Op&, \ + const int, const int); + +#define FOR_OP(stride, i, bias) \ + INSTANTIATION(stride, i, bias, NoneOp) \ + INSTANTIATION(stride, i, bias, ReluOp) \ + INSTANTIATION(stride, i, bias, HSwishOp) + +#define FOR_BIAS(stride, i) \ + FOR_OP(stride, i, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) \ + FOR_BIAS(stride, 7) + +FOR_FILTER(stride2) + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_IC +#undef FOR_BIAS +#undef FOR_NONLINEAR +#undef FOR_REMAIN +#undef INSTANTIATION diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h new file mode 100644 index 
000000000..ec3fca810 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h @@ -0,0 +1,38 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" + +namespace megdnn { +namespace arm_common { +namespace conv_bias { +#define KERN(stride, i, layout) \ + template \ + void conv_direct_##stride##_##i##x##i##_fp32_nchw_##layout( \ + const float* src, const float* filter, const float* bias, \ + float* temp, float* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op, const int ph, const int pw); + +KERN(stride2, 2, nchw44) +KERN(stride2, 3, nchw44) +KERN(stride2, 5, nchw44) +KERN(stride2, 7, nchw44) +#undef KERN +void pack_weight_fp32_nchw_nchw44(const float_t* in_ptr, float_t* dst_ptr, + const int oc, const int kh, const int kw, + const int ic); + +} // namespace conv_bias +} // namespace arm_common +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index 5946b3fcd..a5a6c6c5e 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -174,7 +174,167 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { StoreOcxOw4Remain::impl(c, op, dst_ptr, ld_dst_oc); } +////////////////////Store_OCX_OW8_Remain///////////////////////// +template +struct StoreOcxOw8Remain { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc); +}; + +template +struct StoreOcxOw8Remain<2, 0, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op({{c[0][4], c[0][5]}}, dst_ptr + 16); + op({{c[0][6], c[0][7]}}, dst_ptr + 24); + + op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); + op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); + op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); + op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24); + } +}; + +template +struct StoreOcxOw8Remain<2, 7, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op({{c[0][4], c[0][5]}}, dst_ptr + 16); + op(c[0][6], dst_ptr + 24); + + op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); + op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); + op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); + op(c[1][6], dst_ptr + ld_dst_oc + 24); + } +}; +template +struct StoreOcxOw8Remain<2, 6, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op({{c[0][4], c[0][5]}}, dst_ptr + 16); + + op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); + op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); + op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); + } +}; +template +struct StoreOcxOw8Remain<2, 5, Op, T> { + static void 
impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op(c[0][4], dst_ptr + 16); + + op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); + op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); + op(c[1][4], dst_ptr + ld_dst_oc + 16); + } +}; +template +struct StoreOcxOw8Remain<2, 4, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + + op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); + op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); + } +}; +template +struct StoreOcxOw8Remain<2, 3, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op(c[0][2], dst_ptr + 8); + + op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); + op(c[1][2], dst_ptr + ld_dst_oc + 8); + } +}; +template +struct StoreOcxOw8Remain<2, 2, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); + } +}; +template +struct StoreOcxOw8Remain<2, 1, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { + op(c[0][0], dst_ptr); + op(c[1][0], dst_ptr + ld_dst_oc); + } +}; + +template +struct StoreOcxOw8Remain<1, 0, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op({{c[0][4], c[0][5]}}, dst_ptr + 16); + op({{c[0][6], c[0][7]}}, dst_ptr + 24); + } +}; + +template +struct StoreOcxOw8Remain<1, 7, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op({{c[0][4], c[0][5]}}, dst_ptr + 16); + op(c[0][6], dst_ptr + 24); + } +}; +template +struct StoreOcxOw8Remain<1, 6, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op({{c[0][4], c[0][5]}}, dst_ptr + 16); + } +}; +template +struct StoreOcxOw8Remain<1, 5, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + op(c[0][4], dst_ptr + 16); + } +}; +template +struct StoreOcxOw8Remain<1, 4, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op({{c[0][2], c[0][3]}}, dst_ptr + 8); + } +}; +template +struct StoreOcxOw8Remain<1, 3, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op({{c[0][0], c[0][1]}}, dst_ptr); + op(c[0][2], dst_ptr + 8); + } +}; +template +struct StoreOcxOw8Remain<1, 2, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op({{c[0][0], c[0][1]}}, dst_ptr); + } +}; +template +struct StoreOcxOw8Remain<1, 1, Op, T> { + static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { + op(c[0][0], dst_ptr); + } +}; +template +inline void store_ocx_ow8_remain_static(T& c, const Op& op, float32_t* dst_ptr, + int ld_dst_oc) { + StoreOcxOw8Remain::impl(c, op, dst_ptr, ld_dst_oc); +} ////////////////////Store_OC8_OW8_Remain///////////////////////// template @@ -299,14 +459,15 @@ struct Store_OC8_OW8_Remain<1, Op> { } }; -template -inline void store_oc8_ow8_remain_static(int32x4_t c[2][8], const Op& op, - int8_t* dst_ptr, int ld_dst_oc) { +/////////// + +template +inline 
void store_oc8_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, + int ld_dst_oc) { Store_OC8_OW8_Remain::impl(c, op, dst_ptr, ld_dst_oc); } -/////////////////////////////////////////////////////// - +////////////////////////////////////// template inline void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) { if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { @@ -337,6 +498,49 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, #undef BAIS_INIT } } +/////////////////////////init_ocx_ow8//////////////////// +template +struct InitOcxOw8 { + static void impl(T& c, T2 bias_ptr, int oc_step); +}; +template +struct InitOcxOw8<2, bias_mode, T, T2> { + static void impl(T& c, const float32_t* bias_ptr, int oc_step) { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { +#define BAIS_INIT(step) \ + c[0][step] = vld1q_f32(bias_ptr); \ + c[1][step] = vld1q_f32(bias_ptr + oc_step); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } else { +#define BAIS_INIT(step) \ + c[0][step] = vdupq_n_f32(0); \ + c[1][step] = vdupq_n_f32(0); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } + } +}; +template +struct InitOcxOw8<1, bias_mode, T, T2> { + static void impl(T& c, const float32_t* bias_ptr, int) { + if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { +#define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } else { +#define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); + UNROLL_CALL_RAW(8, BAIS_INIT); +#undef BAIS_INIT + } + } +}; + +template +inline void init_ocx_ow8(T& c, T2 bias_ptr, int oc_step) { + InitOcxOw8::impl(c, bias_ptr, oc_step); +} +/////////////////////init_ocx_ow4///////////////////// template struct InitOcxOw4 { static void impl(T& c, const int32_t* bias_ptr, int oc_step); @@ -383,57 +587,54 @@ inline void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) { } /////////////////////////////////////// template + typename Func, typename T, typename T2, typename... XT> struct LoadHelper { - static void impl(T& weight, const int8_t* ptr, int oc_offset, XT... args); + static void impl(T& weight, T2 ptr, int oc_offset, XT... args); }; #define WEIGHT_CB(step) \ src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); -template -struct LoadHelper<1, base_offset, ptr_step, 0, Func, T, XT...> { - static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { +struct LoadHelper<1, base_offset, ptr_step, 0, Func, T, T2, XT...> { + static void impl(T& src, T2 ptr, int, XT... args) { UNROLL_CALL_RAW(1, WEIGHT_CB); } }; -template -struct LoadHelper<2, base_offset, ptr_step, 0, Func, T, XT...> { - static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { +struct LoadHelper<2, base_offset, ptr_step, 0, Func, T, T2, XT...> { + static void impl(T& src, T2 ptr, int, XT... args) { UNROLL_CALL_RAW(2, WEIGHT_CB); } }; -template -struct LoadHelper<3, base_offset, ptr_step, 0, Func, T, XT...> { - static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { +struct LoadHelper<3, base_offset, ptr_step, 0, Func, T, T2, XT...> { + static void impl(T& src, T2 ptr, int, XT... args) { UNROLL_CALL_RAW(3, WEIGHT_CB); } }; -template -struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, XT...> { - static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { - MEGDNN_MARK_USED_VAR(oc_offset); +struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, T2, XT...> { + static void impl(T& src, T2 ptr, int, XT... 
args) { UNROLL_CALL_RAW(4, WEIGHT_CB); } }; -template -struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, XT...> { - static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { - MEGDNN_MARK_USED_VAR(oc_offset); +struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, T2, XT...> { + static void impl(T& src, T2 ptr, int, XT... args) { UNROLL_CALL_RAW(5, WEIGHT_CB); } }; -template -struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> { - static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { - MEGDNN_MARK_USED_VAR(oc_offset); +struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, T2, XT...> { + static void impl(T& src, T2 ptr, int, XT... args) { UNROLL_CALL_RAW(6, WEIGHT_CB); } }; @@ -441,27 +642,36 @@ struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> { #define WEIGHT_CB(step) \ src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); -template -struct LoadHelper<1, base_offset, ptr_step, 1, Func, T> { - static void impl(T& src, const int8_t* ptr, int oc_offset) { - MEGDNN_MARK_USED_VAR(oc_offset); - UNROLL_CALL_RAW(1, WEIGHT_CB); - } +template +struct LoadHelper<1, base_offset, ptr_step, 1, Func, T, T2> { + static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(1, WEIGHT_CB); } }; -template -struct LoadHelper<2, base_offset, ptr_step, 1, Func, T> { - static void impl(T& src, const int8_t* ptr, int oc_offset) { - MEGDNN_MARK_USED_VAR(oc_offset); - UNROLL_CALL_RAW(2, WEIGHT_CB); - } +template +struct LoadHelper<2, base_offset, ptr_step, 1, Func, T, T2> { + static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(2, WEIGHT_CB); } }; -template -struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> { - static void impl(T& src, const int8_t* ptr, int oc_offset) { - MEGDNN_MARK_USED_VAR(oc_offset); - UNROLL_CALL_RAW(3, WEIGHT_CB); - } +template +struct LoadHelper<3, base_offset, ptr_step, 1, Func, T, T2> { + static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(3, WEIGHT_CB); } +}; +template +struct LoadHelper<4, base_offset, ptr_step, 1, Func, T, T2> { + static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(4, WEIGHT_CB); } +}; + +template +struct LoadHelper<5, base_offset, ptr_step, 1, Func, T, T2> { + static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(5, WEIGHT_CB); } +}; +template +struct LoadHelper<6, base_offset, ptr_step, 1, Func, T, T2> { + static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(6, WEIGHT_CB); } +}; + +template +struct LoadHelper<7, base_offset, ptr_step, 1, Func, T, T2> { + static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(7, WEIGHT_CB); } }; #undef WEIGHT_CB @@ -470,40 +680,63 @@ struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> { src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); -template -struct LoadHelper<1, base_offset, ptr_step, 2, Func, T> { - static void impl(T& src, const int8_t* ptr, int oc_offset) { +template +struct LoadHelper<1, base_offset, ptr_step, 2, Func, T, T2> { + static void impl(T& src, T2 ptr, int oc_offset) { UNROLL_CALL_RAW(1, WEIGHT_CB); } }; -template -struct LoadHelper<2, base_offset, ptr_step, 2, Func, T> { - static void impl(T& src, const int8_t* ptr, int oc_offset) { +template +struct LoadHelper<2, base_offset, ptr_step, 2, Func, T, T2> { + static void impl(T& src, T2 ptr, int oc_offset) { UNROLL_CALL_RAW(2, WEIGHT_CB); } }; -template -struct LoadHelper<3, base_offset, ptr_step, 2, Func, T> { - static void impl(T& src, const int8_t* ptr, int oc_offset) { 
+template +struct LoadHelper<3, base_offset, ptr_step, 2, Func, T, T2> { + static void impl(T& src, T2 ptr, int oc_offset) { UNROLL_CALL_RAW(3, WEIGHT_CB); } }; +template +struct LoadHelper<4, base_offset, ptr_step, 2, Func, T, T2> { + static void impl(T& src, T2 ptr, int oc_offset) { + UNROLL_CALL_RAW(4, WEIGHT_CB); + } +}; +template +struct LoadHelper<5, base_offset, ptr_step, 2, Func, T, T2> { + static void impl(T& src, T2 ptr, int oc_offset) { + UNROLL_CALL_RAW(5, WEIGHT_CB); + } +}; +template +struct LoadHelper<6, base_offset, ptr_step, 2, Func, T, T2> { + static void impl(T& src, T2 ptr, int oc_offset) { + UNROLL_CALL_RAW(6, WEIGHT_CB); + } +}; +template +struct LoadHelper<7, base_offset, ptr_step, 2, Func, T, T2> { + static void impl(T& src, T2 ptr, int oc_offset) { + UNROLL_CALL_RAW(7, WEIGHT_CB); + } +}; #undef WEIGHT_CB template -inline void load_helper(T& weight, const int8_t* ptr, int oc_offset) { - LoadHelper::impl( + typename Func, typename T, typename T2> +inline void load_helper(T& weight, T2 ptr, int oc_offset) { + LoadHelper::impl( weight, ptr, oc_offset); } template -inline void load_helper_x(T& weight, const int8_t* ptr, int oc_offset, - XT... args) { - LoadHelper +inline void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) { + LoadHelper::impl(weight, ptr, oc_offset, args...); } diff --git a/dnn/src/arm_common/conv_bias/neon_struct.h b/dnn/src/arm_common/conv_bias/neon_struct.h index 535674ec3..4303689bb 100644 --- a/dnn/src/arm_common/conv_bias/neon_struct.h +++ b/dnn/src/arm_common/conv_bias/neon_struct.h @@ -34,6 +34,9 @@ struct Vmlal_s16 { struct Vld1q_s8 { static int8x16_t impl(const int8_t* ptr) { return vld1q_s8(ptr); } }; +struct Vld1q_f32 { + static float32x4_t impl(const float32_t* ptr) { return vld1q_f32(ptr); } +}; struct Vld1_s8 { static int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); } }; @@ -50,5 +53,13 @@ struct Vldq_tbl_low_s8 { struct Vld1_dup_s8_s16 { static int16x8_t impl(const int8_t* ptr) { return vld1_dup_s8_s16(ptr); } }; + +struct Vfmaq_laneq_f32 { + template + static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vfmaq_laneq_f32(a, b, v, lane); + } +}; + } // namespace } // namespace megdnn \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index b1f7808e5..b2a81eaf2 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -71,6 +71,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; AlgoF32DirectStride1 f32_direct_stride1_large_group{true}; AlgoF32DirectStride1 f32_direct_stride1_small_group{false}; + AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44; AlgoI8x8x16Direct i8x8x16_direct_large_group{true}; AlgoI8x8x16Direct i8x8x16_direct_small_group{false}; AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true}; @@ -123,6 +124,7 @@ public: direct_algos.emplace_back(&i8x8x16_stride2_filter2); direct_algos.emplace_back(&i8x8x16_stride2_large_group); direct_algos.emplace_back(&i8x8x16_stride2_small_group); + direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); direct_algos.emplace_back(&f32_direct_stride1_large_group); direct_algos.emplace_back(&f32_direct_stride1_small_group); direct_algos.emplace_back(&f32_direct_stride2_large_group); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index f21dba87d..40eac1474 100644 --- 
a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -67,6 +67,7 @@ private: class AlgoF32Direct; class AlgoF32DirectStride1; class AlgoF32DirectStride2; + class AlgoF32DirectStride2NCHWNCHW44; class AlgoI8x8x16Direct; class AlgoI8x8x16Stride2; class AlgoI8x8x16Stride2Filter2; diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h b/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h index dccd94868..d1774ada2 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h @@ -45,13 +45,17 @@ struct HSwishOp; vst1q_##_func_suffix(dst, vitem.val[0]); \ vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ } \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ _neon_type2 operator()(const _neon_type2& src) const { \ auto val1 = src.val[0]; \ auto val2 = src.val[1]; \ H_SWISH_KERN(_func_suffix, val1, val2); \ return {{val1, val2}}; \ } \ - _neon_type operator()(const _neon_type& src) { \ + _neon_type operator()(const _neon_type& src) const { \ auto val_zero = vdupq_n_##_func_suffix(0.f); \ auto val_six = vdupq_n_##_func_suffix(6.f); \ auto val_three = vdupq_n_##_func_suffix(3.f); \ @@ -64,6 +68,7 @@ struct HSwishOp; val_rec_six); \ } \ }; + OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/none.h b/dnn/src/arm_common/elemwise_helper/kimpl/none.h index 6cf5bd000..224148eb3 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/none.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/none.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once @@ -30,6 +31,13 @@ struct NoneOp; using NoneOpBase::operator(); \ constexpr static size_t SIMD_WIDTH = _simd_width; \ _neon_type2 operator()(const _neon_type2& src) const { return src; } \ + void operator()(const _neon_type2& src, _ctype* dst) const { \ + vst1q_##_func_suffix(dst, src.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, src.val[1]); \ + } \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + vst1q_##_func_suffix(dst, src); \ + } \ _neon_type operator()(const _neon_type& src) const { return src; } \ }; diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/relu.h b/dnn/src/arm_common/elemwise_helper/kimpl/relu.h index 5335070c1..76949ddbe 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/relu.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/relu.h @@ -47,11 +47,16 @@ struct ReluOp; auto vitem1 = vmaxq_##_func_suffix(src.val[1], vzero); \ return {{vitem0, vitem1}}; \ } \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ _neon_type operator()(const _neon_type& src) const { \ auto vzero = vdupq_n_##_func_suffix(0); \ return vmaxq_##_func_suffix(src, vzero); \ } \ }; + OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h index f0dc8458c..a923ff872 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.h +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -479,6 +479,39 @@ UNROLL_CALL_RAW(4, cb); #undef cb } // namespace #define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7::impl(vec) +namespace { +template +struct Vfmap_laneq_f32_armv7 { + static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); +}; + +template <> +struct Vfmap_laneq_f32_armv7<0> { + static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vmlaq_lane_f32(a, b, vget_low_f32(v), 0); + } +}; +template <> +struct Vfmap_laneq_f32_armv7<1> { + static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vmlaq_lane_f32(a, b, vget_low_f32(v), 1); + } +}; +template <> +struct Vfmap_laneq_f32_armv7<2> { + static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vmlaq_lane_f32(a, b, vget_high_f32(v), 0); + } +}; +template <> +struct Vfmap_laneq_f32_armv7<3> { + static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); + } +}; +} // namespace +#define vfmaq_laneq_f32(a, b, v, lane) \ + Vfmap_laneq_f32_armv7::impl(a, b, v) #endif diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index fe7f01104..65134ce75 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -85,7 +85,7 @@ TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QU8) { #if MEGDNN_WITH_BENCHMARK -static void benchmark_convbias(Handle* handle) { +static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { constexpr size_t RUNS = 30; Benchmarker benchmarker_int(handle); @@ -102,15 +102,25 @@ static void benchmark_convbias(Handle* handle) { Benchmarker benchmarker_float(handle); benchmarker_float.set_display(false).set_times(RUNS); benchmarker_float.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker(".+")); + conv_bias::ConvBiasAlgoChecker( + "IM2COLMATMUL:AARCH64_F32K8X12X1:192")); Benchmarker benchmarker_int_nchw44(handle); - 
benchmarker_int_nchw44.set_times(RUNS) - .set_dtype(0, dtype::QuantizedS8(2.5)) - .set_dtype(1, dtype::QuantizedS8(2.5)) - .set_dtype(2, dtype::QuantizedS32(6.25)) - .set_dtype(4, dtype::QuantizedS8(60.25)) - .set_display(false); + if (is_fp32) { + benchmarker_int_nchw44.set_times(RUNS) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_dtype(4, dtype::Float32()) + .set_display(false); + } else { + benchmarker_int_nchw44.set_times(RUNS) + .set_dtype(0, dtype::QuantizedS8(2.5)) + .set_dtype(1, dtype::QuantizedS8(2.5)) + .set_dtype(2, dtype::QuantizedS32(6.25)) + .set_dtype(4, dtype::QuantizedS8(60.25)) + .set_display(false); + } benchmarker_int_nchw44.set_before_exec_callback( conv_bias::ConvBiasAlgoChecker(".+")); @@ -151,7 +161,6 @@ static void benchmark_convbias(Handle* handle) { auto int_nchw44_used = benchmarker_int_nchw44.set_param(param).exec( {src, filter, bias, {}, dst}) / RUNS; - float computations = IC * (FS * FS) * dst.total_nr_elems() * 2 * 1e-6; printf("run: %s %s %s->%s \n", src.to_string().c_str(), filter.to_string().c_str(), bias.to_string().c_str(), @@ -160,32 +169,42 @@ static void benchmark_convbias(Handle* handle) { computations / float_used); printf("int_nchw: %f ms %f Gflops, ", int_used, computations / int_used); - printf("int_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used, - computations / int_nchw44_used, int_used / int_nchw44_used); + auto speed_up = int_used / int_nchw44_used; + if (is_fp32) { + speed_up = float_used / int_nchw44_used; + printf("fp32_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used, + computations / int_nchw44_used, speed_up); + } else { + printf("int_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used, + computations / int_nchw44_used, speed_up); + } printf("\n"); }; - run(1, 3, 32, 224, 224, 3, 2, true); - run(1, 3, 64, 224, 224, 5, 2, true); - run(1, 3, 64, 224, 224, 7, 2, true); - run(1, 3, 32, 224, 224, 7, 2, true); - for (size_t stride : {1, 2}) { - printf("stride %zu\n", stride); - for (size_t filter_size : {2, 3, 5, 7}) { - for (size_t img_size : {32}) { - for (size_t channel : {8, 16, 32, 64, 128, 256}) { - run(1, channel, channel, img_size, img_size, filter_size, - stride, false); + + if (is_fp32) { + run(1, 3, 32, 224, 224, 3, 2, true); + run(1, 3, 64, 224, 224, 7, 2, true); + } else { + for (size_t stride : {1, 2}) { + printf("stride %zu\n", stride); + for (size_t filter_size : {2, 3, 5, 7}) { + for (size_t img_size : {32}) { + for (size_t channel : {8, 16, 32, 64, 128, 256}) { + run(1, channel, channel, img_size, img_size, + filter_size, stride, false); + } } } } } } TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { - benchmark_convbias(handle()); + benchmark_convbias(handle(), true); } TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { - benchmark_convbias(handle()); + benchmark_convbias(handle(), true); } + #endif TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QS8) { using namespace conv_bias; @@ -1464,7 +1483,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) { #if MEGDNN_WITH_BENCHMARK namespace { -std::vector get_conv_bias_1x1_benchmark_args(size_t pack_size = 1) { +std::vector get_conv_bias_1x1_benchmark_args( + size_t pack_size = 1) { using namespace conv_bias; std::vector args; param::ConvBias param; @@ -1474,15 +1494,17 @@ std::vector get_conv_bias_1x1_benchmark_args(size_t pack_siz param.pad_w = 0; param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; auto bench_case = [&](size_t OC, size_t IC, size_t H, size_t W) { - 
if(pack_size == 1) + if (pack_size == 1) args.emplace_back(param, TensorShape{1, IC, H, W}, - TensorShape{OC, IC, 1, 1}, TensorShape{}); + TensorShape{OC, IC, 1, 1}, TensorShape{}); else { - if(pack_size == 4) + if (pack_size == 4) param.format = param::ConvBias::Format::NCHW44; - args.emplace_back(param, TensorShape{1, IC / pack_size, H, W, pack_size}, - TensorShape{OC / pack_size, IC / pack_size, 1, 1, pack_size, pack_size}, - TensorShape{}); + args.emplace_back(param, + TensorShape{1, IC / pack_size, H, W, pack_size}, + TensorShape{OC / pack_size, IC / pack_size, 1, 1, + pack_size, pack_size}, + TensorShape{}); } }; diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 28bc67d7a..7740c329a 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -78,9 +78,10 @@ std::vector get_nchw44_conv_bias_args( std::vector args; auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, - size_t kernel, size_t stride, size_t group, NLMode nlmode) { + size_t kernel, size_t stride, size_t group, NLMode nlmode, + int any_pad = -1) { constexpr int pack_c = 4; - const size_t pad = no_pad ? 0 : kernel / 2; + const size_t pad = any_pad >= 0 ? any_pad : kernel / 2; auto bias_mode = no_bias ? megdnn::BiasMode::NO_BIAS : megdnn::BiasMode::BROADCAST_CHANNEL_BIAS; auto oc_per_group = oc / group; @@ -90,7 +91,8 @@ std::vector get_nchw44_conv_bias_args( ic_per_group > 0; bool nchw_disable = group > 1 || ic_per_group >= 4; bool nchw44_disable = ic_per_group % pack_c != 0; - if (!(ok_group)) { + bool invalid_pad = (w + 2 * pad < kernel) || (h + 2 * pad < kernel); + if (!(ok_group) || invalid_pad) { return; } if ((is_input_nchw && nchw_disable) || @@ -107,6 +109,7 @@ std::vector get_nchw44_conv_bias_args( param.pad_h = pad; param.pad_w = pad; param.nonlineMode = nlmode; + auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c}; auto weight_tensor_shape = TensorShape{ oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c}; @@ -338,6 +341,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), handle(), "F32STRD2_SMALL_GROUP"); } +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) { + check_conv_bias( + get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true), + handle(), "F32_CONV_NCHW_NCHW44"); +} /**********************************F16 direct************************/ #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) { -- GitLab
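
Some of the helpers above are easier to follow with concrete numbers.
block_helper in f32_direct_stride2_nchw_nchw44_algo.cpp chooses the
per-thread block of output rows: it computes how many rows each thread must
cover, then compares two candidate block counts and keeps the one that
wastes fewer rows while staying close to best_block = 16. A worked example
(the numbers are mine, not from the patch):

// nthread = 2, oh = 100
//   block_per_thread = div_ceil(100, 2) = 50
//   max_block_num    = div_ceil(50, 16) = 4;  min_block_num = 3
//   max_block = div_ceil(50, 4) = 13  ->  loss = |4 * 13 - 50| = 2
//   min_block = div_ceil(50, 3) = 17  ->  loss = |3 * 17 - 50| = 1
//   min_loss < max_loss, so block = 17, and ncb_range gets
//   div_ceil(100, 17) = 6 row blocks in its last dimension.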
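
In tensor terms, usable() restricts F32_CONV_NCHW_NCHW44 to the hybrid
NCHW-to-NCHW44 case, typically the first convolution of a network. The
following shape set is how I read the checks; the filter layout is inferred
from the traversal order of pack_weight_fp32_nchw_nchw44 and should be
treated as an assumption:

// src:    {N, IC, H, W}           plain NCHW float32, IC = fm.icpg < 4
//                                 (e.g. an RGB stem convolution)
// filter: {OC/4, FH, FW, IC, 4}   hybrid layout (assumed), OC % 4 == 0,
//                                 FH == FW, FH in {3, 5, 7}, group == 1
// bias:   {1, OC/4, 1, 1, 4}      NO_BIAS or BROADCAST_CHANNEL_BIAS only
// dst:    {N, OC/4, OH, OW, 4}    NCHW44 output; stride 2, no dilation,
//                                 no filter flip, BiasMode::BIAS rejected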
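
pack_weight_fp32_nchw_nchw44 transposes each block of four output channels
from [fh][fw][ic][4] to [ic][fh][fw][4], so the compute kernel can stream
one input channel at a time with a fixed stride of fh * fw * 4 between
channels (ld_weight_ic). A scalar restatement of the same index mapping, as
an illustrative sketch rather than code from the patch:

// Scalar model of pack_weight_fp32_nchw_nchw44. The NEON version moves the
// innermost 4-float vector with vld1q_f32/vst1q_f32; here it is a plain loop.
void pack_weight_ref(const float* in, float* out, int oc, int fh, int fw,
                     int ic) {
    const int oc_step = 4;  // channels per nchw44 block
    for (int ocb = 0; ocb < oc / oc_step; ++ocb) {
        const float* src = in + ocb * fh * fw * ic * oc_step;
        float* dst = out + ocb * fh * fw * ic * oc_step;
        for (int h = 0; h < fh; ++h)
            for (int w = 0; w < fw; ++w)
                for (int c = 0; c < ic; ++c)
                    for (int v = 0; v < oc_step; ++v)
                        // source [h][w][c][v] -> destination [c][h][w][v]
                        dst[((c * fh + h) * fw + w) * oc_step + v] =
                                src[((h * fw + w) * ic + c) * oc_step + v];
    }
}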
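
The KerNeonXXs2NchwNchw44FP32 specializations are "gemm-like" in that each
call fills an 8-wide, one- or two-block-deep output tile held entirely in
registers: c[c_dim][8] is initialized from the bias by init_ocx_ow8 and
written through the elemwise Op by store_ocx_ow8_remain_static. With stride
2, output column `step` and horizontal tap `src_idx` read input element
2 * step + src_idx, which ShiftCalHelper maps to register
src[(2 * step + src_idx) / 4] and lane (2 * step + src_idx) % 4; that is why
eight outputs per row need only five or six loaded q-registers. A scalar
model of one tile (my sketch; it assumes dst was pre-filled with the bias
and leaves the activation to the caller):

// One oc4/oc8 x ow8 tile at stride 2. src is the padded NCHW stripe the
// thread copied with copy_pad_src; filter is packed [ic][fh][fw][4] per
// 4-oc block (see pack_weight_ref above).
void conv_tile_ref(const float* src, const float* filter, float* dst, int ic,
                   int ih, int iw, int fh, int fw, int ld_dst_oc,
                   int oc_blocks /* 1 for oc4, 2 for oc8 */) {
    for (int ob = 0; ob < oc_blocks; ++ob)
        for (int ow = 0; ow < 8; ++ow)      // 8 output columns
            for (int v = 0; v < 4; ++v) {   // 4 channels per nchw44 block
                float acc = 0.f;
                for (int c = 0; c < ic; ++c)
                    for (int h = 0; h < fh; ++h)
                        for (int w = 0; w < fw; ++w)
                            acc += src[(c * ih + h) * iw + ow * 2 + w] *
                                   filter[(ob * ic + c) * fh * fw * 4 +
                                          (h * fw + w) * 4 + v];
                dst[ob * ld_dst_oc + ow * 4 + v] += acc;
            }
}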
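
On armv7 the patch maps vfmaq_laneq_f32 onto vmlaq_lane_f32 applied to the
low or high half of the 128-bit vector, because the aarch64 by-lane FMA
intrinsic does not exist there. Both paths compute the same element-wise
expression; note that vmlaq_lane_f32 is a separate multiply and add rather
than a fused FMA, so armv7 results may differ from aarch64 in the last ulp.
A scalar reference for the operation (my restatement):

// vfmaq_laneq_f32(a, b, v, lane): a[i] += b[i] * v[lane] for i = 0..3.
void fmaq_laneq_ref(float a[4], const float b[4], const float v[4],
                    int lane) {
    for (int i = 0; i < 4; ++i)
        a[i] += b[i] * v[lane];
}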