提交 0558b212 编写于 作者: M Megvii Engine Team

feat(mge/opr): add interpolate nearest mode

GitOrigin-RevId: d384b87f504c7dd2731bb3c618f35f8b70d00ed2
上级 171d6915
...@@ -198,6 +198,9 @@ public: ...@@ -198,6 +198,9 @@ public:
protected: protected:
//! get origin coord //! get origin coord
std::pair<float, int> get_origin_coord(float scale, int size, int idx); std::pair<float, int> get_origin_coord(float scale, int size, int idx);
//! get nearest index in src
int get_nearest_src(float scale, int size, int idx);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
}; };
......
...@@ -6,9 +6,11 @@ ...@@ -6,9 +6,11 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "megdnn/handle.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/common/utils.h" #include "src/common/utils.h"
...@@ -26,8 +28,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src, ...@@ -26,8 +28,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src,
errmsg().c_str()); errmsg().c_str());
if (param().format == Param::Format::NCHW) { if (param().format == Param::Format::NCHW) {
megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str()); megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str());
megdnn_assert(param().imode == auto imode = param().imode;
param::Resize::InterpolationMode::INTER_LINEAR); megdnn_assert(imode == param::Resize::InterpolationMode::INTER_LINEAR ||
imode == param::Resize::InterpolationMode::NEAREST);
} else if (param().format == Param::Format::NHWC) { } else if (param().format == Param::Format::NHWC) {
megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str()); megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str());
} else if (param().format == Param::Format::NCHW4) { } else if (param().format == Param::Format::NCHW4) {
...@@ -79,6 +82,9 @@ std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size, ...@@ -79,6 +82,9 @@ std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
return {alpha, origin_idx}; return {alpha, origin_idx};
} }
int ResizeBase::get_nearest_src(float scale, int size, int idx) {
return std::min(static_cast<int>(idx / scale), size - 1);
}
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -30,8 +30,9 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, ...@@ -30,8 +30,9 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
size_t max_batch_size = max_batch_x_channel / C; size_t max_batch_size = max_batch_x_channel / C;
while (N > 0) { while (N > 0) {
size_t curr_batch_size = N > max_batch_size ? max_batch_size : N; size_t curr_batch_size = N > max_batch_size ? max_batch_size : N;
resize::backward_data_proxy(diff_ptr, grad_ptr, curr_batch_size, C, IH, resize::backward_data_proxy(resize::get_imode(param().imode), diff_ptr,
IW, OH, OW, stream); grad_ptr, curr_batch_size, C, IH, IW, OH,
OW, stream);
if (N <= max_batch_size) { if (N <= max_batch_size) {
break; break;
......
...@@ -17,9 +17,9 @@ namespace megdnn { ...@@ -17,9 +17,9 @@ namespace megdnn {
namespace cuda { namespace cuda {
namespace resize { namespace resize {
__global__ void resize_bwd_kernel(const float* hidden, float* dst, int N, int C, __global__ void resize_bwd_linear_kernel(const float* hidden, float* dst, int N,
int IH, int IW, int OH, int OW, float scale_h, int C, int IH, int IW, int OH, int OW,
float scale_w) { float scale_h, float scale_w) {
int n = blockIdx.z; int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x; int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y; int oh = blockIdx.y * blockDim.y + threadIdx.y;
...@@ -51,8 +51,30 @@ __global__ void resize_bwd_kernel(const float* hidden, float* dst, int N, int C, ...@@ -51,8 +51,30 @@ __global__ void resize_bwd_kernel(const float* hidden, float* dst, int N, int C,
} }
} }
void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH, __global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst,
int IW, int OH, int OW, cudaStream_t stream) { int N, int C, int IH, int IW, int OH,
int OW, float scale_h,
float scale_w) {
int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
hidden += n * C * OH * OW;
dst += n * C * IH * IW;
if (ow < OW && oh < OH) {
int ih = get_nearest_src(scale_h, IH, oh);
int iw = get_nearest_src(scale_w, IW, ow);
for (int c = 0; c < C; ++c) {
atomicAdd(dst + ih * IW + iw,
hidden[oh * OW + ow]);
hidden += OH * OW;
dst += IH * IW;
}
}
}
void backward_data_proxy(InterpolationMode imode, const float* diff,
float* grad, int N, int C, int IH, int IW, int OH,
int OW, cudaStream_t stream) {
const int BY = 16, BX = 32; const int BY = 16, BX = 32;
{ {
dim3 threads(BX, BY); dim3 threads(BX, BY);
...@@ -61,9 +83,15 @@ void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH, ...@@ -61,9 +83,15 @@ void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH,
stream)); stream));
float scale_h = static_cast<float>(OH) / IH; float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW; float scale_w = static_cast<float>(OW) / IW;
resize_bwd_kernel<<<blocks, threads, 0, stream>>>( if(imode == InterpolationMode::INTER_LINEAR) {
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
}
else if (imode == InterpolationMode::INTER_NEAREST) {
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
} }
}
after_kernel_launch(); after_kernel_launch();
} }
......
...@@ -28,6 +28,10 @@ __device__ inline void get_origin_coord(float scale, int size, int idx, ...@@ -28,6 +28,10 @@ __device__ inline void get_origin_coord(float scale, int size, int idx,
} }
} }
__device__ inline int get_nearest_src(float scale, int size, int idx) {
return min(static_cast<int>(idx / scale), size - 1);
}
} // namespace resize } // namespace resize
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
...@@ -20,16 +20,17 @@ namespace resize { ...@@ -20,16 +20,17 @@ namespace resize {
// all these kernels use bilinear interpolation // all these kernels use bilinear interpolation
template <typename ctype> template <typename ctype>
void forward_proxy(bool is_nhwc, const ctype* src, ctype* dst, int N, int C, void forward_proxy(bool is_nhwc, InterpolationMode imode, const ctype* src,
int IH, int IW, int OH, int OW, int S_IN, int S_IC, int S_IH, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
int S_IW, cudaStream_t stream); int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream);
template <typename ctype> template <typename ctype>
void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH, void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream); int IW, int OH, int OW, cudaStream_t stream);
void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH, void backward_data_proxy(InterpolationMode imode, const float* diff,
int IW, int OH, int OW, cudaStream_t stream); float* grad, int N, int C, int IH, int IW, int OH,
int OW, cudaStream_t stream);
} // namespace resize } // namespace resize
} // namespace cuda } // namespace cuda
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "src/common/cv/common.h" #include "src/common/cv/common.h"
#include "src/common/cv/enums.h"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#include "src/cuda/resize/common.h" #include "src/cuda/resize/common.h"
#include "src/cuda/resize/helper.h" #include "src/cuda/resize/helper.h"
...@@ -146,19 +147,23 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, ...@@ -146,19 +147,23 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
C, IH, IW, OH, OW, stream); C, IH, IW, OH, OW, stream);
return; return;
} }
megdnn_assert(param().imode == Param::InterpolationMode::LINEAR, megdnn_assert(param().imode == Param::InterpolationMode::LINEAR ||
param().imode == Param::InterpolationMode::NEAREST,
"unsupported interpolation mode for NCHW format"); "unsupported interpolation mode for NCHW format");
if (src.layout.dtype == dtype::Float32{}) { if (src.layout.dtype == dtype::Float32{}) {
resize::forward_proxy(is_nhwc, src.ptr<dt_float32>(), resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
dst.ptr<dt_float32>(), src.layout[0], C, IH, IW, src.ptr<dt_float32>(), dst.ptr<dt_float32>(),
OH, OW, S_IN, S_IC, S_IH, S_IW, stream); src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream);
} else if (src.layout.dtype == dtype::Uint8()) { } else if (src.layout.dtype == dtype::Uint8()) {
resize::forward_proxy(is_nhwc, src.ptr<dt_uint8>(), dst.ptr<dt_uint8>(), resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
src.ptr<dt_uint8>(), dst.ptr<dt_uint8>(),
src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream); S_IH, S_IW, stream);
} else if (src.layout.dtype == dtype::Int8()) { } else if (src.layout.dtype == dtype::Int8()) {
resize::forward_proxy(is_nhwc, src.ptr<dt_int8>(), dst.ptr<dt_int8>(), resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
src.ptr<dt_int8>(), dst.ptr<dt_int8>(),
src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream); S_IH, S_IW, stream);
} else { } else {
......
...@@ -32,9 +32,10 @@ struct DirectSrcVisitor { ...@@ -32,9 +32,10 @@ struct DirectSrcVisitor {
}; };
template <typename ctype, typename SrcVisitor, typename OutputConverter> template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general(SrcVisitor src, ctype* __restrict dst, int C, __global__ void kern_general_linear(SrcVisitor src, ctype* __restrict dst,
int IH, int IW, int OH, int OW, int S_IN, int S_IC, int C, int IH, int IW, int OH, int OW,
int S_IH, int S_IW, float scale_h, float scale_w) { int S_IN, int S_IC, int S_IH, int S_IW,
float scale_h, float scale_w) {
OutputConverter output_converter; OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x; int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y; int oh = blockIdx.y * blockDim.y + threadIdx.y;
...@@ -64,6 +65,31 @@ __global__ void kern_general(SrcVisitor src, ctype* __restrict dst, int C, ...@@ -64,6 +65,31 @@ __global__ void kern_general(SrcVisitor src, ctype* __restrict dst, int C,
} }
} }
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst,
int C, int IH, int IW, int OH, int OW,
int S_IN, int S_IC, int S_IH, int S_IW,
float scale_h, float scale_w) {
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
dst += blockIdx.z * C * OH * OW;
if (ow < OW && oh < OH) {
int ih = get_nearest_src(scale_h, IH, oh);
int iw = get_nearest_src(scale_w, IW, ow);
for (int c = 0; c < C; ++c) {
dst[oh * OW + ow] = output_converter(
sptr[ih * S_IH + iw * S_IW]);
sptr += S_IC;
dst += OH * OW;
}
}
}
template <typename ctype, typename SrcVisitor, typename OutputConverter> template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C, __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scale_h, int IH, int IW, int OH, int OW, float scale_h,
...@@ -94,9 +120,10 @@ __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C, ...@@ -94,9 +120,10 @@ __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
} }
template <typename ctype, typename SrcVisitor> template <typename ctype, typename SrcVisitor>
void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, ctype* dst, int N, void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode,
int C, int IH, int IW, int OH, int OW, int S_IN, SrcVisitor src, ctype* dst, int N, int C, int IH,
int S_IC, int S_IH, int S_IW, cudaStream_t stream) { int IW, int OH, int OW, int S_IN, int S_IC, int S_IH,
int S_IW, cudaStream_t stream) {
const int BY = 16, BX = 32; const int BY = 16, BX = 32;
const int max_batch_size = 65535; const int max_batch_size = 65535;
...@@ -113,10 +140,19 @@ void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, ctype* dst, int N, ...@@ -113,10 +140,19 @@ void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, ctype* dst, int N,
<<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH, <<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH,
OW, scale_h, scale_w); OW, scale_h, scale_w);
} else { } else {
kern_general<ctype, SrcVisitor, rounding::RoundingConverter<ctype>> if (imode == InterpolationMode::INTER_LINEAR) {
<<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH, kern_general_linear<ctype, SrcVisitor,
OW, S_IN, S_IC, S_IH, S_IW, rounding::RoundingConverter<ctype>>
scale_h, scale_w); <<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH,
S_IW, scale_h, scale_w);
} else if (imode == InterpolationMode::INTER_NEAREST) {
kern_general_nearest<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH,
S_IW, scale_h, scale_w);
}
} }
N -= curr_batch_size; N -= curr_batch_size;
src.move_batch(curr_batch_size, C * IH * IW); src.move_batch(curr_batch_size, C * IH * IW);
...@@ -194,13 +230,14 @@ namespace cuda { ...@@ -194,13 +230,14 @@ namespace cuda {
namespace resize { namespace resize {
template <typename ctype> template <typename ctype>
void forward_proxy(bool is_nhwc, const ctype* src, ctype* dst, int N, int C, void forward_proxy(bool is_nhwc, InterpolationMode imode, const ctype* src,
int IH, int IW, int OH, int OW, int S_IN, int S_IC, int S_IH, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
int S_IW, cudaStream_t stream) { int S_IN, int S_IC, int S_IH, int S_IW,
cudaStream_t stream) {
DirectSrcVisitor<ctype> visitor; DirectSrcVisitor<ctype> visitor;
visitor.ptr = src; visitor.ptr = src;
dispatch_with_visitor(is_nhwc, visitor, dst, N, C, IH, IW, OH, OW, S_IN, dispatch_with_visitor(is_nhwc, imode, visitor, dst, N, C, IH, IW, OH, OW,
S_IC, S_IH, S_IW, stream); S_IN, S_IC, S_IH, S_IW, stream);
after_kernel_launch(); after_kernel_launch();
} }
...@@ -214,7 +251,7 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH, ...@@ -214,7 +251,7 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
} }
#define INST(ctype) \ #define INST(ctype) \
template void forward_proxy(bool, const ctype*, ctype*, int, int, int, \ template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, int, int, int, \
int, int, int, int, int, int, int, \ int, int, int, int, int, int, int, \
cudaStream_t); cudaStream_t);
INST(float) INST(float)
......
...@@ -116,7 +116,9 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) { ...@@ -116,7 +116,9 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) {
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size); check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW4) { if (param().format == param::Resize::Format::NCHW4 ||
(param().format == param::Resize::Format::NCHW &&
param().imode == param::Resize::InterpolationMode::NEAREST)) {
naive::ResizeImpl::exec(src, dst, workspace); naive::ResizeImpl::exec(src, dst, workspace);
return; return;
} }
......
...@@ -10,12 +10,14 @@ ...@@ -10,12 +10,14 @@
*/ */
#include "src/common/rounding_converter.cuh" #include "src/common/rounding_converter.cuh"
#include "src/common/utils.cuh"
#include "src/naive/handle.h" #include "src/naive/handle.h"
#include "src/naive/resize/opr_impl.h" #include "src/naive/resize/opr_impl.h"
#include "src/naive/resize/resize_cv.h" #include "src/naive/resize/resize_cv.h"
#include "midout.h" #include "midout.h"
MIDOUT_DECL(megdnn_naive_resize_layout) MIDOUT_DECL(megdnn_naive_resize_layout)
MIDOUT_DECL(megdnn_naive_resize_layout_nearest)
using namespace megdnn; using namespace megdnn;
using namespace naive; using namespace naive;
...@@ -86,6 +88,28 @@ INST(dt_qint8); ...@@ -86,6 +88,28 @@ INST(dt_qint8);
INST(dt_quint8); INST(dt_quint8);
#undef INST #undef INST
template <typename ctype>
void ResizeImpl::kern_nchw_nearest (const KernParam<ctype>& kern_param) {
megdnn_assert(kern_param.format == Format::NCHW);
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param);
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
rep(n, N) {
rep(oh, OH) rep(ow, OW) {
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] = sptr[c * S_IC + ih * S_IH + iw * S_IW];
}
}
sptr += S_IN;
dptr += C * OH * OW;
}
}
template <typename ctype> template <typename ctype>
void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) { void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) {
if (kern_param.format == Format::NHWC) { if (kern_param.format == Format::NHWC) {
...@@ -266,6 +290,39 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) { ...@@ -266,6 +290,39 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size); check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW &&
param().imode == param::Resize::InterpolationMode::NEAREST) {
#define cb(dt, ct, _midout_iv) \
case DTypeTrait<dt>::enumv: { \
MIDOUT_BEGIN(megdnn_naive_resize_layout_nearest, \
midout_iv(_midout_iv)) { \
auto kparam = KernParam<ct>::from_tensors(param().format, src, \
dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw_nearest(kparam)); \
} \
MIDOUT_END(); \
return; \
}
switch (src.layout.dtype.enumv()) {
cb(dtype::Float32, float, 0);
DNN_INC_FLOAT16(cb(dtype::Float16, dt_float16, 1));
cb(dtype::Int8, int8_t, 2);
cb(dtype::QuantizedS8, int8_t, 3);
cb(dtype::Uint8, uint8_t, 4);
cb(dtype::Quantized8Asymm, uint8_t, 5);
default:
megdnn_throw(ssprintf("Unsupported input DType in Resize "
"NEAREST mode: %s",
src.layout.dtype.name())
.c_str());
return;
}
#undef cb
#undef cb
}
if ((param().format == param::Resize::Format::NCHW || if ((param().format == param::Resize::Format::NCHW ||
(src.layout[3] != 1 && src.layout[3] != 3) || (src.layout[3] != 1 && src.layout[3] != 3) ||
!is_nhwc_contig_wc(src.layout)) || !is_nhwc_contig_wc(src.layout)) ||
...@@ -306,8 +363,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, ...@@ -306,8 +363,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(diff.layout, grad.layout, workspace.size); check_exec(diff.layout, grad.layout, workspace.size);
megdnn_assert(param().format == param::WarpPerspective::Format::NCHW, megdnn_assert(param().format == param::Resize::Format::NCHW,
"invalid warp_perspective format"); "invalid resize format");
const int N = grad.layout.shape[0], C = grad.layout.shape[1], const int N = grad.layout.shape[0], C = grad.layout.shape[1],
IH = grad.layout.shape[2], IW = grad.layout.shape[3]; IH = grad.layout.shape[2], IW = grad.layout.shape[3];
const int OH = diff.layout.shape[2], OW = diff.layout.shape[3]; const int OH = diff.layout.shape[2], OW = diff.layout.shape[3];
...@@ -321,6 +378,7 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, ...@@ -321,6 +378,7 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); std::memset(sptr, 0, sizeof(float) * N * C * IH * IW);
rep(n, N) { rep(n, N) {
rep(oh, OH) rep(ow, OW) { rep(oh, OH) rep(ow, OW) {
if(param().imode == InterpolationMode::INTER_LINEAR) {
auto coord_h = get_origin_coord(scale_h, IH, oh); auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow); auto coord_w = get_origin_coord(scale_w, IW, ow);
...@@ -343,6 +401,14 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, ...@@ -343,6 +401,14 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
sptr[c * IH * IW + ih1 * IW + iw1] += sptr[c * IH * IW + ih1 * IW + iw1] +=
alphaw * alphah * hidden; alphaw * alphah * hidden;
} }
} else if (param().imode == InterpolationMode::NEAREST) {
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) {
sptr[c * IH * IW + ih * IW + iw] += hptr[c * OH * OW + oh * OW + ow];
}
}
else megdnn_throw("unsupported mode in ResizeBackwardImpl");
} }
sptr += C * IH * IW; sptr += C * IH * IW;
hptr += C * OH * OW; hptr += C * OH * OW;
......
...@@ -46,6 +46,9 @@ private: ...@@ -46,6 +46,9 @@ private:
template <typename ctype> template <typename ctype>
void kern_naive(const KernParam<ctype>& kern_param); void kern_naive(const KernParam<ctype>& kern_param);
template <typename ctype>
void kern_nchw_nearest(const KernParam<ctype>& kern_param);
template <typename ctype> template <typename ctype>
void kern_naive_nhwc(const KernParam<ctype>& kern_param); void kern_naive_nhwc(const KernParam<ctype>& kern_param);
......
...@@ -18,6 +18,8 @@ namespace megdnn { ...@@ -18,6 +18,8 @@ namespace megdnn {
namespace test { namespace test {
namespace resize { namespace resize {
using IMode = param::Resize::InterpolationMode;
struct TestArg { struct TestArg {
param::Resize param; param::Resize param;
TensorShape src; TensorShape src;
...@@ -62,17 +64,18 @@ static void set_nchw_args(std::vector<TestArg>& args) { ...@@ -62,17 +64,18 @@ static void set_nchw_args(std::vector<TestArg>& args) {
args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
} }
static inline std::vector<TestArg> get_args() { static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) {
std::vector<TestArg> args; std::vector<TestArg> args;
set_nchw_args(args); set_nchw_args(args);
if(imode == IMode::INTER_LINEAR) {
//! test NHWC with ch != 1 or ch != 3 //! test NHWC with ch != 1 or ch != 3
param::Resize param; param::Resize param;
param.format = param::Resize::Format::NHWC; param.format = param::Resize::Format::NHWC;
param.imode = param::Resize::InterpolationMode::LINEAR; param.imode = imode;
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 4, 6, 4}); args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 4, 6, 4});
args.emplace_back(param, TensorShape{2, 4, 6, 4}, TensorShape{2, 2, 3, 4}); args.emplace_back(param, TensorShape{2, 4, 6, 4}, TensorShape{2, 2, 3, 4});
}
return args; return args;
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "test/common/resize.h" #include "test/common/resize.h"
#include "src/common/cv/enums.h"
#include "test/common/benchmarker.h" #include "test/common/benchmarker.h"
#include "test/common/checker.h" #include "test/common/checker.h"
#include "test/cuda/fixture.h" #include "test/cuda/fixture.h"
...@@ -42,7 +43,9 @@ TEST_F(CUDA, RESIZE_CV) { ...@@ -42,7 +43,9 @@ TEST_F(CUDA, RESIZE_CV) {
TEST_F(CUDA, RESIZE_FORWARD) { TEST_F(CUDA, RESIZE_FORWARD) {
using namespace resize; using namespace resize;
std::vector<TestArg> args = get_args(); IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
for (auto imode : modes) {
std::vector<TestArg> args = get_args(imode);
Checker<Resize> checker(handle_cuda()); Checker<Resize> checker(handle_cuda());
for (auto&& arg : args) { for (auto&& arg : args) {
...@@ -67,6 +70,7 @@ TEST_F(CUDA, RESIZE_FORWARD) { ...@@ -67,6 +70,7 @@ TEST_F(CUDA, RESIZE_FORWARD) {
.set_epsilon(1e-3) .set_epsilon(1e-3)
.execs({arg.src, arg.dst}); .execs({arg.src, arg.dst});
} }
}
} }
TEST_F(CUDA, RESIZE_NCHW4) { TEST_F(CUDA, RESIZE_NCHW4) {
...@@ -84,9 +88,11 @@ TEST_F(CUDA, RESIZE_NCHW4) { ...@@ -84,9 +88,11 @@ TEST_F(CUDA, RESIZE_NCHW4) {
} }
TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) {
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
for (auto imode : modes) {
param::Resize param; param::Resize param;
param.format = param::Resize::Format::NCHW; param.format = param::Resize::Format::NCHW;
param.imode = param::Resize::InterpolationMode::LINEAR; param.imode = imode;
Checker<Resize> checker(handle_cuda()); Checker<Resize> checker(handle_cuda());
checker.set_epsilon(1 + 1e-3) checker.set_epsilon(1 + 1e-3)
.set_param(param); .set_param(param);
...@@ -107,19 +113,23 @@ TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { ...@@ -107,19 +113,23 @@ TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) {
run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype); run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype); run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
} }
}
} }
TEST_F(CUDA, RESIZE_BACKWARD) { TEST_F(CUDA, RESIZE_BACKWARD) {
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
for (auto imode : modes) {
Checker<ResizeBackward> checker(handle_cuda()); Checker<ResizeBackward> checker(handle_cuda());
param::Resize param; param::Resize param;
param.format = param::Resize::Format::NCHW; param.format = param::Resize::Format::NCHW;
param.imode = param::Resize::InterpolationMode::LINEAR; param.imode = imode;
checker.set_param(param); checker.set_param(param);
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}}); checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}}); checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}}); checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}});
checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}}); checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}});
}
} }
#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
......
...@@ -522,29 +522,13 @@ def interpolate( ...@@ -522,29 +522,13 @@ def interpolate(
if align_corners is None: if align_corners is None:
align_corners = False align_corners = False
if (
size is not None
and scale_factor is None
and not align_corners
and mode == "bilinear"
and inp.ndim in [4, 5]
):
# fastpath for interpolate
op = builtin.Resize(imode="linear", format="NCHW")
shape = astensor1d(size, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result
if mode == "linear": if mode == "linear":
inp = expand_dims(inp, 3) inp = expand_dims(inp, 3)
if inp.ndim != 4: if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode") raise ValueError("shape of input tensor must correspond to the operartion mode")
if size is None: def get_dsize(scale_factor):
if scale_factor is None:
raise ValueError("scale_factor must not be None when size is None")
if isinstance(scale_factor, (float, int)): if isinstance(scale_factor, (float, int)):
scale_factor = float(scale_factor) scale_factor = float(scale_factor)
if mode == "linear": if mode == "linear":
...@@ -572,6 +556,13 @@ def interpolate( ...@@ -572,6 +556,13 @@ def interpolate(
for i in range(2) for i in range(2)
) )
dsize = concat([dsize[0], dsize[1]], axis=0) dsize = concat([dsize[0], dsize[1]], axis=0)
return dsize
if size is None:
if scale_factor is None:
raise ValueError("scale_factor must not be None when size is None")
dsize = get_dsize(scale_factor)
else: else:
if scale_factor is not None: if scale_factor is not None:
raise ValueError("scale_factor must be None when size is provided") raise ValueError("scale_factor must be None when size is provided")
...@@ -583,6 +574,15 @@ def interpolate( ...@@ -583,6 +574,15 @@ def interpolate(
raise ValueError("under linear mode, size can only be single value") raise ValueError("under linear mode, size can only be single value")
dsize = size dsize = size
if not align_corners and mode in ("bilinear", "nearest") and inp.ndim in [4, 5]:
# fastpath for interpolate
op = builtin.Resize(
imode="linear" if mode == "bilinear" else "nearest", format="NCHW"
)
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result
oh, ow = dsize[0], dsize[1] oh, ow = dsize[0], dsize[1]
ih, iw = inp.shape[2], inp.shape[3] ih, iw = inp.shape[2], inp.shape[3]
...@@ -630,15 +630,10 @@ def interpolate( ...@@ -630,15 +630,10 @@ def interpolate(
if mode == "linear": if mode == "linear":
ret = reshape(ret, ret.shape[0:3]) ret = reshape(ret, ret.shape[0:3])
else: else:
# only NHWC format support "cubic" and "nearest" mode # only NHWC format support "cubic" mode
assert mode == "bicubic"
inp = transpose(inp, (0, 2, 3, 1)) inp = transpose(inp, (0, 2, 3, 1))
ret = warp_perspective( ret = warp_perspective(inp, weight, dsize, format="NHWC", interp_mode="cubic",)
inp,
weight,
dsize,
format="NHWC",
interp_mode="cubic" if mode == "bicubic" else mode,
)
ret = transpose(ret, (0, 3, 1, 2)) ret = transpose(ret, (0, 3, 1, 2))
return ret return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册