From bde5cf35642233949d97ed5b45dbae45b2b530dc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 19 Sep 2021 21:09:57 +0800 Subject: [PATCH] feat(dnn): add resize linear for arm GitOrigin-RevId: 14ac5bda3f60f530ca9d42e94b1f4e401d0a1309 --- dnn/src/arm_common/resize/direct_nchwxx.cpp | 105 ++++++ dnn/src/arm_common/resize/direct_nchwxx.h | 36 ++ dnn/src/arm_common/resize/helper.h | 134 +++++++ dnn/src/arm_common/resize/opr_impl.cpp | 329 ++++++++---------- dnn/src/arm_common/resize/opr_impl.h | 10 - dnn/src/arm_common/resize/upsample2_nchw.cpp | 228 ++++++++++++ dnn/src/arm_common/resize/upsample2_nchw.h | 36 ++ .../arm_common/resize/upsample2_nchwxx.cpp | 197 +++++++++++ dnn/src/arm_common/resize/upsample2_nchwxx.h | 36 ++ dnn/test/arm_common/resize.cpp | 58 ++- dnn/test/cuda/resize.cpp | 3 +- 11 files changed, 976 insertions(+), 196 deletions(-) create mode 100644 dnn/src/arm_common/resize/direct_nchwxx.cpp create mode 100644 dnn/src/arm_common/resize/direct_nchwxx.h create mode 100644 dnn/src/arm_common/resize/helper.h create mode 100644 dnn/src/arm_common/resize/upsample2_nchw.cpp create mode 100644 dnn/src/arm_common/resize/upsample2_nchw.h create mode 100644 dnn/src/arm_common/resize/upsample2_nchwxx.cpp create mode 100644 dnn/src/arm_common/resize/upsample2_nchwxx.h diff --git a/dnn/src/arm_common/resize/direct_nchwxx.cpp b/dnn/src/arm_common/resize/direct_nchwxx.cpp new file mode 100644 index 000000000..93edd1c67 --- /dev/null +++ b/dnn/src/arm_common/resize/direct_nchwxx.cpp @@ -0,0 +1,105 @@ +/** + * \file dnn/src/arm_common/resize/direct_nchwxx.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/resize/direct_nchwxx.h" + +#include "src/arm_common/resize/helper.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace resize; + +namespace { + +template +void resize_direct_nchwxx(const ctype* sptr, ctype* dptr, size_t N, size_t IH, + size_t IW, size_t OH, size_t OW) { + using simd_helper = SIMDHelper; + constexpr size_t PC = simd_helper::simd_width; + using simd_type = typename simd_helper::simd_type; + + float scale_h = static_cast(OH) / IH; + float scale_w = static_cast(OW) / IW; + + for (size_t n = 0; n < N; ++n) { + for (size_t oh = 0; oh < OH; ++oh) { + for (size_t ow = 0; ow < OW; ++ow) { + int ih0, ih1, iw0, iw1; + float ah0, ah1, aw0, aw1; + + std::tie(ah0, ih0, ah1, ih1) = + get_nearest_linear_coord(imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = + get_nearest_linear_coord(imode, scale_w, IW, ow); + + simd_type r0 = simd_helper::load(sptr + (ih0 * IW + iw0) * PC); + simd_type r1 = simd_helper::load(sptr + (ih0 * IW + iw1) * PC); + simd_type r2 = simd_helper::load(sptr + (ih1 * IW + iw0) * PC); + simd_type r3 = simd_helper::load(sptr + (ih1 * IW + iw1) * PC); + + // FIXME: weight fp16 may cause precision problem + ctype a0 = ah0 * aw0; + ctype a1 = ah0 * aw1; + ctype a2 = ah1 * aw0; + ctype a3 = ah1 * aw1; + + simd_type c = simd_helper::dup(0); + c = simd_helper::fma(c, r0, a0); + c = simd_helper::fma(c, r1, a1); + c = simd_helper::fma(c, r2, a2); + c = simd_helper::fma(c, r3, a3); + + simd_helper::store(dptr + (oh * OW + ow) * PC, c); + } + } + sptr += IH * IW * PC; + dptr += OH * OW * PC; + } +} +} + +void megdnn::arm_common::resize_direct_nearest_nchw44_fp32( + const ResizeImpl::KernParam& kern_param) { + resize_direct_nchwxx( + kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, + kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); +} + +void megdnn::arm_common::resize_direct_linear_nchw44_fp32( + const ResizeImpl::KernParam& kern_param) { + resize_direct_nchwxx( + kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, + kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +void megdnn::arm_common::resize_direct_nearest_nchw88_fp16( + const ResizeImpl::KernParam& kern_param) { + auto sptr = reinterpret_cast(kern_param.sptr); + auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); + resize_direct_nchwxx<__fp16, InterpolationMode::INTER_NEAREST>( + sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, + kern_param.iw, kern_param.oh, kern_param.ow); +} + +void megdnn::arm_common::resize_direct_linear_nchw88_fp16( + const ResizeImpl::KernParam& kern_param) { + auto sptr = reinterpret_cast(kern_param.sptr); + auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); + resize_direct_nchwxx<__fp16, InterpolationMode::INTER_LINEAR>( + sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, + kern_param.iw, kern_param.oh, kern_param.ow); +} + +#endif diff --git a/dnn/src/arm_common/resize/direct_nchwxx.h b/dnn/src/arm_common/resize/direct_nchwxx.h new file mode 100644 index 000000000..aec01a5cb --- /dev/null +++ b/dnn/src/arm_common/resize/direct_nchwxx.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/arm_common/resize/direct_nchwxx.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "src/arm_common/resize/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +void resize_direct_linear_nchw44_fp32( + const ResizeImpl::KernParam& kern_param); + +void resize_direct_nearest_nchw44_fp32( + const ResizeImpl::KernParam& kern_param); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +void resize_direct_linear_nchw88_fp16( + const ResizeImpl::KernParam& kern_param); + +void resize_direct_nearest_nchw88_fp16( + const ResizeImpl::KernParam& kern_param); + +#endif + +} // namespace arm_common +} // namespace megdnn diff --git a/dnn/src/arm_common/resize/helper.h b/dnn/src/arm_common/resize/helper.h new file mode 100644 index 000000000..4117024fc --- /dev/null +++ b/dnn/src/arm_common/resize/helper.h @@ -0,0 +1,134 @@ +/** + * \file dnn/src/arm_common/resize/helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace arm_common { +namespace resize { + +using InterpolationMode = Resize::InterpolationMode; + +template +struct SIMDHelper {}; + +template <> +struct SIMDHelper { + using simd_type = float32x4_t; + using simd_type_x2 = float32x4x2_t; + using ctype = float; + static constexpr size_t simd_width = 4; + + static inline simd_type load(const ctype* src_ptr) { + return vld1q_f32(src_ptr); + } + static inline void store(ctype* dst_ptr, const simd_type& rdst) { + vst1q_f32(dst_ptr, rdst); + } + static inline void store2_interleave(ctype* dst_ptr, const simd_type& rdst1, + const simd_type& rdst2) { + simd_type_x2 rdst; + rdst.val[0] = rdst1; + rdst.val[1] = rdst2; + vst2q_f32(dst_ptr, rdst); + } + static inline simd_type fma(const simd_type& a, const simd_type& b, + ctype n) { +#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__) + return vfmaq_n_f32(a, b, n); +#else + return vmlaq_n_f32(a, b, n); +#endif + } + static inline simd_type fma(const simd_type& a, const simd_type& b, + const simd_type& c) { +#if defined(__ARM_FEATURE_FMA) + return vfmaq_f32(a, b, c); +#else + return vmlaq_f32(a, b, c); +#endif + } + static inline simd_type dup(float val) { return vdupq_n_f32(val); } +}; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template <> +struct SIMDHelper<__fp16> { + using simd_type = float16x8_t; + using simd_type_x2 = float16x8x2_t; + using ctype = __fp16; + static constexpr size_t simd_width = 8; + + static inline simd_type load(const ctype* src_ptr) { + return vld1q_f16(src_ptr); + } + static inline void store(ctype* dst_ptr, const simd_type& rdst) { + vst1q_f16(dst_ptr, rdst); + } + static inline void store2_interleave(ctype* dst_ptr, const simd_type& rdst1, + const simd_type& rdst2) { + simd_type_x2 rdst; + rdst.val[0] = rdst1; + rdst.val[1] = rdst2; + vst2q_f16(dst_ptr, rdst); + } + static inline simd_type fma(const simd_type& a, const simd_type& b, + ctype n) { +#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__) + return vfmaq_n_f16(a, b, n); +#else + return vaddq_f16(a, vmulq_n_f16(b, n)); +#endif + } + static inline simd_type fma(const simd_type& a, const simd_type& b, + const simd_type& c) { + return vfmaq_f16(a, b, c); + } + static inline simd_type dup(float val) { return vdupq_n_f16(val); } +}; + +#endif + +static inline int get_nearest_src(float scale, int size, int idx) { + return std::min(static_cast(idx / scale), size - 1); +} + +static inline std::tuple get_nearest_linear_coord( + InterpolationMode imode, float scale, int size, int idx) { + if (size == 1) { + return std::make_tuple(1.0f, 0, 0.0f, 0); + } + + float alpha = (idx + 0.5f) / scale - 0.5f; + int origin_idx = static_cast(floor(alpha)); + alpha -= origin_idx; + + if (imode == InterpolationMode::INTER_NEAREST) { + origin_idx = get_nearest_src(scale, size, idx); + alpha = 0; + } + + if (origin_idx < 0) { + origin_idx = 0; + alpha = 0; + } else if (origin_idx + 1 >= size) { + origin_idx = size - 2; + alpha = 1; + } + + return std::make_tuple(1 - alpha, origin_idx, alpha, origin_idx + 1); +} +}; +}; +}; diff --git a/dnn/src/arm_common/resize/opr_impl.cpp b/dnn/src/arm_common/resize/opr_impl.cpp index 9d145a836..3c421becc 100644 --- a/dnn/src/arm_common/resize/opr_impl.cpp +++ b/dnn/src/arm_common/resize/opr_impl.cpp @@ -12,212 +12,181 @@ #include "src/arm_common/resize/opr_impl.h" #include "src/arm_common/handle.h" +#include "src/arm_common/resize/direct_nchwxx.h" #include "src/arm_common/resize/resize_cv.h" +#include "src/arm_common/resize/upsample2_nchw.h" +#include "src/arm_common/resize/upsample2_nchwxx.h" #include "src/arm_common/simd_macro/marm_neon.h" -using namespace megdnn; -using namespace arm_common; +#include "midout.h" +MIDOUT_DECL(megdnn_arm_resize) + +namespace megdnn { +namespace arm_common { void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); - if (param().format == param::Resize::Format::NCHW44 || - param().format == param::Resize::Format::NCHW88) { - bool is_contiguous = - src.layout.is_contiguous() && dst.layout.is_contiguous(); - bool dtype_same = src.layout.dtype == dst.layout.dtype; - bool nchw44_enable = param().format == param::Resize::Format::NCHW44 && - src.layout.dtype == dtype::Float32(); - bool nchw88_enable = - param().format == param::Resize::Format::NCHW88 && - DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false); - bool interp_supported = - param().imode == - param::Resize::InterpolationMode::INTER_NEAREST || - param().imode == param::Resize::InterpolationMode::INTER_LINEAR; - bool is_upsample2 = - param().imode == - param::Resize::InterpolationMode::INTER_NEAREST && - src.layout.shape[2] * 2 == dst.layout.shape[2] && - src.layout.shape[3] * 2 == dst.layout.shape[3]; - bool need_fallback = !is_contiguous || !dtype_same || - !interp_supported || - (!nchw44_enable && !nchw88_enable); + bool is_contiguous = + src.layout.is_contiguous() && dst.layout.is_contiguous(); + bool is_dtype_same = src.layout.dtype == dst.layout.dtype; + bool is_dtype_fp32 = src.layout.dtype == dtype::Float32(); + bool is_dtype_fp16 = + DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false); + bool is_dtype_supported = is_dtype_same && (is_dtype_fp32 || is_dtype_fp16); + + bool is_nchw = param().format == param::Resize::Format::NCHW && + (is_dtype_fp32 || is_dtype_fp16); + bool is_nchw44_fp32 = + param().format == param::Resize::Format::NCHW44 && is_dtype_fp32; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + bool is_nchw88_fp16 = + param().format == param::Resize::Format::NCHW88 && is_dtype_fp16; +#endif - if (need_fallback) { - fallback::ResizeImpl::exec(src, dst, workspace); - } else if (nchw44_enable) { - auto kern_param = KernParam::from_tensors( - param().format, param().imode, src, dst, workspace); + bool is_imode_nearest = + param().imode == param::Resize::InterpolationMode::INTER_NEAREST; + bool is_imode_linear = + param().imode == param::Resize::InterpolationMode::INTER_LINEAR; + bool is_imode_supported = is_imode_nearest || is_imode_linear; + + bool is_upsample2 = src.layout.shape[2] * 2 == dst.layout.shape[2] && + src.layout.shape[3] * 2 == dst.layout.shape[3]; + bool usable = is_contiguous && is_dtype_supported && is_imode_supported; + + if (param().format == param::Resize::Format::NHWC && + (src.layout[3] == 1 || src.layout[3] == 3) && + is_nhwc_contig_wc(src.layout)) { + MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode)); + } else if (!usable) { + fallback::ResizeImpl::exec(src, dst, workspace); + } else if (is_dtype_fp32) { + auto kern_param = KernParam::from_tensors( + param().format, param().imode, src, dst, workspace); + if (is_nchw44_fp32) { if (is_upsample2) { - MEGDNN_DISPATCH_CPU_KERN_OPR( - kern_nearest_upsample2_pack_simd_width(src, dst)); + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(0)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_nearest_upsample2_nchw44_fp32( + kern_param)); + } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(1)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_linear_upsample2_nchw44_fp32( + kern_param)); + } + MIDOUT_END(); + } } else { - MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw44_fp32(kern_param)); + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(2)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_direct_nearest_nchw44_fp32(kern_param)); + } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(3)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_direct_linear_nchw44_fp32(kern_param)); + } + MIDOUT_END(); + } } -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - } else if (nchw88_enable) { - auto kern_param = KernParam::from_tensors( - param().format, param().imode, src, dst, workspace); + } else if (is_nchw) { if (is_upsample2) { - MEGDNN_DISPATCH_CPU_KERN_OPR( - kern_nearest_upsample2_pack_simd_width(src, dst)); + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(4)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_nearest_upsample2_nchw_fp32(kern_param)); + } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(5)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_linear_upsample2_nchw_fp32(kern_param)); + } + MIDOUT_END(); + } } else { - MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw88_fp16(kern_param)); + fallback::ResizeImpl::exec(src, dst, workspace); } -#endif } else { fallback::ResizeImpl::exec(src, dst, workspace); } - } else if (param().format == param::Resize::Format::NCHW || - (src.layout[3] != 1 && src.layout[3] != 3) || - !is_nhwc_contig_wc(src.layout)) { - fallback::ResizeImpl::exec(src, dst, workspace); - } else { - megdnn_assert(param().format == param::Resize::Format::NHWC, - "invalid resize format"); - MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode)); - } -} - -template -void ResizeImpl::kern_nchw44_fp32(const KernParam& kern_param) { - UNPACK_RESIZE_FWD_KERN_PARAM(kern_param); - float scale_h = static_cast(OH) / IH; - float scale_w = static_cast(OW) / IW; - - for (size_t n = 0; n < N; ++n) { - for (size_t c = 0; c < C / 4; ++c) { - for (size_t oh = 0; oh < OH; ++oh) { - for (size_t ow = 0; ow < OW; ++ow) { - int ih0, ih1, iw0, iw1; - float ah0, ah1, aw0, aw1; - - std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( - kern_param.imode, scale_h, IH, oh); - std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( - kern_param.imode, scale_w, IW, ow); - -#define SRC_ADDRESS(ih, iw) \ - (sptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 4) -#define DST_ADDRESS(oh, ow) \ - (dptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 4) - float32x4_t r0 = vld1q_f32(SRC_ADDRESS(ih0, iw0)); - float32_t a0 = ah0 * aw0; - float32x4_t r1 = vld1q_f32(SRC_ADDRESS(ih0, iw1)); - float32_t a1 = ah0 * aw1; - float32x4_t r2 = vld1q_f32(SRC_ADDRESS(ih1, iw0)); - float32_t a2 = ah1 * aw0; - float32x4_t r3 = vld1q_f32(SRC_ADDRESS(ih1, iw1)); - float32_t a3 = ah1 * aw1; - - r0 = vmulq_n_f32(r0, a0); -#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__) - r0 = vfmaq_n_f32(r0, r1, a1); - r0 = vfmaq_n_f32(r0, r2, a2); - r0 = vfmaq_n_f32(r0, r3, a3); -#else - r0 = vaddq_f32(r0, vmulq_n_f32(r1, a1)); - r0 = vaddq_f32(r0, vmulq_n_f32(r2, a2)); - r0 = vaddq_f32(r0, vmulq_n_f32(r3, a3)); -#endif - - vst1q_f32(DST_ADDRESS(oh, ow), r0); -#undef SRC_ADDRESS -#undef DST_ADDRESS - } - } - } - } -} - #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template -void ResizeImpl::kern_nchw88_fp16(const KernParam& kern_param) { - UNPACK_RESIZE_FWD_KERN_PARAM(kern_param); - float scale_h = static_cast(OH) / IH; - float scale_w = static_cast(OW) / IW; - const float16_t* src_ptr = reinterpret_cast(sptr); - float16_t* dst_ptr = reinterpret_cast(dptr); - - for (size_t n = 0; n < N; ++n) { - for (size_t c = 0; c < C / 8; ++c) { - for (size_t oh = 0; oh < OH; ++oh) { - for (size_t ow = 0; ow < OW; ++ow) { - int ih0, ih1, iw0, iw1; - float ah0, ah1, aw0, aw1; - - std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( - kern_param.imode, scale_h, IH, oh); - std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( - kern_param.imode, scale_w, IW, ow); - -#define SRC_ADDRESS(ih, iw) \ - (src_ptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 8) -#define DST_ADDRESS(oh, ow) \ - (dst_ptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 8) - float16x8_t r0 = vld1q_f16(SRC_ADDRESS(ih0, iw0)); - float32_t a0 = ah0 * aw0; - float16x8_t r1 = vld1q_f16(SRC_ADDRESS(ih0, iw1)); - float32_t a1 = ah0 * aw1; - float16x8_t r2 = vld1q_f16(SRC_ADDRESS(ih1, iw0)); - float32_t a2 = ah1 * aw0; - float16x8_t r3 = vld1q_f16(SRC_ADDRESS(ih1, iw1)); - float32_t a3 = ah1 * aw1; - - r0 = vmulq_n_f16(r0, a0); -#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__) - r0 = vfmaq_n_f16(r0, r1, a1); - r0 = vfmaq_n_f16(r0, r2, a2); - r0 = vfmaq_n_f16(r0, r3, a3); -#else - r0 = vaddq_f16(r0, vmulq_n_f16(r1, a1)); - r0 = vaddq_f16(r0, vmulq_n_f16(r2, a2)); - r0 = vaddq_f16(r0, vmulq_n_f16(r3, a3)); -#endif - - vst1q_f16(DST_ADDRESS(oh, ow), r0); -#undef SRC_ADDRESS -#undef DST_ADDRESS + } else if (is_dtype_fp16) { + auto kern_param = KernParam::from_tensors( + param().format, param().imode, src, dst, workspace); + if (is_nchw88_fp16) { + if (is_upsample2) { + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(6)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_nearest_upsample2_nchw88_fp16( + kern_param)); + } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(7)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_linear_upsample2_nchw88_fp16( + kern_param)); + } + MIDOUT_END(); + } + } else { + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(8)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_direct_nearest_nchw88_fp16(kern_param)); + } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(9)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_direct_linear_nchw88_fp16(kern_param)); + } + MIDOUT_END(); } } - } - } -} -#endif - -void ResizeImpl::kern_nearest_upsample2_pack_simd_width( - _megdnn_tensor_in src, _megdnn_tensor_out dst) { - const uint8_t* src_ptr = reinterpret_cast(src.raw_ptr); - uint8_t* dst_ptr = reinterpret_cast(dst.raw_ptr); - - size_t S = 2; - size_t N = src.layout.shape[0]; - size_t IC = src.layout.shape[1]; - size_t IH = src.layout.shape[2]; - size_t IW = src.layout.shape[3]; - size_t OH = dst.layout.shape[2]; - size_t OW = dst.layout.shape[3]; - - for (size_t i = 0; i < N * IC; ++i) { - for (size_t ih = 0; ih < IH; ++ih) { - for (size_t iw = 0; iw < IW; ++iw) { - size_t oh = ih * S; - size_t ow = iw * S; - uint8x16_t r0 = vld1q_u8(src_ptr + i * IH * IW * 16 + - ih * IW * 16 + iw * 16); - - for (size_t fh = 0; fh < S; ++fh) { - for (size_t fw = 0; fw < S; ++fw) { - vst1q_u8(dst_ptr + i * OH * OW * 16 + - (oh + fh) * OW * 16 + (ow + fw) * 16, - r0); + } else if (is_nchw) { + if (is_upsample2) { + if (is_imode_nearest) { + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(10)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_nearest_upsample2_nchw_fp16(kern_param)); } + MIDOUT_END(); + } else { + megdnn_assert(is_imode_linear, "invalid imode"); + MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(11)) { + MEGDNN_DISPATCH_CPU_KERN_OPR( + resize_linear_upsample2_nchw_fp16(kern_param)); + } + MIDOUT_END(); } + } else { + fallback::ResizeImpl::exec(src, dst, workspace); } + } else { + fallback::ResizeImpl::exec(src, dst, workspace); } +#endif + } else { + fallback::ResizeImpl::exec(src, dst, workspace); } } +} // namespace arm_common +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/resize/opr_impl.h b/dnn/src/arm_common/resize/opr_impl.h index f40f25212..6f2233685 100644 --- a/dnn/src/arm_common/resize/opr_impl.h +++ b/dnn/src/arm_common/resize/opr_impl.h @@ -26,16 +26,6 @@ public: const TensorLayout&) override { return 0; } - -private: - template - void kern_nchw44_fp32(const KernParam& kern_param); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - template - void kern_nchw88_fp16(const KernParam& kern_param); -#endif - void kern_nearest_upsample2_pack_simd_width(_megdnn_tensor_in src, - _megdnn_tensor_out dst); }; } // namespace arm_common diff --git a/dnn/src/arm_common/resize/upsample2_nchw.cpp b/dnn/src/arm_common/resize/upsample2_nchw.cpp new file mode 100644 index 000000000..ac012d40f --- /dev/null +++ b/dnn/src/arm_common/resize/upsample2_nchw.cpp @@ -0,0 +1,228 @@ +/** + * \file dnn/src/arm_common/resize/upsample2_nchw.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/resize/upsample2_nchw.h" + +#include "src/arm_common/resize/helper.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace resize; + +namespace { + +template +static inline ctype compute_linear_element(const ctype src[4], + const ctype alpha[2]) { + return src[0] * alpha[0 ^ fh] * alpha[0 ^ fw] + + src[1] * alpha[0 ^ fh] * alpha[1 ^ fw] + + src[2] * alpha[1 ^ fh] * alpha[0 ^ fw] + + src[3] * alpha[1 ^ fh] * alpha[1 ^ fw]; +} + +template +static inline typename simd_helper::simd_type compute_linear_element_simd( + const typename simd_helper::simd_type src[4], + const typename simd_helper::simd_type alpha[2][2]) { + typename simd_helper::simd_type c = simd_helper::dup(0); + c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]); + c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]); + c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]); + c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]); + return c; +} + +template +static inline void compute_linear_2x2_element(const ctype* src, ctype* dst, + size_t IW, size_t OW, + const ctype alpha[2]) { + const ctype* src_ptr[4] = {src, src, src, src}; + + if (has_right) { + src_ptr[1] += 1; + src_ptr[3] += 1; + } + if (has_bottom) { + src_ptr[2] += IW; + src_ptr[3] += IW; + } + + ctype rsrc[4]; + rsrc[0] = *src_ptr[0]; + rsrc[1] = *src_ptr[1]; + rsrc[2] = *src_ptr[2]; + rsrc[3] = *src_ptr[3]; + + dst[0] = compute_linear_element(rsrc, alpha); + if (has_right) { + dst[1] = compute_linear_element(rsrc, alpha); + } + if (has_bottom) { + dst[OW] = compute_linear_element(rsrc, alpha); + } + if (has_right && has_bottom) { + dst[OW + 1] = compute_linear_element(rsrc, alpha); + } +} + +template +static inline void compute_linear_2x2_element_simd( + const typename simd_helper::ctype* src, + typename simd_helper::ctype* dst, size_t IW, size_t OW, + const typename simd_helper::simd_type alpha[2][2]) { + using simd_type = typename simd_helper::simd_type; + + simd_type rsrc[4]; + rsrc[0] = simd_helper::load(src); + rsrc[1] = simd_helper::load(src + 1); + rsrc[2] = simd_helper::load(src + IW); + rsrc[3] = simd_helper::load(src + IW + 1); + + simd_type rdst[4]; + rdst[0] = compute_linear_element_simd(rsrc, alpha); + rdst[1] = compute_linear_element_simd(rsrc, alpha); + rdst[2] = compute_linear_element_simd(rsrc, alpha); + rdst[3] = compute_linear_element_simd(rsrc, alpha); + + simd_helper::store2_interleave(dst, rdst[0], rdst[1]); + simd_helper::store2_interleave(dst + OW, rdst[2], rdst[3]); +} + +template +void linear_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, + size_t IH, size_t IW) { + using simd_helper = SIMDHelper; + size_t OW = IW * 2; + constexpr size_t PC = simd_helper::simd_width; + + ctype alpha[2] = {0.75, 0.25}; + + typename simd_helper::simd_type simd_alpha[2][2]; + simd_alpha[0][0] = simd_helper::dup(0.75 * 0.75); + simd_alpha[0][1] = simd_helper::dup(0.75 * 0.25); + simd_alpha[1][0] = simd_helper::dup(0.25 * 0.75); + simd_alpha[1][1] = simd_helper::dup(0.25 * 0.25); + + for (size_t i = 0; i < N; ++i) { + compute_linear_2x2_element(src_ptr, dst_ptr, IW, + OW, alpha); + { + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha); + } + } + compute_linear_2x2_element( + src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha); + dst_ptr += OW; + + for (size_t ih = 0; ih + 1 < IH; ++ih) { + compute_linear_2x2_element(src_ptr, dst_ptr, IW, + OW, alpha); + size_t iw = 0; + for (; iw + PC < IW; iw += PC) { + compute_linear_2x2_element_simd( + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, + simd_alpha); + } + for (; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha); + } + compute_linear_2x2_element( + src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha); + + src_ptr += IW; + dst_ptr += 2 * OW; + } + + compute_linear_2x2_element(src_ptr, dst_ptr, IW, + OW, alpha); + { + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, alpha); + } + } + compute_linear_2x2_element( + src_ptr + (IW - 1), dst_ptr + (OW - 1), IW, OW, alpha); + src_ptr += IW; + dst_ptr += OW; + } +} + +template +void nearest_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, + size_t IH, size_t IW) { + using simd_helper = SIMDHelper; + size_t OW = IW * 2; + constexpr size_t PC = simd_helper::simd_width; + + for (size_t i = 0; i < N; ++i) { + for (size_t ih = 0; ih < IH; ++ih) { + size_t iw = 0; + for (; iw + PC - 1 < IW; iw += PC) { + typename simd_helper::simd_type r0 = + simd_helper::load(src_ptr + iw); + + simd_helper::store2_interleave(dst_ptr + (iw * 2), r0, r0); + simd_helper::store2_interleave(dst_ptr + (OW + iw * 2), r0, r0); + } + for (; iw < IW; iw += 1) { + ctype v = src_ptr[iw]; + dst_ptr[iw * 2] = v; + dst_ptr[iw * 2 + 1] = v; + dst_ptr[OW + iw * 2] = v; + dst_ptr[OW + iw * 2 + 1] = v; + } + src_ptr += IW; + dst_ptr += 2 * OW; + } + } +} + +} // namespace + +void megdnn::arm_common::resize_linear_upsample2_nchw_fp32( + const ResizeImpl::KernParam& kern_param) { + linear_upsample2_nchw(kern_param.sptr, kern_param.dptr, + kern_param.n * kern_param.c, kern_param.ih, + kern_param.iw); +} + +void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32( + const ResizeImpl::KernParam& kern_param) { + nearest_upsample2_nchw(kern_param.sptr, kern_param.dptr, + kern_param.n * kern_param.c, kern_param.ih, + kern_param.iw); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +void megdnn::arm_common::resize_linear_upsample2_nchw_fp16( + const ResizeImpl::KernParam& kern_param) { + auto sptr = reinterpret_cast(kern_param.sptr); + auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); + linear_upsample2_nchw(sptr, dptr, kern_param.n * kern_param.c, + kern_param.ih, kern_param.iw); +} + +void megdnn::arm_common::resize_nearest_upsample2_nchw_fp16( + const ResizeImpl::KernParam& kern_param) { + auto sptr = reinterpret_cast(kern_param.sptr); + auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); + nearest_upsample2_nchw(sptr, dptr, kern_param.n * kern_param.c, + kern_param.ih, kern_param.iw); +} + +#endif diff --git a/dnn/src/arm_common/resize/upsample2_nchw.h b/dnn/src/arm_common/resize/upsample2_nchw.h new file mode 100644 index 000000000..3b6aa7ce2 --- /dev/null +++ b/dnn/src/arm_common/resize/upsample2_nchw.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/arm_common/resize/upsample2_nchw.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "src/arm_common/resize/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +void resize_linear_upsample2_nchw_fp32( + const ResizeImpl::KernParam& kern_param); + +void resize_nearest_upsample2_nchw_fp32( + const ResizeImpl::KernParam& kern_param); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +void resize_linear_upsample2_nchw_fp16( + const ResizeImpl::KernParam& kern_param); + +void resize_nearest_upsample2_nchw_fp16( + const ResizeImpl::KernParam& kern_param); + +#endif + +} // namespace arm_common +} // namespace megdnn diff --git a/dnn/src/arm_common/resize/upsample2_nchwxx.cpp b/dnn/src/arm_common/resize/upsample2_nchwxx.cpp new file mode 100644 index 000000000..7e91c416c --- /dev/null +++ b/dnn/src/arm_common/resize/upsample2_nchwxx.cpp @@ -0,0 +1,197 @@ +/** + * \file dnn/src/arm_common/resize/upsample2_nchwxx.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/resize/upsample2_nchwxx.h" + +#include "src/arm_common/resize/helper.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +using namespace megdnn; +using namespace arm_common; +using namespace resize; + +namespace { + +template +static inline typename simd_helper::simd_type compute_linear_element( + const typename simd_helper::simd_type src[4], + const typename simd_helper::simd_type alpha[2][2]) { + typename simd_helper::simd_type c = simd_helper::dup(0); + c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]); + c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]); + c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]); + c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]); + return c; +} + +template +static inline void compute_linear_2x2_element( + const typename simd_helper::ctype* src, + typename simd_helper::ctype* dst, size_t IW, size_t OW, + const typename simd_helper::simd_type alpha[2][2]) { + constexpr size_t PC = simd_helper::simd_width; + const typename simd_helper::ctype* src_ptr[4] = {src, src, src, src}; + + if (has_right) { + src_ptr[1] += PC; + src_ptr[3] += PC; + } + if (has_bottom) { + src_ptr[2] += IW * PC; + src_ptr[3] += IW * PC; + } + + typename simd_helper::simd_type rsrc[4]; + rsrc[0] = simd_helper::load(src_ptr[0]); + rsrc[1] = simd_helper::load(src_ptr[1]); + rsrc[2] = simd_helper::load(src_ptr[2]); + rsrc[3] = simd_helper::load(src_ptr[3]); + + typename simd_helper::simd_type rdst[4]; + rdst[0] = compute_linear_element(rsrc, alpha); + rdst[1] = compute_linear_element(rsrc, alpha); + rdst[2] = compute_linear_element(rsrc, alpha); + rdst[3] = compute_linear_element(rsrc, alpha); + + simd_helper::store(dst, rdst[0]); + if (has_right) { + simd_helper::store(dst + PC, rdst[1]); + } + if (has_bottom) { + simd_helper::store(dst + OW * PC, rdst[2]); + } + if (has_right && has_bottom) { + simd_helper::store(dst + (OW + 1) * PC, rdst[3]); + } +} + +template +void linear_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N, + size_t IH, size_t IW) { + using simd_helper = SIMDHelper; + size_t OW = IW * 2; + constexpr size_t PC = simd_helper::simd_width; + + typename simd_helper::simd_type alpha[2][2]; + alpha[0][0] = simd_helper::dup(0.75 * 0.75); + alpha[0][1] = simd_helper::dup(0.75 * 0.25); + alpha[1][0] = simd_helper::dup(0.25 * 0.75); + alpha[1][1] = simd_helper::dup(0.25 * 0.25); + + for (size_t i = 0; i < N; ++i) { + compute_linear_2x2_element(src_ptr, dst_ptr, + IW, OW, alpha); + + { + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, + alpha); + } + } + compute_linear_2x2_element( + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, + alpha); + dst_ptr += OW * PC; + + for (size_t ih = 0; ih + 1 < IH; ++ih) { + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, + alpha); + } + compute_linear_2x2_element( + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, + alpha); + + src_ptr += IW * PC; + dst_ptr += 2 * OW * PC; + } + + compute_linear_2x2_element(src_ptr, dst_ptr, + IW, OW, alpha); + { + for (size_t iw = 0; iw + 1 < IW; ++iw) { + compute_linear_2x2_element( + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, + alpha); + } + } + + compute_linear_2x2_element( + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, + alpha); + src_ptr += IW * PC; + dst_ptr += OW * PC; + } +} + +template +void nearest_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N, + size_t IH, size_t IW) { + using simd_helper = SIMDHelper; + size_t OW = IW * 2; + constexpr size_t PC = simd_helper::simd_width; + + for (size_t i = 0; i < N; ++i) { + for (size_t ih = 0; ih < IH; ++ih) { + for (size_t iw = 0; iw < IW; ++iw) { + typename simd_helper::simd_type r0 = + simd_helper::load(src_ptr + iw * PC); + + simd_helper::store(dst_ptr + (iw * 2) * PC, r0); + simd_helper::store(dst_ptr + (iw * 2 + 1) * PC, r0); + simd_helper::store(dst_ptr + (OW + iw * 2) * PC, r0); + simd_helper::store(dst_ptr + (OW + iw * 2 + 1) * PC, r0); + } + src_ptr += IW * PC; + dst_ptr += 2 * OW * PC; + } + } +} +} // namespace + +void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32( + const ResizeImpl::KernParam& kern_param) { + linear_upsample2_nchwxx(kern_param.sptr, kern_param.dptr, + kern_param.n * kern_param.c / 4, kern_param.ih, + kern_param.iw); +} + +void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32( + const ResizeImpl::KernParam& kern_param) { + nearest_upsample2_nchwxx(kern_param.sptr, kern_param.dptr, + kern_param.n * kern_param.c / 4, kern_param.ih, + kern_param.iw); +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16( + const ResizeImpl::KernParam& kern_param) { + auto sptr = reinterpret_cast(kern_param.sptr); + auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); + linear_upsample2_nchwxx(sptr, dptr, kern_param.n * kern_param.c / 8, + kern_param.ih, kern_param.iw); +} + +void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16( + const ResizeImpl::KernParam& kern_param) { + auto sptr = reinterpret_cast(kern_param.sptr); + auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); + nearest_upsample2_nchwxx(sptr, dptr, kern_param.n * kern_param.c / 8, + kern_param.ih, kern_param.iw); +} + +#endif diff --git a/dnn/src/arm_common/resize/upsample2_nchwxx.h b/dnn/src/arm_common/resize/upsample2_nchwxx.h new file mode 100644 index 000000000..22e3fdbc5 --- /dev/null +++ b/dnn/src/arm_common/resize/upsample2_nchwxx.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/arm_common/resize/upsample2_nchwxx.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "src/arm_common/resize/opr_impl.h" + +namespace megdnn { +namespace arm_common { + +void resize_linear_upsample2_nchw44_fp32( + const ResizeImpl::KernParam& kern_param); + +void resize_nearest_upsample2_nchw44_fp32( + const ResizeImpl::KernParam& kern_param); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +void resize_linear_upsample2_nchw88_fp16( + const ResizeImpl::KernParam& kern_param); + +void resize_nearest_upsample2_nchw88_fp16( + const ResizeImpl::KernParam& kern_param); + +#endif + +} // namespace arm_common +} // namespace megdnn diff --git a/dnn/test/arm_common/resize.cpp b/dnn/test/arm_common/resize.cpp index 1455c725a..66504321e 100644 --- a/dnn/test/arm_common/resize.cpp +++ b/dnn/test/arm_common/resize.cpp @@ -16,8 +16,25 @@ namespace megdnn { namespace test { +using namespace resize; + +static void set_nchw_args(IMode imode, std::vector& args) { + param::Resize param; + param.format = param::Resize::Format::NCHW; + param.imode = imode; + rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul) + args.emplace_back( + param, TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul}, + TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul}); + args.emplace_back(param, TensorShape{1, 1, 10, 10}, + TensorShape{1, 1, 20, 20}); + args.emplace_back(param, TensorShape{1, 1, 10, 10}, + TensorShape{1, 1, 7, 9}); + args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); + args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); +} + TEST_F(ARM_COMMON, RESIZE_CV) { - using namespace resize; std::vector args = get_cv_args(); Checker checker(handle()); @@ -37,8 +54,38 @@ TEST_F(ARM_COMMON, RESIZE_CV) { } } -TEST_F(ARM_COMMON, RESIZE_NCHW44) { - using namespace resize; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON, RESIZE_NCHW_FP16) { + std::vector args; + set_nchw_args(IMode::INTER_LINEAR, args); + set_nchw_args(IMode::INTER_NEAREST, args); + Checker checker(handle()); + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_epsilon(0.01) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .execs({arg.src, arg.dst}); + } +} +#endif + +TEST_F(ARM_COMMON, RESIZE_NCHW_FP32) { + std::vector args; + set_nchw_args(IMode::INTER_LINEAR, args); + set_nchw_args(IMode::INTER_NEAREST, args); + Checker checker(handle()); + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .execs({arg.src, arg.dst}); + } +} + +TEST_F(ARM_COMMON, RESIZE_NCHW44_FP32) { std::vector args = get_nchw44_args(); Checker checker(handle()); @@ -50,8 +97,8 @@ TEST_F(ARM_COMMON, RESIZE_NCHW44) { } } -TEST_F(ARM_COMMON, RESIZE_NCHW88) { - using namespace resize; +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +TEST_F(ARM_COMMON, RESIZE_NCHW88_FP16) { std::vector args = get_nchw88_args(); Checker checker(handle()); @@ -63,6 +110,7 @@ TEST_F(ARM_COMMON, RESIZE_NCHW88) { .execs({arg.src, arg.dst}); } } +#endif } // namespace test } // namespace megdnn diff --git a/dnn/test/cuda/resize.cpp b/dnn/test/cuda/resize.cpp index cfc219299..ac777cf4c 100644 --- a/dnn/test/cuda/resize.cpp +++ b/dnn/test/cuda/resize.cpp @@ -52,6 +52,7 @@ TEST_F(CUDA, RESIZE_FORWARD) { checker.set_param(arg.param) .set_dtype(0, dtype::Uint8()) .set_dtype(1, dtype::Uint8()) + .set_epsilon(1) .execs({arg.src, arg.dst}); } @@ -67,7 +68,7 @@ TEST_F(CUDA, RESIZE_FORWARD) { checker.set_param(arg.param) .set_dtype(0, dtype::Int8()) .set_dtype(1, dtype::Int8()) - .set_epsilon(1e-3) + .set_epsilon(1) .execs({arg.src, arg.dst}); } } -- GitLab