From 67575d582c038c0e4aad52a6c8e66a4e8d784486 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 13 Aug 2021 12:20:26 +0800 Subject: [PATCH] feat(mge/opr): add interpolate bilinear mode GitOrigin-RevId: f7023a3fd381f36e64702893576702d59c5be2c6 --- dnn/include/megdnn/oprs/cv.h | 2 +- dnn/src/common/resize.cpp | 23 +- dnn/src/common/resize.cuh | 39 +++ dnn/src/common/rounding_converter.cuh | 3 + dnn/src/cuda/cv/kernel_common.cuh | 12 +- dnn/src/cuda/resize/backward.cu | 67 ++++- dnn/src/cuda/resize/common.cuh | 17 +- dnn/src/cuda/resize/forward.cpp | 8 +- dnn/src/cuda/resize/forward.cu | 115 ++++++--- dnn/src/cuda/resize/resize_cv.cu | 14 +- dnn/src/fallback/resize/opr_impl.cpp | 2 +- dnn/src/naive/resize/opr_impl.cpp | 235 +++++++++++------- dnn/src/naive/resize/opr_impl.h | 2 +- dnn/src/naive/resize/resize_cv.cpp | 10 +- dnn/test/cuda/resize.cpp | 6 +- .../python/megengine/functional/vision.py | 57 ++--- 16 files changed, 404 insertions(+), 208 deletions(-) create mode 100644 dnn/src/common/resize.cuh diff --git a/dnn/include/megdnn/oprs/cv.h b/dnn/include/megdnn/oprs/cv.h index 940e6639d..55306f327 100644 --- a/dnn/include/megdnn/oprs/cv.h +++ b/dnn/include/megdnn/oprs/cv.h @@ -197,7 +197,7 @@ public: protected: //! get origin coord - std::pair get_origin_coord(float scale, int size, int idx); + std::pair get_origin_coord(float scale, int size, int idx, bool cubic=false); //! 
get nearest index in src int get_nearest_src(float scale, int size, int idx); diff --git a/dnn/src/common/resize.cpp b/dnn/src/common/resize.cpp index 799ac15b9..f9c78602c 100644 --- a/dnn/src/common/resize.cpp +++ b/dnn/src/common/resize.cpp @@ -11,6 +11,7 @@ */ #include "megdnn/handle.h" +#include "megdnn/opr_param_defs.h" #include "megdnn/oprs.h" #include "src/common/utils.h" @@ -29,8 +30,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src, if (param().format == Param::Format::NCHW) { megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str()); auto imode = param().imode; - megdnn_assert(imode == param::Resize::InterpolationMode::INTER_LINEAR || - imode == param::Resize::InterpolationMode::NEAREST); + using IMode = param::Resize::InterpolationMode; + megdnn_assert(imode == IMode::INTER_LINEAR || imode == IMode::NEAREST || + imode == IMode::INTER_CUBIC); } else if (param().format == Param::Format::NHWC) { megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str()); } else if (param().format == Param::Format::NCHW4) { @@ -66,19 +68,20 @@ void ResizeBackward::check_exec(const TensorLayout& diff, } std::pair ResizeBase::get_origin_coord(float scale, int size, - int idx) { + int idx, bool cubic) { //! 
copy from resize_cv.cpp float alpha = (idx + 0.5f) / scale - 0.5f; int origin_idx = static_cast(floor(alpha)); alpha -= origin_idx; - if (origin_idx < 0) { - origin_idx = 0; - alpha = 0; - } else if (origin_idx + 1 >= size) { - origin_idx = size - 2; - alpha = 1; + if (!cubic) { + if (origin_idx < 0) { + origin_idx = 0; + alpha = 0; + } else if (origin_idx + 1 >= size) { + origin_idx = size - 2; + alpha = 1; + } } - return {alpha, origin_idx}; } diff --git a/dnn/src/common/resize.cuh b/dnn/src/common/resize.cuh new file mode 100644 index 000000000..8445dfb85 --- /dev/null +++ b/dnn/src/common/resize.cuh @@ -0,0 +1,39 @@ +/** + * \file dnn/src/common/resize.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "megdnn/arch.h" + +#if MEGDNN_CC_HOST && !defined(__host__) +#if __GNUC__ || __has_attribute(always_inline) +#define __forceinline__ inline __attribute__((always_inline)) +#else +#define __forceinline__ inline +#endif +#endif + +namespace megdnn { +namespace resize { + +MEGDNN_HOST MEGDNN_DEVICE __forceinline__ void interpolate_cubic( + float x, float* coeffs) { + const float A = -0.75f; + + coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A; + coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1; + coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1; + coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2]; +} +} // namespace resize +} // namespace megdnn + +/* vim: set ft=cpp: */ diff --git a/dnn/src/common/rounding_converter.cuh b/dnn/src/common/rounding_converter.cuh index ba06dbbff..8d03f533b 100644 --- a/dnn/src/common/rounding_converter.cuh +++ b/dnn/src/common/rounding_converter.cuh @@ -71,7 +71,10 @@ struct RoundingConverter { __host__ __device__ __forceinline__ uint8_t operator()(float x) const { #if MEGDNN_CC_HOST using std::round; + using std::max; + using std::min; #endif + x = min(255.0f, max(0.0f, x)); //! FIXME!!! 
check other places return static_cast(round(x)); } }; diff --git a/dnn/src/cuda/cv/kernel_common.cuh b/dnn/src/cuda/cv/kernel_common.cuh index 4c1e5968c..07bcf0cdb 100644 --- a/dnn/src/cuda/cv/kernel_common.cuh +++ b/dnn/src/cuda/cv/kernel_common.cuh @@ -11,6 +11,7 @@ #pragma once #include "src/common/cv/enums.h" +#include "src/common/resize.cuh" #include "megdnn/basic_types.h" @@ -49,15 +50,6 @@ __device__ inline void interpolate_linear_coefs(float x, float* coeffs) { coeffs[1] = x; } -__host__ __device__ inline void interpolate_cubic_coefs(float x, - float* coeffs) { - const float A = -0.75f; - coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A; - coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1; - coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1; - coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2]; -} - __device__ inline void interpolate_lanczos4_coefs(float x, float* coeffs) { const float s45 = 0.70710678118654752440084436210485; const float cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, @@ -197,7 +189,7 @@ __device__ inline void interpolate_coefs(float x, float* coeffs) { } template <> __device__ inline void interpolate_coefs(float x, float* coeffs) { - interpolate_cubic_coefs(x, coeffs); + megdnn::resize::interpolate_cubic(x, coeffs); } template <> __device__ inline void interpolate_coefs(float x, diff --git a/dnn/src/cuda/resize/backward.cu b/dnn/src/cuda/resize/backward.cu index 88d6a6e68..981879231 100644 --- a/dnn/src/cuda/resize/backward.cu +++ b/dnn/src/cuda/resize/backward.cu @@ -12,6 +12,10 @@ #include "src/cuda/resize/common.h" #include "src/cuda/utils.cuh" +#include "src/cuda/cv/kernel_common.cuh" + +using megdnn::resize::interpolate_cubic; +using megdnn::megcv::saturate; namespace megdnn { namespace cuda { @@ -72,6 +76,42 @@ __global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst, } } } + +__global__ void resize_bwd_cubic_kernel(const float* hidden, float* dst, int N, + int C, int IH, 
int IW, int OH, int OW, + float scale_h, float scale_w) { + int n = blockIdx.z; + int ow = blockIdx.x * blockDim.x + threadIdx.x; + int oh = blockIdx.y * blockDim.y + threadIdx.y; + hidden += n * C * OH * OW; + dst += n * C * IH * IW; + if (ow < OW && oh < OH) { + float alphah, alphaw; + int ih0, iw0; + get_origin_coord(scale_h, IH, oh, alphah, ih0, true); + get_origin_coord(scale_w, IW, ow, alphaw, iw0, true); + ih0--; + iw0--; + float h_coeff[4], w_coeff[4]; + interpolate_cubic(alphah, h_coeff); + interpolate_cubic(alphaw, w_coeff); + for (int c = 0; c < C; ++c) { + constexpr int ksize = 4; + for (int kh = 0; kh < ksize; kh++) { + int ih = saturate(ih0 + kh, 0, IH - 1); + for (int kw = 0; kw < ksize; kw++) { + int iw = saturate(iw0 + kw, 0, IW - 1); + atomicAdd(dst + ih * IW + iw, + hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]); + } + } + + hidden += OH * OW; + dst += IH * IW; + } + } +} + void backward_data_proxy(InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH, int IW, int OH, int OW, cudaStream_t stream) { @@ -83,13 +123,26 @@ void backward_data_proxy(InterpolationMode imode, const float* diff, stream)); float scale_h = static_cast(OH) / IH; float scale_w = static_cast(OW) / IW; - if(imode == InterpolationMode::INTER_LINEAR) { - resize_bwd_linear_kernel<<>>( - diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); - } - else if (imode == InterpolationMode::INTER_NEAREST) { - resize_bwd_nearest_kernel<<>>( - diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); + switch (imode) { + case InterpolationMode::INTER_LINEAR: { + resize_bwd_linear_kernel<<>>( + diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); + break; + } + case InterpolationMode::INTER_NEAREST: { + resize_bwd_nearest_kernel<<>>( + diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); + break; + } + case InterpolationMode::INTER_CUBIC: { + resize_bwd_cubic_kernel<<>>( + diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); + break; + } + default: { + 
megdnn_throw("unsupported interpolation mode"); + break; + } } } after_kernel_launch(); diff --git a/dnn/src/cuda/resize/common.cuh b/dnn/src/cuda/resize/common.cuh index 9eb02f280..7a15cb8b8 100644 --- a/dnn/src/cuda/resize/common.cuh +++ b/dnn/src/cuda/resize/common.cuh @@ -15,16 +15,19 @@ namespace cuda { namespace resize { __device__ inline void get_origin_coord(float scale, int size, int idx, - float& alpha, int& origin_idx) { + float& alpha, int& origin_idx, + bool cubic = false) { alpha = (idx + 0.5f) / scale - 0.5f; origin_idx = static_cast(floor(alpha)); alpha -= origin_idx; - if (origin_idx < 0) { - origin_idx = 0; - alpha = 0; - } else if (origin_idx + 1 >= size) { - origin_idx = size - 2; - alpha = 1; + if (!cubic) { + if (origin_idx < 0) { + origin_idx = 0; + alpha = 0; + } else if (origin_idx + 1 >= size) { + origin_idx = size - 2; + alpha = 1; + } } } diff --git a/dnn/src/cuda/resize/forward.cpp b/dnn/src/cuda/resize/forward.cpp index 06f760e97..f3d311bc9 100644 --- a/dnn/src/cuda/resize/forward.cpp +++ b/dnn/src/cuda/resize/forward.cpp @@ -147,9 +147,11 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, C, IH, IW, OH, OW, stream); return; } - megdnn_assert(param().imode == Param::InterpolationMode::LINEAR || - param().imode == Param::InterpolationMode::NEAREST, - "unsupported interpolation mode for NCHW format"); + megdnn_assert( + param().imode == Param::InterpolationMode::LINEAR || + param().imode == Param::InterpolationMode::NEAREST || + param().imode == Param::InterpolationMode::INTER_CUBIC, + "unsupported interpolation mode for NCHW format"); if (src.layout.dtype == dtype::Float32{}) { resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)), diff --git a/dnn/src/cuda/resize/forward.cu b/dnn/src/cuda/resize/forward.cu index cb6778149..6be6d99b4 100644 --- a/dnn/src/cuda/resize/forward.cu +++ b/dnn/src/cuda/resize/forward.cu @@ -8,15 +8,20 @@ * software distributed under the License is distributed on an * "AS IS" 
BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "src/common/rounding_converter.cuh" +#include "src/common/utils.cuh" #include "src/cuda/resize/common.cuh" #include "src/cuda/resize/common.h" -#include "src/common/rounding_converter.cuh" +#include "src/cuda/resize/resize_cv.cuh" -#include "src/cuda/utils.cuh" +#include "src/cuda/cv/kernel_common.cuh" +#include "src/common/resize.cuh" using namespace megdnn; using namespace cuda; -using namespace resize; +using namespace megdnn::cuda::resize; +using megdnn::resize::interpolate_cubic; +using megdnn::megcv::saturate; namespace { @@ -81,8 +86,7 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst, int iw = get_nearest_src(scale_w, IW, ow); for (int c = 0; c < C; ++c) { - dst[oh * OW + ow] = output_converter( - sptr[ih * S_IH + iw * S_IW]); + dst[oh * OW + ow] = output_converter(sptr[ih * S_IH + iw * S_IW]); sptr += S_IC; dst += OH * OW; @@ -90,6 +94,45 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst, } } +template +__global__ void kern_general_cubic(SrcVisitor src, ctype* __restrict dst, int C, + int IH, int IW, int OH, int OW, int S_IN, + int S_IC, int S_IH, int S_IW, float scale_h, + float scale_w) { + OutputConverter output_converter; + int ow = blockIdx.x * blockDim.x + threadIdx.x; + int oh = blockIdx.y * blockDim.y + threadIdx.y; + const ctype* __restrict sptr = src.get(blockIdx.z, S_IN); + dst += blockIdx.z * C * OH * OW; + + if (ow < OW && oh < OH) { + float alphah, alphaw; + int ih0, iw0; + get_origin_coord(scale_h, IH, oh, alphah, ih0, true); + get_origin_coord(scale_w, IW, ow, alphaw, iw0, true); + ih0--; + iw0--; + float h_coeff[4], w_coeff[4]; + interpolate_cubic(alphah, h_coeff); + interpolate_cubic(alphaw, w_coeff); + for (int c = 0; c < C; ++c) { + float ret = 0; + constexpr int ksize = 4; + for (int kh = 0; kh < ksize; kh++) { + int ih = saturate(ih0 + kh, 0, IH - 1); + for (int kw = 0; kw < ksize; kw++) { 
+ int iw = saturate(iw0 + kw, 0, IW - 1); + ret += sptr[ih * S_IH + iw * S_IW] * h_coeff[kh] * + w_coeff[kw]; + } + } + dst[oh * OW + ow] = output_converter(ret); + + sptr += S_IC; + dst += OH * OW; + } + } +} template __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C, int IH, int IW, int OH, int OW, float scale_h, @@ -140,18 +183,31 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode, <<>>(src, dst, C, IH, IW, OH, OW, scale_h, scale_w); } else { - if (imode == InterpolationMode::INTER_LINEAR) { - kern_general_linear> - <<>>( - src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH, - S_IW, scale_h, scale_w); - } else if (imode == InterpolationMode::INTER_NEAREST) { - kern_general_nearest> - <<>>( - src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH, - S_IW, scale_h, scale_w); + switch (imode) { + case InterpolationMode::INTER_LINEAR: + kern_general_linear> + <<>>( + src, dst, C, IH, IW, OH, OW, S_IN, S_IC, + S_IH, S_IW, scale_h, scale_w); + break; + case InterpolationMode::INTER_NEAREST: + kern_general_nearest> + <<>>( + src, dst, C, IH, IW, OH, OW, S_IN, S_IC, + S_IH, S_IW, scale_h, scale_w); + break; + case InterpolationMode::INTER_CUBIC: + kern_general_cubic> + <<>>( + src, dst, C, IH, IW, OH, OW, S_IN, S_IC, + S_IH, S_IW, scale_h, scale_w); + break; + default: + megdnn_throw("unsupported interpolation mode"); + break; } } N -= curr_batch_size; @@ -162,8 +218,8 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode, template __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, - int IH, int IW, int OH, int OW, float scale_h, - float scale_w) { + int IH, int IW, int OH, int OW, + float scale_h, float scale_w) { OutputConverter output_converter; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; @@ -188,10 +244,11 @@ __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, #pragma unroll for (int c1 = 0; c1 < 4; ++c1) { 
dst[o_coor + c1] = output_converter( - sptr[i_coor00 + c1] * (1.0f - alphaw) * (1.0f - alphah) + - sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) + - sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah + - sptr[i_coor11 + c1] * alphaw * alphah); + sptr[i_coor00 + c1] * (1.0f - alphaw) * + (1.0f - alphah) + + sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) + + sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah + + sptr[i_coor11 + c1] * alphaw * alphah); } dst += OH * OW * 4; sptr += IH * IW * 4; @@ -250,18 +307,18 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH, after_kernel_launch(); } -#define INST(ctype) \ - template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, int, int, int, \ - int, int, int, int, int, int, int, \ - cudaStream_t); +#define INST(ctype) \ + template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, \ + int, int, int, int, int, int, int, int, int, \ + int, cudaStream_t); INST(float) INST(uint8_t) INST(int8_t) #undef INST -#define INST(ctype) \ +#define INST(ctype) \ template void forward_proxy_nchw4(const ctype*, ctype*, int, int, int, \ - int, int, int, cudaStream_t) + int, int, int, cudaStream_t) INST(int8_t); #undef INST diff --git a/dnn/src/cuda/resize/resize_cv.cu b/dnn/src/cuda/resize/resize_cv.cu index d3a7bdd27..d3ec01cd2 100644 --- a/dnn/src/cuda/resize/resize_cv.cu +++ b/dnn/src/cuda/resize/resize_cv.cu @@ -59,12 +59,14 @@ * --------------------------------------------------------------------------- */ #include "src/cuda/cv/kernel_common.cuh" +#include "src/common/resize.cuh" #include "src/cuda/resize/resize_cv.cuh" #include "src/cuda/utils.cuh" using namespace megdnn; using namespace cuda; using namespace megcv; +using megdnn::resize::interpolate_cubic; namespace { @@ -126,7 +128,7 @@ __global__ void precompute_cubic_coef_f32(float* dst, float scale, fr -= sr[tid]; float coef[4]; - interpolate_cubic_coefs(fr, coef); + interpolate_cubic(fr, coef); #pragma unroll for (int j = 0, 
index = 0; j < 4; j++, index += size) { dst[tid + index] = coef[j]; @@ -144,7 +146,7 @@ __global__ void precompute_cubic_coef_u8(short* dst, float scale, size_t size) { fr -= sr[tid]; float coef[4]; - interpolate_cubic_coefs(fr, coef); + interpolate_cubic(fr, coef); #pragma unroll for (int j = 0, index = 0; j < 4; j++, index += size) { dst[tid + index] = (short)(coef[j] * ONE); @@ -406,7 +408,7 @@ __global__ void resize_cubic_32f_kernel_vector( int sc = floor(fc); fc -= sc; float coef_col[4]; - interpolate_cubic_coefs(fc, coef_col); + interpolate_cubic(fc, coef_col); for (int i = 0; i < ELEMENTS_PER_THREADS; i++) { if (dr >= dst_rows) @@ -415,7 +417,7 @@ __global__ void resize_cubic_32f_kernel_vector( int sr = floor(fr); fr -= sr; float coef_row[4]; - interpolate_cubic_coefs(fr, coef_row); + interpolate_cubic(fr, coef_row); float dst_data[CH] = {0}; #pragma unroll for (int offset_r = 0; offset_r < 4; ++offset_r) { @@ -459,7 +461,7 @@ __global__ void resize_cubic_8u_kernel_vector( short icoef_col[4] = {0}; float coef_col[4]; - interpolate_cubic_coefs(fc, coef_col); + interpolate_cubic(fc, coef_col); #pragma unroll for (int i = 0; i < 4; i++) { icoef_col[i] = (short)(coef_col[i] * ONE); @@ -473,7 +475,7 @@ __global__ void resize_cubic_8u_kernel_vector( fr -= sr; short icoef_row[4]; float coef_row[4]; - interpolate_cubic_coefs(fr, coef_row); + interpolate_cubic(fr, coef_row); #pragma unroll for (int i = 0; i < 4; i++) { icoef_row[i] = (short)(coef_row[i] * ONE); diff --git a/dnn/src/fallback/resize/opr_impl.cpp b/dnn/src/fallback/resize/opr_impl.cpp index cb8b02eb7..207c892f1 100644 --- a/dnn/src/fallback/resize/opr_impl.cpp +++ b/dnn/src/fallback/resize/opr_impl.cpp @@ -118,7 +118,7 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, check_exec(src.layout, dst.layout, workspace.size); if (param().format == param::Resize::Format::NCHW4 || (param().format == param::Resize::Format::NCHW && - param().imode == param::Resize::InterpolationMode::NEAREST)) 
{ + param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) { naive::ResizeImpl::exec(src, dst, workspace); return; } diff --git a/dnn/src/naive/resize/opr_impl.cpp b/dnn/src/naive/resize/opr_impl.cpp index 9778c17cd..df6b52744 100644 --- a/dnn/src/naive/resize/opr_impl.cpp +++ b/dnn/src/naive/resize/opr_impl.cpp @@ -9,18 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "src/naive/resize/opr_impl.h" +#include "midout.h" +#include "src/common/cv/enums.h" +#include "src/common/resize.cuh" #include "src/common/rounding_converter.cuh" #include "src/common/utils.cuh" #include "src/naive/handle.h" -#include "src/naive/resize/opr_impl.h" #include "src/naive/resize/resize_cv.h" -#include "midout.h" MIDOUT_DECL(megdnn_naive_resize_layout) -MIDOUT_DECL(megdnn_naive_resize_layout_nearest) +MIDOUT_DECL(megdnn_naive_resize_nchw) using namespace megdnn; using namespace naive; +using namespace resize; template ResizeImpl::KernParam ResizeImpl::KernParam::from_tensors( @@ -90,20 +93,84 @@ INST(dt_quint8); #undef INST template -void ResizeImpl::kern_nchw_nearest (const KernParam& kern_param) { +void ResizeImpl::kern_nchw(const KernParam& kern_param, + InterpolationMode imode) { megdnn_assert(kern_param.format == Format::NCHW); UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param); float scale_h = static_cast(OH) / IH; float scale_w = static_cast(OW) / IW; + rounding::RoundingConverter output_converter; rep(n, N) { rep(oh, OH) rep(ow, OW) { - auto ih = get_nearest_src(scale_h, IH, oh); - auto iw = get_nearest_src(scale_w, IW, ow); + switch (imode) { + case InterpolationMode::NEAREST: { + auto ih = get_nearest_src(scale_h, IH, oh); + auto iw = get_nearest_src(scale_w, IW, ow); + + rep(c, static_cast(C)) { + dptr[c * OH * OW + oh * OW + ow] = + sptr[c * S_IC + ih * S_IH + iw * S_IW]; + } + break; + } + case InterpolationMode::INTER_LINEAR: { + auto coord_h = get_origin_coord(scale_h, IH, oh); + auto coord_w = 
get_origin_coord(scale_w, IW, ow); + + float alphah = coord_h.first; + float alphaw = coord_w.first; + + int ih0 = coord_h.second; + int ih1 = ih0 + 1; + int iw0 = coord_w.second; + int iw1 = iw0 + 1; + rep(c, static_cast(C)) { + dptr[c * OH * OW + oh * OW + ow] = output_converter( + sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * + (1.0f - alphaw) * (1.0f - alphah) + + sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * + alphaw * (1.0f - alphah) + + sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * + (1.0f - alphaw) * alphah + + sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * + alphaw * alphah); + } + break; + } + case InterpolationMode::INTER_CUBIC: { + auto coord_h = get_origin_coord(scale_h, IH, oh, true); + auto coord_w = get_origin_coord(scale_w, IW, ow, true); + + float alphah = coord_h.first; + float alphaw = coord_w.first; - rep(c, static_cast(C)) { - dptr[c * OH * OW + oh * OW + ow] = sptr[c * S_IC + ih * S_IH + iw * S_IW]; + int ih0 = coord_h.second - 1; + int iw0 = coord_w.second - 1; + float h_coeff[4], w_coeff[4]; + interpolate_cubic(alphah, h_coeff); + interpolate_cubic(alphaw, w_coeff); + + rep(c, static_cast(C)) { + constexpr int ksize = 4; + float ret = 0; + rep(kh, ksize) { + int h = saturate(ih0 + kh, 0, IH - 1); + rep(kw, ksize) { + int w = saturate(iw0 + kw, 0, IW - 1); + ret += sptr[c * S_IC + h * S_IH + w * S_IW] * + h_coeff[kh] * w_coeff[kw]; + } + } + dptr[c * OH * OW + oh * OW + ow] = + output_converter(ret); + } + break; + } + default: + megdnn_throw("unsupported mode in ResizeBackwardImpl"); + break; } } sptr += S_IN; @@ -131,40 +198,6 @@ void ResizeImpl::kern_naive(const KernParam& kern_param) { MIDOUT_END(); return; } - megdnn_assert(kern_param.format == Format::NCHW); - UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param); - rounding::RoundingConverter output_converter; - float scale_h = static_cast(OH) / IH; - float scale_w = static_cast(OW) / IW; - - rep(n, N) { - rep(oh, OH) rep(ow, OW) { - auto coord_h = get_origin_coord(scale_h, IH, oh); - auto coord_w 
= get_origin_coord(scale_w, IW, ow); - - float alphah = coord_h.first; - float alphaw = coord_w.first; - - int ih0 = coord_h.second; - int ih1 = ih0 + 1; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; - - rep(c, static_cast(C)) { - dptr[c * OH * OW + oh * OW + ow] = output_converter( - sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * - (1.0f - alphaw) * (1.0f - alphah) + - sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * alphaw * - (1.0f - alphah) + - sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * - (1.0f - alphaw) * alphah + - sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * alphaw * - alphah); - } - } - sptr += S_IN; - dptr += C * OH * OW; - } } template @@ -290,18 +323,16 @@ void ResizeImpl::kern_naive_nchw4(const KernParam& kern_param) { void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); - if (param().format == param::Resize::Format::NCHW && - param().imode == param::Resize::InterpolationMode::NEAREST) { -#define cb(dt, ct, _midout_iv) \ - case DTypeTrait
::enumv: { \ - MIDOUT_BEGIN(megdnn_naive_resize_layout_nearest, \ - midout_iv(_midout_iv)) { \ - auto kparam = KernParam::from_tensors(param().format, src, \ - dst, workspace); \ - MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw_nearest(kparam)); \ - } \ - MIDOUT_END(); \ - return; \ + if (param().format == param::Resize::Format::NCHW) { +#define cb(dt, ct, _midout_iv) \ + case DTypeTrait
::enumv: { \ + MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) { \ + auto kparam = KernParam::from_tensors(param().format, src, \ + dst, workspace); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode)); \ + } \ + MIDOUT_END(); \ + return; \ } switch (src.layout.dtype.enumv()) { @@ -319,12 +350,10 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, return; } -#undef cb #undef cb } - if ((param().format == param::Resize::Format::NCHW || - (src.layout[3] != 1 && src.layout[3] != 3) || + if (((src.layout[3] != 1 && src.layout[3] != 3) || !is_nhwc_contig_wc(src.layout)) || (param().imode == param::Resize::InterpolationMode::LINEAR)) { #define cb(dt, ct, _midout_iv) \ @@ -378,37 +407,73 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); rep(n, N) { rep(oh, OH) rep(ow, OW) { - if(param().imode == InterpolationMode::INTER_LINEAR) { - auto coord_h = get_origin_coord(scale_h, IH, oh); - auto coord_w = get_origin_coord(scale_w, IW, ow); - - float alphah = coord_h.first; - float alphaw = coord_w.first; - - int ih0 = coord_h.second; - int ih1 = ih0 + 1; - int iw0 = coord_w.second; - int iw1 = iw0 + 1; - - rep(c, C) { - float hidden = hptr[c * OH * OW + oh * OW + ow]; - sptr[c * IH * IW + ih0 * IW + iw0] += - (1.0f - alphaw) * (1.0f - alphah) * hidden; - sptr[c * IH * IW + ih1 * IW + iw0] += - (1.0f - alphaw) * alphah * hidden; - sptr[c * IH * IW + ih0 * IW + iw1] += - alphaw * (1.0f - alphah) * hidden; - sptr[c * IH * IW + ih1 * IW + iw1] += - alphaw * alphah * hidden; + switch (param().imode) { + case InterpolationMode::INTER_LINEAR: { + auto coord_h = get_origin_coord(scale_h, IH, oh); + auto coord_w = get_origin_coord(scale_w, IW, ow); + + float alphah = coord_h.first; + float alphaw = coord_w.first; + + int ih0 = coord_h.second; + int ih1 = ih0 + 1; + int iw0 = coord_w.second; + int iw1 = iw0 + 1; + + rep(c, C) { + float hidden = hptr[c * OH * 
OW + oh * OW + ow]; + sptr[c * IH * IW + ih0 * IW + iw0] += + (1.0f - alphaw) * (1.0f - alphah) * hidden; + sptr[c * IH * IW + ih1 * IW + iw0] += + (1.0f - alphaw) * alphah * hidden; + sptr[c * IH * IW + ih0 * IW + iw1] += + alphaw * (1.0f - alphah) * hidden; + sptr[c * IH * IW + ih1 * IW + iw1] += + alphaw * alphah * hidden; + } + break; } - } else if (param().imode == InterpolationMode::NEAREST) { - auto ih = get_nearest_src(scale_h, IH, oh); - auto iw = get_nearest_src(scale_w, IW, ow); - rep(c, static_cast(C)) { - sptr[c * IH * IW + ih * IW + iw] += hptr[c * OH * OW + oh * OW + ow]; + case InterpolationMode::NEAREST: { + auto ih = get_nearest_src(scale_h, IH, oh); + auto iw = get_nearest_src(scale_w, IW, ow); + rep(c, static_cast(C)) { + sptr[c * IH * IW + ih * IW + iw] += + hptr[c * OH * OW + oh * OW + ow]; + } + break; + } + case InterpolationMode::INTER_CUBIC: { + auto coord_h = get_origin_coord(scale_h, IH, oh, true); + auto coord_w = get_origin_coord(scale_w, IW, ow, true); + + float alphah = coord_h.first; + float alphaw = coord_w.first; + + int ih0 = coord_h.second - 1; + int iw0 = coord_w.second - 1; + float h_coeff[4], w_coeff[4]; + interpolate_cubic(alphah, h_coeff); + interpolate_cubic(alphaw, w_coeff); + + rep(c, static_cast(C)) { + constexpr int ksize = 4; + rep(kh, ksize) { + int h = saturate(ih0 + kh, 0, IH - 1); + rep(kw, ksize) { + int w = saturate(iw0 + kw, 0, IW - 1); + sptr[c * IH * IW + h * IW + w] += + hptr[c * OH * OW + oh * OW + ow] * + h_coeff[kh] * w_coeff[kw]; + } + } + } + break; + } + default: { + megdnn_throw("unsupported mode in ResizeBackwardImpl"); + break; } } - else megdnn_throw("unsupported mode in ResizeBackwardImpl"); } sptr += C * IH * IW; hptr += C * OH * OW; diff --git a/dnn/src/naive/resize/opr_impl.h b/dnn/src/naive/resize/opr_impl.h index b9e3bfb86..59c6fb9dc 100644 --- a/dnn/src/naive/resize/opr_impl.h +++ b/dnn/src/naive/resize/opr_impl.h @@ -47,7 +47,7 @@ private: void kern_naive(const KernParam& kern_param); 
template - void kern_nchw_nearest(const KernParam& kern_param); + void kern_nchw(const KernParam& kern_param, InterpolationMode imode); template void kern_naive_nhwc(const KernParam& kern_param); diff --git a/dnn/src/naive/resize/resize_cv.cpp b/dnn/src/naive/resize/resize_cv.cpp index d9bd079f7..ed511f231 100644 --- a/dnn/src/naive/resize/resize_cv.cpp +++ b/dnn/src/naive/resize/resize_cv.cpp @@ -68,6 +68,7 @@ #include "src/common/cv/helper.h" #include "src/common/utils.h" #include "src/naive/handle.h" +#include "src/common/resize.cuh" MIDOUT_DECL(megdnn_naive_resizecv_imode) MIDOUT_DECL(megdnn_naive_resizecv_dtype) @@ -75,6 +76,7 @@ MIDOUT_DECL(megdnn_naive_resizecv_dtype) using namespace megdnn; using namespace naive; using namespace megcv; +using namespace megdnn::resize; namespace { @@ -383,14 +385,6 @@ using ResizeAreaFunc = void (*)(const Mat& src, Mat& dst, const DecimateAlpha* ytab, int ytab_size, const int* yofs); -static inline void interpolate_cubic(float x, float* coeffs) { - const float A = -0.75f; - - coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A; - coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1; - coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1; - coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2]; -} static inline void interpolate_lanczos4(float x, float* coeffs) { static const double s45 = 0.70710678118654752440084436210485; static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, diff --git a/dnn/test/cuda/resize.cpp b/dnn/test/cuda/resize.cpp index e668543ec..cfc219299 100644 --- a/dnn/test/cuda/resize.cpp +++ b/dnn/test/cuda/resize.cpp @@ -43,7 +43,7 @@ TEST_F(CUDA, RESIZE_CV) { TEST_F(CUDA, RESIZE_FORWARD) { using namespace resize; - IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; + IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; for (auto imode : modes) { std::vector args = get_args(imode); Checker checker(handle_cuda()); @@ -88,7 +88,7 @@ 
TEST_F(CUDA, RESIZE_NCHW4) { } TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { - IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; + IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; for (auto imode : modes) { param::Resize param; param.format = param::Resize::Format::NCHW; @@ -117,7 +117,7 @@ TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { } TEST_F(CUDA, RESIZE_BACKWARD) { - IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; + IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; for (auto imode : modes) { Checker checker(handle_cuda()); param::Resize param; diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index d26e792e6..f067f08f9 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -574,19 +574,25 @@ def interpolate( raise ValueError("under linear mode, size can only be single value") dsize = size - if not align_corners and mode in ("bilinear", "nearest") and inp.ndim in [4, 5]: + if not align_corners: # fastpath for interpolate - op = builtin.Resize( - imode="linear" if mode == "bilinear" else "nearest", format="NCHW" - ) + mode_map = { + "linear": "linear", + "bilinear": "linear", + "nearest": "nearest", + "bicubic": "cubic", + } + + op = builtin.Resize(imode=mode_map[mode], format="NCHW") shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) - (result,) = apply(op, inp, shape) - return result - - oh, ow = dsize[0], dsize[1] - ih, iw = inp.shape[2], inp.shape[3] - - if align_corners: + (ret,) = apply(op, inp, shape) + else: + assert mode in [ + "linear", + "bilinear", + ], "align_corners only support linear or bilinear mode" + oh, ow = dsize[0], dsize[1] + ih, iw = inp.shape[2], inp.shape[3] hscale = (ih - 1.0) / (oh - 1.0) wscale = 1.0 * iw / ow if mode != "linear": @@ -607,34 +613,11 @@ def interpolate( axis=0, ).reshape(1, 3, 3) weight = broadcast_to(weight, (inp.shape[0], 3, 
3)) - else: - hscale = 1.0 * ih / oh - wscale = 1.0 * iw / ow - row0 = concat( - [wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5], - axis=0, - ).reshape(1, 3) - row1 = concat( - [Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5], - axis=0, - ).reshape(1, 3) - weight = concat( - [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)], - axis=0, - ).reshape(1, 3, 3) - weight = broadcast_to(weight, (inp.shape[0], 3, 3)) - weight = weight.astype("float32") - if mode in ["linear", "bilinear"]: ret = warp_perspective(inp, weight, dsize, interp_mode="linear") - if mode == "linear": - ret = reshape(ret, ret.shape[0:3]) - else: - # only NHWC format support "cubic" mode - assert mode == "bicubic" - inp = transpose(inp, (0, 2, 3, 1)) - ret = warp_perspective(inp, weight, dsize, format="NHWC", interp_mode="cubic",) - ret = transpose(ret, (0, 3, 1, 2)) + + if mode == "linear": + ret = reshape(ret, ret.shape[0:3]) return ret -- GitLab