diff --git a/dnn/src/common/remap.cpp b/dnn/src/common/remap.cpp
index 6c2e5a473ed945f8dcad39ff4bb7fdd5cec32da9..9d543bb397736ed2d600500ec42c9bbd723f790b 100644
--- a/dnn/src/common/remap.cpp
+++ b/dnn/src/common/remap.cpp
@@ -18,21 +18,22 @@ namespace megdnn {
void RemapBase::deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst) {
- dst.dtype = src.dtype;
- dst.ndim = src.ndim;
- dst.shape[0] = src.shape[0];
- size_t height_index, channel_index;
+ size_t n = src.shape[0];
+ size_t c, oh, ow;
+ oh = map_xy.shape[1];
+ ow = map_xy.shape[2];
if (param().format == param::Remap::Format::NHWC) {
- height_index = 1;
- channel_index = 3;
+ c = src.shape[3];
+ dst = TensorLayout(TensorShape({n, oh, ow, c}), src.dtype);
+ } else if (param().format == param::Remap::Format::NCHW) {
+ c = src.shape[1];
+ dst = TensorLayout(TensorShape{n, c, oh, ow}, src.dtype, src.format);
+ } else if (param().format == param::Remap::Format::NHWCD4) {
+ c = src.shape[2];
+ dst = TensorLayout{{n, oh, c, ow, 4}, src.dtype, src.format};
} else {
- megdnn_assert(param().format == param::Remap::Format::NCHW);
- height_index = 2;
- channel_index = 1;
+ megdnn_throw("unsupport format");
}
- dst.shape[height_index] = map_xy.shape[1];
- dst.shape[height_index + 1] = map_xy.shape[2];
- dst.shape[channel_index] = src.shape[channel_index];
}
void RemapBase::check_layout_fwd(
@@ -42,7 +43,7 @@ void RemapBase::check_layout_fwd(
megdnn_layout_msg(dst);
};
MEGDNN_MARK_USED_VAR(errmsg);
- megdnn_assert(src.ndim == map_xy.ndim && src.ndim == dst.ndim && src.ndim == 4);
+ megdnn_assert(src.ndim == dst.ndim);
megdnn_assert(dst.dtype == src.dtype);
megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str());
megdnn_assert(map_xy.shape[3] == 2);
@@ -64,10 +65,13 @@ void RemapBase::check_layout_fwd(
megdnn_assert(
dst.shape[2] == map_xy.shape[1] && dst.shape[3] == map_xy.shape[2],
"%s", errmsg().c_str());
+ } else if (param().format == param::Remap::Format::NHWCD4) {
+ megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
+ megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
+ megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
+ megdnn_assert(param().format == Param::Format::NHWCD4);
} else {
- megdnn_throw(
- "currently do not support other param.format except NHWC and "
- "NCHW");
+ megdnn_throw("unsupport format");
}
}
diff --git a/dnn/src/cuda/remap/backward_data.cpp b/dnn/src/cuda/remap/backward_data.cpp
index a2ce89505fec850499cc5501974bb947ae1109d1..26b9933aab56453dd30924b6ccd614749c0cd609 100644
--- a/dnn/src/cuda/remap/backward_data.cpp
+++ b/dnn/src/cuda/remap/backward_data.cpp
@@ -22,8 +22,9 @@ void RemapBackwardDataImpl::exec(
_megdnn_workspace workspace) {
check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size);
megdnn_assert(
- param().imode == param::Remap::InterpolationMode::LINEAR,
- "only support LINEAR interpolationMode");
+ (param().imode == param::Remap::InterpolationMode::NEAREST) ||
+ (param().imode == param::Remap::InterpolationMode::LINEAR),
+ "only support NEAREST and LINEAR interpolationMode");
megdnn_assert(
param().format == param::Remap::Format::NCHW,
"only support NCHW format for remap backward");
@@ -36,13 +37,15 @@ void RemapBackwardDataImpl::exec(
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
-#define cb(dt, _format, bmode) \
+#define cb(dt, _format, bmode, inter_mode) \
if (param().format == param::Remap::Format::_format && \
- param().border_type == param::Remap::BorderMode::bmode) { \
+ param().border_type == param::Remap::BorderMode::bmode && \
+ param().imode == param::Remap::InterpolationMode::inter_mode) { \
using ctype = DTypeTrait
::ctype; \
remap::backwarddata_proxy< \
ctype, param_enumv::Remap::Format::_format, \
- ::BorderMode::BORDER_##bmode>( \
+ ::BorderMode::BORDER_##bmode, \
+ ::InterpolationMode::INTER_##inter_mode>( \
grad.compatible_ptr(), map_xy.compatible_ptr(), \
diff.compatible_ptr(), N, C, IH, IW, OH, OW, stream); \
break; \
@@ -50,11 +53,16 @@ void RemapBackwardDataImpl::exec(
#define support_dtype(dt) \
case DTypeTrait::enumv: { \
- cb(dt, NCHW, CONSTANT); \
- cb(dt, NCHW, REPLICATE); \
- cb(dt, NCHW, REFLECT); \
- cb(dt, NCHW, REFLECT_101); \
- cb(dt, NCHW, WRAP); \
+ cb(dt, NCHW, CONSTANT, NEAREST); \
+ cb(dt, NCHW, REPLICATE, NEAREST); \
+ cb(dt, NCHW, REFLECT, NEAREST); \
+ cb(dt, NCHW, REFLECT_101, NEAREST); \
+ cb(dt, NCHW, WRAP, NEAREST); \
+ cb(dt, NCHW, CONSTANT, LINEAR); \
+ cb(dt, NCHW, REPLICATE, LINEAR); \
+ cb(dt, NCHW, REFLECT, LINEAR); \
+ cb(dt, NCHW, REFLECT_101, LINEAR); \
+ cb(dt, NCHW, WRAP, LINEAR); \
megdnn_throw("unsupported border type in remap cuda"); \
}
diff --git a/dnn/src/cuda/remap/backward_data.cu b/dnn/src/cuda/remap/backward_data.cu
index 662f9ac62ad909b15d1250207af72517fba69cf5..0074f8bb8da7134dbe0f5b14efc9e60795721922 100644
--- a/dnn/src/cuda/remap/backward_data.cu
+++ b/dnn/src/cuda/remap/backward_data.cu
@@ -52,8 +52,49 @@ struct GetSrcData {
}
};
+__device__ inline float round_half_to_even(float f) {
+ const float round_away_from_zero = round(f);
+ const float diff = round_away_from_zero - f;
+
+ if ((diff != 0.5f) && (diff != -0.5f)) {
+ return round_away_from_zero;
+ }
+
+ if (fmod(round_away_from_zero, 2.0f) == 0.0f) {
+ return round_away_from_zero;
+ }
+
+ return f - diff;
+}
+
+template
+__global__ void kern_general_nearest(
+ ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH,
+ int IW, int OH, int OW) {
+ int ow = blockIdx.x * blockDim.x + threadIdx.x;
+ int oh = blockIdx.y * blockDim.y + threadIdx.y;
+ grad += blockIdx.z * C * IH * IW;
+ diff += blockIdx.z * C * OH * OW;
+ map_xy += blockIdx.z * 2 * OH * OW;
+
+ if (ow < OW && oh < OH) {
+ float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
+ float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
+ int col = static_cast(round_half_to_even(index_col));
+ int row = static_cast(round_half_to_even(index_row));
+ for (int c = 0; c < C; ++c) {
+ ctype hidden = diff[get_offset(oh, ow, c, OH, OW, C)];
+ int idx =
+ GetSrcData::get_index(row, col, c, IH, IW, C);
+ if (idx != -1) {
+ atomic_add(grad + idx, hidden);
+ }
+ }
+ }
+}
+
template
-__global__ void kern_general(
+__global__ void kern_general_linear(
ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH,
int IW, int OH, int OW) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
@@ -93,8 +134,8 @@ __global__ void kern_general(
atomic_add(grad + a10, round_converter(u * (one - v) * hidden));
}
- int a11 = GetSrcData::
- get_index(row + 1, col + 1, c, IH, IW, C);
+ int a11 = GetSrcData::get_index(
+ row + 1, col + 1, c, IH, IW, C);
if (a11 != -1) {
atomic_add(grad + a11, round_converter(u * v * hidden));
}
@@ -102,7 +143,9 @@ __global__ void kern_general(
}
}
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void dispatch_backwarddata(
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream) {
@@ -115,8 +158,13 @@ void dispatch_backwarddata(
cuda_check(cudaMemsetAsync(
grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW, stream));
- kern_general
- <<>>(grad, map_xy, diff, C, IH, IW, OH, OW);
+ if (imode == ::InterpolationMode::INTER_NEAREST) {
+ kern_general_nearest<<>>(
+ grad, map_xy, diff, C, IH, IW, OH, OW);
+ } else if (imode == ::InterpolationMode::INTER_LINEAR) {
+ kern_general_linear<<>>(
+ grad, map_xy, diff, C, IH, IW, OH, OW);
+ }
N -= curr_batch_size;
grad += curr_batch_size * C * IH * IW;
@@ -131,27 +179,35 @@ namespace megdnn {
namespace cuda {
namespace remap {
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void backwarddata_proxy(
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream) {
- dispatch_backwarddata(
+ dispatch_backwarddata(
grad, map_xy, diff, N, C, IH, IW, OH, OW, stream);
after_kernel_launch();
}
-#define INST(ctype, format, bmode) \
+#define INST(ctype, format, bmode, imode) \
template void backwarddata_proxy< \
- ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \
+ ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \
+ ::InterpolationMode::imode>( \
ctype*, const float*, const ctype*, int, int, int, int, int, int, \
cudaStream_t);
-#define FOR_FORMAT_BMODE(ctype) \
- INST(ctype, NCHW, BORDER_CONSTANT) \
- INST(ctype, NCHW, BORDER_REPLICATE) \
- INST(ctype, NCHW, BORDER_REFLECT) \
- INST(ctype, NCHW, BORDER_REFLECT_101) \
- INST(ctype, NCHW, BORDER_WRAP)
+#define FOR_FORMAT_BMODE(ctype) \
+ INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR)
FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
diff --git a/dnn/src/cuda/remap/backward_mat.cpp b/dnn/src/cuda/remap/backward_mat.cpp
index e14d9b09f8a140a6724dfee9c11c3ef90869bc22..c1dcc53cbb5eb1751de906feea86b13501478ccc 100644
--- a/dnn/src/cuda/remap/backward_mat.cpp
+++ b/dnn/src/cuda/remap/backward_mat.cpp
@@ -22,8 +22,9 @@ void RemapBackwardMatImpl::exec(
_megdnn_tensor_out grad, _megdnn_workspace workspace) {
check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, workspace.size);
megdnn_assert(
- param().imode == param::Remap::InterpolationMode::LINEAR,
- "only support LINEAR interpolationMode");
+ (param().imode == param::Remap::InterpolationMode::NEAREST) ||
+ (param().imode == param::Remap::InterpolationMode::LINEAR),
+ "only support NEAREST and LINEAR interpolationMode");
megdnn_assert(
param().format == param::Remap::Format::NCHW,
"only support NCHW format for remap backward");
@@ -36,13 +37,15 @@ void RemapBackwardMatImpl::exec(
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
-#define cb(dt, _format, bmode) \
+#define cb(dt, _format, bmode, inter_mode) \
if (param().format == param::Remap::Format::_format && \
- param().border_type == param::Remap::BorderMode::bmode) { \
+ param().border_type == param::Remap::BorderMode::bmode && \
+ param().imode == param::Remap::InterpolationMode::inter_mode) { \
using ctype = DTypeTrait::ctype; \
remap::backwardmat_proxy< \
ctype, param_enumv::Remap::Format::_format, \
- ::BorderMode::BORDER_##bmode>( \
+ ::BorderMode::BORDER_##bmode, \
+ ::InterpolationMode::INTER_##inter_mode>( \
src.compatible_ptr(), map_xy.compatible_ptr(), \
diff.compatible_ptr(), grad.compatible_ptr(), N, C, \
IH, IW, OH, OW, param().scalar, stream); \
@@ -51,11 +54,16 @@ void RemapBackwardMatImpl::exec(
#define support_dtype(dt) \
case DTypeTrait::enumv: { \
- cb(dt, NCHW, CONSTANT); \
- cb(dt, NCHW, REPLICATE); \
- cb(dt, NCHW, REFLECT); \
- cb(dt, NCHW, REFLECT_101); \
- cb(dt, NCHW, WRAP); \
+ cb(dt, NCHW, CONSTANT, NEAREST); \
+ cb(dt, NCHW, REPLICATE, NEAREST); \
+ cb(dt, NCHW, REFLECT, NEAREST); \
+ cb(dt, NCHW, REFLECT_101, NEAREST); \
+ cb(dt, NCHW, WRAP, NEAREST); \
+ cb(dt, NCHW, CONSTANT, LINEAR); \
+ cb(dt, NCHW, REPLICATE, LINEAR); \
+ cb(dt, NCHW, REFLECT, LINEAR); \
+ cb(dt, NCHW, REFLECT_101, LINEAR); \
+ cb(dt, NCHW, WRAP, LINEAR); \
megdnn_throw("unsupported border type in remap cuda"); \
}
diff --git a/dnn/src/cuda/remap/backward_mat.cu b/dnn/src/cuda/remap/backward_mat.cu
index f0f6498a0fb07647772c430b83f9f4ee662dc404..882bd8e2153932e019646a88278cf3f97d949372 100644
--- a/dnn/src/cuda/remap/backward_mat.cu
+++ b/dnn/src/cuda/remap/backward_mat.cu
@@ -53,7 +53,7 @@ struct GetSrcData {
};
template
-__global__ void kern_general(
+__global__ void kern_general_linear(
const ctype* src, const float* map_xy, const ctype* diff,
float* __restrict grad, int C, int IH, int IW, int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
@@ -62,7 +62,6 @@ __global__ void kern_general(
diff += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
grad += blockIdx.z * 2 * OH * OW;
- RoundingConverter round_converter;
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
@@ -86,23 +85,25 @@ __global__ void kern_general(
int a11 = GetSrcData::get_index(
row + 1, col + 1, c, IH, IW, C);
- dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
- dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
- dv -= ((a10 != -1) ? src[a10] : scalar) * u;
- dv += ((a11 != -1) ? src[a11] : scalar) * u;
+ dv -= ((a00 != -1) ? static_cast(src[a00]) : scalar) * (one - u);
+ dv += ((a01 != -1) ? static_cast(src[a01]) : scalar) * (one - u);
+ dv -= ((a10 != -1) ? static_cast(src[a10]) : scalar) * u;
+ dv += ((a11 != -1) ? static_cast(src[a11]) : scalar) * u;
- du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
- du -= ((a01 != -1) ? src[a01] : scalar) * v;
- du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
- du += ((a11 != -1) ? src[a11] : scalar) * v;
+ du -= ((a00 != -1) ? static_cast(src[a00]) : scalar) * (one - v);
+ du -= ((a01 != -1) ? static_cast(src[a01]) : scalar) * v;
+ du += ((a10 != -1) ? static_cast(src[a10]) : scalar) * (one - v);
+ du += ((a11 != -1) ? static_cast(src[a11]) : scalar) * v;
- grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv);
- grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du);
+ grad[oh * OW * 2 + ow * 2 + 0] += hidden * dv;
+ grad[oh * OW * 2 + ow * 2 + 1] += hidden * du;
}
}
}
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void dispatch_backwardmat(
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) {
@@ -115,8 +116,11 @@ void dispatch_backwardmat(
cuda_check(cudaMemsetAsync(
grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, stream));
- kern_general<<>>(
- src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar);
+
+ if (imode == ::InterpolationMode::INTER_LINEAR) {
+ kern_general_linear<<>>(
+ src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar);
+ }
N -= curr_batch_size;
src += curr_batch_size * C * IH * IW;
@@ -132,27 +136,35 @@ namespace megdnn {
namespace cuda {
namespace remap {
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void backwardmat_proxy(
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) {
- dispatch_backwardmat(
+ dispatch_backwardmat(
src, map_xy, diff, grad, N, C, IH, IW, OH, OW, scalar, stream);
after_kernel_launch();
}
-#define INST(ctype, format, bmode) \
- template void \
- backwardmat_proxy( \
+#define INST(ctype, format, bmode, imode) \
+ template void backwardmat_proxy< \
+ ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \
+ ::InterpolationMode::imode>( \
const ctype*, const float*, const ctype*, float*, int, int, int, int, int, \
int, float, cudaStream_t);
-#define FOR_FORMAT_BMODE(ctype) \
- INST(ctype, NCHW, BORDER_CONSTANT) \
- INST(ctype, NCHW, BORDER_REPLICATE) \
- INST(ctype, NCHW, BORDER_REFLECT) \
- INST(ctype, NCHW, BORDER_REFLECT_101) \
- INST(ctype, NCHW, BORDER_WRAP)
+#define FOR_FORMAT_BMODE(ctype) \
+ INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR)
FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
diff --git a/dnn/src/cuda/remap/common.h b/dnn/src/cuda/remap/common.h
index 0d85f6a82c7e6c1ab9a92f36f18699edbcb6e81c..779c205606de28f6294a9b8d53a918d5dcd183a3 100644
--- a/dnn/src/cuda/remap/common.h
+++ b/dnn/src/cuda/remap/common.h
@@ -21,17 +21,23 @@ namespace remap {
// all these kernels use LINEAR interpolation
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void forward_proxy(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar, cudaStream_t stream);
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void backwarddata_proxy(
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream);
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void backwardmat_proxy(
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream);
diff --git a/dnn/src/cuda/remap/forward.cpp b/dnn/src/cuda/remap/forward.cpp
index 4a40b551c9a630e4217c358a634d3fab8b4d14a2..2e1876cddf96771696cfecd6c748a5107e231c5a 100644
--- a/dnn/src/cuda/remap/forward.cpp
+++ b/dnn/src/cuda/remap/forward.cpp
@@ -30,8 +30,9 @@ void RemapImpl::exec(
OW = map_xy.layout.shape[2];
megdnn_assert(
- param().imode == param::Remap::InterpolationMode::LINEAR,
- "only support LINEAR interpolationMode");
+ (param().imode == param::Remap::InterpolationMode::NEAREST) ||
+ (param().imode == param::Remap::InterpolationMode::LINEAR),
+ "only support NEAREST and LINEAR interpolationMode");
if (param().format == param::Remap::Format::NCHW) {
N = src.layout.shape[0];
@@ -47,13 +48,15 @@ void RemapImpl::exec(
megdnn_throw("unsupported format, cuda remap");
}
-#define cb(dt, _format, bmode) \
+#define cb(dt, _format, bmode, inter_mode) \
if (param().format == param::Remap::Format::_format && \
- param().border_type == param::Remap::BorderMode::bmode) { \
+ param().border_type == param::Remap::BorderMode::bmode && \
+ param().imode == param::Remap::InterpolationMode::inter_mode) { \
using ctype = DTypeTrait::ctype; \
remap::forward_proxy< \
ctype, param_enumv::Remap::Format::_format, \
- ::BorderMode::BORDER_##bmode>( \
+ ::BorderMode::BORDER_##bmode, \
+ ::InterpolationMode::INTER_##inter_mode>( \
src.compatible_ptr(), map_xy.compatible_ptr(), \
dst.compatible_ptr(), N, C, IH, IW, OH, OW, param().scalar, \
stream); \
@@ -62,16 +65,26 @@ void RemapImpl::exec(
#define support_dtype(dt) \
case DTypeTrait::enumv: { \
- cb(dt, NCHW, CONSTANT); \
- cb(dt, NCHW, REPLICATE); \
- cb(dt, NCHW, REFLECT); \
- cb(dt, NCHW, REFLECT_101); \
- cb(dt, NCHW, WRAP); \
- cb(dt, NHWC, CONSTANT); \
- cb(dt, NHWC, REPLICATE); \
- cb(dt, NHWC, REFLECT); \
- cb(dt, NHWC, REFLECT_101); \
- cb(dt, NHWC, WRAP); \
+ cb(dt, NCHW, CONSTANT, NEAREST); \
+ cb(dt, NCHW, REPLICATE, NEAREST); \
+ cb(dt, NCHW, REFLECT, NEAREST); \
+ cb(dt, NCHW, REFLECT_101, NEAREST); \
+ cb(dt, NCHW, WRAP, NEAREST); \
+ cb(dt, NHWC, CONSTANT, NEAREST); \
+ cb(dt, NHWC, REPLICATE, NEAREST); \
+ cb(dt, NHWC, REFLECT, NEAREST); \
+ cb(dt, NHWC, REFLECT_101, NEAREST); \
+ cb(dt, NHWC, WRAP, NEAREST); \
+ cb(dt, NCHW, CONSTANT, LINEAR); \
+ cb(dt, NCHW, REPLICATE, LINEAR); \
+ cb(dt, NCHW, REFLECT, LINEAR); \
+ cb(dt, NCHW, REFLECT_101, LINEAR); \
+ cb(dt, NCHW, WRAP, LINEAR); \
+ cb(dt, NHWC, CONSTANT, LINEAR); \
+ cb(dt, NHWC, REPLICATE, LINEAR); \
+ cb(dt, NHWC, REFLECT, LINEAR); \
+ cb(dt, NHWC, REFLECT_101, LINEAR); \
+ cb(dt, NHWC, WRAP, LINEAR); \
megdnn_throw("unsupported border type in remap cuda"); \
}
diff --git a/dnn/src/cuda/remap/forward.cu b/dnn/src/cuda/remap/forward.cu
index 648c4188e81b86fbe5fec8c964cfd30f6d7fbd31..700008421d242f79a515f2f06096b60150564e6e 100644
--- a/dnn/src/cuda/remap/forward.cu
+++ b/dnn/src/cuda/remap/forward.cu
@@ -62,8 +62,23 @@ struct GetSrcData {
}
};
-template
-__global__ void kern_general(
+__device__ inline float round_half_to_even(float f) {
+ const float round_away_from_zero = round(f);
+ const float diff = round_away_from_zero - f;
+
+ if ((diff != 0.5f) && (diff != -0.5f)) {
+ return round_away_from_zero;
+ }
+
+ if (fmod(round_away_from_zero, 2.0f) == 0.0f) {
+ return round_away_from_zero;
+ }
+
+ return f - diff;
+}
+
+template
+__global__ void kern_general_nearest(
const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
@@ -71,37 +86,22 @@ __global__ void kern_general(
sptr += blockIdx.z * C * IH * IW;
dst += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
- RoundingConverter round_converter;
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
- int col = static_cast(floor(index_col));
- int row = static_cast(floor(index_row));
- float v = index_col - col;
- float u = index_row - row;
+ int col = static_cast(round_half_to_even(index_col));
+ int row = static_cast(round_half_to_even(index_row));
for (int c = 0; c < C; ++c) {
- ctype a00 = GetSrcData::get(
- sptr, row + 0, col + 0, c, IH, IW, C, scalar);
- ctype a01 = GetSrcData::get(
- sptr, row + 0, col + 1, c, IH, IW, C, scalar);
- ctype a10 = GetSrcData::get(
- sptr, row + 1, col + 0, c, IH, IW, C, scalar);
- ctype a11 = GetSrcData::get(
- sptr, row + 1, col + 1, c, IH, IW, C, scalar);
- /* in remap, we use float as the type of intermediate result */
- float result = static_cast(a00) * (1.f - u) * (1.f - v) +
- static_cast(a01) * (1.f - u) * v +
- static_cast(a10) * (1.f - v) * u +
- static_cast(a11) * u * v;
- dst[get_offset(oh, ow, c, OH, OW, C)] =
- round_converter(result);
+ dst[get_offset(oh, ow, c, OH, OW, C)] =
+ GetSrcData::get(
+ sptr, row, col, c, IH, IW, C, scalar);
}
}
}
-template
-__global__ void kern_general_nhwc(
+template
+__global__ void kern_general_linear(
const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
@@ -119,26 +119,27 @@ __global__ void kern_general_nhwc(
float v = index_col - col;
float u = index_row - row;
for (int c = 0; c < C; ++c) {
- ctype a00 = GetSrcData::get(
+ ctype a00 = GetSrcData::get(
sptr, row + 0, col + 0, c, IH, IW, C, scalar);
- ctype a01 = GetSrcData::get(
+ ctype a01 = GetSrcData::get(
sptr, row + 0, col + 1, c, IH, IW, C, scalar);
- ctype a10 = GetSrcData::get(
+ ctype a10 = GetSrcData::get(
sptr, row + 1, col + 0, c, IH, IW, C, scalar);
- ctype a11 = GetSrcData::get(
+ ctype a11 = GetSrcData::get(
sptr, row + 1, col + 1, c, IH, IW, C, scalar);
/* in remap, we use float as the type of intermediate result */
float result = static_cast(a00) * (1.f - u) * (1.f - v) +
static_cast(a01) * (1.f - u) * v +
static_cast(a10) * (1.f - v) * u +
static_cast(a11) * u * v;
- dst[get_offset(oh, ow, c, OH, OW, C)] =
- round_converter(result);
+ dst[get_offset(oh, ow, c, OH, OW, C)] = round_converter(result);
}
}
}
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void dispatch_forward(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar, cudaStream_t stream) {
@@ -150,11 +151,11 @@ void dispatch_forward(
dim3 threads(BX, BY);
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
- if (format == param_enumv::Remap::Format::NCHW) {
- kern_general<<>>(
+ if (imode == ::InterpolationMode::INTER_NEAREST) {
+ kern_general_nearest<<>>(
src, map_xy, dst, C, IH, IW, OH, OW, scalar);
- } else if (format == param_enumv::Remap::Format::NHWC) {
- kern_general_nhwc<<>>(
+ } else if (imode == ::InterpolationMode::INTER_LINEAR) {
+ kern_general_linear<<>>(
src, map_xy, dst, C, IH, IW, OH, OW, scalar);
}
@@ -171,32 +172,45 @@ namespace megdnn {
namespace cuda {
namespace remap {
-template
+template <
+ typename ctype, const uint32_t format, ::BorderMode bmode,
+ ::InterpolationMode imode>
void forward_proxy(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar, cudaStream_t stream) {
- dispatch_forward(
+ dispatch_forward(
src, map_xy, dst, N, C, IH, IW, OH, OW, scalar, stream);
after_kernel_launch();
}
-#define INST(ctype, format, bmode) \
- template void \
- forward_proxy( \
+#define INST(ctype, format, bmode, imode) \
+ template void forward_proxy< \
+ ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \
+ ::InterpolationMode::imode>( \
const ctype*, const float*, ctype*, int, int, int, int, int, int, float, \
cudaStream_t);
-#define FOR_FORMAT_BMODE(ctype) \
- INST(ctype, NCHW, BORDER_CONSTANT) \
- INST(ctype, NCHW, BORDER_REPLICATE) \
- INST(ctype, NCHW, BORDER_REFLECT) \
- INST(ctype, NCHW, BORDER_REFLECT_101) \
- INST(ctype, NCHW, BORDER_WRAP) \
- INST(ctype, NHWC, BORDER_CONSTANT) \
- INST(ctype, NHWC, BORDER_REPLICATE) \
- INST(ctype, NHWC, BORDER_REFLECT) \
- INST(ctype, NHWC, BORDER_REFLECT_101) \
- INST(ctype, NHWC, BORDER_WRAP)
+#define FOR_FORMAT_BMODE(ctype) \
+ INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \
+ INST(ctype, NHWC, BORDER_CONSTANT, INTER_NEAREST) \
+ INST(ctype, NHWC, BORDER_REPLICATE, INTER_NEAREST) \
+ INST(ctype, NHWC, BORDER_REFLECT, INTER_NEAREST) \
+ INST(ctype, NHWC, BORDER_REFLECT_101, INTER_NEAREST) \
+ INST(ctype, NHWC, BORDER_WRAP, INTER_NEAREST) \
+ INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \
+ INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) \
+ INST(ctype, NHWC, BORDER_CONSTANT, INTER_LINEAR) \
+ INST(ctype, NHWC, BORDER_REPLICATE, INTER_LINEAR) \
+ INST(ctype, NHWC, BORDER_REFLECT, INTER_LINEAR) \
+ INST(ctype, NHWC, BORDER_REFLECT_101, INTER_LINEAR) \
+ INST(ctype, NHWC, BORDER_WRAP, INTER_LINEAR)
FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
diff --git a/dnn/src/naive/remap/opr_impl.cpp b/dnn/src/naive/remap/opr_impl.cpp
index 56e8b632e6b51560e16212862c3d19f91c7557b8..440d73d2b995180f8a635c4acd3c719eb3658832 100644
--- a/dnn/src/naive/remap/opr_impl.cpp
+++ b/dnn/src/naive/remap/opr_impl.cpp
@@ -36,6 +36,12 @@ inline int get_offset(
return height * w * c + width * c + channel;
}
+template <>
+inline int get_offset(
+ int height, int width, int channel, int, int w, int c) {
+ return ((height * c + channel) * w + width) * 4;
+}
+
template <
typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
@@ -80,8 +86,12 @@ void remap_LINEAR(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar) {
RoundingConverter round_converter;
- for (int n = 0; n < N;
- ++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) {
+ size_t c_scale = 1;
+ if (format == param::Remap::Format::NHWCD4) {
+ c_scale = 4;
+ }
+ for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW,
+ dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) {
for (int h = 0; h < OH; ++h) {
for (int w = 0; w < OW; ++w) {
float index_col = map_xy[h * OW * 2 + w * 2 + 0];
@@ -92,18 +102,102 @@ void remap_LINEAR(
float u = index_row - row; // alphah
const float one = 1.f;
for (int c = 0; c < C; ++c) {
- ctype a00 = GetSrcData::get(
- src, row + 0, col + 0, c, IH, IW, C, scalar);
- ctype a01 = GetSrcData::get(
- src, row + 0, col + 1, c, IH, IW, C, scalar);
- ctype a10 = GetSrcData::get(
- src, row + 1, col + 0, c, IH, IW, C, scalar);
- ctype a11 = GetSrcData::get(
- src, row + 1, col + 1, c, IH, IW, C, scalar);
-
- dst[get_offset(h, w, c, OH, OW, C)] = round_converter(
- a00 * (one - v) * (one - u) + a01 * (one - u) * v +
- a10 * (one - v) * u + a11 * u * v);
+ if (format == param::Remap::Format::NHWCD4) {
+ int idx00 = GetSrcData::get_index(
+ row + 0, col + 0, c, IH, IW, C);
+ int idx01 = GetSrcData::get_index(
+ row + 0, col + 1, c, IH, IW, C);
+ int idx10 = GetSrcData::get_index(
+ row + 1, col + 0, c, IH, IW, C);
+ int idx11 = GetSrcData::get_index(
+ row + 1, col + 1, c, IH, IW, C);
+ for (int c_inner = 0; c_inner < 4; ++c_inner) {
+ ctype a00 = (idx00 != -1) ? src[idx00 + c_inner]
+ : round_converter(scalar);
+ ctype a01 = (idx01 != -1) ? src[idx01 + c_inner]
+ : round_converter(scalar);
+ ctype a10 = (idx10 != -1) ? src[idx10 + c_inner]
+ : round_converter(scalar);
+ ctype a11 = (idx11 != -1) ? src[idx11 + c_inner]
+ : round_converter(scalar);
+ dst[get_offset(h, w, c, OH, OW, C) + c_inner] =
+ round_converter(
+ a00 * (one - v) * (one - u) +
+ a01 * (one - u) * v + a10 * (one - v) * u +
+ a11 * u * v);
+ }
+ } else {
+ ctype a00 = GetSrcData::get(
+ src, row + 0, col + 0, c, IH, IW, C, scalar);
+ ctype a01 = GetSrcData::get(
+ src, row + 0, col + 1, c, IH, IW, C, scalar);
+ ctype a10 = GetSrcData::get(
+ src, row + 1, col + 0, c, IH, IW, C, scalar);
+ ctype a11 = GetSrcData::get(
+ src, row + 1, col + 1, c, IH, IW, C, scalar);
+
+ dst[get_offset(h, w, c, OH, OW, C)] = round_converter(
+ a00 * (one - v) * (one - u) + a01 * (one - u) * v +
+ a10 * (one - v) * u + a11 * u * v);
+ }
+ }
+ }
+ }
+ }
+}
+
+namespace {
+
+inline float round_half_to_even(float f) {
+ const float round_away_from_zero = std::round(f);
+ const float diff = round_away_from_zero - f;
+
+ if ((diff != 0.5f) && (diff != -0.5f)) {
+ return round_away_from_zero;
+ }
+
+ if (std::fmod(round_away_from_zero, 2.0f) == 0.0f) {
+ return round_away_from_zero;
+ }
+
+ return f - diff;
+}
+
+} // anonymous namespace
+
+template <
+ typename ctype, param::Remap::Format format,
+ param::Remap::BorderMode bordertype>
+void remap_NEAREST(
+ const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
+ int OH, int OW, float scalar) {
+ RoundingConverter round_converter;
+ size_t c_scale = 1;
+ if (format == param::Remap::Format::NHWCD4) {
+ c_scale = 4;
+ }
+ for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW,
+ dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) {
+ for (int h = 0; h < OH; ++h) {
+ for (int w = 0; w < OW; ++w) {
+ float index_col = map_xy[h * OW * 2 + w * 2 + 0];
+ float index_row = map_xy[h * OW * 2 + w * 2 + 1];
+ int col = static_cast(round_half_to_even(index_col));
+ int row = static_cast(round_half_to_even(index_row));
+ for (int c = 0; c < C; ++c) {
+ if (format == param::Remap::Format::NHWCD4) {
+ int idx = GetSrcData::get_index(
+ row, col, c, IH, IW, C);
+ for (int c_inner = 0; c_inner < 4; ++c_inner) {
+ dst[get_offset(h, w, c, OH, OW, C) + c_inner] =
+ (idx != -1) ? (src[idx + c_inner])
+ : round_converter(scalar);
+ }
+ } else {
+ dst[get_offset(h, w, c, OH, OW, C)] =
+ GetSrcData::get(
+ src, row, col, c, IH, IW, C, scalar);
+ }
}
}
}
@@ -161,13 +255,40 @@ void remap_LINEAR_backwarddata(
}
}
+template <
+ typename ctype, param::Remap::Format format,
+ param::Remap::BorderMode bordertype>
+void remap_NEAREST_backwarddata(
+ ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
+ int IW, int OH, int OW) {
+ std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW);
+ for (int n = 0; n < N;
+ ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) {
+ for (int h = 0; h < OH; ++h) {
+ for (int w = 0; w < OW; ++w) {
+ float index_col = map_xy[h * OW * 2 + w * 2 + 0];
+ float index_row = map_xy[h * OW * 2 + w * 2 + 1];
+ int col = static_cast(round_half_to_even(index_col));
+ int row = static_cast(round_half_to_even(index_row));
+ for (int c = 0; c < C; ++c) {
+ ctype hidden = diff[get_offset(h, w, c, OH, OW, C)];
+ int idx = GetSrcData::get_index(
+ row, col, c, IH, IW, C);
+ if (idx != -1) {
+ grad[idx] += hidden;
+ }
+ }
+ }
+ }
+ }
+}
+
template <
typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
void remap_LINEAR_backwardmat(
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
int C, int IH, int IW, int OH, int OW, float scalar) {
- RoundingConverter round_converter;
std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW);
for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW,
map_xy += OH * OW * 2, grad += OH * OW * 2) {
@@ -194,24 +315,38 @@ void remap_LINEAR_backwardmat(
int a11 = GetSrcData::get_index(
row + 1, col + 1, c, IH, IW, C);
- dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
- dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
- dv -= ((a10 != -1) ? src[a10] : scalar) * u;
- dv += ((a11 != -1) ? src[a11] : scalar) * u;
-
- du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
- du -= ((a01 != -1) ? src[a01] : scalar) * v;
- du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
- du += ((a11 != -1) ? src[a11] : scalar) * v;
-
- grad[h * OW * 2 + w * 2 + 0] += round_converter(hidden * dv);
- grad[h * OW * 2 + w * 2 + 1] += round_converter(hidden * du);
+ dv -= ((a00 != -1) ? static_cast(src[a00]) : scalar) *
+ (one - u);
+ dv += ((a01 != -1) ? static_cast(src[a01]) : scalar) *
+ (one - u);
+ dv -= ((a10 != -1) ? static_cast(src[a10]) : scalar) * u;
+ dv += ((a11 != -1) ? static_cast(src[a11]) : scalar) * u;
+
+ du -= ((a00 != -1) ? static_cast(src[a00]) : scalar) *
+ (one - v);
+ du -= ((a01 != -1) ? static_cast(src[a01]) : scalar) * v;
+ du += ((a10 != -1) ? static_cast(src[a10]) : scalar) *
+ (one - v);
+ du += ((a11 != -1) ? static_cast(src[a11]) : scalar) * v;
+
+ grad[h * OW * 2 + w * 2 + 0] += hidden * dv;
+ grad[h * OW * 2 + w * 2 + 1] += hidden * du;
}
}
}
}
}
+template <
+ typename ctype, param::Remap::Format format,
+ param::Remap::BorderMode bordertype>
+void remap_NEAREST_backwardmat(
+ const ctype*, const float*, const ctype*, float* grad, int N, int, int, int,
+ int OH, int OW, float) {
+ std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW);
+ return;
+}
+
} // namespace
void RemapImpl::exec(
@@ -229,6 +364,11 @@ void RemapImpl::exec(
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
+ } else if (param().format == param::Remap::Format::NHWCD4) {
+ N = src.layout.shape[0];
+ C = src.layout.shape[2];
+ IH = src.layout.shape[1];
+ IW = src.layout.shape[3];
} else {
megdnn_throw("unsupported format");
}
@@ -255,11 +395,31 @@ void RemapImpl::exec(
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
+ cb(dt, NHWCD4, CONSTANT, LINEAR); \
+ cb(dt, NHWCD4, REPLICATE, LINEAR); \
+ cb(dt, NHWCD4, REFLECT, LINEAR); \
+ cb(dt, NHWCD4, REFLECT_101, LINEAR); \
+ cb(dt, NHWCD4, WRAP, LINEAR); \
cb(dt, NHWC, CONSTANT, LINEAR); \
cb(dt, NHWC, REPLICATE, LINEAR); \
cb(dt, NHWC, REFLECT, LINEAR); \
cb(dt, NHWC, REFLECT_101, LINEAR); \
cb(dt, NHWC, WRAP, LINEAR); \
+ cb(dt, NCHW, CONSTANT, NEAREST); \
+ cb(dt, NCHW, REPLICATE, NEAREST); \
+ cb(dt, NCHW, REFLECT, NEAREST); \
+ cb(dt, NCHW, REFLECT_101, NEAREST); \
+ cb(dt, NCHW, WRAP, NEAREST); \
+ cb(dt, NHWCD4, CONSTANT, NEAREST); \
+ cb(dt, NHWCD4, REPLICATE, NEAREST); \
+ cb(dt, NHWCD4, REFLECT, NEAREST); \
+ cb(dt, NHWCD4, REFLECT_101, NEAREST); \
+ cb(dt, NHWCD4, WRAP, NEAREST); \
+ cb(dt, NHWC, CONSTANT, NEAREST); \
+ cb(dt, NHWC, REPLICATE, NEAREST); \
+ cb(dt, NHWC, REFLECT, NEAREST); \
+ cb(dt, NHWC, REFLECT_101, NEAREST); \
+ cb(dt, NHWC, WRAP, NEAREST); \
megdnn_throw( \
"format, border type or imode is incorrect in remap navie " \
"with dtype = " #dt); \
@@ -313,6 +473,11 @@ void RemapBackwardDataImpl::exec(
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
+ cb(dt, NCHW, CONSTANT, NEAREST); \
+ cb(dt, NCHW, REPLICATE, NEAREST); \
+ cb(dt, NCHW, REFLECT, NEAREST); \
+ cb(dt, NCHW, REFLECT_101, NEAREST); \
+ cb(dt, NCHW, WRAP, NEAREST); \
megdnn_throw( \
"format, border type or imode is incorrect in remap navie " \
"with dtype = " #dt); \
@@ -365,6 +530,11 @@ void RemapBackwardMatImpl::exec(
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
+ cb(dt, NCHW, CONSTANT, NEAREST); \
+ cb(dt, NCHW, REPLICATE, NEAREST); \
+ cb(dt, NCHW, REFLECT, NEAREST); \
+ cb(dt, NCHW, REFLECT_101, NEAREST); \
+ cb(dt, NCHW, WRAP, NEAREST); \
megdnn_throw( \
"format, border type or imode is incorrect in remap navie " \
"with dtype = " #dt); \
diff --git a/dnn/test/common/remap.h b/dnn/test/common/remap.h
index 8c2198843e2a248357d8d056831c864fd9940308..9ae64d204c6fd69d87c3b8bc0d1fa646541feb98 100644
--- a/dnn/test/common/remap.h
+++ b/dnn/test/common/remap.h
@@ -34,53 +34,91 @@ static inline std::vector get_nchw_args() {
param::Remap param;
std::vector format_vec = {param::Remap::Format::NCHW};
+ std::vector interp_mode_vec = {
+ param::Remap::InterpolationMode::NEAREST,
+ param::Remap::InterpolationMode::LINEAR};
std::vector border_mode_vec = {
param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101,
param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP,
param::Remap::BorderMode::REPLICATE};
+
// current do not test this.
std::vector scalar;
for (auto fmt : format_vec) {
- for (auto border_type : border_mode_vec) {
- param.format = fmt;
- param.border_type = border_type;
- args.emplace_back(
- param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2},
- TensorShape{70000, 1, 2, 2});
-
- args.emplace_back(
- param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2},
- TensorShape{1, 1, 2, 2});
-
- args.emplace_back(
- param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2},
- TensorShape{1, 3, 2, 2});
-
- args.emplace_back(
- param, TensorShape{1, 10, 100, 100}, TensorShape{1, 100, 100, 2},
- TensorShape{1, 10, 100, 100});
-
- args.emplace_back(
- param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2},
- TensorShape{2, 4, 100, 200});
-
- args.emplace_back(
- param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2},
- TensorShape{2, 4, 20, 30});
-
- args.emplace_back(
- param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2},
- TensorShape{2, 4, 20, 30});
+ for (auto interp_mode : interp_mode_vec) {
+ for (auto border_type : border_mode_vec) {
+ param.format = fmt;
+ param.imode = interp_mode;
+ param.border_type = border_type;
+ args.emplace_back(
+ param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2},
+ TensorShape{70000, 1, 2, 2});
+
+ args.emplace_back(
+ param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2},
+ TensorShape{1, 1, 2, 2});
+
+ args.emplace_back(
+ param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2},
+ TensorShape{1, 3, 2, 2});
+
+ args.emplace_back(
+ param, TensorShape{1, 10, 100, 100},
+ TensorShape{1, 100, 100, 2}, TensorShape{1, 10, 100, 100});
+
+ args.emplace_back(
+ param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2},
+ TensorShape{2, 4, 100, 200});
+
+ args.emplace_back(
+ param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2},
+ TensorShape{2, 4, 20, 30});
+
+ args.emplace_back(
+ param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2},
+ TensorShape{2, 4, 20, 30});
+ }
}
}
return args;
}
+static inline std::vector get_nhwcd4_args() {
+ std::vector args;
+
+ param::Remap param;
+ param.format = param::Remap::Format::NHWCD4;
+ param.imode = param::Remap::InterpolationMode::LINEAR;
+ param.border_type = param::Remap::BorderMode::CONSTANT;
+ // FIXME: when fractional part of bval is not zero, naive and opencl bankend may
+ // have different rounding result
+ param.scalar = 77;
+ args.emplace_back(
+ param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2},
+ TensorShape{2, 4, 1, 6, 4});
+ args.emplace_back(
+ param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2},
+ TensorShape{2, 2, 1, 3, 4});
+
+ param.imode = param::Remap::InterpolationMode::NEAREST;
+ args.emplace_back(
+ param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2},
+ TensorShape{2, 4, 1, 6, 4});
+ args.emplace_back(
+ param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2},
+ TensorShape{2, 2, 1, 3, 4});
+
+ return args;
+}
+
static inline std::vector get_nhwc_args() {
std::vector args;
param::Remap param;
std::vector format_vec = {param::Remap::Format::NHWC};
+ std::vector interp_mode_vec = {
+ param::Remap::InterpolationMode::NEAREST,
+ param::Remap::InterpolationMode::LINEAR};
std::vector border_mode_vec = {
param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101,
param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP,
@@ -88,41 +126,44 @@ static inline std::vector get_nhwc_args() {
// current do not test this.
std::vector scalar;
for (auto fmt : format_vec) {
- for (auto border_type : border_mode_vec) {
- param.format = fmt;
- param.border_type = border_type;
- param.scalar = 12.f;
- args.emplace_back(
- param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2},
- TensorShape{70000, 2, 2, 1});
-
- args.emplace_back(
- param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2},
- TensorShape{1, 2, 2, 1});
-
- args.emplace_back(
- param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2},
- TensorShape{1, 2, 2, 3});
-
- args.emplace_back(
- param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2},
- TensorShape{1, 2, 2, 66});
-
- args.emplace_back(
- param, TensorShape{1, 100, 100, 10}, TensorShape{1, 100, 100, 2},
- TensorShape{1, 100, 100, 10});
-
- args.emplace_back(
- param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2},
- TensorShape{2, 100, 200, 4});
-
- args.emplace_back(
- param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2},
- TensorShape{2, 20, 30, 4});
-
- args.emplace_back(
- param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2},
- TensorShape{2, 20, 30, 4});
+ for (auto interp_mode : interp_mode_vec) {
+ for (auto border_type : border_mode_vec) {
+ param.format = fmt;
+ param.imode = interp_mode;
+ param.border_type = border_type;
+ param.scalar = 12.f;
+ args.emplace_back(
+ param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2},
+ TensorShape{70000, 2, 2, 1});
+
+ args.emplace_back(
+ param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2},
+ TensorShape{1, 2, 2, 1});
+
+ args.emplace_back(
+ param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2},
+ TensorShape{1, 2, 2, 3});
+
+ args.emplace_back(
+ param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2},
+ TensorShape{1, 2, 2, 66});
+
+ args.emplace_back(
+ param, TensorShape{1, 100, 100, 10},
+ TensorShape{1, 100, 100, 2}, TensorShape{1, 100, 100, 10});
+
+ args.emplace_back(
+ param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2},
+ TensorShape{2, 100, 200, 4});
+
+ args.emplace_back(
+ param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2},
+ TensorShape{2, 20, 30, 4});
+
+ args.emplace_back(
+ param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2},
+ TensorShape{2, 20, 30, 4});
+ }
}
}
return args;
diff --git a/dnn/test/common/resize.h b/dnn/test/common/resize.h
index caddf8208c425a5254e51a5e98a3ab293b0208f6..6e78d74e8f1a718934c8f3ffee3791fb2c0055d9 100644
--- a/dnn/test/common/resize.h
+++ b/dnn/test/common/resize.h
@@ -58,6 +58,11 @@ static void set_nchw_args(std::vector& args) {
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3});
args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
+
+ param.imode = param::Resize::InterpolationMode::NEAREST;
+ args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
+ args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3});
+ args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
}
static inline std::vector get_args(IMode imode = IMode::INTER_LINEAR) {
@@ -75,6 +80,25 @@ static inline std::vector get_args(IMode imode = IMode::INTER_LINEAR) {
return args;
}
+static inline std::vector get_nhwc_args() {
+ std::vector args;
+
+ param::Resize param;
+ param.format = param::Resize::Format::NHWC;
+ param.imode = param::Resize::InterpolationMode::LINEAR;
+
+ args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2});
+ args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2});
+ args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2});
+
+ param.imode = param::Resize::InterpolationMode::NEAREST;
+ args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2});
+ args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2});
+ args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2});
+
+ return args;
+}
+
static inline std::vector get_nhwcd4_args() {
std::vector args;
@@ -83,6 +107,9 @@ static inline std::vector get_nhwcd4_args() {
param.imode = param::Resize::InterpolationMode::LINEAR;
args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4});
args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4});
+ param.imode = param::Resize::InterpolationMode::NEAREST;
+ args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4});
+ args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4});
return args;
}
diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py
index 0fd905ee4c9639590482464b58870a6811369caa..bf98e3e15f082a8071b090857f93aaddfe81996b 100644
--- a/imperative/python/megengine/functional/vision.py
+++ b/imperative/python/megengine/functional/vision.py
@@ -351,7 +351,7 @@ def remap(
"reflect_101", "wrap".
scalar: value used in case of a constant border. Default: 0
interp_mode: interpolation methods.
- Default: "linear". Currently only support "linear" mode.
+ Default: "linear". Currently also support "nearest" mode.
Returns:
output tensor.