From 58ca0cf0b0fdc1d78b1b07960d7e94f4ab6655a5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 24 Nov 2022 11:28:12 +0800 Subject: [PATCH] feat(dnn): add fp16 hybrid direct conv algo GitOrigin-RevId: 6192b0ffb40ed812fc386299d0e1354bf4556f11 --- dnn/src/arm_common/conv_bias/block_helper.h | 24 -- dnn/src/arm_common/conv_bias/f16/algos.h | 17 + .../conv_bias/f16/direct_nchw88_algo.cpp | 1 - .../conv_bias/f16/direct_nchw_nchw88_algo.cpp | 302 ++++++++++++++++++ .../conv_bias/f16/direct_nchw_nchw88_kern.h | 41 +++ .../f16/direct_nchw_nchw88_kern_common.h | 287 +++++++++++++++++ .../int8/direct_dotprod_nchw44_algo.cpp | 2 +- .../int8/dot_direct_nchw_nchw44_algo.cpp | 2 +- .../int8x8x16/direct_nchw_nchw44_algo.cpp | 2 +- .../arm_common/conv_bias/intrinsic_helper.h | 199 ++++++------ dnn/src/arm_common/conv_bias/opr_impl.cpp | 2 + dnn/src/arm_common/conv_bias/opr_impl.h | 1 + dnn/src/arm_common/neon_struct.h | 5 + dnn/src/common/nchw_nchwxx_valid.cpp | 15 +- dnn/src/common/nchw_nchwxx_valid.h | 32 ++ dnn/src/common/unroll_macro.h | 17 + dnn/src/fallback/conv_bias/gi/block_helper.h | 2 +- .../fallback/conv_bias/gi/intrinsic_helper.h | 40 +-- dnn/src/fallback/conv_bias/im2col/algos.cpp | 36 ++- dnn/src/fallback/conv_bias/opr_impl.h | 1 + dnn/test/aarch64/conv_bias.cpp | 59 ---- .../arm_common/conv_bias_multi_thread.cpp | 50 +++ .../conv_bias_multi_thread_benchmark.cpp | 62 ++++ dnn/test/common/conv_bias.cpp | 5 +- 24 files changed, 963 insertions(+), 241 deletions(-) delete mode 100644 dnn/src/arm_common/conv_bias/block_helper.h create mode 100644 dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_algo.cpp create mode 100644 dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern.h create mode 100644 dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern_common.h diff --git a/dnn/src/arm_common/conv_bias/block_helper.h b/dnn/src/arm_common/conv_bias/block_helper.h deleted file mode 100644 index 74f0ae749..000000000 --- a/dnn/src/arm_common/conv_bias/block_helper.h +++ /dev/null @@ -1,24 +0,0 @@ -#include "src/common/utils.h" -namespace megdnn { -namespace { -// block_helper is used to calculate oh block size -static inline int l2_block_helper( - const int nthread, const int amount, const int size_per_unit) { - constexpr int l2_cache_size = 256 * 1024; - const int block_per_thread = div_ceil(amount, nthread); - const int best_block = - std::min(amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); - const int max_block_num = div_ceil(block_per_thread, best_block); - const int min_block_num = std::max(max_block_num - 1, 1); - const int max_block = div_ceil(block_per_thread, max_block_num); - const int min_block = div_ceil(block_per_thread, min_block_num); - const int max_loss = std::abs(max_block_num * max_block - block_per_thread); - const int min_loss = std::abs(min_block_num * min_block - block_per_thread); - int block = max_loss > min_loss ? min_block : max_block; - return block; -} - -} // namespace -} // namespace megdnn - -// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h index ee25b5a1a..a759c7e85 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.h +++ b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -156,6 +156,23 @@ public: MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW88_FP16) }; +class ConvBiasImpl::AlgoF16DirectNchwNchw88 final : public AlgoBase { +public: + AlgoF16DirectNchwNchw88() = default; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "DIRECT_CONV_F16_NCHW_NCHW88"; } + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; + } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW88_FP16) +}; + } // namespace arm_common } // namespace megdnn #endif diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp index df72f4be7..68d3828af 100644 --- a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp @@ -1,5 +1,4 @@ #include "megdnn/oprs.h" -#include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/f16/algos.h" #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_algo.cpp b/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_algo.cpp new file mode 100644 index 000000000..a72d8b030 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_algo.cpp @@ -0,0 +1,302 @@ +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "src/arm_common/conv_bias/f16/algos.h" +#include "src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern.h" +#include "src/arm_common/elemwise_helper/elemwise_op.h" +#include "src/common/nchw_nchwxx_valid.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/block_helper.h" +#include "src/fallback/conv_bias/opr_impl.h" +using namespace megdnn; +using namespace arm_common; + +MIDOUT_DECL(megdnn_arm_common_direct_conv_nchw_nchw88_fp16) + +namespace { + +static void get_rectified_size( + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, + int& iw2, int& oh2, int& ow2, int& block_oh) { + int ic = param.filter_meta.icpg; + int iw = param.isz[1]; + int oh = param.osz[0]; + int ow = param.osz[1]; + + oh2 = oh; + ow2 = ow; + iw2 = iw + 2 * static_cast(param.filter_meta.padding[1]); + const int sh = static_cast(param.filter_meta.stride[0]); + block_oh = l2_block_helper(param.nr_threads, oh, ic * iw2 * sh * sizeof(__fp16)); + const int fh = static_cast(param.filter_meta.spatial[0]); + ih2 = (block_oh - 1) * sh + fh; +} + +static WorkspaceBundle get_bundle( + const fallback::ConvBiasImpl::NCBKernSizeParam& param) { + const auto& fm = param.filter_meta; + const int group = fm.group; + const int ic = fm.icpg; + const int oc = fm.ocpg; + const int fh = fm.spatial[0]; + const int fw = fm.spatial[1]; + int ih2, iw2, oh2, ow2, oh_block; + get_rectified_size(param, ih2, iw2, oh2, ow2, oh_block); + + megdnn_assert(oh_block != 0, "oh_block == 0"); + const size_t src_size = ic * ih2 * iw2 * sizeof(__fp16); + const size_t weight_size = group * oc * ic * fh * fw * sizeof(__fp16); + return {nullptr, {src_size * param.nr_threads, weight_size}}; +} + +static void pack_weight( + const WorkspaceBundle& bundle, + const fallback::ConvBiasImpl::NCBKernParam& kern_param, + const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { + const int group_id = ncb_index.ndrange_id[0]; + const auto& fm = kern_param.filter_meta; + const int oc = fm.ocpg; + const int ic = fm.icpg; + const int fh = fm.spatial[0]; + const int fw = fm.spatial[1]; + + const int oc_idx = 0; + const int oc_block = oc; + + const __fp16* weight = + reinterpret_cast(kern_param.filter(group_id)) + + oc_idx * ic * fh * fw; + __fp16* packed_weight = static_cast<__fp16*>(bundle.get(1)) + + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; + fp16_direct_nchw_nchw88::pack_weight_fp16_nchw_nchw88( + weight, packed_weight, oc_block, fh, fw, ic); +} + +/** + * @brief Copy data from sptr_origin to sptr, and padding. + * + */ +static inline void src_copy_pad( + __fp16* sptr, const __fp16* sptr_origin, const int pw, const int pad_right, + const int pad_top, const int pad_buttom, const int ic, const int ih, + const int iw, const int iw2, const int ld_src_ic) { + rep(ic_idx, ic) { + const __fp16* ic_sptr_origin = sptr_origin + ic_idx * ld_src_ic; + + memset(sptr, 0, iw2 * pad_top * sizeof(__fp16)); + sptr += iw2 * pad_top; + + rep(ih_idx, ih) { + memset(sptr, 0, pw * sizeof(__fp16)); + sptr += pw; + + memcpy(sptr, ic_sptr_origin, iw * sizeof(__fp16)); + sptr += iw; + ic_sptr_origin += iw; + + memset(sptr, 0, pad_right * sizeof(__fp16)); + sptr += pad_right; + } + + memset(sptr, 0, iw2 * pad_buttom * sizeof(__fp16)); + sptr += iw2 * pad_buttom; + } +} + +template +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { + const int oc = kern_param.filter_meta.ocpg; + const int oh = kern_param.osz[0]; + const int ow = kern_param.osz[1]; + + const int ic = kern_param.filter_meta.icpg; + const int ih = kern_param.isz[0]; + const int iw = kern_param.isz[1]; + + const int fh = kern_param.filter_meta.spatial[0]; + const int fw = kern_param.filter_meta.spatial[1]; + + const int sh = kern_param.filter_meta.stride[0]; + + const int ph = kern_param.filter_meta.padding[0]; + const int pw = kern_param.filter_meta.padding[1]; + + int ih2, iw2, oh2, ow2, oh_block; + get_rectified_size(kern_param, ih2, iw2, oh2, ow2, oh_block); + + constexpr int pack_oc = 8; + const int batch_id = ncb_index.ndrange_id[0]; + const int group_id = ncb_index.ndrange_id[1]; + int oc_idx = 0; + int oc_block = oc; + const int oh_idx = ncb_index.ndrange_id[2]; + const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); + const int ih_block_real = (oh_block_real - 1) * sh + fh; + const int src_top_pad = std::max(0, ph - oh_idx * oh_block * sh); + const int src_buttom_pad = + std::max(0, (oh_idx * oh_block + oh_block_real - 1) * sh + fh - ph - ih); + const int src_right_pad = std::max(iw2 - iw - pw, 0); + const int src_offset = std::max(oh_idx * oh_block * sh - ph, 0) * iw; + const __fp16* origin_ptr = reinterpret_cast( + kern_param.src(batch_id, group_id)) + + src_offset; + const size_t src_size = sizeof(__fp16) * ic * ih2 * iw2; + __fp16* sptr = reinterpret_cast<__fp16*>( + reinterpret_cast(bundle.get(0)) + ncb_index.thread_id * src_size); + src_copy_pad( + sptr, origin_ptr, pw, src_right_pad, src_top_pad, src_buttom_pad, ic, + ih_block_real - src_top_pad - src_buttom_pad, iw, iw2, ih * iw); + + //! packed weight + const __fp16* weight = reinterpret_cast(bundle.get(1)) + + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; + __fp16* dst = + reinterpret_cast<__fp16*>(kern_param.dst(batch_id, group_id)) + + oh_idx * oh_block * ow * pack_oc; + const __fp16* bias = reinterpret_cast( + kern_param.bias(batch_id, group_id)) + + oc_idx; + Op op; + fp16_direct_nchw_nchw88::fp16_direct_conv_nchw_nchw88< + bias_mode, Op, filter_size, stride>( + sptr, weight, bias, dst, oc_block, ic, ih_block_real, iw2, oh, + oh_block_real, ow2, op); +} +} // namespace + +bool ConvBiasImpl::AlgoF16DirectNchwNchw88::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + return nchw_nchwxx_valid( + param.src_type.enumv(), param.filter_type.enumv(), param.dst_type.enumv(), + param.filter_meta, param.bias_mode, param.nonlineMode); +} + +size_t ConvBiasImpl::AlgoF16DirectNchwNchw88::get_workspace( + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_direct_conv_nchw_nchw88_fp16, + midout_iv("AlgoF16DirectNchwNchw88::get_workspace"_hash)) { + return get_bundle(param).total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} + +SmallVector ConvBiasImpl::AlgoF16DirectNchwNchw88:: + dispatch_kerns(const NCBKernSizeParam& param) const { + using conv_func_ptr = std::function; + const auto& fm = param.filter_meta; + const int batch = param.n; + const int group = fm.group; + auto bundle = get_bundle(param); + conv_func_ptr conv_func = nullptr; + +#define CONV_FUNC(bias_mode, op, filter, stride) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_direct_conv_nchw_nchw88_fp16, \ + midout_iv(#bias_mode #op #filter #stride##_hash)) { \ + conv_func = do_conv_kern; \ + } \ + MIDOUT_END(); + +#define FOR_OP(bias_mode, filter, stride) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + CONV_FUNC(bias_mode, NoneOp<__fp16>, filter, stride); \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + CONV_FUNC(bias_mode, ReluOp<__fp16>, filter, stride); \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + CONV_FUNC(bias_mode, HSwishOp<__fp16>, filter, stride); \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define FOR_BIAS_MODE(filter, stride) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + FOR_OP(BiasMode::NO_BIAS, filter, stride); \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS, filter, stride); \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define FOR_FILTER(stride) \ + switch (fm.spatial[0]) { \ + case 2: \ + FOR_BIAS_MODE(2, stride); \ + break; \ + case 3: \ + FOR_BIAS_MODE(3, stride); \ + break; \ + case 5: \ + FOR_BIAS_MODE(5, stride); \ + break; \ + case 7: \ + FOR_BIAS_MODE(7, stride); \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define FOR_STRIDE() \ + switch (fm.stride[0]) { \ + case 1: \ + FOR_FILTER(1); \ + break; \ + case 2: \ + FOR_FILTER(2); \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ + } + + FOR_STRIDE() + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_BIAS_MODE +#undef FOR_OP +#undef CONV_FUNC + + megdnn_assert(conv_func); + + SmallVector ret_kerns; + int oh = param.osz[0]; + + int ih2, iw2, oh2, ow2, oh_block; + get_rectified_size(param, ih2, iw2, oh2, ow2, oh_block); + auto do_pack_weight = [bundle]( + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) mutable { + bundle.set(kern_param.workspace_ptr); + pack_weight(bundle, kern_param, ncb_index); + }; + ret_kerns.push_back({do_pack_weight, {static_cast(group)}}); + + CpuNDRange ncb_range{ + static_cast(batch), static_cast(group), + static_cast(div_ceil(oh, oh_block))}; + auto do_conv = [bundle, conv_func]( + const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) mutable { + bundle.set(kern_param.workspace_ptr); + conv_func(bundle, kern_param, ncb_index); + }; + ret_kerns.push_back({do_conv, ncb_range}); + + return ret_kerns; +} +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern.h b/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern.h new file mode 100644 index 000000000..a26371ff5 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern.h @@ -0,0 +1,41 @@ +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern_common.h" +#pragma once +namespace megdnn { +namespace arm_common { +namespace fp16_direct_nchw_nchw88 { +//! (OC/8, FH, FW, IC, 8) --> (OC/8, IC, FH, FW, 8) +static inline void pack_weight_fp16_nchw_nchw88( + const __fp16* in_ptr, __fp16* dst_ptr, const int oc, const int fh, const int fw, + const int ic) { + constexpr int oc_step = 8; + const int ld_ic = fh * fw * oc_step; + const int ld_oc = ic * fh * fw; + + for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { + const __fp16* in_ptr_oc = in_ptr + ld_oc * oc_idx; + __fp16* dst_ptr_oc = dst_ptr + ld_oc * oc_idx; + for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { + for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { + for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { + vst1q_f16(dst_ptr_oc + ic_idx * ld_ic, vld1q_f16(in_ptr_oc)); + in_ptr_oc += oc_step; + } + dst_ptr_oc += oc_step; + } + } + } +} + +template +static void fp16_direct_conv_nchw_nchw88( + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + const int oc, const int ic, const int ih, const int iw, const int oh, + const int oh_block, const int ow, const Op& op) { + ConvDirectNchwNchw88Fp16::impl( + src, filter, bias, dst, oc, ic, ih, iw, oh, oh_block, ow, op); +} +} // namespace fp16_direct_nchw_nchw88 +} // namespace arm_common +} // namespace megdnn +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern_common.h b/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern_common.h new file mode 100644 index 000000000..1d3cef4c7 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw_nchw88_kern_common.h @@ -0,0 +1,287 @@ +#pragma once +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "src/arm_common/conv_bias/intrinsic_helper.h" +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/arm_common/intrinsic_helper.h" +#include "src/arm_common/neon_struct.h" +#include "src/common/unroll_macro.h" +#include "src/fallback/conv_bias/common.h" +namespace megdnn { +namespace arm_common { +namespace { + +/** + * @brief Say src -> (N, IC, IH, IW), weight -> (OC, IC, FH, FW), bias -> (1, OC, 1, 1) + * Calculate (n, ic, ih, iw) * (oc : oc + nr_oc_block * oc_block, ic, fh, fw) + (0, + * oc : oc + nr_oc_block * oc_block, 0, 0) + * + * @tparam src_idx related to ih and iw above + * @tparam weight_idx related to fh and fw above + * @tparam nr_oc_block number of oc block + * @tparam stride + * @tparam nr_ow This function calculates the value of nr_ow positions at a time. + * @tparam T1 + * @tparam T2 + * @tparam T3 + */ +template < + int src_idx, int weight_idx, int nr_oc_block, int stride, int nr_ow, + typename T1, typename T2, typename T3> +struct CalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T1& bias, const T2& src, const T3& weight); +}; + +template < + int src_idx, int weight_idx, int nr_oc_block, int stride, typename T1, + typename T2, typename T3> +struct CalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T1& bias, const T2& src, const T3& weight){}; +}; + +#if defined(__ARM_FEATURE_FMA) + +#if MEGDNN_AARCH64 +#define fma_lane_f16(a, b, v, lane) vfmaq_laneq_f16((a), (b), (v), (lane)) +#else +#define fma_lane_f16(a, b, v, lane) \ + vfmaq_f16((a), (b), vdupq_n_f16(vgetq_lane_f16(v, lane))) +#endif + +#else + +#if MEGDNN_AARCH64 +#define fma_lane_f16(a, b, v, lane) vaddq_f16((a), vmulq_laneq_f16((b), (v), (lane))) +#else +#define fma_lane_f16(a, b, v, lane) \ + vaddq_f16((a), vmulq_n_f16(b, vgetq_lane_f16(v, lane))) +#endif + +#endif + +#define cb1(step) \ + bias[0][step] = fma_lane_f16( \ + bias[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 8], \ + (step * stride + src_idx) % 8); + +#define cb2(step) \ + bias[0][step] = fma_lane_f16( \ + bias[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 8], \ + (step * stride + src_idx) % 8); \ + bias[1][step] = fma_lane_f16( \ + bias[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 8], \ + (step * stride + src_idx) % 8); + +#define CAL_HELPER(nr_ow) \ + template < \ + int src_idx, int weight_idx, int stride, typename T1, typename T2, \ + typename T3> \ + struct CalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl( \ + T1& bias, const T2& src, const T3& weight) { \ + UNROLL_CALL_NOWRAPPER(nr_ow, cb1); \ + } \ + }; \ + template < \ + int src_idx, int weight_idx, int stride, typename T1, typename T2, \ + typename T3> \ + struct CalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl( \ + T1& bias, const T2& src, const T3& weight) { \ + UNROLL_CALL_NOWRAPPER(nr_ow, cb2); \ + } \ + }; + +CAL_HELPER(1) +CAL_HELPER(2) +CAL_HELPER(3) +CAL_HELPER(4) +CAL_HELPER(5) +CAL_HELPER(6) +CAL_HELPER(7) +CAL_HELPER(8) + +#undef CAL_HELPER +#undef cb2 +#undef cb1 +#undef fma_lane_f16 + +template < + int src_idx, int weight_idx, int nr_oc_block, int stride, int nr_ow, + typename T1, typename T2, typename T3> +MEGDNN_ALWAYS_INLINE void cal_helper(T1& bias, const T2& src, const T3& weight) { + CalHelper::impl( + bias, src, weight); +} + +template +struct OCHelper { + static constexpr int val = -1; +}; +template <> +struct OCHelper<8> { + static constexpr int val = 1; +}; +template <> +struct OCHelper<16> { + static constexpr int val = 2; +}; + +template < + BiasMode bias_mode, typename Op, int nr_ow, int filter, int oc_block, + int stride> //! CHECK +struct KernFilterXStrideXNchwNchw88FP16 { + static void impl( + const __fp16* src_ptr, const __fp16* weight_ptr, const __fp16* bias_ptr, + __fp16* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op); +}; + +#define KERNEL_CB(i) cal_helper(bias, src, weight); + +#define KERNEL(step, FILTER_SIZE) \ + load_helper(src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fh, ld_weight_oc); \ + UNROLL_CALL_RAW(FILTER_SIZE, KERNEL_CB); + +#define INSTANCE_KERN(FILTER_SIZE) \ + template \ + struct KernFilterXStrideXNchwNchw88FP16< \ + bias_mode, Op, nr_ow, FILTER_SIZE, oc_block, stride> { \ + static void impl( \ + const __fp16* src_ptr, const __fp16* weight_ptr, \ + const __fp16* bias_ptr, __fp16* dst_ptr, int ic, int ih, int iw, \ + int ld_dst_oc, const Op& op) { \ + constexpr int filter_size = FILTER_SIZE; \ + constexpr int oc_step = 8; \ + constexpr int simd_len = 8; \ + constexpr int src_reg_size = \ + ((nr_ow - 1) * stride + filter_size + simd_len - 1) / simd_len; \ + \ + constexpr int ld_weight_fh = filter_size * oc_step; \ + constexpr int ld_weight_ic = ld_weight_fh * filter_size; \ + const int ld_weight_oc = ld_weight_ic * ic; \ + const int ld_src_ic = ih * iw; \ + \ + constexpr int nr_oc_block = OCHelper::val; \ + float16x8_t bias[nr_oc_block][nr_ow]; \ + init_ocx_ow8(bias, bias_ptr, oc_step); \ + \ + rep(ic_idx, ic) { \ + float16x8_t src[src_reg_size], weight[nr_oc_block][filter_size]; \ + UNROLL_CALL_ONE_ARG_RAW(FILTER_SIZE, KERNEL, FILTER_SIZE); \ + src_ptr += ld_src_ic; \ + weight_ptr += ld_weight_ic; \ + } \ + store_ocx_ow8_remain_static( \ + bias, op, dst_ptr, ld_dst_oc); \ + } \ + }; + +INSTANCE_KERN(2) +INSTANCE_KERN(3) +INSTANCE_KERN(5) +INSTANCE_KERN(7) + +#undef INSTANCE_KERN +#undef KERNEL +#undef KERNEL_CB + +template +struct ConvDirectNchwNchw88Fp16 { + static MEGDNN_ALWAYS_INLINE void impl( + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + const int oc, const int ic, const int ih, const int iw, const int oh, + const int oh_block, const int ow, const Op& op) { + constexpr int fh = filter_size; + constexpr int fw = filter_size; +#ifdef MEGDNN_ARMV7 + constexpr int big_oc_step = 8; +#else + constexpr int big_oc_step = 16; +#endif + constexpr int oc_step = 8; + constexpr int ow_step = 8; + constexpr int sh = stride; + constexpr int sw = stride; + + const int ld_dst_oc = oh * ow * oc_step; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + const int oc_end = oc / big_oc_step * big_oc_step; + const int oc_remain = oc - oc_end; + + using remain_func = std::function; + + remain_func big_oc_remain = nullptr, small_oc_remain = nullptr; + if (ow_remain) { + switch (ow_remain) { +#define cb(i) \ + case i + 1: \ + big_oc_remain = KernFilterXStrideXNchwNchw88FP16< \ + bias_mode, Op, i + 1, filter_size, big_oc_step, stride>::impl; \ + small_oc_remain = KernFilterXStrideXNchwNchw88FP16< \ + bias_mode, Op, i + 1, filter_size, oc_step, stride>::impl; \ + break; + UNROLL_CALL_NOWRAPPER(7, cb); +#undef cb + default: + megdnn_assert(0, "Don't support remain %d for kern", ow_remain); + } + } + + for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { + const __fp16* weight_ptr = filter + oc_idx * ic * fh * fw; + rep(oh_idx, oh_block) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const __fp16* src_ptr = src + oh_idx * sh * iw + ow_idx * sw; + const __fp16* bias_ptr = bias + oc_idx; + __fp16* dst_ptr = + dst + oc_idx * oh * ow + (oh_idx * ow + ow_idx) * oc_step; + KernFilterXStrideXNchwNchw88FP16< + bias_mode, Op, ow_step, filter_size, big_oc_step, stride>:: + impl(src_ptr, weight_ptr, bias_ptr, dst_ptr, ic, ih, iw, + ld_dst_oc, op); + } + if (ow_remain > 0) { + const __fp16* src_ptr = src + oh_idx * sh * iw + ow_end * sw; + const __fp16* bias_ptr = bias + oc_idx; + __fp16* dst_ptr = + dst + oc_idx * oh * ow + (oh_idx * ow + ow_end) * oc_step; + big_oc_remain( + src_ptr, weight_ptr, bias_ptr, dst_ptr, ic, ih, iw, + ld_dst_oc, op); + } + } + } + if (oc_remain > 0) { + const __fp16* weight_ptr = filter + oc_end * ic * fh * fw; + rep(oh_idx, oh_block) { + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const __fp16* src_ptr = src + oh_idx * sh * iw + ow_idx * sw; + const __fp16* bias_ptr = bias + oc_end; + __fp16* dst_ptr = + dst + oc_end * oh * ow + (oh_idx * ow + ow_idx) * oc_step; + KernFilterXStrideXNchwNchw88FP16< + bias_mode, Op, ow_step, filter_size, oc_step, stride>:: + impl(src_ptr, weight_ptr, bias_ptr, dst_ptr, ic, ih, iw, + ld_dst_oc, op); + } + if (ow_remain > 0) { + const __fp16* src_ptr = src + oh_idx * sh * iw + ow_end * sw; + const __fp16* bias_ptr = bias + oc_end; + __fp16* dst_ptr = + dst + oc_end * oh * ow + (oh_idx * ow + ow_end) * oc_step; + small_oc_remain( + src_ptr, weight_ptr, bias_ptr, dst_ptr, ic, ih, iw, + ld_dst_oc, op); + } + } + } + } +}; +} // namespace +} // namespace arm_common +} // namespace megdnn +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp index 23a0cbd88..74f4796ec 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp @@ -1,4 +1,4 @@ -#include "src/arm_common/conv_bias/block_helper.h" +#include "src/fallback/conv_bias/gi/block_helper.h" #if MGB_ENABLE_DOT #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp index 5f7ff1c7a..1c2ef29d5 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp @@ -12,11 +12,11 @@ */ #include "megdnn/oprs.h" #if MGB_ENABLE_DOT -#include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" #include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/nchw_nchwxx_valid.h" +#include "src/fallback/conv_bias/gi/block_helper.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp index 7b8a3818f..79a6cb34a 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp @@ -12,12 +12,12 @@ */ #include "megdnn/oprs.h" -#include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/int8x8x16/algos.h" #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" #include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/nchw_nchwxx_valid.h" #include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/gi/block_helper.h" #include "midout.h" diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index 790ab9b2a..62488b3b0 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -174,191 +174,195 @@ __ai void store_ocx_ow4_remain_static( StoreOcxOw4Remain::impl(c, op, dst_ptr, ld_dst_oc); } ////////////////////Store_OCX_OW8_Remain///////////////////////// -template +template < + int c_dim, int ow_remain, typename Op, typename T, typename T2, typename T3, + size_t simd_lenx2> struct StoreOcxOw8Remain { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc); }; -template -struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { +template < + int c_dim, typename Op, typename T, typename T2, typename T3, size_t simd_lenx2> +struct StoreOcxOw8Remain { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { - op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); - op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); - - op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); - op({{c[1][6], c[1][7]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + MEGDNN_MARK_USED_VAR(c); + MEGDNN_MARK_USED_VAR(op); + MEGDNN_MARK_USED_VAR(dst_ptr); + MEGDNN_MARK_USED_VAR(ld_dst_oc); } }; -template -struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); - op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + simd_lenx2 * 2)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + simd_lenx2 * 3)); op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); - op({{c[1][6], c[1][7]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2 * 2)); + op({{c[1][6], c[1][7]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2 * 3)); } }; -template -struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); - op(c[0][6], reinterpret_cast(dst_ptr + 24)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + simd_lenx2 * 2)); + op(c[0][6], reinterpret_cast(dst_ptr + simd_lenx2 * 3)); op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); - op(c[1][6], reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2 * 2)); + op(c[1][6], reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2 * 3)); } }; -template -struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + simd_lenx2 * 2)); op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2)); + op({{c[1][4], c[1][5]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2 * 2)); } }; -template -struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op(c[0][4], reinterpret_cast(dst_ptr + 16)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); + op(c[0][4], reinterpret_cast(dst_ptr + simd_lenx2 * 2)); op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op(c[1][4], reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2)); + op(c[1][4], reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2 * 2)); } }; -template -struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][2], c[1][3]}}, + reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2)); } }; -template -struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op(c[0][2], reinterpret_cast(dst_ptr + 8)); + op(c[0][2], reinterpret_cast(dst_ptr + simd_lenx2)); op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); - op(c[1][2], reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op(c[1][2], reinterpret_cast(dst_ptr + ld_dst_oc + simd_lenx2)); } }; -template -struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); } }; -template -struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { op(c[0][0], reinterpret_cast(dst_ptr)); op(c[1][0], reinterpret_cast(dst_ptr + ld_dst_oc)); } }; -template -struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> { - static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { - op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); - op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); - } -}; -template -struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); - op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + simd_lenx2 * 2)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + simd_lenx2 * 3)); } }; -template -struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); - op(c[0][6], reinterpret_cast(dst_ptr + 24)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + simd_lenx2 * 2)); + op(c[0][6], reinterpret_cast(dst_ptr + simd_lenx2 * 3)); } }; -template -struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + simd_lenx2 * 2)); } }; -template -struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op(c[0][4], reinterpret_cast(dst_ptr + 16)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); + op(c[0][4], reinterpret_cast(dst_ptr + simd_lenx2 * 2)); } }; -template -struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + simd_lenx2)); } }; -template -struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op(c[0][2], reinterpret_cast(dst_ptr + 8)); + op(c[0][2], reinterpret_cast(dst_ptr + simd_lenx2)); } }; -template -struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); } }; -template -struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3, simd_lenx2> { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int) { op(c[0][0], reinterpret_cast(dst_ptr)); } }; -template +template < + int c_dim, int ow_remain, typename Op, size_t simd_lenx2 = 8, typename T, + typename T2> __ai void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { - StoreOcxOw8Remain::impl(c, op, dst_ptr, ld_dst_oc); + StoreOcxOw8Remain::impl( + c, op, dst_ptr, ld_dst_oc); } -template +template < + int c_dim, int ow_remain, typename Op, typename T3, size_t simd_lenx2 = 8, + typename T, typename T2> __ai void store_ocx_ow8_remain_static_dt( T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { - StoreOcxOw8Remain::impl(c, op, dst_ptr, ld_dst_oc); + StoreOcxOw8Remain::impl( + c, op, dst_ptr, ld_dst_oc); } ////////////////////Store_OCX_OW8_Remain///////////////////////// template < @@ -611,6 +615,15 @@ __ai int32x4_t neon_vld1q(const int* ptr) { __ai int16x8_t neon_vld1q(const int16_t* ptr) { return vld1q_s16(ptr); } + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +__ai float16x8_t neon_vdupq_n(__fp16 val) { + return vdupq_n_f16(val); +} +__ai float16x8_t neon_vld1q(const __fp16* ptr) { + return vld1q_f16(ptr); +} +#endif template struct NeonLdqSimd; template <> diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index e43a8835b..dfe87f9f4 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -67,6 +67,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoF16DirectStride1 f16_direct_stride1; AlgoF16ChannelWiseNCHW88 f16_channel_wise_nchw88; AlgoF16DirectNCHW88 f16_direct_nchw88; + AlgoF16DirectNchwNchw88 f16_direct_nchw_nchw88; #endif SmallVector> refhold; @@ -104,6 +105,7 @@ public: m_direct_algos.emplace_back(&f16_direct); m_direct_algos.emplace_back(&f16_channel_wise_nchw88); m_direct_algos.emplace_back(&f16_direct_nchw88); + m_direct_algos.emplace_back(&f16_direct_nchw_nchw88); #endif m_direct_algos.emplace_back(&i8x8x16_direct); m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index ec8a4c15d..1226b94b6 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -73,6 +73,7 @@ private: class AlgoF16DirectStride1; class AlgoF16ChannelWiseNCHW88; class AlgoF16DirectNCHW88; + class AlgoF16DirectNchwNchw88; #endif class AlgoPack; diff --git a/dnn/src/arm_common/neon_struct.h b/dnn/src/arm_common/neon_struct.h index c5aeff381..adee32b2b 100644 --- a/dnn/src/arm_common/neon_struct.h +++ b/dnn/src/arm_common/neon_struct.h @@ -31,6 +31,11 @@ struct Vld1q_f32 { struct Vld1_s8 { static __ai int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); } }; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +struct Vld1q_f16 { + static __ai float16x8_t impl(const __fp16* ptr) { return vld1q_f16(ptr); } +}; +#endif struct Vldq_dup_4s8_8s16 { static __ai int16x8_t impl(const int8_t* ptr) { return vldq_dup_4s8_8s16(ptr); } }; diff --git a/dnn/src/common/nchw_nchwxx_valid.cpp b/dnn/src/common/nchw_nchwxx_valid.cpp index 7dd58c684..458f3d6bb 100644 --- a/dnn/src/common/nchw_nchwxx_valid.cpp +++ b/dnn/src/common/nchw_nchwxx_valid.cpp @@ -8,12 +8,15 @@ using NchwNchwxxFuncInterface = std::function::CanonizedFilterMeta& fm, const ConvBiasForward::BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode)>; -static SmallVector g_func_vec{ - nchw_nchwxx_valid, - nchw_nchwxx_valid, - nchw_nchwxx_valid, - nchw_nchwxx_valid, - nchw_nchwxx_valid, +static SmallVector g_func_vec { + nchw_nchwxx_valid, + nchw_nchwxx_valid, + nchw_nchwxx_valid, + nchw_nchwxx_valid, + nchw_nchwxx_valid, +#if !MEGDNN_DISABLE_FLOAT16 + nchw_nchwxx_valid, +#endif }; } // namespace bool ConvBiasForward::is_nchw_nchwxx_optimized( diff --git a/dnn/src/common/nchw_nchwxx_valid.h b/dnn/src/common/nchw_nchwxx_valid.h index 53f39357e..f3b3f2ab6 100644 --- a/dnn/src/common/nchw_nchwxx_valid.h +++ b/dnn/src/common/nchw_nchwxx_valid.h @@ -9,6 +9,9 @@ enum NchwNchwxxType { NCHW44_INT8_INT8_INT16, NCHW44_INT8_DOT, NCHW88, +#if !MEGDNN_DISABLE_FLOAT16 + NCHW88_FP16, +#endif }; template static inline bool nchw_nchwxx_valid( @@ -139,6 +142,35 @@ inline bool nchw_nchwxx_valid( bool avaible = ok_type && ok_src_dst && ok_slide && ok_conv; return avaible; } +#if !MEGDNN_DISABLE_FLOAT16 +template <> +inline bool nchw_nchwxx_valid( + const DTypeEnum src_dtype, const DTypeEnum filter_dtype, + const DTypeEnum dst_dtype, + const ConvolutionBase::CanonizedFilterMeta& fm, + const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) { + bool ok_type = + ((src_dtype == DTypeEnum::Float16 && filter_dtype == DTypeEnum::Float16 && + (dst_dtype == DTypeEnum::Float16))) && + (fm.format == param::Convolution::Format::NCHW88); + bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY || + nonline_mode == param::ConvBias::NonlineMode::RELU || + nonline_mode == param::ConvBias::NonlineMode::H_SWISH; + bool ok_src_dst = + fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1; + + bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && + (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || + fm.spatial[0] == 7); + bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == fm.stride[1] && + (fm.stride[0] == 1 || fm.stride[1] == 2); + bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; + bool avaible = + ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv; + return avaible; +} +#endif } // namespace } // namespace megdnn \ No newline at end of file diff --git a/dnn/src/common/unroll_macro.h b/dnn/src/common/unroll_macro.h index 004adc825..600be8341 100644 --- a/dnn/src/common/unroll_macro.h +++ b/dnn/src/common/unroll_macro.h @@ -239,6 +239,23 @@ #define UNROLL_CALL_NOWRAPPER_D2(step, step2, cb) \ UNROLL_CALL_RAW_D2(step, step2, cb) +/////////////////// unroll call with one argument /////////////////// + +//! If the arg of cb is related to a macro inside cb, use this. +//! Reason: ## before VA_ARGS removes the ',' when there are no arguments, +//! but with that we can not nest macros. +//! Ref: https://stackoverflow.com/questions/5891221/variadic-macros-with-zero-arguments + +#define UNROLL_ONE_ARG_RAW2(cb, arg) cb(0, arg) cb(1, arg) +#define UNROLL_ONE_ARG_RAW3(cb, arg) UNROLL_ONE_ARG_RAW2(cb, arg) cb(2, arg) +#define UNROLL_ONE_ARG_RAW5(cb, arg) \ + UNROLL_ONE_ARG_RAW3(cb, arg) \ + cb(3, arg) cb(4, arg) +#define UNROLL_ONE_ARG_RAW7(cb, arg) \ + UNROLL_ONE_ARG_RAW5(cb, arg) \ + cb(5, arg) cb(6, arg) +#define UNROLL_CALL_ONE_ARG_RAW(step, cb, arg) UNROLL_ONE_ARG_RAW##step(cb, arg) + // clang-format on // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/gi/block_helper.h b/dnn/src/fallback/conv_bias/gi/block_helper.h index 664954e3b..d32bd616a 100644 --- a/dnn/src/fallback/conv_bias/gi/block_helper.h +++ b/dnn/src/fallback/conv_bias/gi/block_helper.h @@ -22,4 +22,4 @@ static inline int l2_block_helper( } // namespace } // namespace megdnn -// vim: syntax=cpp.doxygen +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/fallback/conv_bias/gi/intrinsic_helper.h b/dnn/src/fallback/conv_bias/gi/intrinsic_helper.h index 1364322d4..a0bd3f1b9 100644 --- a/dnn/src/fallback/conv_bias/gi/intrinsic_helper.h +++ b/dnn/src/fallback/conv_bias/gi/intrinsic_helper.h @@ -255,27 +255,13 @@ struct StoreOcxOw8Remain { static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc); }; -template -struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { +template +struct StoreOcxOw8Remain { static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { - ParamElemFixLenVisitor vis; - op(vis(c[0][0]), reinterpret_cast(dst_ptr)); - op(vis(c[0][1]), reinterpret_cast(dst_ptr + 4)); - op(vis(c[0][2]), reinterpret_cast(dst_ptr + 8)); - op(vis(c[0][3]), reinterpret_cast(dst_ptr + 12)); - op(vis(c[0][4]), reinterpret_cast(dst_ptr + 16)); - op(vis(c[0][5]), reinterpret_cast(dst_ptr + 20)); - op(vis(c[0][6]), reinterpret_cast(dst_ptr + 24)); - op(vis(c[0][7]), reinterpret_cast(dst_ptr + 28)); - - op(vis(c[1][0]), reinterpret_cast(dst_ptr + ld_dst_oc)); - op(vis(c[1][1]), reinterpret_cast(dst_ptr + ld_dst_oc + 4)); - op(vis(c[1][2]), reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op(vis(c[1][3]), reinterpret_cast(dst_ptr + ld_dst_oc + 12)); - op(vis(c[1][4]), reinterpret_cast(dst_ptr + ld_dst_oc + 16)); - op(vis(c[1][5]), reinterpret_cast(dst_ptr + ld_dst_oc + 20)); - op(vis(c[1][6]), reinterpret_cast(dst_ptr + ld_dst_oc + 24)); - op(vis(c[1][7]), reinterpret_cast(dst_ptr + ld_dst_oc + 28)); + MEGDNN_MARK_USED_VAR(c); + MEGDNN_MARK_USED_VAR(op); + MEGDNN_MARK_USED_VAR(dst_ptr); + MEGDNN_MARK_USED_VAR(ld_dst_oc); } }; template @@ -406,20 +392,6 @@ struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> { } }; -template -struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> { - static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { - ParamElemFixLenVisitor vis; - op(vis(c[0][0]), reinterpret_cast(dst_ptr)); - op(vis(c[0][1]), reinterpret_cast(dst_ptr + 4)); - op(vis(c[0][2]), reinterpret_cast(dst_ptr + 8)); - op(vis(c[0][3]), reinterpret_cast(dst_ptr + 12)); - op(vis(c[0][4]), reinterpret_cast(dst_ptr + 16)); - op(vis(c[0][5]), reinterpret_cast(dst_ptr + 20)); - op(vis(c[0][6]), reinterpret_cast(dst_ptr + 24)); - op(vis(c[0][7]), reinterpret_cast(dst_ptr + 28)); - } -}; template struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> { static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp index 5b28e9043..b91a1fab1 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.cpp +++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp @@ -335,19 +335,23 @@ bool ConvBiasImpl::AlgoIm2col::usable( return false; } if (format == param::ConvBias::Format::NCHW88) { - if (matmul_desc.packmode != Pack_Mode::DEFAULT) { + //! current NCHW88 im2col only support DEFAULT mode matmul + bool is_packmode_not_default = (matmul_desc.packmode != Pack_Mode::DEFAULT); + //! NCHW88 hybrid mode and channel wise is not support + bool is_hybrid_mode_or_channel_wise = + (param.filter_meta.icpg < 8_z || param.filter_meta.ocpg == 1); + if (is_packmode_not_default || is_hybrid_mode_or_channel_wise) { return false; } } if (format == param::ConvBias::Format::NCHW44 || format == param::ConvBias::Format::NCHW44_DOT) { //! current NCHW44 im2col only support DEFAULT mode matmul - if (matmul_desc.packmode != Pack_Mode::DEFAULT) { - return false; - //! nchw44 hybird mode and channel wise is not support - } else if ( - param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || - param.filter_meta.ocpg == 1) { + bool is_packmode_not_default = (matmul_desc.packmode != Pack_Mode::DEFAULT); + //! NCHW44 hybrid mode and channel wise is not support + bool is_hybrid_mode_or_channel_wise = + (param.filter_meta.icpg < 4_z || param.filter_meta.ocpg == 1); + if (is_packmode_not_default || is_hybrid_mode_or_channel_wise) { return false; } } @@ -358,12 +362,11 @@ bool ConvBiasImpl::AlgoIm2col::usable( } if (format == param::ConvBias::Format::NCHW44) { //! current NCHW44 im2col only support DEFAULT mode matmul - if (matmul_desc.packmode != Pack_Mode::DEFAULT) { - return false; - //! nchw44 hybird mode and channel wise is not support - } else if ( - param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || - param.filter_meta.ocpg == 1) { + bool is_packmode_not_default = (matmul_desc.packmode != Pack_Mode::DEFAULT); + //! NCHW44 hybrid mode and channel wise is not support + bool is_hybrid_mode_or_channel_wise = + (param.filter_meta.icpg < 4_z || param.filter_meta.ocpg == 1); + if (is_packmode_not_default || is_hybrid_mode_or_channel_wise) { return false; } } @@ -411,15 +414,14 @@ bool ConvBiasImpl::AlgoIm2col::usable( matmul_desc.innerblocksize.n, m_ohw_tile_size, matmul_desc.packmode); fallback::MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); - bool matmulusable = m_matmul_algo->usable(matmul_param); - return matmulusable && - (!(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + return (!(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && param.filter_meta.spatial[0] == 1 && param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1)) && (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && - param.compute_mode == param::ConvBias::ComputeMode::DEFAULT; + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + m_matmul_algo->usable(matmul_param); } MIDOUT_END(); return false; diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index a2222b517..344f0e280 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -255,6 +255,7 @@ public: ARM_COMMON_DIRECT_STRD1_FP16, ARM_COMMON_CHWNWISE_NCHW88_F16, ARM_COMMON_DIRECT_NCHW88_FP16, + ARM_COMMON_DIRECT_NCHW_NCHW88_FP16, ARM_COMMON_DIRECT_STRD1_S8, ARM_COMMON_DIRECT_STRD2_S8, ARM_COMMON_DIRECT_NCHW44, diff --git a/dnn/test/aarch64/conv_bias.cpp b/dnn/test/aarch64/conv_bias.cpp index 4b0750abb..331391b8d 100644 --- a/dnn/test/aarch64/conv_bias.cpp +++ b/dnn/test/aarch64/conv_bias.cpp @@ -11,65 +11,6 @@ namespace megdnn { namespace test { -std::vector get_conv_bias_args( - std::vector kernel, size_t stride) { - using namespace conv_bias; - using Param = param::ConvBias; - using NLMode = param::ConvBias::NonlineMode; - - std::vector args; - auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h, size_t kernel, - size_t stride, NLMode nonline_mode) { - Param param; - param.stride_h = stride; - param.stride_w = stride; - param.pad_h = kernel == 1 ? 0 : kernel / 2; - param.pad_w = kernel == 1 ? 0 : kernel / 2; - param.nonlineMode = nonline_mode; - - //! no bias - args.emplace_back( - param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel}, - TensorShape{}); - //! bias broadcast channle - args.emplace_back( - param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel}, - TensorShape{1, oc, 1, 1}); - //! bias - args.emplace_back( - param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, kernel, kernel}, - TensorShape{ - n, oc, (h + 2 * param.pad_h - kernel) / stride + 1, - (w + 2 * param.pad_h - kernel) / stride + 1}); - }; - - for (auto nlmode : {NLMode::IDENTITY, NLMode::RELU, NLMode::SIGMOID}) { - for (size_t n : {1, 2}) { - for (size_t ic : {1, 2, 3, 4, 8}) { - for (size_t oc : {1, 2, 3, 4, 8}) { - for (size_t size : {1, 2, 3, 4, 8, 24}) { - for (size_t k : kernel) { - pack(n, oc, ic, size + 24, size + 24, k, stride, nlmode); - } - } - } - } - } - } - return args; -} - -void checker_conv_bias( - std::vector args, Handle* handle, const char* algo_name) { - using namespace conv_bias; - - Checker checker(handle); - checker.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker(algo_name)); - for (auto&& arg : args) { - checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}}); - } -} TEST_F(AARCH64_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { check_conv_bias( conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index d6043f379..fdec1174d 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -378,6 +378,56 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) { get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(), rng, "F16DIRECT", 0.03); } +TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW_NCHW88) { + NormalRNG rng(1); + using namespace conv_bias; + std::vector args; + auto pack = [&](size_t oc, size_t ic, size_t ih, size_t iw, size_t n, size_t filter, + size_t stride, size_t pad, param::ConvBias::NonlineMode nlmode, + megdnn::BiasMode bias_mode) { + if (ih + 2 * pad < filter || iw + 2 * pad < filter) + return; + param::ConvBias param; + constexpr size_t pack_c = 8; + param.format = param::ConvBias::Format::NCHW88; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = pad; + param.pad_w = pad; + param.nonlineMode = nlmode; + + auto src_tensor_shape = TensorShape{n, ic, ih, iw}; + auto weight_tensor_shape = TensorShape{oc / pack_c, filter, filter, ic, pack_c}; + auto bias_tensor_shape = TensorShape{}; + if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { + bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c}; + } + args.emplace_back( + param, src_tensor_shape, weight_tensor_shape, bias_tensor_shape); + }; + + for (size_t n : {1}) { + for (size_t oc : {8, 16}) { + for (size_t ic : {4}) { + for (size_t ih = 1; ih < 71; ih += 17) { + for (size_t filter : {2, 3, 5, 7}) { + for (size_t stride : {1, 2}) { + for (size_t pad : {0, 1}) { + for (auto nlmode : QUAN_NLMODE) { + for (auto bias_mode : BR_AND_BIAS_BIASMODE) { + pack(oc, ic, ih, ih, n, filter, stride, pad, + nlmode, bias_mode); + } + } + } + } + } + } + } + } + } + checker_conv_bias_f16(args, handle(), rng, "DIRECT_CONV_F16_NCHW_NCHW88", 0.03); +} TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) { NormalRNG rng(1); checker_conv_bias_f16( diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index ab4c432f0..7acbe4bc8 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -1677,6 +1677,68 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_NCHW44_VS_N algo_name_nchw44, data_type_fp32, RUNS, {1, {4}}); } +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, + BENCHMARK_CONVBIAS_DIRECT_HYBRID_NCHW44_VS_NCHW88) { + constexpr size_t RUNS = 50; + using NLMode = param::ConvBias::NonlineMode; + + std::vector args_nchw88, args_nchw44; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, + size_t FS) { + param::ConvBias param_nchw88, param_nchw44; + param_nchw88.format = param::ConvBias::Format::NCHW88; + param_nchw44.format = param::ConvBias::Format::NCHW44; + for (size_t pad : {1}) { + for (size_t stride : {1, 2}) { + for (auto nlmode : {NLMode::RELU, NLMode::IDENTITY, NLMode::H_SWISH}) { + param_nchw88.nonlineMode = nlmode; + param_nchw88.pad_h = pad; + param_nchw88.pad_w = pad; + param_nchw88.stride_h = stride; + param_nchw88.stride_w = stride; + + param_nchw44.nonlineMode = nlmode; + param_nchw44.pad_h = pad; + param_nchw44.pad_w = pad; + param_nchw44.stride_h = stride; + param_nchw44.stride_w = stride; + + args_nchw88.emplace_back( + param_nchw88, TensorShape{N, IC, H, W}, + TensorShape{OC / 8, FS, FS, IC, 8}, + TensorShape{1, OC / 8, 1, 1, 8}); + args_nchw44.emplace_back( + param_nchw44, TensorShape{N, IC, H, W}, + TensorShape{OC / 4, FS, FS, IC, 4}, + TensorShape{1, OC / 4, 1, 1, 4}); + } + } + } + }; + std::vector data_type_fp16 = { + dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()}; + std::vector data_type_fp32 = { + dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; + bench_case(1, 3, 32, 400, 400, 7); + bench_case(1, 3, 32, 300, 300, 5); + bench_case(1, 3, 32, 100, 100, 3); + bench_case(1, 3, 32, 80, 80, 2); + bench_case(1, 3, 64, 200, 200, 7); + bench_case(1, 3, 64, 128, 128, 5); + bench_case(1, 3, 64, 100, 100, 3); + bench_case(1, 3, 64, 80, 80, 2); + bench_case(1, 3, 128, 200, 200, 7); + bench_case(1, 3, 128, 128, 128, 5); + bench_case(1, 3, 128, 100, 100, 3); + bench_case(1, 3, 128, 80, 80, 2); + std::string algo_name_nchw88 = "DIRECT_CONV_F16_NCHW_NCHW88"; + std::string algo_name_nchw44 = "F32_CONV_NCHW_NCHW44"; + + benchmark_with_contrast( + args_nchw88, algo_name_nchw88, data_type_fp16, args_nchw44, + algo_name_nchw44, data_type_fp32, RUNS, {1, {4}}); +} + TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) { constexpr size_t RUNS = 50; diff --git a/dnn/test/common/conv_bias.cpp b/dnn/test/common/conv_bias.cpp index 3b865828c..2aa70a8fc 100644 --- a/dnn/test/common/conv_bias.cpp +++ b/dnn/test/common/conv_bias.cpp @@ -1163,7 +1163,7 @@ void benchmark_with_contrast( {arg.src, data_type[0]}, {arg.filter, data_type[1]}, {arg.bias, data_type[2]}, {}, dst_layout); float computation = (dst_layout.total_nr_elems() * arg.filter[1] * - arg.filter[2] * arg.filter[3] * arg.filter[4] * 2.0) / + arg.filter[2] * arg.filter[3] * 2.0) / (1024 * 1024 * 1024) * 1e3; benchmarker.set_param(arg.param); auto used = benchmarker.exec({arg.src, arg.filter, arg.bias, {}, {}}) / RUNS; @@ -1176,8 +1176,7 @@ void benchmark_with_contrast( {arg_contrast.bias, data_type_contrast[2]}, {}, dst_layout_contrast); float computation_contrast = (dst_layout_contrast.total_nr_elems() * arg_contrast.filter[1] * - arg_contrast.filter[2] * arg_contrast.filter[3] * - arg_contrast.filter[4] * 2.0) / + arg_contrast.filter[2] * arg_contrast.filter[3] * 2.0) / (1024 * 1024 * 1024) * 1e3; benchmarker_contrast.set_param(arg_contrast.param); auto used_contrast = benchmarker_contrast.exec( -- GitLab