提交 0558b212 编写于 作者: M Megvii Engine Team

feat(mge/opr): add interpolate nearest mode

GitOrigin-RevId: d384b87f504c7dd2731bb3c618f35f8b70d00ed2
上级 171d6915
......@@ -198,6 +198,9 @@ public:
protected:
//! get origin coord
std::pair<float, int> get_origin_coord(float scale, int size, int idx);
//! get nearest index in src
int get_nearest_src(float scale, int size, int idx);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};
......
......@@ -6,9 +6,11 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/handle.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"
......@@ -26,8 +28,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src,
errmsg().c_str());
if (param().format == Param::Format::NCHW) {
megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str());
megdnn_assert(param().imode ==
param::Resize::InterpolationMode::INTER_LINEAR);
auto imode = param().imode;
megdnn_assert(imode == param::Resize::InterpolationMode::INTER_LINEAR ||
imode == param::Resize::InterpolationMode::NEAREST);
} else if (param().format == Param::Format::NHWC) {
megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str());
} else if (param().format == Param::Format::NCHW4) {
......@@ -79,6 +82,9 @@ std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
return {alpha, origin_idx};
}
//! Map destination index \p idx to a source index for nearest mode:
//! idx / scale truncated to int, capped at the last valid index size - 1.
int ResizeBase::get_nearest_src(float scale, int size, int idx) {
    const int src_idx = static_cast<int>(idx / scale);
    const int last_idx = size - 1;
    // same result as std::min(src_idx, last_idx)
    return last_idx < src_idx ? last_idx : src_idx;
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -30,8 +30,9 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
size_t max_batch_size = max_batch_x_channel / C;
while (N > 0) {
size_t curr_batch_size = N > max_batch_size ? max_batch_size : N;
resize::backward_data_proxy(diff_ptr, grad_ptr, curr_batch_size, C, IH,
IW, OH, OW, stream);
resize::backward_data_proxy(resize::get_imode(param().imode), diff_ptr,
grad_ptr, curr_batch_size, C, IH, IW, OH,
OW, stream);
if (N <= max_batch_size) {
break;
......
......@@ -17,9 +17,9 @@ namespace megdnn {
namespace cuda {
namespace resize {
__global__ void resize_bwd_kernel(const float* hidden, float* dst, int N, int C,
int IH, int IW, int OH, int OW, float scale_h,
float scale_w) {
__global__ void resize_bwd_linear_kernel(const float* hidden, float* dst, int N,
int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) {
int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
......@@ -51,8 +51,30 @@ __global__ void resize_bwd_kernel(const float* hidden, float* dst, int N, int C,
}
}
void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream) {
/*!
 * \brief backward pass of nearest-mode resize
 *
 * Each thread owns one pixel (oh, ow) of the incoming gradient \p hidden for
 * the batch item selected by blockIdx.z and walks over all C channels,
 * accumulating the value into the source pixel chosen by get_nearest_src.
 * Several (oh, ow) positions can map to the same (ih, iw), so the
 * accumulation goes through atomicAdd. Threads outside the OH x OW extent do
 * nothing.
 */
__global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst,
                                          int N, int C, int IH, int IW, int OH,
                                          int OW, float scale_h,
                                          float scale_w) {
    const int n = blockIdx.z;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    // move both pointers to the start of this batch item
    hidden += n * C * OH * OW;
    dst += n * C * IH * IW;
    if (ow >= OW || oh >= OH) {
        return;
    }
    // per-plane offsets are the same for every channel
    const int dst_offset =
            get_nearest_src(scale_h, IH, oh) * IW +
            get_nearest_src(scale_w, IW, ow);
    const int hidden_offset = oh * OW + ow;
    for (int c = 0; c < C; ++c, hidden += OH * OW, dst += IH * IW) {
        atomicAdd(dst + dst_offset, hidden[hidden_offset]);
    }
}
void backward_data_proxy(InterpolationMode imode, const float* diff,
float* grad, int N, int C, int IH, int IW, int OH,
int OW, cudaStream_t stream) {
const int BY = 16, BX = 32;
{
dim3 threads(BX, BY);
......@@ -61,8 +83,14 @@ void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH,
stream));
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
resize_bwd_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
if(imode == InterpolationMode::INTER_LINEAR) {
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
}
else if (imode == InterpolationMode::INTER_NEAREST) {
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
}
}
after_kernel_launch();
}
......
......@@ -28,6 +28,10 @@ __device__ inline void get_origin_coord(float scale, int size, int idx,
}
}
//! Device helper for nearest mode: destination index \p idx is divided by
//! \p scale, truncated to int and capped at the last valid index size - 1.
__device__ inline int get_nearest_src(float scale, int size, int idx) {
    const int src_idx = static_cast<int>(idx / scale);
    const int last_idx = size - 1;
    // same result as min(src_idx, last_idx)
    return last_idx < src_idx ? last_idx : src_idx;
}
} // namespace resize
} // namespace cuda
} // namespace megdnn
......
......@@ -20,16 +20,17 @@ namespace resize {
// forward_proxy and backward_data_proxy pick the kernel from imode (linear or
// nearest); forward_proxy_nchw4 takes no imode argument
template <typename ctype>
void forward_proxy(bool is_nhwc, const ctype* src, ctype* dst, int N, int C,
int IH, int IW, int OH, int OW, int S_IN, int S_IC, int S_IH,
int S_IW, cudaStream_t stream);
void forward_proxy(bool is_nhwc, InterpolationMode imode, const ctype* src,
ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream);
template <typename ctype>
void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream);
void backward_data_proxy(const float* diff, float* grad, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream);
void backward_data_proxy(InterpolationMode imode, const float* diff,
float* grad, int N, int C, int IH, int IW, int OH,
int OW, cudaStream_t stream);
} // namespace resize
} // namespace cuda
......
......@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/cv/common.h"
#include "src/common/cv/enums.h"
#include "src/cuda/handle.h"
#include "src/cuda/resize/common.h"
#include "src/cuda/resize/helper.h"
......@@ -146,19 +147,23 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
C, IH, IW, OH, OW, stream);
return;
}
megdnn_assert(param().imode == Param::InterpolationMode::LINEAR,
megdnn_assert(param().imode == Param::InterpolationMode::LINEAR ||
param().imode == Param::InterpolationMode::NEAREST,
"unsupported interpolation mode for NCHW format");
if (src.layout.dtype == dtype::Float32{}) {
resize::forward_proxy(is_nhwc, src.ptr<dt_float32>(),
dst.ptr<dt_float32>(), src.layout[0], C, IH, IW,
OH, OW, S_IN, S_IC, S_IH, S_IW, stream);
resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
src.ptr<dt_float32>(), dst.ptr<dt_float32>(),
src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream);
} else if (src.layout.dtype == dtype::Uint8()) {
resize::forward_proxy(is_nhwc, src.ptr<dt_uint8>(), dst.ptr<dt_uint8>(),
resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
src.ptr<dt_uint8>(), dst.ptr<dt_uint8>(),
src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream);
} else if (src.layout.dtype == dtype::Int8()) {
resize::forward_proxy(is_nhwc, src.ptr<dt_int8>(), dst.ptr<dt_int8>(),
resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)),
src.ptr<dt_int8>(), dst.ptr<dt_int8>(),
src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream);
} else {
......
......@@ -32,9 +32,10 @@ struct DirectSrcVisitor {
};
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general(SrcVisitor src, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, int S_IN, int S_IC,
int S_IH, int S_IW, float scale_h, float scale_w) {
__global__ void kern_general_linear(SrcVisitor src, ctype* __restrict dst,
int C, int IH, int IW, int OH, int OW,
int S_IN, int S_IC, int S_IH, int S_IW,
float scale_h, float scale_w) {
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
......@@ -64,6 +65,31 @@ __global__ void kern_general(SrcVisitor src, ctype* __restrict dst, int C,
}
}
/*!
 * \brief forward resize kernel for nearest mode with a strided source
 *
 * Each thread handles one destination pixel (oh, ow) of the batch item
 * selected by blockIdx.z and loops over all C channels. The source is read
 * from the per-batch pointer returned by src.get() using the strides
 * S_IC / S_IH / S_IW; the destination is written as consecutive OH x OW
 * planes. Every value passes through OutputConverter before being stored.
 * Threads outside the OH x OW extent do nothing.
 */
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst,
                                     int C, int IH, int IW, int OH, int OW,
                                     int S_IN, int S_IC, int S_IH, int S_IW,
                                     float scale_h, float scale_w) {
    OutputConverter output_converter;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
    // move dst to the start of this batch item
    dst += blockIdx.z * C * OH * OW;
    if (ow >= OW || oh >= OH) {
        return;
    }
    // offsets inside one channel plane; identical for every channel
    const int src_offset = get_nearest_src(scale_h, IH, oh) * S_IH +
                           get_nearest_src(scale_w, IW, ow) * S_IW;
    const int dst_offset = oh * OW + ow;
    for (int c = 0; c < C; ++c, sptr += S_IC, dst += OH * OW) {
        dst[dst_offset] = output_converter(sptr[src_offset]);
    }
}
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scale_h,
......@@ -94,9 +120,10 @@ __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C,
}
template <typename ctype, typename SrcVisitor>
void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, ctype* dst, int N,
int C, int IH, int IW, int OH, int OW, int S_IN,
int S_IC, int S_IH, int S_IW, cudaStream_t stream) {
void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode,
SrcVisitor src, ctype* dst, int N, int C, int IH,
int IW, int OH, int OW, int S_IN, int S_IC, int S_IH,
int S_IW, cudaStream_t stream) {
const int BY = 16, BX = 32;
const int max_batch_size = 65535;
......@@ -113,10 +140,19 @@ void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, ctype* dst, int N,
<<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH,
OW, scale_h, scale_w);
} else {
kern_general<ctype, SrcVisitor, rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH,
OW, S_IN, S_IC, S_IH, S_IW,
scale_h, scale_w);
if (imode == InterpolationMode::INTER_LINEAR) {
kern_general_linear<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH,
S_IW, scale_h, scale_w);
} else if (imode == InterpolationMode::INTER_NEAREST) {
kern_general_nearest<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH,
S_IW, scale_h, scale_w);
}
}
N -= curr_batch_size;
src.move_batch(curr_batch_size, C * IH * IW);
......@@ -194,13 +230,14 @@ namespace cuda {
namespace resize {
template <typename ctype>
void forward_proxy(bool is_nhwc, const ctype* src, ctype* dst, int N, int C,
int IH, int IW, int OH, int OW, int S_IN, int S_IC, int S_IH,
int S_IW, cudaStream_t stream) {
void forward_proxy(bool is_nhwc, InterpolationMode imode, const ctype* src,
ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
int S_IN, int S_IC, int S_IH, int S_IW,
cudaStream_t stream) {
DirectSrcVisitor<ctype> visitor;
visitor.ptr = src;
dispatch_with_visitor(is_nhwc, visitor, dst, N, C, IH, IW, OH, OW, S_IN,
S_IC, S_IH, S_IW, stream);
dispatch_with_visitor(is_nhwc, imode, visitor, dst, N, C, IH, IW, OH, OW,
S_IN, S_IC, S_IH, S_IW, stream);
after_kernel_launch();
}
......@@ -214,7 +251,7 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH,
}
#define INST(ctype) \
template void forward_proxy(bool, const ctype*, ctype*, int, int, int, \
template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, int, int, int, \
int, int, int, int, int, int, int, \
cudaStream_t);
INST(float)
......
......@@ -116,7 +116,9 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) {
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW4) {
if (param().format == param::Resize::Format::NCHW4 ||
(param().format == param::Resize::Format::NCHW &&
param().imode == param::Resize::InterpolationMode::NEAREST)) {
naive::ResizeImpl::exec(src, dst, workspace);
return;
}
......
......@@ -10,12 +10,14 @@
*/
#include "src/common/rounding_converter.cuh"
#include "src/common/utils.cuh"
#include "src/naive/handle.h"
#include "src/naive/resize/opr_impl.h"
#include "src/naive/resize/resize_cv.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_resize_layout)
MIDOUT_DECL(megdnn_naive_resize_layout_nearest)
using namespace megdnn;
using namespace naive;
......@@ -86,6 +88,28 @@ INST(dt_qint8);
INST(dt_quint8);
#undef INST
//! Forward resize for NCHW tensors in nearest mode: every destination pixel
//! copies the source pixel whose row/column index is returned by
//! get_nearest_src. The source is addressed through the strides
//! S_IN / S_IC / S_IH / S_IW; the destination is written as consecutive
//! C x OH x OW blocks, one per batch item.
//! NOTE(review): N, C, IH, IW, OH, OW, the strides, sptr and dptr are
//! presumably introduced by UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE -- confirm
//! against the macro definition.
template <typename ctype>
void ResizeImpl::kern_nchw_nearest (const KernParam<ctype>& kern_param) {
// this kernel only handles the NCHW format
megdnn_assert(kern_param.format == Format::NCHW);
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param);
// output/input size ratio along each spatial axis
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
// one iteration per batch item
rep(n, N) {
rep(oh, OH) rep(ow, OW) {
// source row/column selected for this destination pixel
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
// copy the selected source pixel for every channel
rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] = sptr[c * S_IC + ih * S_IH + iw * S_IW];
}
}
// advance both pointers to the next batch item
sptr += S_IN;
dptr += C * OH * OW;
}
}
template <typename ctype>
void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) {
if (kern_param.format == Format::NHWC) {
......@@ -266,6 +290,39 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW &&
param().imode == param::Resize::InterpolationMode::NEAREST) {
#define cb(dt, ct, _midout_iv) \
case DTypeTrait<dt>::enumv: { \
MIDOUT_BEGIN(megdnn_naive_resize_layout_nearest, \
midout_iv(_midout_iv)) { \
auto kparam = KernParam<ct>::from_tensors(param().format, src, \
dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw_nearest(kparam)); \
} \
MIDOUT_END(); \
return; \
}
switch (src.layout.dtype.enumv()) {
cb(dtype::Float32, float, 0);
DNN_INC_FLOAT16(cb(dtype::Float16, dt_float16, 1));
cb(dtype::Int8, int8_t, 2);
cb(dtype::QuantizedS8, int8_t, 3);
cb(dtype::Uint8, uint8_t, 4);
cb(dtype::Quantized8Asymm, uint8_t, 5);
default:
megdnn_throw(ssprintf("Unsupported input DType in Resize "
"NEAREST mode: %s",
src.layout.dtype.name())
.c_str());
return;
}
#undef cb
#undef cb
}
if ((param().format == param::Resize::Format::NCHW ||
(src.layout[3] != 1 && src.layout[3] != 3) ||
!is_nhwc_contig_wc(src.layout)) ||
......@@ -306,8 +363,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(diff.layout, grad.layout, workspace.size);
megdnn_assert(param().format == param::WarpPerspective::Format::NCHW,
"invalid warp_perspective format");
megdnn_assert(param().format == param::Resize::Format::NCHW,
"invalid resize format");
const int N = grad.layout.shape[0], C = grad.layout.shape[1],
IH = grad.layout.shape[2], IW = grad.layout.shape[3];
const int OH = diff.layout.shape[2], OW = diff.layout.shape[3];
......@@ -321,28 +378,37 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW);
rep(n, N) {
rep(oh, OH) rep(ow, OW) {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow];
sptr[c * IH * IW + ih0 * IW + iw0] +=
(1.0f - alphaw) * (1.0f - alphah) * hidden;
sptr[c * IH * IW + ih1 * IW + iw0] +=
(1.0f - alphaw) * alphah * hidden;
sptr[c * IH * IW + ih0 * IW + iw1] +=
alphaw * (1.0f - alphah) * hidden;
sptr[c * IH * IW + ih1 * IW + iw1] +=
alphaw * alphah * hidden;
if(param().imode == InterpolationMode::INTER_LINEAR) {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow];
sptr[c * IH * IW + ih0 * IW + iw0] +=
(1.0f - alphaw) * (1.0f - alphah) * hidden;
sptr[c * IH * IW + ih1 * IW + iw0] +=
(1.0f - alphaw) * alphah * hidden;
sptr[c * IH * IW + ih0 * IW + iw1] +=
alphaw * (1.0f - alphah) * hidden;
sptr[c * IH * IW + ih1 * IW + iw1] +=
alphaw * alphah * hidden;
}
} else if (param().imode == InterpolationMode::NEAREST) {
auto ih = get_nearest_src(scale_h, IH, oh);
auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) {
sptr[c * IH * IW + ih * IW + iw] += hptr[c * OH * OW + oh * OW + ow];
}
}
else megdnn_throw("unsupported mode in ResizeBackwardImpl");
}
sptr += C * IH * IW;
hptr += C * OH * OW;
......
......@@ -46,6 +46,9 @@ private:
template <typename ctype>
void kern_naive(const KernParam<ctype>& kern_param);
template <typename ctype>
void kern_nchw_nearest(const KernParam<ctype>& kern_param);
template <typename ctype>
void kern_naive_nhwc(const KernParam<ctype>& kern_param);
......
......@@ -18,6 +18,8 @@ namespace megdnn {
namespace test {
namespace resize {
using IMode = param::Resize::InterpolationMode;
struct TestArg {
param::Resize param;
TensorShape src;
......@@ -62,17 +64,18 @@ static void set_nchw_args(std::vector<TestArg>& args) {
args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
}
static inline std::vector<TestArg> get_args() {
static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) {
std::vector<TestArg> args;
set_nchw_args(args);
if(imode == IMode::INTER_LINEAR) {
//! test NHWC with ch != 1 or ch != 3
param::Resize param;
param.format = param::Resize::Format::NHWC;
param.imode = param::Resize::InterpolationMode::LINEAR;
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 4, 6, 4});
args.emplace_back(param, TensorShape{2, 4, 6, 4}, TensorShape{2, 2, 3, 4});
param::Resize param;
param.format = param::Resize::Format::NHWC;
param.imode = imode;
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 4, 6, 4});
args.emplace_back(param, TensorShape{2, 4, 6, 4}, TensorShape{2, 2, 3, 4});
}
return args;
}
......
......@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/common/resize.h"
#include "src/common/cv/enums.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/cuda/fixture.h"
......@@ -42,30 +43,33 @@ TEST_F(CUDA, RESIZE_CV) {
TEST_F(CUDA, RESIZE_FORWARD) {
using namespace resize;
std::vector<TestArg> args = get_args();
Checker<Resize> checker(handle_cuda());
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Uint8())
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
for (auto imode : modes) {
std::vector<TestArg> args = get_args(imode);
Checker<Resize> checker(handle_cuda());
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Uint8())
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
}
}
......@@ -84,42 +88,48 @@ TEST_F(CUDA, RESIZE_NCHW4) {
}
TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) {
param::Resize param;
param.format = param::Resize::Format::NCHW;
param.imode = param::Resize::InterpolationMode::LINEAR;
Checker<Resize> checker(handle_cuda());
checker.set_epsilon(1 + 1e-3)
.set_param(param);
auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
TensorShape dst_shape, DType dtype) {
checker.set_dtype(0, dtype)
.set_dtype(1, dtype)
.execl({{src_shape, src_layout, dtype}, {dst_shape, dtype}});
};
for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8(),
dtype::Int8()}) {
run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
run({1, 3, 40, 40}, {25600, 3200, 80, 1}, {1, 3, 30, 30}, dtype);
run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
for (auto imode : modes) {
param::Resize param;
param.format = param::Resize::Format::NCHW;
param.imode = imode;
Checker<Resize> checker(handle_cuda());
checker.set_epsilon(1 + 1e-3)
.set_param(param);
auto run = [&](TensorShape src_shape, std::vector<ptrdiff_t> src_layout,
TensorShape dst_shape, DType dtype) {
checker.set_dtype(0, dtype)
.set_dtype(1, dtype)
.execl({{src_shape, src_layout, dtype}, {dst_shape, dtype}});
};
for (DType& dtype : std::vector<DType>{dtype::Float32(), dtype::Uint8(),
dtype::Int8()}) {
run({2, 3, 4, 4}, {256, 32, 8, 1}, {2, 3, 3, 3}, dtype);
run({1, 3, 4, 3}, {105, 35, 7, 2}, {1, 3, 5, 5}, dtype);
run({1, 3, 40, 40}, {25600, 3200, 80, 1}, {1, 3, 30, 30}, dtype);
run({2, 3, 4, 4}, {-256, 32, -8, 1}, {2, 3, 3, 3}, dtype);
run({2, 3, 4, 4}, {256, -32, 8, -1}, {2, 3, 3, 3}, dtype);
run({2, 3, 4, 4}, {-256, -32, -8, -1}, {2, 3, 3, 3}, dtype);
}
}
}
TEST_F(CUDA, RESIZE_BACKWARD) {
Checker<ResizeBackward> checker(handle_cuda());
param::Resize param;
param.format = param::Resize::Format::NCHW;
param.imode = param::Resize::InterpolationMode::LINEAR;
checker.set_param(param);
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}});
checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}});
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST};
for (auto imode : modes) {
Checker<ResizeBackward> checker(handle_cuda());
param::Resize param;
param.format = param::Resize::Format::NCHW;
param.imode = imode;
checker.set_param(param);
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}});
checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}});
}
}
#if MEGDNN_WITH_BENCHMARK
......
......@@ -522,29 +522,13 @@ def interpolate(
if align_corners is None:
align_corners = False
if (
size is not None
and scale_factor is None
and not align_corners
and mode == "bilinear"
and inp.ndim in [4, 5]
):
# fastpath for interpolate
op = builtin.Resize(imode="linear", format="NCHW")
shape = astensor1d(size, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result
if mode == "linear":
inp = expand_dims(inp, 3)
if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode")
if size is None:
if scale_factor is None:
raise ValueError("scale_factor must not be None when size is None")
def get_dsize(scale_factor):
if isinstance(scale_factor, (float, int)):
scale_factor = float(scale_factor)
if mode == "linear":
......@@ -572,6 +556,13 @@ def interpolate(
for i in range(2)
)
dsize = concat([dsize[0], dsize[1]], axis=0)
return dsize
if size is None:
if scale_factor is None:
raise ValueError("scale_factor must not be None when size is None")
dsize = get_dsize(scale_factor)
else:
if scale_factor is not None:
raise ValueError("scale_factor must be None when size is provided")
......@@ -583,6 +574,15 @@ def interpolate(
raise ValueError("under linear mode, size can only be single value")
dsize = size
if not align_corners and mode in ("bilinear", "nearest") and inp.ndim in [4, 5]:
# fastpath for interpolate
op = builtin.Resize(
imode="linear" if mode == "bilinear" else "nearest", format="NCHW"
)
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device)
(result,) = apply(op, inp, shape)
return result
oh, ow = dsize[0], dsize[1]
ih, iw = inp.shape[2], inp.shape[3]
......@@ -630,15 +630,10 @@ def interpolate(
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
else:
# only NHWC format support "cubic" and "nearest" mode
# only NHWC format support "cubic" mode
assert mode == "bicubic"
inp = transpose(inp, (0, 2, 3, 1))
ret = warp_perspective(
inp,
weight,
dsize,
format="NHWC",
interp_mode="cubic" if mode == "bicubic" else mode,
)
ret = warp_perspective(inp, weight, dsize, format="NHWC", interp_mode="cubic",)
ret = transpose(ret, (0, 3, 1, 2))
return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册