提交 67575d58 编写于 作者: M Megvii Engine Team

feat(mge/opr): add interpolate bilinear mode

GitOrigin-RevId: f7023a3fd381f36e64702893576702d59c5be2c6
上级 0558b212
...@@ -197,7 +197,7 @@ public: ...@@ -197,7 +197,7 @@ public:
protected: protected:
//! get origin coord //! get origin coord
std::pair<float, int> get_origin_coord(float scale, int size, int idx); std::pair<float, int> get_origin_coord(float scale, int size, int idx, bool cubic=false);
//! get nearest index in src //! get nearest index in src
int get_nearest_src(float scale, int size, int idx); int get_nearest_src(float scale, int size, int idx);
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
*/ */
#include "megdnn/handle.h" #include "megdnn/handle.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/common/utils.h" #include "src/common/utils.h"
...@@ -29,8 +30,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src, ...@@ -29,8 +30,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src,
if (param().format == Param::Format::NCHW) { if (param().format == Param::Format::NCHW) {
megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str()); megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str());
auto imode = param().imode; auto imode = param().imode;
megdnn_assert(imode == param::Resize::InterpolationMode::INTER_LINEAR || using IMode = param::Resize::InterpolationMode;
imode == param::Resize::InterpolationMode::NEAREST); megdnn_assert(imode == IMode::INTER_LINEAR || imode == IMode::NEAREST ||
imode == IMode::INTER_CUBIC);
} else if (param().format == Param::Format::NHWC) { } else if (param().format == Param::Format::NHWC) {
megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str()); megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str());
} else if (param().format == Param::Format::NCHW4) { } else if (param().format == Param::Format::NCHW4) {
...@@ -66,19 +68,20 @@ void ResizeBackward::check_exec(const TensorLayout& diff, ...@@ -66,19 +68,20 @@ void ResizeBackward::check_exec(const TensorLayout& diff,
} }
std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size, std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
int idx) { int idx, bool cubic) {
//! copy from resize_cv.cpp //! copy from resize_cv.cpp
float alpha = (idx + 0.5f) / scale - 0.5f; float alpha = (idx + 0.5f) / scale - 0.5f;
int origin_idx = static_cast<int>(floor(alpha)); int origin_idx = static_cast<int>(floor(alpha));
alpha -= origin_idx; alpha -= origin_idx;
if (origin_idx < 0) { if (!cubic) {
origin_idx = 0; if (origin_idx < 0) {
alpha = 0; origin_idx = 0;
} else if (origin_idx + 1 >= size) { alpha = 0;
origin_idx = size - 2; } else if (origin_idx + 1 >= size) {
alpha = 1; origin_idx = size - 2;
alpha = 1;
}
} }
return {alpha, origin_idx}; return {alpha, origin_idx};
} }
......
/**
* \file dnn/src/common/resize.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/arch.h"
#if MEGDNN_CC_HOST && !defined(__host__)
#if __GNUC__ || __has_attribute(always_inline)
#define __forceinline__ inline __attribute__((always_inline))
#else
#define __forceinline__ inline
#endif
#endif
namespace megdnn {
namespace resize {
MEGDNN_HOST MEGDNN_DEVICE __forceinline__ void interpolate_cubic(
float x, float* coeffs) {
const float A = -0.75f;
coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A;
coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1;
coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1;
coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2];
}
} // namespace resize
} // namespace megdnn
/* vim: set ft=cpp: */
...@@ -71,7 +71,10 @@ struct RoundingConverter<uint8_t> { ...@@ -71,7 +71,10 @@ struct RoundingConverter<uint8_t> {
__host__ __device__ __forceinline__ uint8_t operator()(float x) const { __host__ __device__ __forceinline__ uint8_t operator()(float x) const {
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
using std::round; using std::round;
using std::max;
using std::min;
#endif #endif
x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places
return static_cast<uint8_t>(round(x)); return static_cast<uint8_t>(round(x));
} }
}; };
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#pragma once #pragma once
#include "src/common/cv/enums.h" #include "src/common/cv/enums.h"
#include "src/common/resize.cuh"
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
...@@ -49,15 +50,6 @@ __device__ inline void interpolate_linear_coefs(float x, float* coeffs) { ...@@ -49,15 +50,6 @@ __device__ inline void interpolate_linear_coefs(float x, float* coeffs) {
coeffs[1] = x; coeffs[1] = x;
} }
__host__ __device__ inline void interpolate_cubic_coefs(float x,
float* coeffs) {
const float A = -0.75f;
coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A;
coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1;
coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1;
coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2];
}
__device__ inline void interpolate_lanczos4_coefs(float x, float* coeffs) { __device__ inline void interpolate_lanczos4_coefs(float x, float* coeffs) {
const float s45 = 0.70710678118654752440084436210485; const float s45 = 0.70710678118654752440084436210485;
const float cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, const float cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45},
...@@ -197,7 +189,7 @@ __device__ inline void interpolate_coefs<INTER_LINEAR>(float x, float* coeffs) { ...@@ -197,7 +189,7 @@ __device__ inline void interpolate_coefs<INTER_LINEAR>(float x, float* coeffs) {
} }
template <> template <>
__device__ inline void interpolate_coefs<INTER_CUBIC>(float x, float* coeffs) { __device__ inline void interpolate_coefs<INTER_CUBIC>(float x, float* coeffs) {
interpolate_cubic_coefs(x, coeffs); megdnn::resize::interpolate_cubic(x, coeffs);
} }
template <> template <>
__device__ inline void interpolate_coefs<INTER_LANCZOS4>(float x, __device__ inline void interpolate_coefs<INTER_LANCZOS4>(float x,
......
...@@ -12,6 +12,10 @@ ...@@ -12,6 +12,10 @@
#include "src/cuda/resize/common.h" #include "src/cuda/resize/common.h"
#include "src/cuda/utils.cuh" #include "src/cuda/utils.cuh"
#include "src/cuda/cv/kernel_common.cuh"
using megdnn::resize::interpolate_cubic;
using megdnn::megcv::saturate;
namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {
...@@ -72,6 +76,42 @@ __global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst, ...@@ -72,6 +76,42 @@ __global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst,
} }
} }
} }
__global__ void resize_bwd_cubic_kernel(const float* hidden, float* dst, int N,
int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) {
int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
hidden += n * C * OH * OW;
dst += n * C * IH * IW;
if (ow < OW && oh < OH) {
float alphah, alphaw;
int ih0, iw0;
get_origin_coord(scale_h, IH, oh, alphah, ih0, true);
get_origin_coord(scale_w, IW, ow, alphaw, iw0, true);
ih0--;
iw0--;
float h_coeff[4], w_coeff[4];
interpolate_cubic(alphah, h_coeff);
interpolate_cubic(alphaw, w_coeff);
for (int c = 0; c < C; ++c) {
constexpr int ksize = 4;
for (int kh = 0; kh < ksize; kh++) {
int ih = saturate(ih0 + kh, 0, IH - 1);
for (int kw = 0; kw < ksize; kw++) {
int iw = saturate(iw0 + kw, 0, IW - 1);
atomicAdd(dst + ih * IW + iw,
hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]);
}
}
hidden += OH * OW;
dst += IH * IW;
}
}
}
void backward_data_proxy(InterpolationMode imode, const float* diff, void backward_data_proxy(InterpolationMode imode, const float* diff,
float* grad, int N, int C, int IH, int IW, int OH, float* grad, int N, int C, int IH, int IW, int OH,
int OW, cudaStream_t stream) { int OW, cudaStream_t stream) {
...@@ -83,13 +123,26 @@ void backward_data_proxy(InterpolationMode imode, const float* diff, ...@@ -83,13 +123,26 @@ void backward_data_proxy(InterpolationMode imode, const float* diff,
stream)); stream));
float scale_h = static_cast<float>(OH) / IH; float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW; float scale_w = static_cast<float>(OW) / IW;
if(imode == InterpolationMode::INTER_LINEAR) { switch (imode) {
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>( case InterpolationMode::INTER_LINEAR: {
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
} diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
else if (imode == InterpolationMode::INTER_NEAREST) { break;
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>( }
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); case InterpolationMode::INTER_NEAREST: {
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
case InterpolationMode::INTER_CUBIC: {
resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
default: {
megdnn_throw("unsupported interpolation mode");
break;
}
} }
} }
after_kernel_launch(); after_kernel_launch();
......
...@@ -15,16 +15,19 @@ namespace cuda { ...@@ -15,16 +15,19 @@ namespace cuda {
namespace resize { namespace resize {
__device__ inline void get_origin_coord(float scale, int size, int idx, __device__ inline void get_origin_coord(float scale, int size, int idx,
float& alpha, int& origin_idx) { float& alpha, int& origin_idx,
bool cubic = false) {
alpha = (idx + 0.5f) / scale - 0.5f; alpha = (idx + 0.5f) / scale - 0.5f;
origin_idx = static_cast<int>(floor(alpha)); origin_idx = static_cast<int>(floor(alpha));
alpha -= origin_idx; alpha -= origin_idx;
if (origin_idx < 0) { if (!cubic) {
origin_idx = 0; if (origin_idx < 0) {
alpha = 0; origin_idx = 0;
} else if (origin_idx + 1 >= size) { alpha = 0;
origin_idx = size - 2; } else if (origin_idx + 1 >= size) {
alpha = 1; origin_idx = size - 2;
alpha = 1;
}
} }
} }
......
...@@ -147,9 +147,11 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, ...@@ -147,9 +147,11 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
C, IH, IW, OH, OW, stream); C, IH, IW, OH, OW, stream);
return; return;
} }
megdnn_assert(param().imode == Param::InterpolationMode::LINEAR || megdnn_assert(
param().imode == Param::InterpolationMode::NEAREST, param().imode == Param::InterpolationMode::LINEAR ||
"unsupported interpolation mode for NCHW format"); param().imode == Param::InterpolationMode::NEAREST ||
param().imode == Param::InterpolationMode::INTER_CUBIC,
"unsupported interpolation mode for NCHW format");
if (src.layout.dtype == dtype::Float32{}) { if (src.layout.dtype == dtype::Float32{}) {
resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)), resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
......
...@@ -8,15 +8,20 @@ ...@@ -8,15 +8,20 @@
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "src/common/rounding_converter.cuh"
#include "src/common/utils.cuh"
#include "src/cuda/resize/common.cuh" #include "src/cuda/resize/common.cuh"
#include "src/cuda/resize/common.h" #include "src/cuda/resize/common.h"
#include "src/common/rounding_converter.cuh" #include "src/cuda/resize/resize_cv.cuh"
#include "src/cuda/utils.cuh" #include "src/cuda/cv/kernel_common.cuh"
#include "src/common/resize.cuh"
using namespace megdnn; using namespace megdnn;
using namespace cuda; using namespace cuda;
using namespace resize; using namespace megdnn::cuda::resize;
using megdnn::resize::interpolate_cubic;
using megdnn::megcv::saturate;
namespace { namespace {
...@@ -81,8 +86,7 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst, ...@@ -81,8 +86,7 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst,
int iw = get_nearest_src(scale_w, IW, ow); int iw = get_nearest_src(scale_w, IW, ow);
for (int c = 0; c < C; ++c) { for (int c = 0; c < C; ++c) {
dst[oh * OW + ow] = output_converter( dst[oh * OW + ow] = output_converter(sptr[ih * S_IH + iw * S_IW]);
sptr[ih * S_IH + iw * S_IW]);
sptr += S_IC; sptr += S_IC;
dst += OH * OW; dst += OH * OW;
...@@ -90,6 +94,45 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst, ...@@ -90,6 +94,45 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst,
} }
} }
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_cubic(SrcVisitor src, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, int S_IN,
int S_IC, int S_IH, int S_IW, float scale_h,
float scale_w) {
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
dst += blockIdx.z * C * OH * OW;
if (ow < OW && oh < OH) {
float alphah, alphaw;
int ih0, iw0;
get_origin_coord(scale_h, IH, oh, alphah, ih0, true);
get_origin_coord(scale_w, IW, ow, alphaw, iw0, true);
ih0--;
iw0--;
float h_coeff[4], w_coeff[4];
interpolate_cubic(alphah, h_coeff);
interpolate_cubic(alphaw, w_coeff);
for (int c = 0; c < C; ++c) {
float ret = 0;
constexpr int ksize = 4;
for (int kh = 0; kh < ksize; kh++) {
int ih = saturate(ih0 + kh, 0, IH - 1);
for (int kw = 0; kw < ksize; kw++) {
int iw = saturate(iw0 + kw, 0, IW - 1);
ret += sptr[ih * S_IH + iw * S_IW] * h_coeff[kh] *
w_coeff[kw];
}
}
dst[oh * OW + ow] = output_converter(ret);
sptr += S_IC;
dst += OH * OW;
}
}
}
template <typename ctype, typename SrcVisitor, typename OutputConverter> template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C, __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scale_h, int IH, int IW, int OH, int OW, float scale_h,
...@@ -140,18 +183,31 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode, ...@@ -140,18 +183,31 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode,
<<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH, <<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH,
OW, scale_h, scale_w); OW, scale_h, scale_w);
} else { } else {
if (imode == InterpolationMode::INTER_LINEAR) { switch (imode) {
kern_general_linear<ctype, SrcVisitor, case InterpolationMode::INTER_LINEAR:
rounding::RoundingConverter<ctype>> kern_general_linear<ctype, SrcVisitor,
<<<blocks, threads, 0, stream>>>( rounding::RoundingConverter<ctype>>
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH, <<<blocks, threads, 0, stream>>>(
S_IW, scale_h, scale_w); src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
} else if (imode == InterpolationMode::INTER_NEAREST) { S_IH, S_IW, scale_h, scale_w);
kern_general_nearest<ctype, SrcVisitor, break;
rounding::RoundingConverter<ctype>> case InterpolationMode::INTER_NEAREST:
<<<blocks, threads, 0, stream>>>( kern_general_nearest<ctype, SrcVisitor,
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH, rounding::RoundingConverter<ctype>>
S_IW, scale_h, scale_w); <<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, scale_h, scale_w);
break;
case InterpolationMode::INTER_CUBIC:
kern_general_cubic<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, scale_h, scale_w);
break;
default:
megdnn_throw("unsupported interpolation mode");
break;
} }
} }
N -= curr_batch_size; N -= curr_batch_size;
...@@ -162,8 +218,8 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode, ...@@ -162,8 +218,8 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode,
template <typename ctype, typename SrcVisitor, typename OutputConverter> template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scale_h, int IH, int IW, int OH, int OW,
float scale_w) { float scale_h, float scale_w) {
OutputConverter output_converter; OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x; int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y; int oh = blockIdx.y * blockDim.y + threadIdx.y;
...@@ -188,10 +244,11 @@ __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, ...@@ -188,10 +244,11 @@ __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C,
#pragma unroll #pragma unroll
for (int c1 = 0; c1 < 4; ++c1) { for (int c1 = 0; c1 < 4; ++c1) {
dst[o_coor + c1] = output_converter( dst[o_coor + c1] = output_converter(
sptr[i_coor00 + c1] * (1.0f - alphaw) * (1.0f - alphah) + sptr[i_coor00 + c1] * (1.0f - alphaw) *
sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) + (1.0f - alphah) +
sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah + sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) +
sptr[i_coor11 + c1] * alphaw * alphah); sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah +
sptr[i_coor11 + c1] * alphaw * alphah);
} }
dst += OH * OW * 4; dst += OH * OW * 4;
sptr += IH * IW * 4; sptr += IH * IW * 4;
...@@ -250,18 +307,18 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH, ...@@ -250,18 +307,18 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
after_kernel_launch(); after_kernel_launch();
} }
#define INST(ctype) \ #define INST(ctype) \
template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, int, int, int, \ template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, \
int, int, int, int, int, int, int, \ int, int, int, int, int, int, int, int, int, \
cudaStream_t); int, cudaStream_t);
INST(float) INST(float)
INST(uint8_t) INST(uint8_t)
INST(int8_t) INST(int8_t)
#undef INST #undef INST
#define INST(ctype) \ #define INST(ctype) \
template void forward_proxy_nchw4(const ctype*, ctype*, int, int, int, \ template void forward_proxy_nchw4(const ctype*, ctype*, int, int, int, \
int, int, int, cudaStream_t) int, int, int, cudaStream_t)
INST(int8_t); INST(int8_t);
#undef INST #undef INST
......
...@@ -59,12 +59,14 @@ ...@@ -59,12 +59,14 @@
* --------------------------------------------------------------------------- * ---------------------------------------------------------------------------
*/ */
#include "src/cuda/cv/kernel_common.cuh" #include "src/cuda/cv/kernel_common.cuh"
#include "src/common/resize.cuh"
#include "src/cuda/resize/resize_cv.cuh" #include "src/cuda/resize/resize_cv.cuh"
#include "src/cuda/utils.cuh" #include "src/cuda/utils.cuh"
using namespace megdnn; using namespace megdnn;
using namespace cuda; using namespace cuda;
using namespace megcv; using namespace megcv;
using megdnn::resize::interpolate_cubic;
namespace { namespace {
...@@ -126,7 +128,7 @@ __global__ void precompute_cubic_coef_f32(float* dst, float scale, ...@@ -126,7 +128,7 @@ __global__ void precompute_cubic_coef_f32(float* dst, float scale,
fr -= sr[tid]; fr -= sr[tid];
float coef[4]; float coef[4];
interpolate_cubic_coefs(fr, coef); interpolate_cubic(fr, coef);
#pragma unroll #pragma unroll
for (int j = 0, index = 0; j < 4; j++, index += size) { for (int j = 0, index = 0; j < 4; j++, index += size) {
dst[tid + index] = coef[j]; dst[tid + index] = coef[j];
...@@ -144,7 +146,7 @@ __global__ void precompute_cubic_coef_u8(short* dst, float scale, size_t size) { ...@@ -144,7 +146,7 @@ __global__ void precompute_cubic_coef_u8(short* dst, float scale, size_t size) {
fr -= sr[tid]; fr -= sr[tid];
float coef[4]; float coef[4];
interpolate_cubic_coefs(fr, coef); interpolate_cubic(fr, coef);
#pragma unroll #pragma unroll
for (int j = 0, index = 0; j < 4; j++, index += size) { for (int j = 0, index = 0; j < 4; j++, index += size) {
dst[tid + index] = (short)(coef[j] * ONE); dst[tid + index] = (short)(coef[j] * ONE);
...@@ -406,7 +408,7 @@ __global__ void resize_cubic_32f_kernel_vector( ...@@ -406,7 +408,7 @@ __global__ void resize_cubic_32f_kernel_vector(
int sc = floor(fc); int sc = floor(fc);
fc -= sc; fc -= sc;
float coef_col[4]; float coef_col[4];
interpolate_cubic_coefs(fc, coef_col); interpolate_cubic(fc, coef_col);
for (int i = 0; i < ELEMENTS_PER_THREADS; i++) { for (int i = 0; i < ELEMENTS_PER_THREADS; i++) {
if (dr >= dst_rows) if (dr >= dst_rows)
...@@ -415,7 +417,7 @@ __global__ void resize_cubic_32f_kernel_vector( ...@@ -415,7 +417,7 @@ __global__ void resize_cubic_32f_kernel_vector(
int sr = floor(fr); int sr = floor(fr);
fr -= sr; fr -= sr;
float coef_row[4]; float coef_row[4];
interpolate_cubic_coefs(fr, coef_row); interpolate_cubic(fr, coef_row);
float dst_data[CH] = {0}; float dst_data[CH] = {0};
#pragma unroll #pragma unroll
for (int offset_r = 0; offset_r < 4; ++offset_r) { for (int offset_r = 0; offset_r < 4; ++offset_r) {
...@@ -459,7 +461,7 @@ __global__ void resize_cubic_8u_kernel_vector( ...@@ -459,7 +461,7 @@ __global__ void resize_cubic_8u_kernel_vector(
short icoef_col[4] = {0}; short icoef_col[4] = {0};
float coef_col[4]; float coef_col[4];
interpolate_cubic_coefs(fc, coef_col); interpolate_cubic(fc, coef_col);
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
icoef_col[i] = (short)(coef_col[i] * ONE); icoef_col[i] = (short)(coef_col[i] * ONE);
...@@ -473,7 +475,7 @@ __global__ void resize_cubic_8u_kernel_vector( ...@@ -473,7 +475,7 @@ __global__ void resize_cubic_8u_kernel_vector(
fr -= sr; fr -= sr;
short icoef_row[4]; short icoef_row[4];
float coef_row[4]; float coef_row[4];
interpolate_cubic_coefs(fr, coef_row); interpolate_cubic(fr, coef_row);
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
icoef_row[i] = (short)(coef_row[i] * ONE); icoef_row[i] = (short)(coef_row[i] * ONE);
......
...@@ -118,7 +118,7 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, ...@@ -118,7 +118,7 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
check_exec(src.layout, dst.layout, workspace.size); check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW4 || if (param().format == param::Resize::Format::NCHW4 ||
(param().format == param::Resize::Format::NCHW && (param().format == param::Resize::Format::NCHW &&
param().imode == param::Resize::InterpolationMode::NEAREST)) { param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) {
naive::ResizeImpl::exec(src, dst, workspace); naive::ResizeImpl::exec(src, dst, workspace);
return; return;
} }
......
...@@ -9,18 +9,21 @@ ...@@ -9,18 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "src/naive/resize/opr_impl.h"
#include "midout.h"
#include "src/common/cv/enums.h"
#include "src/common/resize.cuh"
#include "src/common/rounding_converter.cuh" #include "src/common/rounding_converter.cuh"
#include "src/common/utils.cuh" #include "src/common/utils.cuh"
#include "src/naive/handle.h" #include "src/naive/handle.h"
#include "src/naive/resize/opr_impl.h"
#include "src/naive/resize/resize_cv.h" #include "src/naive/resize/resize_cv.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_resize_layout) MIDOUT_DECL(megdnn_naive_resize_layout)
MIDOUT_DECL(megdnn_naive_resize_layout_nearest) MIDOUT_DECL(megdnn_naive_resize_nchw)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
using namespace resize;
template <typename ctype> template <typename ctype>
ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors( ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
...@@ -90,20 +93,84 @@ INST(dt_quint8); ...@@ -90,20 +93,84 @@ INST(dt_quint8);
#undef INST #undef INST
template <typename ctype> template <typename ctype>
void ResizeImpl::kern_nchw_nearest (const KernParam<ctype>& kern_param) { void ResizeImpl::kern_nchw(const KernParam<ctype>& kern_param,
InterpolationMode imode) {
megdnn_assert(kern_param.format == Format::NCHW); megdnn_assert(kern_param.format == Format::NCHW);
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param); UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param);
float scale_h = static_cast<float>(OH) / IH; float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW; float scale_w = static_cast<float>(OW) / IW;
rounding::RoundingConverter<ctype> output_converter;
rep(n, N) { rep(n, N) {
rep(oh, OH) rep(ow, OW) { rep(oh, OH) rep(ow, OW) {
auto ih = get_nearest_src(scale_h, IH, oh); switch (imode) {
auto iw = get_nearest_src(scale_w, IW, ow); case InterpolationMode::NEAREST: {
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] =
sptr[c * S_IC + ih * S_IH + iw * S_IW];
}
break;
}
case InterpolationMode::INTER_LINEAR: {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] = output_converter(
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * (1.0f - alphah) +
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] *
alphaw * (1.0f - alphah) +
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * alphah +
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] *
alphaw * alphah);
}
break;
}
case InterpolationMode::INTER_CUBIC: {
auto coord_h = get_origin_coord(scale_h, IH, oh, true);
auto coord_w = get_origin_coord(scale_w, IW, ow, true);
float alphah = coord_h.first;
float alphaw = coord_w.first;
rep(c, static_cast<int>(C)) { int ih0 = coord_h.second - 1;
dptr[c * OH * OW + oh * OW + ow] = sptr[c * S_IC + ih * S_IH + iw * S_IW]; int iw0 = coord_w.second - 1;
float h_coeff[4], w_coeff[4];
interpolate_cubic(alphah, h_coeff);
interpolate_cubic(alphaw, w_coeff);
rep(c, static_cast<int>(C)) {
constexpr int ksize = 4;
float ret = 0;
rep(kh, ksize) {
int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
rep(kw, ksize) {
int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
ret += sptr[c * S_IC + h * S_IH + w * S_IW] *
h_coeff[kh] * w_coeff[kw];
}
}
dptr[c * OH * OW + oh * OW + ow] =
output_converter(ret);
}
break;
}
default:
megdnn_throw("unsupported mode in ResizeBackwardImpl");
break;
} }
} }
sptr += S_IN; sptr += S_IN;
...@@ -131,40 +198,6 @@ void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) { ...@@ -131,40 +198,6 @@ void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) {
MIDOUT_END(); MIDOUT_END();
return; return;
} }
megdnn_assert(kern_param.format == Format::NCHW);
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param);
rounding::RoundingConverter<ctype> output_converter;
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
rep(n, N) {
rep(oh, OH) rep(ow, OW) {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] = output_converter(
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * (1.0f - alphah) +
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * alphaw *
(1.0f - alphah) +
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * alphah +
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * alphaw *
alphah);
}
}
sptr += S_IN;
dptr += C * OH * OW;
}
} }
template <typename ctype> template <typename ctype>
...@@ -290,18 +323,16 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) { ...@@ -290,18 +323,16 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size); check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW && if (param().format == param::Resize::Format::NCHW) {
param().imode == param::Resize::InterpolationMode::NEAREST) { #define cb(dt, ct, _midout_iv) \
#define cb(dt, ct, _midout_iv) \ case DTypeTrait<dt>::enumv: { \
case DTypeTrait<dt>::enumv: { \ MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) { \
MIDOUT_BEGIN(megdnn_naive_resize_layout_nearest, \ auto kparam = KernParam<ct>::from_tensors(param().format, src, \
midout_iv(_midout_iv)) { \ dst, workspace); \
auto kparam = KernParam<ct>::from_tensors(param().format, src, \ MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode)); \
dst, workspace); \ } \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw_nearest(kparam)); \ MIDOUT_END(); \
} \ return; \
MIDOUT_END(); \
return; \
} }
switch (src.layout.dtype.enumv()) { switch (src.layout.dtype.enumv()) {
...@@ -319,12 +350,10 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, ...@@ -319,12 +350,10 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
return; return;
} }
#undef cb
#undef cb #undef cb
} }
if ((param().format == param::Resize::Format::NCHW || if (((src.layout[3] != 1 && src.layout[3] != 3) ||
(src.layout[3] != 1 && src.layout[3] != 3) ||
!is_nhwc_contig_wc(src.layout)) || !is_nhwc_contig_wc(src.layout)) ||
(param().imode == param::Resize::InterpolationMode::LINEAR)) { (param().imode == param::Resize::InterpolationMode::LINEAR)) {
#define cb(dt, ct, _midout_iv) \ #define cb(dt, ct, _midout_iv) \
...@@ -378,37 +407,73 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, ...@@ -378,37 +407,73 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); std::memset(sptr, 0, sizeof(float) * N * C * IH * IW);
rep(n, N) { rep(n, N) {
rep(oh, OH) rep(ow, OW) { rep(oh, OH) rep(ow, OW) {
if(param().imode == InterpolationMode::INTER_LINEAR) { switch (param().imode) {
auto coord_h = get_origin_coord(scale_h, IH, oh); case InterpolationMode::INTER_LINEAR: {
auto coord_w = get_origin_coord(scale_w, IW, ow); auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first; float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1; int ih0 = coord_h.second;
int iw0 = coord_w.second; int ih1 = ih0 + 1;
int iw1 = iw0 + 1; int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow]; rep(c, C) {
sptr[c * IH * IW + ih0 * IW + iw0] += float hidden = hptr[c * OH * OW + oh * OW + ow];
(1.0f - alphaw) * (1.0f - alphah) * hidden; sptr[c * IH * IW + ih0 * IW + iw0] +=
sptr[c * IH * IW + ih1 * IW + iw0] += (1.0f - alphaw) * (1.0f - alphah) * hidden;
(1.0f - alphaw) * alphah * hidden; sptr[c * IH * IW + ih1 * IW + iw0] +=
sptr[c * IH * IW + ih0 * IW + iw1] += (1.0f - alphaw) * alphah * hidden;
alphaw * (1.0f - alphah) * hidden; sptr[c * IH * IW + ih0 * IW + iw1] +=
sptr[c * IH * IW + ih1 * IW + iw1] += alphaw * (1.0f - alphah) * hidden;
alphaw * alphah * hidden; sptr[c * IH * IW + ih1 * IW + iw1] +=
alphaw * alphah * hidden;
}
break;
} }
} else if (param().imode == InterpolationMode::NEAREST) { case InterpolationMode::NEAREST: {
auto ih = get_nearest_src(scale_h, IH, oh); auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow); auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) { rep(c, static_cast<int>(C)) {
sptr[c * IH * IW + ih * IW + iw] += hptr[c * OH * OW + oh * OW + ow]; sptr[c * IH * IW + ih * IW + iw] +=
hptr[c * OH * OW + oh * OW + ow];
}
break;
}
case InterpolationMode::INTER_CUBIC: {
auto coord_h = get_origin_coord(scale_h, IH, oh, true);
auto coord_w = get_origin_coord(scale_w, IW, ow, true);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second - 1;
int iw0 = coord_w.second - 1;
float h_coeff[4], w_coeff[4];
interpolate_cubic(alphah, h_coeff);
interpolate_cubic(alphaw, w_coeff);
rep(c, static_cast<int>(C)) {
constexpr int ksize = 4;
rep(kh, ksize) {
int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
rep(kw, ksize) {
int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
sptr[c * IH * IW + h * IW + w] +=
hptr[c * OH * OW + oh * OW + ow] *
h_coeff[kh] * w_coeff[kw];
}
}
}
break;
}
default: {
megdnn_throw("unsupported mode in ResizeBackwardImpl");
break;
} }
} }
else megdnn_throw("unsupported mode in ResizeBackwardImpl");
} }
sptr += C * IH * IW; sptr += C * IH * IW;
hptr += C * OH * OW; hptr += C * OH * OW;
......
...@@ -47,7 +47,7 @@ private: ...@@ -47,7 +47,7 @@ private:
void kern_naive(const KernParam<ctype>& kern_param); void kern_naive(const KernParam<ctype>& kern_param);
template <typename ctype> template <typename ctype>
void kern_nchw_nearest(const KernParam<ctype>& kern_param); void kern_nchw(const KernParam<ctype>& kern_param, InterpolationMode imode);
template <typename ctype> template <typename ctype>
void kern_naive_nhwc(const KernParam<ctype>& kern_param); void kern_naive_nhwc(const KernParam<ctype>& kern_param);
......
...@@ -68,6 +68,7 @@ ...@@ -68,6 +68,7 @@
#include "src/common/cv/helper.h" #include "src/common/cv/helper.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
#include "src/common/resize.cuh"
MIDOUT_DECL(megdnn_naive_resizecv_imode) MIDOUT_DECL(megdnn_naive_resizecv_imode)
MIDOUT_DECL(megdnn_naive_resizecv_dtype) MIDOUT_DECL(megdnn_naive_resizecv_dtype)
...@@ -75,6 +76,7 @@ MIDOUT_DECL(megdnn_naive_resizecv_dtype) ...@@ -75,6 +76,7 @@ MIDOUT_DECL(megdnn_naive_resizecv_dtype)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
using namespace megcv; using namespace megcv;
using namespace megdnn::resize;
namespace { namespace {
...@@ -383,14 +385,6 @@ using ResizeAreaFunc = void (*)(const Mat<T>& src, Mat<T>& dst, ...@@ -383,14 +385,6 @@ using ResizeAreaFunc = void (*)(const Mat<T>& src, Mat<T>& dst,
const DecimateAlpha* ytab, int ytab_size, const DecimateAlpha* ytab, int ytab_size,
const int* yofs); const int* yofs);
static inline void interpolate_cubic(float x, float* coeffs) {
const float A = -0.75f;
coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A;
coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1;
coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1;
coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2];
}
static inline void interpolate_lanczos4(float x, float* coeffs) { static inline void interpolate_lanczos4(float x, float* coeffs) {
static const double s45 = 0.70710678118654752440084436210485; static const double s45 = 0.70710678118654752440084436210485;
static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45},
......
...@@ -43,7 +43,7 @@ TEST_F(CUDA, RESIZE_CV) { ...@@ -43,7 +43,7 @@ TEST_F(CUDA, RESIZE_CV) {
TEST_F(CUDA, RESIZE_FORWARD) { TEST_F(CUDA, RESIZE_FORWARD) {
using namespace resize; using namespace resize;
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC};
for (auto imode : modes) { for (auto imode : modes) {
std::vector<TestArg> args = get_args(imode); std::vector<TestArg> args = get_args(imode);
Checker<Resize> checker(handle_cuda()); Checker<Resize> checker(handle_cuda());
...@@ -88,7 +88,7 @@ TEST_F(CUDA, RESIZE_NCHW4) { ...@@ -88,7 +88,7 @@ TEST_F(CUDA, RESIZE_NCHW4) {
} }
TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) {
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC};
for (auto imode : modes) { for (auto imode : modes) {
param::Resize param; param::Resize param;
param.format = param::Resize::Format::NCHW; param.format = param::Resize::Format::NCHW;
...@@ -117,7 +117,7 @@ TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { ...@@ -117,7 +117,7 @@ TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) {
} }
TEST_F(CUDA, RESIZE_BACKWARD) { TEST_F(CUDA, RESIZE_BACKWARD) {
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC};
for (auto imode : modes) { for (auto imode : modes) {
Checker<ResizeBackward> checker(handle_cuda()); Checker<ResizeBackward> checker(handle_cuda());
param::Resize param; param::Resize param;
......
...@@ -574,19 +574,25 @@ def interpolate( ...@@ -574,19 +574,25 @@ def interpolate(
raise ValueError("under linear mode, size can only be single value") raise ValueError("under linear mode, size can only be single value")
dsize = size dsize = size
if not align_corners and mode in ("bilinear", "nearest") and inp.ndim in [4, 5]: if not align_corners:
# fastpath for interpolate # fastpath for interpolate
op = builtin.Resize( mode_map = {
imode="linear" if mode == "bilinear" else "nearest", format="NCHW" "linear": "linear",
) "bilinear": "linear",
"nearest": "nearest",
"bicubic": "cubic",
}
op = builtin.Resize(imode=mode_map[mode], format="NCHW")
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape) (ret,) = apply(op, inp, shape)
return result else:
assert mode in [
oh, ow = dsize[0], dsize[1] "linear",
ih, iw = inp.shape[2], inp.shape[3] "bilinear",
], "align_corners only support linear or bilinear mode"
if align_corners: oh, ow = dsize[0], dsize[1]
ih, iw = inp.shape[2], inp.shape[3]
hscale = (ih - 1.0) / (oh - 1.0) hscale = (ih - 1.0) / (oh - 1.0)
wscale = 1.0 * iw / ow wscale = 1.0 * iw / ow
if mode != "linear": if mode != "linear":
...@@ -607,34 +613,11 @@ def interpolate( ...@@ -607,34 +613,11 @@ def interpolate(
axis=0, axis=0,
).reshape(1, 3, 3) ).reshape(1, 3, 3)
weight = broadcast_to(weight, (inp.shape[0], 3, 3)) weight = broadcast_to(weight, (inp.shape[0], 3, 3))
else:
hscale = 1.0 * ih / oh
wscale = 1.0 * iw / ow
row0 = concat(
[wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5],
axis=0,
).reshape(1, 3)
row1 = concat(
[Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5],
axis=0,
).reshape(1, 3)
weight = concat(
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
axis=0,
).reshape(1, 3, 3)
weight = broadcast_to(weight, (inp.shape[0], 3, 3))
weight = weight.astype("float32")
if mode in ["linear", "bilinear"]:
ret = warp_perspective(inp, weight, dsize, interp_mode="linear") ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3]) if mode == "linear":
else: ret = reshape(ret, ret.shape[0:3])
# only NHWC format support "cubic" mode
assert mode == "bicubic"
inp = transpose(inp, (0, 2, 3, 1))
ret = warp_perspective(inp, weight, dsize, format="NHWC", interp_mode="cubic",)
ret = transpose(ret, (0, 3, 1, 2))
return ret return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册