From a1cbd9bb71151c9a948015436866c366ccc52766 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 9 Oct 2022 10:32:25 +0800
Subject: [PATCH] feat(dnn): add fp16 nchw88 pooling algo

GitOrigin-RevId: 7a5e9c7df242fd5d7d7811b1af9213e58be20e91
---
 dnn/src/arm_common/pooling/algo.h             |  13 +
 .../pooling/fp16/algo_fp16_nchw88_pooling.cpp | 128 ++++++++
 .../pooling/fp16/kern_fp16_nchw88_pooling.h   | 269 ++++++++++++++++
 .../pooling/kern_fp32_pooling_nchw44.h        | 290 ------------------
 dnn/src/arm_common/pooling/opr_impl.cpp       |   6 +
 dnn/src/arm_common/pooling/opr_impl.h         |   6 +
 dnn/src/common/unroll_macro.h                 |  39 +++
 dnn/test/arm_common/pooling.cpp               | 108 +++++++
 8 files changed, 569 insertions(+), 290 deletions(-)
 create mode 100644 dnn/src/arm_common/pooling/fp16/algo_fp16_nchw88_pooling.cpp
 create mode 100644 dnn/src/arm_common/pooling/fp16/kern_fp16_nchw88_pooling.h
 delete mode 100644 dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h

diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h
index d96847e7b..663572e67 100644
--- a/dnn/src/arm_common/pooling/algo.h
+++ b/dnn/src/arm_common/pooling/algo.h
@@ -124,6 +124,19 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ARM_Filter5ModexStridexNCHW44)
 };
 
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+class PoolingImpl::AlgoFilterxModexStridexNCHW88 final : public AlgoBase {
+public:
+    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
+    const char* name() const override {
+        return "ARM_POOLING_FILTERX_MODEX_STRIDEX_NCHW88";
+    }
+    bool usable(const PoolingKernSizeParam& param) const override;
+    void exec(const PoolingKernParam& param) const override;
+    MEGDNN_DECL_ALGO_TYPE(ARM_Fp16FilterxModexStridexNCHW88)
+};
+#endif
+
 class PoolingImpl::AlgoFallback final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; };
diff --git a/dnn/src/arm_common/pooling/fp16/algo_fp16_nchw88_pooling.cpp b/dnn/src/arm_common/pooling/fp16/algo_fp16_nchw88_pooling.cpp
new file mode 100644
index 000000000..f60968960
--- /dev/null
+++ b/dnn/src/arm_common/pooling/fp16/algo_fp16_nchw88_pooling.cpp
@@ -0,0 +1,128 @@
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#include "midout.h"
+#include "src/arm_common/pooling/algo.h"
+#include "src/arm_common/pooling/fp16/kern_fp16_nchw88_pooling.h"
+
+MIDOUT_DECL(megdnn_arm_common_fp16_nchw88_pooling)
+
+namespace megdnn {
+namespace arm_common {
+bool PoolingImpl::AlgoFilterxModexStridexNCHW88::usable(
+        const PoolingKernSizeParam& param) const {
+    uint32_t sh = param.stride[0];
+    uint32_t sw = param.stride[1];
+    uint32_t fh = param.filter[0];
+    uint32_t fw = param.filter[1];
+
+    bool usable = param.src_type.enumv() == DTypeEnum::Float16 &&
+                  param.format == param::Pooling::Format::NCHW88 &&
+                  (param.mode == PoolingBase::Mode::MAX ||
+                   param.mode == PoolingBase::Mode::AVERAGE) &&
+                  fh == fw && sh == sw;
+    bool size_ok =
+            (((fh == 2 || fh == 3 || fh == 4 || fh == 5) && (sh == 1 || sh == 2)) ||
+             ((fh == 9 || fh == 13) && (sh == 1)));
+    return usable && size_ok;
+}
+
+void PoolingImpl::AlgoFilterxModexStridexNCHW88::exec(
+        const PoolingKernParam& param) const {
+    int ih = param.isz[0];
+    int iw = param.isz[1];
+    int oh = param.osz[0];
+    int ow = param.osz[1];
+    int n = param.n;
+    int ic = param.ic;
+    int ph = param.padding[0];
+    int pw = param.padding[1];
+    int sh = param.stride[0];
+    int fh = param.filter[0];
+
+    auto src = param.src_ptr;
+    auto dst = param.dst_ptr;
+
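+    // The macros below pick the kernel at compile time: DISPATCH_FILTER switches on
+    // the filter size (2/3/4/5/9/13), DISPATCH_STRIDE / DISPATCH_STRIDE1 on the
+    // stride, and DISPATCH_MODE on MAX vs AVERAGE. Each leaf wraps one
+    // (batch, channel-block) slice of pooling_fp16_nchw88 in a lambda and hands it
+    // to the multi-threaded CPU kernel dispatcher.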
+#define DISPATCH_FUNC(filter, stride, mode)                                          \
+    MIDOUT_BEGIN(                                                                    \
+            megdnn_arm_common_fp16_nchw88_pooling, midout_iv(0),                     \
+            midout_iv(#filter #stride #mode##_hash)) {                               \
+        auto run = [=](size_t index, size_t) {                                       \
+            const int c_idx = index;                                                 \
+            pooling_fp16_nchw88<filter, stride, mode>(                               \
+                    static_cast<const __fp16*>(src.get_ptr()) + c_idx * ih * iw * 8, \
+                    static_cast<__fp16*>(dst.get_ptr()) + c_idx * oh * ow * 8, ih, iw, \
+                    oh, ow, ph, pw);                                                 \
+        };                                                                           \
+        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                                       \
+                static_cast<::megdnn::naive::HandleImpl*>(param.handle), n* ic, run); \
+    }                                                                                \
+    MIDOUT_END();
+
+#define DISPATCH_MODE(filter, stride)                                          \
+    switch (param.mode) {                                                      \
+        case PoolingBase::Mode::MAX:                                           \
+            DISPATCH_FUNC(filter, stride, PoolingBase::Mode::MAX);             \
+            break;                                                             \
+        case PoolingBase::Mode::AVERAGE:                                       \
+            DISPATCH_FUNC(filter, stride, PoolingBase::Mode::AVERAGE);         \
+            break;                                                             \
+        default:                                                               \
+            megdnn_assert(0, "invalid mode %u", static_cast<uint32_t>(param.mode)); \
+    }
+
+#define DISPATCH_STRIDE(filter)                                                 \
+    switch (sh) {                                                               \
+        case 1:                                                                 \
+            DISPATCH_MODE(filter, 1);                                           \
+            break;                                                              \
+        case 2:                                                                 \
+            DISPATCH_MODE(filter, 2);                                           \
+            break;                                                              \
+        default:                                                                \
+            megdnn_assert(                                                      \
+                    0,                                                          \
+                    "Invalid stride %d. When the filter size is 2, 3, 4 or 5, stride " \
+                    "can only be 1 or 2.",                                      \
+                    sh);                                                        \
+    }
+
+#define DISPATCH_STRIDE1(filter)                                                \
+    switch (sh) {                                                               \
+        case 1:                                                                 \
+            DISPATCH_MODE(filter, 1);                                           \
+            break;                                                              \
+        default:                                                                \
+            megdnn_assert(                                                      \
+                    0,                                                          \
+                    "Invalid stride %d. When the filter size is 9 or 13, stride " \
+                    "can only be 1.",                                           \
+                    sh);                                                        \
+    }
+
+#define DISPATCH_FILTER()         \
+    switch (fh) {                 \
+        case 2:                   \
+            DISPATCH_STRIDE(2);   \
+            break;                \
+        case 3:                   \
+            DISPATCH_STRIDE(3);   \
+            break;                \
+        case 4:                   \
+            DISPATCH_STRIDE(4);   \
+            break;                \
+        case 5:                   \
+            DISPATCH_STRIDE(5);   \
+            break;                \
+        case 9:                   \
+            DISPATCH_STRIDE1(9);  \
+            break;                \
+        case 13:                  \
+            DISPATCH_STRIDE1(13); \
+            break;                \
+    }
+
+    DISPATCH_FILTER();
+}
+
+}  // namespace arm_common
+}  // namespace megdnn
+#endif
\ No newline at end of file
diff --git a/dnn/src/arm_common/pooling/fp16/kern_fp16_nchw88_pooling.h b/dnn/src/arm_common/pooling/fp16/kern_fp16_nchw88_pooling.h
new file mode 100644
index 000000000..1ea81c0ce
--- /dev/null
+++ b/dnn/src/arm_common/pooling/fp16/kern_fp16_nchw88_pooling.h
@@ -0,0 +1,269 @@
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#pragma once
+#include <arm_neon.h>
+#include <limits>
+#include "src/arm_common/pooling/opr_impl.h"
+#include "src/common/unroll_macro.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace {
+#if MEGDNN_AARCH64
+#define OW_STEP 4
+#else
+#define OW_STEP 2
+#endif
+
+template <
+        int filter, int stride, int ow_step, PoolingBase::Mode mode, typename T1,
+        typename T2>
+struct CalXsXNchw88 {
+    static void impl(T1 result, T2 src);
+};
+
+#define CAL_MAX_CB(step, ow_step) \
+    result[ow_step] = vmaxq_f16(result[ow_step], src[ow_step * stride + step]);
+
+#define CAL_AVE_CB(step, ow_step) \
+    result[ow_step] = vaddq_f16(result[ow_step], src[ow_step * stride + step]);
+
+#define INSTANCE_CAL(filter, ow_step)                                              \
+    template <int stride, typename T1, typename T2>                                \
+    struct CalXsXNchw88<filter, stride, ow_step, PoolingBase::Mode::MAX, T1, T2> { \
+        static void impl(T1 result, T2 src) {                                      \
+            UNROLL_CALL_NOWRAPPER_D2(filter, ow_step, CAL_MAX_CB);                 \
+        }                                                                          \
+    };                                                                             \
+    template <int stride, typename T1, typename T2>                                \
+    struct CalXsXNchw88<filter, stride, ow_step, PoolingBase::Mode::AVERAGE, T1, T2> { \
+        static void impl(T1 result, T2 src) {                                      \
+            UNROLL_CALL_NOWRAPPER_D2(filter, ow_step, CAL_AVE_CB);                 \
+        }                                                                          \
+    };
+
+INSTANCE_CAL(2, OW_STEP)
+INSTANCE_CAL(3, OW_STEP)
+INSTANCE_CAL(4, OW_STEP)
+INSTANCE_CAL(5, OW_STEP)
+INSTANCE_CAL(9, OW_STEP)
+INSTANCE_CAL(13, OW_STEP)
+
+#undef INSTANCE_CAL
+#undef CAL_AVE_CB
+#undef CAL_MAX_CB
+
+template <
+        int filter, int stride, int ow_step, PoolingBase::Mode mode, typename T1,
+        typename T2>
+void calculate_xsx_nchw88(T1 result, T2 src) {
+    CalXsXNchw88<filter, stride, ow_step, mode, T1, T2>::impl(result, src);
+}
+
+template <int filter, int stride, PoolingBase::Mode mode>
+struct KerPoolingFilterXStrideXNchw88 {
+    static void impl(const __fp16* src_ptr, __fp16* dst_ptr, size_t iw);
+};
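+
+// Register blocking: one call to KerPoolingFilterXStrideXNchw88::impl produces
+// OW_STEP output positions (4 on aarch64, 2 on armv7) for a single NCHW88 channel
+// block. For every filter row it loads stride * (OW_STEP - 1) + filter consecutive
+// float16x8 vectors and reduces them with vmaxq_f16 (MAX) or vaddq_f16 (AVERAGE);
+// the AVERAGE variant scales the accumulated sums by 1 / (filter * filter) before
+// storing.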
+
+template <int filter, int stride>
+struct KerPoolingFilterXStrideXNchw88<filter, stride, PoolingBase::Mode::MAX> {
+    static void impl(const __fp16* src_ptr, __fp16* dst_ptr, size_t iw) {
+        constexpr int src_reg_size = stride * (OW_STEP - 1) + filter;
+        constexpr int packed_ic = 8;
+        constexpr int simd_len = 8;
+        constexpr dt_float16 min_float16 = std::numeric_limits<dt_float16>::lowest();
+
+        float16x8_t result[OW_STEP], src[src_reg_size];
+#define cb(i) result[i] = vdupq_n_f16(min_float16);
+        UNROLL_CALL_NOWRAPPER(OW_STEP, cb);
+#undef cb
+
+        for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
+            auto src_base_ptr = src_ptr + fh_idx * iw * packed_ic;
+            rep(i, src_reg_size) { src[i] = vld1q_f16(src_base_ptr + i * simd_len); }
+            calculate_xsx_nchw88<filter, stride, OW_STEP, PoolingBase::Mode::MAX>(
+                    result, src);
+        }
+
+#define cb(i) vst1q_f16(dst_ptr + i * packed_ic, result[i]);
+        UNROLL_CALL_NOWRAPPER(OW_STEP, cb)
+#undef cb
+    }
+};
+
+template <int filter, int stride>
+struct KerPoolingFilterXStrideXNchw88<filter, stride, PoolingBase::Mode::AVERAGE> {
+    static void impl(const __fp16* src_ptr, __fp16* dst_ptr, size_t iw) {
+        constexpr int src_reg_size = stride * (OW_STEP - 1) + filter;
+        constexpr int packed_ic = 8;
+        constexpr int simd_len = 8;
+        const __fp16 zero = static_cast<__fp16>(0);
+        const __fp16 div_filter_pow = static_cast<__fp16>(1.0 / (filter * filter));
+        const float16x8_t div_filter_pow_vec = vdupq_n_f16(div_filter_pow);
+
+        float16x8_t result[OW_STEP], src[src_reg_size];
+#define cb(i) result[i] = vdupq_n_f16(zero);
+        UNROLL_CALL_NOWRAPPER(OW_STEP, cb)
+#undef cb
+        rep(fh, filter) {
+            auto src_base_ptr = src_ptr + fh * iw * packed_ic;
+            rep(i, src_reg_size) { src[i] = vld1q_f16(src_base_ptr + i * simd_len); }
+            calculate_xsx_nchw88<filter, stride, OW_STEP, PoolingBase::Mode::AVERAGE>(
+                    result, src);
+        }
+#define cb(i) \
+    vst1q_f16(dst_ptr + i * simd_len, vmulq_f16(result[i], div_filter_pow_vec));
+        UNROLL_CALL_NOWRAPPER(OW_STEP, cb)
+#undef cb
+    }
+};
+
+template <PoolingBase::Mode mode>
+void kern_pooling_nchw88_remain_pad(
+        const __fp16* src, __fp16* dst, const int iw, const int pad_top,
+        const int pad_left, const int pad_bottom, const int pad_right,
+        const int filter);
+
+template <>
+void kern_pooling_nchw88_remain_pad<PoolingBase::Mode::MAX>(
+        const __fp16* src, __fp16* dst, const int iw, const int pad_top,
+        const int pad_left, const int pad_bottom, const int pad_right,
+        const int filter) {
+    constexpr int ic_step = 8;
+    const int fh_end = filter - pad_bottom;
+    const int fw_end = filter - pad_right;
+    float16x8_t result = vdupq_n_f16(std::numeric_limits<dt_float16>::lowest());
+    for (int fh_idx = pad_top; fh_idx < fh_end; ++fh_idx) {
+        for (int fw_idx = pad_left; fw_idx < fw_end; ++fw_idx) {
+            float16x8_t s = vld1q_f16(src + (fw_idx - pad_left) * ic_step);
+            result = vmaxq_f16(result, s);
+        }
+        src += iw * ic_step;
+    }
+    vst1q_f16(dst, result);
+}
+
+template <>
+void kern_pooling_nchw88_remain_pad<PoolingBase::Mode::AVERAGE>(
+        const __fp16* src, __fp16* dst, const int iw, const int pad_top,
+        const int pad_left, const int pad_bottom, const int pad_right,
+        const int filter) {
+    constexpr int ic_step = 8;
+    const int fh_end = filter - pad_bottom;
+    const int fw_end = filter - pad_right;
+    float16x8_t result = vdupq_n_f16(static_cast<__fp16>(0));
+    float16x8_t div_filter_pow_vec = vdupq_n_f16(1.0 / (filter * filter));
+    for (int fh_idx = pad_top; fh_idx < fh_end; ++fh_idx) {
+        for (int fw_idx = pad_left; fw_idx < fw_end; ++fw_idx) {
+            float16x8_t s = vld1q_f16(src + (fw_idx - pad_left) * ic_step);
+            result = vaddq_f16(result, s);
+        }
+        src += iw * ic_step;
+    }
+    vst1q_f16(dst, vmulq_f16(result, div_filter_pow_vec));
+}
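+
+// Border handling: outputs whose window touches padding are computed one at a time
+// by the kern_pooling_nchw88_remain_pad specializations above, which simply skip the
+// out-of-bounds filter taps. For AVERAGE the divisor stays 1 / (filter * filter)
+// either way, i.e. padded positions are averaged in as zeros (the
+// count-include-padding convention).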
+
+template <PoolingBase::Mode mode>
+static inline void kern_pooling_with_pad_nchw88(
+        const __fp16* src, __fp16* dst, const int filter, const int ow_start,
+        const int ow_end, const int iw, const int ow, const int stride_w, const int pw,
+        const int real_ih_idx, const int oh_idx, const int pad_top,
+        const int pad_bottom) {
+    constexpr int ic_step = 8;
+    constexpr int oc_step = 8;
+    for (int ow_idx = ow_start; ow_idx < ow_end; ++ow_idx) {
+        const int iw_idx = ow_idx * stride_w;
+        const int real_iw_idx = std::max(0, iw_idx - pw);
+        const int pad_left = std::max(0, pw - iw_idx);
+        const int pad_right = std::max(0, iw_idx - pw + filter - iw);
+        const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step;
+        const int dst_offset = (oh_idx * ow + ow_idx) * oc_step;
+        kern_pooling_nchw88_remain_pad<mode>(
+                src + src_offset, dst + dst_offset, iw, pad_top, pad_left, pad_bottom,
+                pad_right, filter);
+    }
+}
+
+template <int filter, int stride, PoolingBase::Mode mode>
+static inline void pooling_fp16_nchw88_pad(
+        const __fp16* src, __fp16* dst, int ih, int iw, int oh, int ow, int ph,
+        int pw) {
+    constexpr int stride_h = stride;
+    constexpr int stride_w = stride;
+    constexpr int ic_step = 8;
+    constexpr int oc_step = 8;
+    constexpr int ow_step = OW_STEP;
+    const int ow_pad_left_end = div_ceil(pw, stride_w);
+    const int ow_pad_right_start = (iw + pw - filter) / stride_w + 1;  //!!!! CHECK
+    const int ow_pad_right_step_end =
+            (ow_pad_right_start - ow_pad_left_end) / ow_step * ow_step +
+            ow_pad_left_end;
+
+    rep(oh_idx, oh) {
+        const int ih_idx = oh_idx * stride_h;
+        const int real_ih_idx = std::max(ih_idx - ph, 0);
+        const int pad_top = std::max(0, ph - ih_idx);
+        const int pad_bottom = std::max(0, ih_idx - ph + filter - ih);
+        if (pad_top > 0 || pad_bottom > 0) {
+            kern_pooling_with_pad_nchw88<mode>(
+                    src, dst, filter, 0, ow, iw, ow, stride_w, pw, real_ih_idx, oh_idx,
+                    pad_top, pad_bottom);
+        } else {
+            kern_pooling_with_pad_nchw88<mode>(
+                    src, dst, filter, 0, ow_pad_left_end, iw, ow, stride_w, pw,
+                    real_ih_idx, oh_idx, pad_top, pad_bottom);
+            for (int ow_idx = ow_pad_left_end; ow_idx < ow_pad_right_step_end;
+                 ow_idx += ow_step) {
+                const int iw_idx = ow_idx * stride_w;
+                const int real_iw_idx = std::max(0, iw_idx - pw);
+                const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step;
+                const int dst_offset = (oh_idx * ow + ow_idx) * oc_step;
+                KerPoolingFilterXStrideXNchw88<filter, stride, mode>::impl(
+                        src + src_offset, dst + dst_offset, iw);
+            }
+            kern_pooling_with_pad_nchw88<mode>(
+                    src, dst, filter, ow_pad_right_step_end, ow, iw, ow, stride_w, pw,
+                    real_ih_idx, oh_idx, pad_top, pad_bottom);
+        }
+    }
+}
+
+template <int filter, int stride, PoolingBase::Mode mode>
+static inline void pooling_fp16_nchw88_no_pad(
+        const __fp16* src, __fp16* dst, const int iw, const int oh, const int ow) {
+    constexpr int stride_h = stride;
+    constexpr int stride_w = stride;
+    constexpr int ic_step = 8;
+    constexpr int oc_step = 8;
+    constexpr int ow_step = OW_STEP;
+    const int ow_end = ow / ow_step * ow_step;
+    const int ow_remain = ow - ow_end;
+
+    rep(oh_idx, oh) {
+        const int ih_idx = oh_idx * stride_h;
+        for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
+            const int iw_idx = ow_idx * stride_w;
+            const int src_offset = (ih_idx * iw + iw_idx) * ic_step;
+            const int dst_offset = (oh_idx * ow + ow_idx) * oc_step;
+            KerPoolingFilterXStrideXNchw88<filter, stride, mode>::impl(
+                    src + src_offset, dst + dst_offset, iw);
+        }
+
+        if (ow_remain > 0) {
+            kern_pooling_with_pad_nchw88<mode>(
+                    src, dst, filter, ow_end, ow, iw, ow, stride_w, 0, ih_idx, oh_idx,
+                    0, 0);
+        }
+    }
+}
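+
+// Top-level kernels for one (batch, channel-block) slice: the padded path splits
+// every output row into a left border, a vectorized middle stepped by OW_STEP and a
+// right border, while rows whose window crosses the top/bottom padding fall back
+// entirely to the per-output border kernel; the no-pad path only needs the
+// right-tail fallback.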
+
+template <int filter, int stride, PoolingBase::Mode mode>
+static inline void pooling_fp16_nchw88(
+        const __fp16* src, __fp16* dst, const int ih, const int iw, const int oh,
+        const int ow, const int ph, const int pw) {
+    if (ph > 0 || pw > 0) {
+        pooling_fp16_nchw88_pad<filter, stride, mode>(
+                src, dst, ih, iw, oh, ow, ph, pw);
+    } else {
+        pooling_fp16_nchw88_no_pad<filter, stride, mode>(src, dst, iw, oh, ow);
+    }
+}
+
+#undef OW_STEP
+}  // namespace
+}  // namespace arm_common
+}  // namespace megdnn
+#endif
\ No newline at end of file
diff --git a/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h b/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h
deleted file mode 100644
index ab24adc82..000000000
--- a/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h
+++ /dev/null
@@ -1,290 +0,0 @@
-#pragma once
-#include <limits>
-#include "megdnn/opr_param_defs.h"
-#include "src/arm_common/intrinsic_helper.h"
-#include "src/arm_common/neon_struct.h"
-#include "src/arm_common/simd_macro/marm_neon.h"
-#include "src/common/unroll_macro.h"
-
-namespace megdnn {
-namespace arm_common {
-namespace {
-
-template <
-        int filter, int stride, int ow_step, PoolingBase::Mode mode, typename T1,
-        typename T2>
-struct CalXsXNchw44 {
-    static void impl(T1 result, T2 src);
-};
-
-template <
-        int filter, int stride, int ow_step, PoolingBase::Mode mode, typename T1,
-        typename T2>
-void calculate_xsx_nchw44(T1 result, T2 src) {
-    CalXsXNchw44<filter, stride, ow_step, mode, T1, T2>::impl(result, src);
-};
-
-#define CALCULATE_MAX_CB(step)                                    \
-    result[0] = vmaxq_f32(result[0], src[0 * stride + step]);     \
-    result[1] = vmaxq_f32(result[1], src[1 * stride + step]);     \
-    result[2] = vmaxq_f32(result[2], src[2 * stride + step]);     \
-    result[3] = vmaxq_f32(result[3], src[3 * stride + step]);
-
-#define CALCULATE_AVG_CB(step)                                    \
-    result[0] = vaddq_f32(result[0], src[0 * stride + step]);     \
-    result[1] = vaddq_f32(result[1], src[1 * stride + step]);     \
-    result[2] = vaddq_f32(result[2], src[2 * stride + step]);     \
-    result[3] = vaddq_f32(result[3], src[3 * stride + step]);
-
-#define INSTANCE_CAL(filter)                                                       \
-    template <int stride, int ow_step, typename T1, typename T2>                   \
-    struct CalXsXNchw44<filter, stride, ow_step, PoolingBase::Mode::MAX, T1, T2> { \
-        static void impl(T1 result, T2 src) {                                      \
-            UNROLL_CALL_RAW(filter, CALCULATE_MAX_CB);                             \
-        }                                                                          \
-    };                                                                             \
-    template <int stride, int ow_step, typename T1, typename T2>                   \
-    struct CalXsXNchw44<filter, stride, ow_step, PoolingBase::Mode::AVERAGE, T1, T2> { \
-        static void impl(T1 result, T2 src) {                                      \
-            UNROLL_CALL_RAW(filter, CALCULATE_AVG_CB);                             \
-        }                                                                          \
-    };
-
-INSTANCE_CAL(2)
-INSTANCE_CAL(3)
-INSTANCE_CAL(4)
-INSTANCE_CAL(5)
-INSTANCE_CAL(9)
-INSTANCE_CAL(13)
-
-#undef INSTANCE_CAL
-#undef CALCULATE_AVG_CB
-#undef CALCULATE_MAX_CB
-
-template <int filter, int stride, int ow_step, PoolingBase::Mode mode>
-struct KerPoolingFilterXStrideXNchw44 {
-    static void impl(const float32_t* src_ptr, float32_t* dst_ptr, size_t iw);
-};
-
-template <int filter, int stride, int ow_step>
-struct KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, PoolingBase::Mode::MAX> {
-    static void impl(const float32_t* src_ptr, float32_t* dst_ptr, size_t iw) {
-        constexpr int src_reg_size = ow_step * stride + filter - stride;
-        constexpr int packed_ic = 4;
-        constexpr int simd_len = 4;
-        constexpr float default_float = std::numeric_limits<float>::lowest();
-        float32x4_t result[ow_step];
-        float32x4_t src[src_reg_size];
-
-        result[0] = vdupq_n_f32(default_float);
-        result[1] = vdupq_n_f32(default_float);
-        result[2] = vdupq_n_f32(default_float);
-        result[3] = vdupq_n_f32(default_float);
-
-        for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
-            load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
-                    src, src_ptr + fh_idx * iw * packed_ic, 0);
-            calculate_xsx_nchw44<filter, stride, ow_step, PoolingBase::Mode::MAX>(
-                    result, src);
-        }
-
-        vst1q_f32(dst_ptr + 0 * packed_ic, result[0]);
-        vst1q_f32(dst_ptr + 1 * packed_ic, result[1]);
-        vst1q_f32(dst_ptr + 2 * packed_ic, result[2]);
-        vst1q_f32(dst_ptr + 3 * packed_ic, result[3]);
-    }
-};
-
-template <int filter, int stride, int ow_step>
-struct KerPoolingFilterXStrideXNchw44<
-        filter, stride, ow_step, PoolingBase::Mode::AVERAGE> {
-    static void impl(const float32_t* src_ptr, float32_t* dst_ptr, size_t iw) {
-        constexpr int src_reg_size = ow_step * stride + filter - stride;
-        constexpr int packed_ic = 4;
-        constexpr int simd_len = 4;
-        constexpr float default_float = 0;
-        constexpr float div_filter_size = 1.f / (filter * filter);
-        const float32x4_t div_filter_size_vec = vdupq_n_f32(div_filter_size);
-        float32x4_t result[ow_step];
-        float32x4_t src[src_reg_size];
-
-        result[0] = vdupq_n_f32(default_float);
-        result[1] = vdupq_n_f32(default_float);
-        result[2] = vdupq_n_f32(default_float);
-        result[3] = vdupq_n_f32(default_float);
-
-        for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
-            load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
-                    src, src_ptr + fh_idx * iw * packed_ic, 0);
-            calculate_xsx_nchw44<filter, stride, ow_step, PoolingBase::Mode::AVERAGE>(
-                    result, src);
-        }
-        result[0] = vmulq_f32(result[0], div_filter_size_vec);
-        result[1] = vmulq_f32(result[1], div_filter_size_vec);
-        result[2] = vmulq_f32(result[2], div_filter_size_vec);
-        result[3] = vmulq_f32(result[3], div_filter_size_vec);
-        vst1q_f32(dst_ptr + 0 * packed_ic, result[0]);
-        vst1q_f32(dst_ptr + 1 * packed_ic, result[1]);
-        vst1q_f32(dst_ptr + 2 * packed_ic, result[2]);
-        vst1q_f32(dst_ptr + 3 * packed_ic, result[3]);
-    }
-};
-
-template <PoolingBase::Mode mode>
-void ker_pooling_nchw44_remain_pad(
-        const float32_t* src_ptr, float32_t* dst_ptr, const int iw, const int pad_top,
-        const int pad_bottom, const int pad_left, const int pad_right,
-        const int filter);
-template <>
-void ker_pooling_nchw44_remain_pad<PoolingBase::Mode::MAX>(
-        const float32_t* src_ptr, float32_t* dst_ptr, const int iw, const int pad_top,
-        const int pad_bottom, const int pad_left, const int pad_right,
-        const int filter) {
-    constexpr int ic_step = 4;
-    const int ih_end = filter - pad_bottom;
-    const int iw_end = filter - pad_right;
-    float32x4_t result = vdupq_n_f32(std::numeric_limits<float>::lowest());
-    for (int ih_idx = pad_top; ih_idx < ih_end; ++ih_idx) {
-        for (int iw_idx = pad_left; iw_idx < iw_end; ++iw_idx) {
-            float32x4_t src = vld1q_f32(src_ptr + (iw_idx - pad_left) * ic_step);
-            result = vmaxq_f32(result, src);
-        }
-        src_ptr += iw * ic_step;
-    }
-    vst1q_f32(dst_ptr, result);
-}
-
-template <>
-void ker_pooling_nchw44_remain_pad<PoolingBase::Mode::AVERAGE>(
-        const float32_t* src_ptr, float32_t* dst_ptr, const int iw, const int pad_top,
-        const int pad_bottom, const int pad_left, const int pad_right,
-        const int filter) {
-    constexpr int ic_step = 4;
-    const int ih_end = filter - pad_bottom;
-    const int iw_end = filter - pad_right;
-    const float div_filter_size = 1.f / (filter * filter);
-    const float32x4_t div_filter_size_vec = vdupq_n_f32(div_filter_size);
-    float32x4_t result = vdupq_n_f32(0.f);
-
-    for (int ih_idx = pad_top; ih_idx < ih_end; ++ih_idx) {
-        for (int iw_idx = pad_left; iw_idx < iw_end; ++iw_idx) {
-            float32x4_t src = vld1q_f32(src_ptr + (iw_idx - pad_left) * ic_step);
-            result = vaddq_f32(result, src);
-        }
-        src_ptr += iw * ic_step;
-    }
-    result = vmulq_f32(result, div_filter_size_vec);
-    vst1q_f32(dst_ptr, result);
-}
-
-template <PoolingBase::Mode mode>
-static inline void kern_pooling_with_pad_nchw44(
-        const float32_t* src, float32_t* dst, const int filter, const int ow_start,
-        const int ow_end, const int iw, const int ow, const int stride_w, const int pw,
-        const int real_ih_idx, const int oh_idx, const int pad_top,
-        const int pad_bottom) {
-    constexpr int ic_step = 4;
-    constexpr int oc_step = 4;
-    for (int ow_idx = ow_start; ow_idx < ow_end; ++ow_idx) {
-        const int iw_idx = ow_idx * stride_w;
-        const int real_iw_idx = std::max(iw_idx - pw, 0);
-        const int pad_left = std::max(0, pw - iw_idx);
-        const int pad_right = std::max(0, iw_idx - pw + filter - iw);
-        const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step;
-        const int dst_offset = (oh_idx * ow + ow_idx) * oc_step;
-        ker_pooling_nchw44_remain_pad<mode>(
-                src + src_offset, dst + dst_offset, iw, pad_top, pad_bottom, pad_left,
-                pad_right, filter);
-    }
-}
-
-template <int filter, int stride, PoolingBase::Mode mode>
-static inline void pooling_fp32_nchw44_pad(
-        const float32_t* src, float32_t* dst, int ih, int iw, int oh, int ow, int ph,
-        int pw) {
-    constexpr int stride_h = stride;
-    constexpr int stride_w = stride;
-    constexpr int ic_step = 4;
-    constexpr int oc_step = 4;
-    constexpr int ow_step = 4;
-    const int ow_pad_left_end = div_ceil(pw, stride_w);
-    const int ow_pad_right_end = (iw - filter + pw - 1) / stride_w;
-    const int ow_pad_right_step_end =
-            (ow_pad_right_end - ow_pad_left_end) / ow_step * ow_step + ow_pad_left_end;
-
-    rep(oh_idx, oh) {
-        const int ih_idx = oh_idx * stride_h;
-        const int real_ih_idx = std::max(ih_idx - ph, 0);
-        const int pad_top = std::max(0, ph - ih_idx);
-        const int pad_bottom = std::max(0, ih_idx - ph + filter - ih);
-        if (pad_top > 0 || pad_bottom > 0) {
-            kern_pooling_with_pad_nchw44<mode>(
-                    src, dst, filter, 0, ow, iw, ow, stride_w, pw, real_ih_idx, oh_idx,
-                    pad_top, pad_bottom);
-
-        } else {
-            kern_pooling_with_pad_nchw44<mode>(
-                    src, dst, filter, 0, ow_pad_left_end, iw, ow, stride_w, pw,
-                    real_ih_idx, oh_idx, pad_top, pad_bottom);
-            for (int ow_idx = ow_pad_left_end; ow_idx < ow_pad_right_step_end;
-                 ow_idx += ow_step) {
-                const int iw_idx = ow_idx * stride_w;
-                const int real_iw_idx = std::max(iw_idx - pw, 0);
-                const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step;
-                const int dst_offset = (oh_idx * ow + ow_idx) * oc_step;
-                KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, mode>::impl(
-                        src + src_offset, dst + dst_offset, iw);
-            }
-            kern_pooling_with_pad_nchw44<mode>(
-                    src, dst, filter, ow_pad_right_step_end, ow, iw, ow, stride_w, pw,
-                    real_ih_idx, oh_idx, pad_top, pad_bottom);
-        }
-    }
-}
-
-template <int filter, int stride, PoolingBase::Mode mode>
-static inline void pooling_fp32_nchw44_no_pad(
-        const float32_t* src, float32_t* dst, int, int iw, int oh, int ow) {
-    constexpr int stride_h = stride;
-    constexpr int stride_w = stride;
-    constexpr int ic_step = 4;
-    constexpr int oc_step = 4;
-    constexpr int ow_step = 4;
-    const int ow_end = ow / ow_step * ow_step;
-    const int ow_remain = ow - ow_end;
-
-    rep(oh_idx, oh) {
-        const int ih_idx = oh_idx * stride_h;
-        const int src_ih_offset = ih_idx * iw;
-        const int dst_oh_offset = oh_idx * ow;
-        for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
-            const int iw_idx = ow_idx * stride_w;
-            const int src_offset = (src_ih_offset + iw_idx) * ic_step;
-            const int dst_offset = (dst_oh_offset + ow_idx) * oc_step;
-            KerPoolingFilterXStrideXNchw44<filter, stride, ow_step, mode>::impl(
-                    src + src_offset, dst + dst_offset, iw);
-        }
-        if (ow_remain > 0) {
-            kern_pooling_with_pad_nchw44<mode>(
-                    src, dst, filter, ow_end, ow, iw, ow, stride_w, 0, ih_idx, oh_idx,
-                    0, 0);
-        }
-    }
-}
-
-template <int filter, int stride, PoolingBase::Mode mode>
-static inline void pooling_fp32_nchw44(
-        const float32_t* src, float32_t* dst, int ih, int iw, int oh, int ow, int ph,
-        int pw) {
-    if (ph > 0 || pw > 0) {
-        pooling_fp32_nchw44_pad<filter, stride, mode>(
-                src, dst, ih, iw, oh, ow, ph, pw);
-    } else {
-        pooling_fp32_nchw44_no_pad<filter, stride, mode>(src, dst, ih, iw, oh, ow);
-    }
-}
-
-}  // namespace
-}  // namespace arm_common
-}  // namespace megdnn
-
-// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp
index 4a356893f..5377ee822 100644
--- a/dnn/src/arm_common/pooling/opr_impl.cpp
+++ b/dnn/src/arm_common/pooling/opr_impl.cpp
@@ -22,6 +22,9 @@ private:
     AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4;
     AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4;
     AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4;
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    AlgoFilterxModexStridexNCHW88 algo_fp16_filterx_modex_stridex_nchw88;
+#endif
     AlgoFallback algo_fallback;
 
 public:
@@ -38,6 +41,9 @@ public:
         all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4);
         all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4);
         all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4);
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        all_algos.emplace_back(&algo_fp16_filterx_modex_stridex_nchw88);
+#endif
         all_algos.emplace_back(&algo_fallback);
 
         for (auto&& algo : all_algos) {
diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h
index 0b6b862d5..d75127e20 100644
--- a/dnn/src/arm_common/pooling/opr_impl.h
+++ b/dnn/src/arm_common/pooling/opr_impl.h
@@ -24,6 +24,9 @@ private:
     class AlgoFilter3ModexStridexNCHW44;
    class AlgoFilter4ModexStridexNCHW44;
     class AlgoFilter5ModexStridexNCHW44;
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    class AlgoFilterxModexStridexNCHW88;
+#endif
     class AlgoFallback;
     class AlgoPack;
     static AlgoPack sm_algo_pack;
@@ -56,6 +59,9 @@ public:
         ARM_Filter3ModexStridexNCHW44,
         ARM_Filter4ModexStridexNCHW44,
         ARM_Filter5ModexStridexNCHW44,
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        ARM_Fp16FilterxModexStridexNCHW88,
+#endif
         ARM_Fp32ModexStridexNCHW44,
         ARM_Fallback
     };
diff --git a/dnn/src/common/unroll_macro.h b/dnn/src/common/unroll_macro.h
index 37975b0ce..004adc825 100644
--- a/dnn/src/common/unroll_macro.h
+++ b/dnn/src/common/unroll_macro.h
@@ -165,6 +165,10 @@
     cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \
     cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a)
 
+#define UNROLL_RAW_3x2(cb, v0, a...) \
+    UNROLL_RAW_2x2(cb, v0, ##a)      \
+    cb(2, 0, ##a) cb(2, 1, ##a)
+
 #define UNROLL_RAW_4x2(cb, v0, a...) \
     cb(0, 0, ##a) cb(0, 1, ##a) cb(1, 0, ##a) cb(1, 1, ##a) \
     cb(2, 0, ##a) cb(2, 1, ##a) cb(3, 0, ##a) cb(3, 1, ##a)
@@ -177,6 +181,19 @@
     UNROLL_RAW_5x2(cb, v0, ##a) \
     cb(5, 0, ##a) cb(5, 1, ##a)
 
+#define UNROLL_RAW_9x2(cb, v0, a...) \
+    UNROLL_RAW_6x2(cb, v0, ##a)      \
+    cb(6, 0, ##a) cb(6, 1, ##a)      \
+    cb(7, 0, ##a) cb(7, 1, ##a)      \
+    cb(8, 0, ##a) cb(8, 1, ##a)
+
+#define UNROLL_RAW_13x2(cb, v0, a...) \
+    UNROLL_RAW_9x2(cb, v0, ##a)       \
+    cb(9, 0, ##a) cb(9, 1, ##a)       \
+    cb(10, 0, ##a) cb(10, 1, ##a)     \
+    cb(11, 0, ##a) cb(11, 1, ##a)     \
+    cb(12, 0, ##a) cb(12, 1, ##a)
+
 #define UNROLL_RAW_4x6(cb, v0, a...) \
     cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) cb(0, 4, ##a) cb(0, 5, ##a) \
     cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a) cb(1, 4, ##a) cb(1, 5, ##a) \
@@ -186,6 +203,28 @@
     UNROLL_RAW_4x6(cb, v0, ##a) \
     cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a) cb(4, 4, ##a) cb(4, 5, ##a)
 
+#define UNROLL_RAW_2x4(cb, v0, a...) \
+    cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \
+    cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a)
+#define UNROLL_RAW_3x4(cb, v0, a...) \
+    UNROLL_RAW_2x4(cb, v0, ##a)      \
+    cb(2, 0, ##a) cb(2, 1, ##a) cb(2, 2, ##a) cb(2, 3, ##a)
+#define UNROLL_RAW_5x4(cb, v0, a...) \
+    UNROLL_RAW_4x4(cb, v0, ##a)      \
+    cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a)
+#define UNROLL_RAW_9x4(cb, v0, a...) \
+    UNROLL_RAW_5x4(cb, v0, ##a)      \
+    cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \
+    cb(6, 0, ##a) cb(6, 1, ##a) cb(6, 2, ##a) cb(6, 3, ##a) \
+    cb(7, 0, ##a) cb(7, 1, ##a) cb(7, 2, ##a) cb(7, 3, ##a) \
+    cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a)
+#define UNROLL_RAW_13x4(cb, v0, a...) \
+    UNROLL_RAW_9x4(cb, v0, ##a)       \
+    cb(9, 0, ##a) cb(9, 1, ##a) cb(9, 2, ##a) cb(9, 3, ##a)     \
+    cb(10, 0, ##a) cb(10, 1, ##a) cb(10, 2, ##a) cb(10, 3, ##a) \
+    cb(11, 0, ##a) cb(11, 1, ##a) cb(11, 2, ##a) cb(11, 3, ##a) \
+    cb(12, 0, ##a) cb(12, 1, ##a) cb(12, 2, ##a) cb(12, 3, ##a)
+
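+// Example: UNROLL_CALL_NOWRAPPER_D2(3, 2, cb) expands through UNROLL_RAW_3x2 into
+// the six calls cb(i, j) with i in [0, 3) and j in [0, 2); the NCHW88 pooling
+// kernels use i as the filter-tap index and j as the unrolled output column.
+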
 #define UNROLL_CALL0_D2(step, step2, cb, v...) \
     UNROLL_RAW_##step##x##step2(cb, 0, ##v)
 #define UNROLL_CALL1_D2(step, step2, cb, v...) \
diff --git a/dnn/test/arm_common/pooling.cpp b/dnn/test/arm_common/pooling.cpp
index ae6b824d3..13558aed1 100644
--- a/dnn/test/arm_common/pooling.cpp
+++ b/dnn/test/arm_common/pooling.cpp
@@ -216,6 +216,48 @@ TEST_F(ARM_COMMON, POOLING_FP16) {
         checker.set_param(param).exec({{2, 3, ih, iw}, {}});
     }
 }
+
+TEST_F(ARM_COMMON, POOLING_FP16_NCHW88) {
+    Checker<Pooling> checker(handle());
+    checker.set_dtype(0, dtype::Float16{});
+    checker.set_dtype(1, dtype::Float16{});
+    checker.set_dtype(2, dtype::Float16{});
+    checker.set_dtype(4, dtype::Float16{});
+    checker.set_epsilon(0.003);
+    for (size_t ic : {1, 2, 3, 5, 7, 11})
+        for (size_t ih : {20, 15})
+            for (size_t iw : {15, 20, 27, 51, 76, 101, 256})
+                for (size_t pad : {2, 3, 5})
+                    for (auto mode :
+                         {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) {
+                        param::Pooling param;
+                        param.mode = mode;
+                        param.format = param::Pooling::Format::NCHW88;
+                        param.pad_h = pad;
+                        param.pad_w = pad;
+                        for (size_t kernel : {2, 3, 4, 5}) {
+                            if (kernel > pad && ih + 2 * pad >= kernel &&
+                                iw + 2 * pad >= kernel) {
+                                param.window_h = param.window_w = kernel;
+                                param.stride_h = param.stride_w = 1;
+                                checker.set_param(param).exec(
+                                        TensorShapeArray{{2, ic, ih, iw, 8}, {}});
+                                param.stride_h = param.stride_w = 2;
+                                checker.set_param(param).exec(
+                                        TensorShapeArray{{2, ic, ih, iw, 8}, {}});
+                            }
+                        }
+                        for (size_t kernel : {9, 13}) {
+                            if (kernel > pad && ih + 2 * pad >= kernel &&
+                                iw + 2 * pad >= kernel) {
+                                param.window_h = param.window_w = kernel;
+                                param.stride_h = param.stride_w = 1;
+                                checker.set_param(param).exec(
+                                        TensorShapeArray{{2, ic, ih, iw, 8}, {}});
+                            }
+                        }
+                    }
+}
 #endif
 
 TEST_F(ARM_COMMON, POOLING_QUANTIZED) {
@@ -367,6 +409,72 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_POOLING_NCHW44_FP32) {
     benchmark_nchw44_fp32(handle());
 }
 
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+void benchmark_nchw88_fp16(Handle* handle) {
+    using Param = param::Pooling;
+    auto run = [&](size_t n, size_t c, size_t h, size_t w, size_t filter, size_t stride,
+                   size_t pad, Param::Mode mode) {
+        Param param;
+        param.window_h = param.window_w = filter;
+        param.stride_h = param.stride_w = stride;
+        param.pad_h = param.pad_w = pad;
+        param.format = Param::Format::NCHW44;
+        param.mode = mode;
+        TensorShape nchw44_shape = {n, c / 4, h, w, 4};
+        TensorShape nchw88_shape = {n, c / 8, h, w, 8};
+        TensorLayout dst_layout;
+        auto opr = handle->create_operator<Pooling>();
+        opr->param() = param;
+        opr->deduce_layout({nchw44_shape, dtype::Float32()}, dst_layout);
+        float calc_amount =
+                dst_layout.total_nr_elems() * param.window_h * param.window_w;
+
+        Benchmarker<Pooling> benchmarker_float16_nchw88(handle);
+        Benchmarker<Pooling> benchmarker_float32_nchw44(handle);
+        size_t RUN = 500;
+        auto t1 = benchmarker_float32_nchw44.set_display(false)
+                          .set_times(RUN)
+                          .set_param(param)
+                          .exec({nchw44_shape, {}});
+
+        param.format = Param::Format::NCHW88;
+        auto t2 = benchmarker_float16_nchw88.set_display(false)
+                          .set_dtype(0, dtype::Float16{})
+                          .set_dtype(1, dtype::Float16{})
+                          .set_dtype(2, dtype::Float16{})
+                          .set_dtype(4, dtype::Float16{})
+                          .set_times(RUN)
+                          .set_param(param)
+                          .exec({nchw88_shape, {}});
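+
+        // t1/t2 accumulate RUN iterations, so the printf below reports per-iteration
+        // milliseconds and an approximate Mflops figure for both layouts, plus the
+        // fp32-NCHW44 vs fp16-NCHW88 speed-up t1 / t2.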
+
+        printf("{%zu %zu %zu %zu} filter = %zu, stride = %zu pad = %zu\n"
+               "nchw44_fp32={%.3f ms, %.3f Mflops}, "
+               "nchw88_fp16={%.3f ms, %.3f Mflops, speed_up %f}\n\n",
+               n, c, h, w, filter, stride, pad, t1 / RUN,
+               calc_amount / (t1 / RUN * 1000), t2 / RUN,
+               calc_amount / (t2 / RUN * 1000), t1 / t2);
+    };
+    // Resnet50
+    run(1, 64, 112, 112, 3, 2, 1, param::Pooling::Mode::MAX);
+    run(1, 2048, 7, 7, 7, 1, 0, param::Pooling::Mode::AVERAGE);
+
+    // VGG16
+    run(1, 64, 224, 224, 2, 2, 0, param::Pooling::Mode::MAX);
+    run(1, 128, 112, 112, 2, 2, 0, param::Pooling::Mode::MAX);
+    run(1, 256, 56, 56, 2, 2, 0, param::Pooling::Mode::MAX);
+    run(1, 512, 28, 28, 2, 2, 0, param::Pooling::Mode::MAX);
+    run(1, 512, 14, 14, 2, 2, 0, param::Pooling::Mode::MAX);
+}
+
+TEST_F(ARM_COMMON, BENCHMARK_POOLING_NCHW88_FP16) {
+    benchmark_nchw88_fp16(handle());
+}
+
+TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_POOLING_NCHW88_FP16) {
+    benchmark_nchw88_fp16(handle());
+}
+#endif
+
 TEST_F(ARM_COMMON, BENCHMARK_POOLING_INT8_W3x3_S2x2) {
     using Param = param::Pooling;
     auto run = [&](const TensorShapeArray& shapes, Param param) {
-- 
GitLab