提交 04193e3b 编写于 作者: M Megvii Engine Team

feat(dnn): add nearest mode for remap and resize

GitOrigin-RevId: 31e7b72a7850d60e1d2992268c798117d8096173
上级 69b89388
......@@ -18,21 +18,22 @@ namespace megdnn {
void RemapBase::deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst) {
dst.dtype = src.dtype;
dst.ndim = src.ndim;
dst.shape[0] = src.shape[0];
size_t height_index, channel_index;
size_t n = src.shape[0];
size_t c, oh, ow;
oh = map_xy.shape[1];
ow = map_xy.shape[2];
if (param().format == param::Remap::Format::NHWC) {
height_index = 1;
channel_index = 3;
c = src.shape[3];
dst = TensorLayout(TensorShape({n, oh, ow, c}), src.dtype);
} else if (param().format == param::Remap::Format::NCHW) {
c = src.shape[1];
dst = TensorLayout(TensorShape{n, c, oh, ow}, src.dtype, src.format);
} else if (param().format == param::Remap::Format::NHWCD4) {
c = src.shape[2];
dst = TensorLayout{{n, oh, c, ow, 4}, src.dtype, src.format};
} else {
megdnn_assert(param().format == param::Remap::Format::NCHW);
height_index = 2;
channel_index = 1;
megdnn_throw("unsupport format");
}
dst.shape[height_index] = map_xy.shape[1];
dst.shape[height_index + 1] = map_xy.shape[2];
dst.shape[channel_index] = src.shape[channel_index];
}
void RemapBase::check_layout_fwd(
......@@ -42,7 +43,7 @@ void RemapBase::check_layout_fwd(
megdnn_layout_msg(dst);
};
MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert(src.ndim == map_xy.ndim && src.ndim == dst.ndim && src.ndim == 4);
megdnn_assert(src.ndim == dst.ndim);
megdnn_assert(dst.dtype == src.dtype);
megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str());
megdnn_assert(map_xy.shape[3] == 2);
......@@ -64,10 +65,13 @@ void RemapBase::check_layout_fwd(
megdnn_assert(
dst.shape[2] == map_xy.shape[1] && dst.shape[3] == map_xy.shape[2],
"%s", errmsg().c_str());
} else if (param().format == param::Remap::Format::NHWCD4) {
megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
megdnn_assert(param().format == Param::Format::NHWCD4);
} else {
megdnn_throw(
"currently do not support other param.format except NHWC and "
"NCHW");
megdnn_throw("unsupport format");
}
}
......
......@@ -22,8 +22,9 @@ void RemapBackwardDataImpl::exec(
_megdnn_workspace workspace) {
check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size);
megdnn_assert(
param().imode == param::Remap::InterpolationMode::LINEAR,
"only support LINEAR interpolationMode");
(param().imode == param::Remap::InterpolationMode::NEAREST) ||
(param().imode == param::Remap::InterpolationMode::LINEAR),
"only support NEAREST and LINEAR interpolationMode");
megdnn_assert(
param().format == param::Remap::Format::NCHW,
"only support NCHW format for remap backward");
......@@ -36,13 +37,15 @@ void RemapBackwardDataImpl::exec(
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
#define cb(dt, _format, bmode) \
#define cb(dt, _format, bmode, inter_mode) \
if (param().format == param::Remap::Format::_format && \
param().border_type == param::Remap::BorderMode::bmode) { \
param().border_type == param::Remap::BorderMode::bmode && \
param().imode == param::Remap::InterpolationMode::inter_mode) { \
using ctype = DTypeTrait<dt>::ctype; \
remap::backwarddata_proxy< \
ctype, param_enumv::Remap::Format::_format, \
::BorderMode::BORDER_##bmode>( \
::BorderMode::BORDER_##bmode, \
::InterpolationMode::INTER_##inter_mode>( \
grad.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \
diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, stream); \
break; \
......@@ -50,11 +53,16 @@ void RemapBackwardDataImpl::exec(
#define support_dtype(dt) \
case DTypeTrait<dt>::enumv: { \
cb(dt, NCHW, CONSTANT); \
cb(dt, NCHW, REPLICATE); \
cb(dt, NCHW, REFLECT); \
cb(dt, NCHW, REFLECT_101); \
cb(dt, NCHW, WRAP); \
cb(dt, NCHW, CONSTANT, NEAREST); \
cb(dt, NCHW, REPLICATE, NEAREST); \
cb(dt, NCHW, REFLECT, NEAREST); \
cb(dt, NCHW, REFLECT_101, NEAREST); \
cb(dt, NCHW, WRAP, NEAREST); \
cb(dt, NCHW, CONSTANT, LINEAR); \
cb(dt, NCHW, REPLICATE, LINEAR); \
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
megdnn_throw("unsupported border type in remap cuda"); \
}
......
......@@ -52,8 +52,49 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
}
};
__device__ inline float round_half_to_even(float f) {
const float round_away_from_zero = round(f);
const float diff = round_away_from_zero - f;
if ((diff != 0.5f) && (diff != -0.5f)) {
return round_away_from_zero;
}
if (fmod(round_away_from_zero, 2.0f) == 0.0f) {
return round_away_from_zero;
}
return f - diff;
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
__global__ void kern_general_nearest(
ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH,
int IW, int OH, int OW) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
grad += blockIdx.z * C * IH * IW;
diff += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
int col = static_cast<int>(round_half_to_even(index_col));
int row = static_cast<int>(round_half_to_even(index_row));
for (int c = 0; c < C; ++c) {
ctype hidden = diff[get_offset<format>(oh, ow, c, OH, OW, C)];
int idx =
GetSrcData<ctype, format, bmode>::get_index(row, col, c, IH, IW, C);
if (idx != -1) {
atomic_add(grad + idx, hidden);
}
}
}
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
__global__ void kern_general(
__global__ void kern_general_linear(
ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH,
int IW, int OH, int OW) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -93,8 +134,8 @@ __global__ void kern_general(
atomic_add(grad + a10, round_converter(u * (one - v) * hidden));
}
int a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::
get_index(row + 1, col + 1, c, IH, IW, C);
int a11 = GetSrcData<ctype, format, bmode>::get_index(
row + 1, col + 1, c, IH, IW, C);
if (a11 != -1) {
atomic_add(grad + a11, round_converter(u * v * hidden));
}
......@@ -102,7 +143,9 @@ __global__ void kern_general(
}
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void dispatch_backwarddata(
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream) {
......@@ -115,8 +158,13 @@ void dispatch_backwarddata(
cuda_check(cudaMemsetAsync(
grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW, stream));
kern_general<ctype, format, bmode>
<<<blocks, threads, 0, stream>>>(grad, map_xy, diff, C, IH, IW, OH, OW);
if (imode == ::InterpolationMode::INTER_NEAREST) {
kern_general_nearest<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
grad, map_xy, diff, C, IH, IW, OH, OW);
} else if (imode == ::InterpolationMode::INTER_LINEAR) {
kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
grad, map_xy, diff, C, IH, IW, OH, OW);
}
N -= curr_batch_size;
grad += curr_batch_size * C * IH * IW;
......@@ -131,27 +179,35 @@ namespace megdnn {
namespace cuda {
namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void backwarddata_proxy(
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream) {
dispatch_backwarddata<ctype, format, bmode>(
dispatch_backwarddata<ctype, format, bmode, imode>(
grad, map_xy, diff, N, C, IH, IW, OH, OW, stream);
after_kernel_launch();
}
#define INST(ctype, format, bmode) \
#define INST(ctype, format, bmode, imode) \
template void backwarddata_proxy< \
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \
::InterpolationMode::imode>( \
ctype*, const float*, const ctype*, int, int, int, int, int, int, \
cudaStream_t);
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT) \
INST(ctype, NCHW, BORDER_REPLICATE) \
INST(ctype, NCHW, BORDER_REFLECT) \
INST(ctype, NCHW, BORDER_REFLECT_101) \
INST(ctype, NCHW, BORDER_WRAP)
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR)
FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
......
......@@ -22,8 +22,9 @@ void RemapBackwardMatImpl::exec(
_megdnn_tensor_out grad, _megdnn_workspace workspace) {
check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, workspace.size);
megdnn_assert(
param().imode == param::Remap::InterpolationMode::LINEAR,
"only support LINEAR interpolationMode");
(param().imode == param::Remap::InterpolationMode::NEAREST) ||
(param().imode == param::Remap::InterpolationMode::LINEAR),
"only support NEAREST and LINEAR interpolationMode");
megdnn_assert(
param().format == param::Remap::Format::NCHW,
"only support NCHW format for remap backward");
......@@ -36,13 +37,15 @@ void RemapBackwardMatImpl::exec(
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
#define cb(dt, _format, bmode) \
#define cb(dt, _format, bmode, inter_mode) \
if (param().format == param::Remap::Format::_format && \
param().border_type == param::Remap::BorderMode::bmode) { \
param().border_type == param::Remap::BorderMode::bmode && \
param().imode == param::Remap::InterpolationMode::inter_mode) { \
using ctype = DTypeTrait<dt>::ctype; \
remap::backwardmat_proxy< \
ctype, param_enumv::Remap::Format::_format, \
::BorderMode::BORDER_##bmode>( \
::BorderMode::BORDER_##bmode, \
::InterpolationMode::INTER_##inter_mode>( \
src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \
diff.compatible_ptr<ctype>(), grad.compatible_ptr<dt_float32>(), N, C, \
IH, IW, OH, OW, param().scalar, stream); \
......@@ -51,11 +54,16 @@ void RemapBackwardMatImpl::exec(
#define support_dtype(dt) \
case DTypeTrait<dt>::enumv: { \
cb(dt, NCHW, CONSTANT); \
cb(dt, NCHW, REPLICATE); \
cb(dt, NCHW, REFLECT); \
cb(dt, NCHW, REFLECT_101); \
cb(dt, NCHW, WRAP); \
cb(dt, NCHW, CONSTANT, NEAREST); \
cb(dt, NCHW, REPLICATE, NEAREST); \
cb(dt, NCHW, REFLECT, NEAREST); \
cb(dt, NCHW, REFLECT_101, NEAREST); \
cb(dt, NCHW, WRAP, NEAREST); \
cb(dt, NCHW, CONSTANT, LINEAR); \
cb(dt, NCHW, REPLICATE, LINEAR); \
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
megdnn_throw("unsupported border type in remap cuda"); \
}
......
......@@ -53,7 +53,7 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
};
template <typename ctype, const uint32_t format, ::BorderMode bmode>
__global__ void kern_general(
__global__ void kern_general_linear(
const ctype* src, const float* map_xy, const ctype* diff,
float* __restrict grad, int C, int IH, int IW, int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -62,7 +62,6 @@ __global__ void kern_general(
diff += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
grad += blockIdx.z * 2 * OH * OW;
RoundingConverter<ctype> round_converter;
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
......@@ -86,23 +85,25 @@ __global__ void kern_general(
int a11 = GetSrcData<ctype, format, bmode>::get_index(
row + 1, col + 1, c, IH, IW, C);
dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
dv -= ((a10 != -1) ? src[a10] : scalar) * u;
dv += ((a11 != -1) ? src[a11] : scalar) * u;
dv -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * (one - u);
dv += ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * (one - u);
dv -= ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * u;
dv += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * u;
du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
du -= ((a01 != -1) ? src[a01] : scalar) * v;
du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
du += ((a11 != -1) ? src[a11] : scalar) * v;
du -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * (one - v);
du -= ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * v;
du += ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * (one - v);
du += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * v;
grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv);
grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du);
grad[oh * OW * 2 + ow * 2 + 0] += hidden * dv;
grad[oh * OW * 2 + ow * 2 + 1] += hidden * du;
}
}
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void dispatch_backwardmat(
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) {
......@@ -115,8 +116,11 @@ void dispatch_backwardmat(
cuda_check(cudaMemsetAsync(
grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, stream));
kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar);
if (imode == ::InterpolationMode::INTER_LINEAR) {
kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar);
}
N -= curr_batch_size;
src += curr_batch_size * C * IH * IW;
......@@ -132,27 +136,35 @@ namespace megdnn {
namespace cuda {
namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void backwardmat_proxy(
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) {
dispatch_backwardmat<ctype, format, bmode>(
dispatch_backwardmat<ctype, format, bmode, imode>(
src, map_xy, diff, grad, N, C, IH, IW, OH, OW, scalar, stream);
after_kernel_launch();
}
#define INST(ctype, format, bmode) \
template void \
backwardmat_proxy<ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \
#define INST(ctype, format, bmode, imode) \
template void backwardmat_proxy< \
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \
::InterpolationMode::imode>( \
const ctype*, const float*, const ctype*, float*, int, int, int, int, int, \
int, float, cudaStream_t);
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT) \
INST(ctype, NCHW, BORDER_REPLICATE) \
INST(ctype, NCHW, BORDER_REFLECT) \
INST(ctype, NCHW, BORDER_REFLECT_101) \
INST(ctype, NCHW, BORDER_WRAP)
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR)
FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
......
......@@ -21,17 +21,23 @@ namespace remap {
// all these kernels use LINEAR interpolation
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void forward_proxy(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar, cudaStream_t stream);
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void backwarddata_proxy(
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream);
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void backwardmat_proxy(
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream);
......
......@@ -30,8 +30,9 @@ void RemapImpl::exec(
OW = map_xy.layout.shape[2];
megdnn_assert(
param().imode == param::Remap::InterpolationMode::LINEAR,
"only support LINEAR interpolationMode");
(param().imode == param::Remap::InterpolationMode::NEAREST) ||
(param().imode == param::Remap::InterpolationMode::LINEAR),
"only support NEAREST and LINEAR interpolationMode");
if (param().format == param::Remap::Format::NCHW) {
N = src.layout.shape[0];
......@@ -47,13 +48,15 @@ void RemapImpl::exec(
megdnn_throw("unsupported format, cuda remap");
}
#define cb(dt, _format, bmode) \
#define cb(dt, _format, bmode, inter_mode) \
if (param().format == param::Remap::Format::_format && \
param().border_type == param::Remap::BorderMode::bmode) { \
param().border_type == param::Remap::BorderMode::bmode && \
param().imode == param::Remap::InterpolationMode::inter_mode) { \
using ctype = DTypeTrait<dt>::ctype; \
remap::forward_proxy< \
ctype, param_enumv::Remap::Format::_format, \
::BorderMode::BORDER_##bmode>( \
::BorderMode::BORDER_##bmode, \
::InterpolationMode::INTER_##inter_mode>( \
src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \
dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, param().scalar, \
stream); \
......@@ -62,16 +65,26 @@ void RemapImpl::exec(
#define support_dtype(dt) \
case DTypeTrait<dt>::enumv: { \
cb(dt, NCHW, CONSTANT); \
cb(dt, NCHW, REPLICATE); \
cb(dt, NCHW, REFLECT); \
cb(dt, NCHW, REFLECT_101); \
cb(dt, NCHW, WRAP); \
cb(dt, NHWC, CONSTANT); \
cb(dt, NHWC, REPLICATE); \
cb(dt, NHWC, REFLECT); \
cb(dt, NHWC, REFLECT_101); \
cb(dt, NHWC, WRAP); \
cb(dt, NCHW, CONSTANT, NEAREST); \
cb(dt, NCHW, REPLICATE, NEAREST); \
cb(dt, NCHW, REFLECT, NEAREST); \
cb(dt, NCHW, REFLECT_101, NEAREST); \
cb(dt, NCHW, WRAP, NEAREST); \
cb(dt, NHWC, CONSTANT, NEAREST); \
cb(dt, NHWC, REPLICATE, NEAREST); \
cb(dt, NHWC, REFLECT, NEAREST); \
cb(dt, NHWC, REFLECT_101, NEAREST); \
cb(dt, NHWC, WRAP, NEAREST); \
cb(dt, NCHW, CONSTANT, LINEAR); \
cb(dt, NCHW, REPLICATE, LINEAR); \
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
cb(dt, NHWC, CONSTANT, LINEAR); \
cb(dt, NHWC, REPLICATE, LINEAR); \
cb(dt, NHWC, REFLECT, LINEAR); \
cb(dt, NHWC, REFLECT_101, LINEAR); \
cb(dt, NHWC, WRAP, LINEAR); \
megdnn_throw("unsupported border type in remap cuda"); \
}
......
......@@ -62,8 +62,23 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
}
};
template <typename ctype, ::BorderMode bmode>
__global__ void kern_general(
__device__ inline float round_half_to_even(float f) {
const float round_away_from_zero = round(f);
const float diff = round_away_from_zero - f;
if ((diff != 0.5f) && (diff != -0.5f)) {
return round_away_from_zero;
}
if (fmod(round_away_from_zero, 2.0f) == 0.0f) {
return round_away_from_zero;
}
return f - diff;
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
__global__ void kern_general_nearest(
const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -71,37 +86,22 @@ __global__ void kern_general(
sptr += blockIdx.z * C * IH * IW;
dst += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
RoundingConverter<ctype> round_converter;
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
int col = static_cast<int>(floor(index_col));
int row = static_cast<int>(floor(index_row));
float v = index_col - col;
float u = index_row - row;
int col = static_cast<int>(round_half_to_even(index_col));
int row = static_cast<int>(round_half_to_even(index_row));
for (int c = 0; c < C; ++c) {
ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get(
sptr, row + 0, col + 0, c, IH, IW, C, scalar);
ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get(
sptr, row + 0, col + 1, c, IH, IW, C, scalar);
ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get(
sptr, row + 1, col + 0, c, IH, IW, C, scalar);
ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get(
sptr, row + 1, col + 1, c, IH, IW, C, scalar);
/* in remap, we use float as the type of intermediate result */
float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) +
static_cast<float>(a01) * (1.f - u) * v +
static_cast<float>(a10) * (1.f - v) * u +
static_cast<float>(a11) * u * v;
dst[get_offset<param_enumv::Remap::Format::NCHW>(oh, ow, c, OH, OW, C)] =
round_converter(result);
dst[get_offset<format>(oh, ow, c, OH, OW, C)] =
GetSrcData<ctype, format, bmode>::get(
sptr, row, col, c, IH, IW, C, scalar);
}
}
}
template <typename ctype, ::BorderMode bmode>
__global__ void kern_general_nhwc(
template <typename ctype, const uint32_t format, ::BorderMode bmode>
__global__ void kern_general_linear(
const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C,
int IH, int IW, int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -119,26 +119,27 @@ __global__ void kern_general_nhwc(
float v = index_col - col;
float u = index_row - row;
for (int c = 0; c < C; ++c) {
ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get(
ctype a00 = GetSrcData<ctype, format, bmode>::get(
sptr, row + 0, col + 0, c, IH, IW, C, scalar);
ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get(
ctype a01 = GetSrcData<ctype, format, bmode>::get(
sptr, row + 0, col + 1, c, IH, IW, C, scalar);
ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get(
ctype a10 = GetSrcData<ctype, format, bmode>::get(
sptr, row + 1, col + 0, c, IH, IW, C, scalar);
ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get(
ctype a11 = GetSrcData<ctype, format, bmode>::get(
sptr, row + 1, col + 1, c, IH, IW, C, scalar);
/* in remap, we use float as the type of intermediate result */
float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) +
static_cast<float>(a01) * (1.f - u) * v +
static_cast<float>(a10) * (1.f - v) * u +
static_cast<float>(a11) * u * v;
dst[get_offset<param_enumv::Remap::Format::NHWC>(oh, ow, c, OH, OW, C)] =
round_converter(result);
dst[get_offset<format>(oh, ow, c, OH, OW, C)] = round_converter(result);
}
}
}
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void dispatch_forward(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar, cudaStream_t stream) {
......@@ -150,11 +151,11 @@ void dispatch_forward(
dim3 threads(BX, BY);
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
if (format == param_enumv::Remap::Format::NCHW) {
kern_general<ctype, bmode><<<blocks, threads, 0, stream>>>(
if (imode == ::InterpolationMode::INTER_NEAREST) {
kern_general_nearest<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
src, map_xy, dst, C, IH, IW, OH, OW, scalar);
} else if (format == param_enumv::Remap::Format::NHWC) {
kern_general_nhwc<ctype, bmode><<<blocks, threads, 0, stream>>>(
} else if (imode == ::InterpolationMode::INTER_LINEAR) {
kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
src, map_xy, dst, C, IH, IW, OH, OW, scalar);
}
......@@ -171,32 +172,45 @@ namespace megdnn {
namespace cuda {
namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
template <
typename ctype, const uint32_t format, ::BorderMode bmode,
::InterpolationMode imode>
void forward_proxy(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar, cudaStream_t stream) {
dispatch_forward<ctype, format, bmode>(
dispatch_forward<ctype, format, bmode, imode>(
src, map_xy, dst, N, C, IH, IW, OH, OW, scalar, stream);
after_kernel_launch();
}
#define INST(ctype, format, bmode) \
template void \
forward_proxy<ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \
#define INST(ctype, format, bmode, imode) \
template void forward_proxy< \
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \
::InterpolationMode::imode>( \
const ctype*, const float*, ctype*, int, int, int, int, int, int, float, \
cudaStream_t);
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT) \
INST(ctype, NCHW, BORDER_REPLICATE) \
INST(ctype, NCHW, BORDER_REFLECT) \
INST(ctype, NCHW, BORDER_REFLECT_101) \
INST(ctype, NCHW, BORDER_WRAP) \
INST(ctype, NHWC, BORDER_CONSTANT) \
INST(ctype, NHWC, BORDER_REPLICATE) \
INST(ctype, NHWC, BORDER_REFLECT) \
INST(ctype, NHWC, BORDER_REFLECT_101) \
INST(ctype, NHWC, BORDER_WRAP)
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \
INST(ctype, NHWC, BORDER_CONSTANT, INTER_NEAREST) \
INST(ctype, NHWC, BORDER_REPLICATE, INTER_NEAREST) \
INST(ctype, NHWC, BORDER_REFLECT, INTER_NEAREST) \
INST(ctype, NHWC, BORDER_REFLECT_101, INTER_NEAREST) \
INST(ctype, NHWC, BORDER_WRAP, INTER_NEAREST) \
INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \
INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) \
INST(ctype, NHWC, BORDER_CONSTANT, INTER_LINEAR) \
INST(ctype, NHWC, BORDER_REPLICATE, INTER_LINEAR) \
INST(ctype, NHWC, BORDER_REFLECT, INTER_LINEAR) \
INST(ctype, NHWC, BORDER_REFLECT_101, INTER_LINEAR) \
INST(ctype, NHWC, BORDER_WRAP, INTER_LINEAR)
FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
......
......@@ -36,6 +36,12 @@ inline int get_offset<param::Remap::Format::NHWC>(
return height * w * c + width * c + channel;
}
template <>
inline int get_offset<param::Remap::Format::NHWCD4>(
int height, int width, int channel, int, int w, int c) {
return ((height * c + channel) * w + width) * 4;
}
template <
typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
......@@ -80,8 +86,12 @@ void remap_LINEAR(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar) {
RoundingConverter<ctype> round_converter;
for (int n = 0; n < N;
++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) {
size_t c_scale = 1;
if (format == param::Remap::Format::NHWCD4) {
c_scale = 4;
}
for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW,
dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) {
for (int h = 0; h < OH; ++h) {
for (int w = 0; w < OW; ++w) {
float index_col = map_xy[h * OW * 2 + w * 2 + 0];
......@@ -92,18 +102,102 @@ void remap_LINEAR(
float u = index_row - row; // alphah
const float one = 1.f;
for (int c = 0; c < C; ++c) {
ctype a00 = GetSrcData<ctype, format, bordertype>::get(
src, row + 0, col + 0, c, IH, IW, C, scalar);
ctype a01 = GetSrcData<ctype, format, bordertype>::get(
src, row + 0, col + 1, c, IH, IW, C, scalar);
ctype a10 = GetSrcData<ctype, format, bordertype>::get(
src, row + 1, col + 0, c, IH, IW, C, scalar);
ctype a11 = GetSrcData<ctype, format, bordertype>::get(
src, row + 1, col + 1, c, IH, IW, C, scalar);
dst[get_offset<format>(h, w, c, OH, OW, C)] = round_converter(
a00 * (one - v) * (one - u) + a01 * (one - u) * v +
a10 * (one - v) * u + a11 * u * v);
if (format == param::Remap::Format::NHWCD4) {
int idx00 = GetSrcData<ctype, format, bordertype>::get_index(
row + 0, col + 0, c, IH, IW, C);
int idx01 = GetSrcData<ctype, format, bordertype>::get_index(
row + 0, col + 1, c, IH, IW, C);
int idx10 = GetSrcData<ctype, format, bordertype>::get_index(
row + 1, col + 0, c, IH, IW, C);
int idx11 = GetSrcData<ctype, format, bordertype>::get_index(
row + 1, col + 1, c, IH, IW, C);
for (int c_inner = 0; c_inner < 4; ++c_inner) {
ctype a00 = (idx00 != -1) ? src[idx00 + c_inner]
: round_converter(scalar);
ctype a01 = (idx01 != -1) ? src[idx01 + c_inner]
: round_converter(scalar);
ctype a10 = (idx10 != -1) ? src[idx10 + c_inner]
: round_converter(scalar);
ctype a11 = (idx11 != -1) ? src[idx11 + c_inner]
: round_converter(scalar);
dst[get_offset<format>(h, w, c, OH, OW, C) + c_inner] =
round_converter(
a00 * (one - v) * (one - u) +
a01 * (one - u) * v + a10 * (one - v) * u +
a11 * u * v);
}
} else {
ctype a00 = GetSrcData<ctype, format, bordertype>::get(
src, row + 0, col + 0, c, IH, IW, C, scalar);
ctype a01 = GetSrcData<ctype, format, bordertype>::get(
src, row + 0, col + 1, c, IH, IW, C, scalar);
ctype a10 = GetSrcData<ctype, format, bordertype>::get(
src, row + 1, col + 0, c, IH, IW, C, scalar);
ctype a11 = GetSrcData<ctype, format, bordertype>::get(
src, row + 1, col + 1, c, IH, IW, C, scalar);
dst[get_offset<format>(h, w, c, OH, OW, C)] = round_converter(
a00 * (one - v) * (one - u) + a01 * (one - u) * v +
a10 * (one - v) * u + a11 * u * v);
}
}
}
}
}
}
namespace {
inline float round_half_to_even(float f) {
const float round_away_from_zero = std::round(f);
const float diff = round_away_from_zero - f;
if ((diff != 0.5f) && (diff != -0.5f)) {
return round_away_from_zero;
}
if (std::fmod(round_away_from_zero, 2.0f) == 0.0f) {
return round_away_from_zero;
}
return f - diff;
}
} // anonymous namespace
template <
typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
void remap_NEAREST(
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW,
int OH, int OW, float scalar) {
RoundingConverter<ctype> round_converter;
size_t c_scale = 1;
if (format == param::Remap::Format::NHWCD4) {
c_scale = 4;
}
for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW,
dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) {
for (int h = 0; h < OH; ++h) {
for (int w = 0; w < OW; ++w) {
float index_col = map_xy[h * OW * 2 + w * 2 + 0];
float index_row = map_xy[h * OW * 2 + w * 2 + 1];
int col = static_cast<int>(round_half_to_even(index_col));
int row = static_cast<int>(round_half_to_even(index_row));
for (int c = 0; c < C; ++c) {
if (format == param::Remap::Format::NHWCD4) {
int idx = GetSrcData<ctype, format, bordertype>::get_index(
row, col, c, IH, IW, C);
for (int c_inner = 0; c_inner < 4; ++c_inner) {
dst[get_offset<format>(h, w, c, OH, OW, C) + c_inner] =
(idx != -1) ? (src[idx + c_inner])
: round_converter(scalar);
}
} else {
dst[get_offset<format>(h, w, c, OH, OW, C)] =
GetSrcData<ctype, format, bordertype>::get(
src, row, col, c, IH, IW, C, scalar);
}
}
}
}
......@@ -161,13 +255,40 @@ void remap_LINEAR_backwarddata(
}
}
template <
typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
void remap_NEAREST_backwarddata(
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH,
int IW, int OH, int OW) {
std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW);
for (int n = 0; n < N;
++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) {
for (int h = 0; h < OH; ++h) {
for (int w = 0; w < OW; ++w) {
float index_col = map_xy[h * OW * 2 + w * 2 + 0];
float index_row = map_xy[h * OW * 2 + w * 2 + 1];
int col = static_cast<int>(round_half_to_even(index_col));
int row = static_cast<int>(round_half_to_even(index_row));
for (int c = 0; c < C; ++c) {
ctype hidden = diff[get_offset<format>(h, w, c, OH, OW, C)];
int idx = GetSrcData<ctype, format, bordertype>::get_index(
row, col, c, IH, IW, C);
if (idx != -1) {
grad[idx] += hidden;
}
}
}
}
}
}
template <
typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
void remap_LINEAR_backwardmat(
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N,
int C, int IH, int IW, int OH, int OW, float scalar) {
RoundingConverter<ctype> round_converter;
std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW);
for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW,
map_xy += OH * OW * 2, grad += OH * OW * 2) {
......@@ -194,24 +315,38 @@ void remap_LINEAR_backwardmat(
int a11 = GetSrcData<ctype, format, bordertype>::get_index(
row + 1, col + 1, c, IH, IW, C);
dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
dv -= ((a10 != -1) ? src[a10] : scalar) * u;
dv += ((a11 != -1) ? src[a11] : scalar) * u;
du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
du -= ((a01 != -1) ? src[a01] : scalar) * v;
du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
du += ((a11 != -1) ? src[a11] : scalar) * v;
grad[h * OW * 2 + w * 2 + 0] += round_converter(hidden * dv);
grad[h * OW * 2 + w * 2 + 1] += round_converter(hidden * du);
dv -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) *
(one - u);
dv += ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) *
(one - u);
dv -= ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * u;
dv += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * u;
du -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) *
(one - v);
du -= ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * v;
du += ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) *
(one - v);
du += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * v;
grad[h * OW * 2 + w * 2 + 0] += hidden * dv;
grad[h * OW * 2 + w * 2 + 1] += hidden * du;
}
}
}
}
}
template <
typename ctype, param::Remap::Format format,
param::Remap::BorderMode bordertype>
void remap_NEAREST_backwardmat(
const ctype*, const float*, const ctype*, float* grad, int N, int, int, int,
int OH, int OW, float) {
std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW);
return;
}
} // namespace
void RemapImpl::exec(
......@@ -229,6 +364,11 @@ void RemapImpl::exec(
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
} else if (param().format == param::Remap::Format::NHWCD4) {
N = src.layout.shape[0];
C = src.layout.shape[2];
IH = src.layout.shape[1];
IW = src.layout.shape[3];
} else {
megdnn_throw("unsupported format");
}
......@@ -255,11 +395,31 @@ void RemapImpl::exec(
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
cb(dt, NHWCD4, CONSTANT, LINEAR); \
cb(dt, NHWCD4, REPLICATE, LINEAR); \
cb(dt, NHWCD4, REFLECT, LINEAR); \
cb(dt, NHWCD4, REFLECT_101, LINEAR); \
cb(dt, NHWCD4, WRAP, LINEAR); \
cb(dt, NHWC, CONSTANT, LINEAR); \
cb(dt, NHWC, REPLICATE, LINEAR); \
cb(dt, NHWC, REFLECT, LINEAR); \
cb(dt, NHWC, REFLECT_101, LINEAR); \
cb(dt, NHWC, WRAP, LINEAR); \
cb(dt, NCHW, CONSTANT, NEAREST); \
cb(dt, NCHW, REPLICATE, NEAREST); \
cb(dt, NCHW, REFLECT, NEAREST); \
cb(dt, NCHW, REFLECT_101, NEAREST); \
cb(dt, NCHW, WRAP, NEAREST); \
cb(dt, NHWCD4, CONSTANT, NEAREST); \
cb(dt, NHWCD4, REPLICATE, NEAREST); \
cb(dt, NHWCD4, REFLECT, NEAREST); \
cb(dt, NHWCD4, REFLECT_101, NEAREST); \
cb(dt, NHWCD4, WRAP, NEAREST); \
cb(dt, NHWC, CONSTANT, NEAREST); \
cb(dt, NHWC, REPLICATE, NEAREST); \
cb(dt, NHWC, REFLECT, NEAREST); \
cb(dt, NHWC, REFLECT_101, NEAREST); \
cb(dt, NHWC, WRAP, NEAREST); \
megdnn_throw( \
"format, border type or imode is incorrect in remap navie " \
"with dtype = " #dt); \
......@@ -313,6 +473,11 @@ void RemapBackwardDataImpl::exec(
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
cb(dt, NCHW, CONSTANT, NEAREST); \
cb(dt, NCHW, REPLICATE, NEAREST); \
cb(dt, NCHW, REFLECT, NEAREST); \
cb(dt, NCHW, REFLECT_101, NEAREST); \
cb(dt, NCHW, WRAP, NEAREST); \
megdnn_throw( \
"format, border type or imode is incorrect in remap navie " \
"with dtype = " #dt); \
......@@ -365,6 +530,11 @@ void RemapBackwardMatImpl::exec(
cb(dt, NCHW, REFLECT, LINEAR); \
cb(dt, NCHW, REFLECT_101, LINEAR); \
cb(dt, NCHW, WRAP, LINEAR); \
cb(dt, NCHW, CONSTANT, NEAREST); \
cb(dt, NCHW, REPLICATE, NEAREST); \
cb(dt, NCHW, REFLECT, NEAREST); \
cb(dt, NCHW, REFLECT_101, NEAREST); \
cb(dt, NCHW, WRAP, NEAREST); \
megdnn_throw( \
"format, border type or imode is incorrect in remap navie " \
"with dtype = " #dt); \
......
......@@ -34,53 +34,91 @@ static inline std::vector<TestArg> get_nchw_args() {
param::Remap param;
std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NCHW};
std::vector<param::Remap::InterpolationMode> interp_mode_vec = {
param::Remap::InterpolationMode::NEAREST,
param::Remap::InterpolationMode::LINEAR};
std::vector<param::Remap::BorderMode> border_mode_vec = {
param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101,
param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP,
param::Remap::BorderMode::REPLICATE};
// current do not test this.
std::vector<float> scalar;
for (auto fmt : format_vec) {
for (auto border_type : border_mode_vec) {
param.format = fmt;
param.border_type = border_type;
args.emplace_back(
param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2},
TensorShape{70000, 1, 2, 2});
args.emplace_back(
param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2},
TensorShape{1, 1, 2, 2});
args.emplace_back(
param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2},
TensorShape{1, 3, 2, 2});
args.emplace_back(
param, TensorShape{1, 10, 100, 100}, TensorShape{1, 100, 100, 2},
TensorShape{1, 10, 100, 100});
args.emplace_back(
param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2},
TensorShape{2, 4, 100, 200});
args.emplace_back(
param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2},
TensorShape{2, 4, 20, 30});
args.emplace_back(
param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2},
TensorShape{2, 4, 20, 30});
for (auto interp_mode : interp_mode_vec) {
for (auto border_type : border_mode_vec) {
param.format = fmt;
param.imode = interp_mode;
param.border_type = border_type;
args.emplace_back(
param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2},
TensorShape{70000, 1, 2, 2});
args.emplace_back(
param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2},
TensorShape{1, 1, 2, 2});
args.emplace_back(
param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2},
TensorShape{1, 3, 2, 2});
args.emplace_back(
param, TensorShape{1, 10, 100, 100},
TensorShape{1, 100, 100, 2}, TensorShape{1, 10, 100, 100});
args.emplace_back(
param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2},
TensorShape{2, 4, 100, 200});
args.emplace_back(
param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2},
TensorShape{2, 4, 20, 30});
args.emplace_back(
param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2},
TensorShape{2, 4, 20, 30});
}
}
}
return args;
}
static inline std::vector<TestArg> get_nhwcd4_args() {
std::vector<TestArg> args;
param::Remap param;
param.format = param::Remap::Format::NHWCD4;
param.imode = param::Remap::InterpolationMode::LINEAR;
param.border_type = param::Remap::BorderMode::CONSTANT;
// FIXME: when fractional part of bval is not zero, naive and opencl bankend may
// have different rounding result
param.scalar = 77;
args.emplace_back(
param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2},
TensorShape{2, 4, 1, 6, 4});
args.emplace_back(
param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2},
TensorShape{2, 2, 1, 3, 4});
param.imode = param::Remap::InterpolationMode::NEAREST;
args.emplace_back(
param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2},
TensorShape{2, 4, 1, 6, 4});
args.emplace_back(
param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2},
TensorShape{2, 2, 1, 3, 4});
return args;
}
static inline std::vector<TestArg> get_nhwc_args() {
std::vector<TestArg> args;
param::Remap param;
std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NHWC};
std::vector<param::Remap::InterpolationMode> interp_mode_vec = {
param::Remap::InterpolationMode::NEAREST,
param::Remap::InterpolationMode::LINEAR};
std::vector<param::Remap::BorderMode> border_mode_vec = {
param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101,
param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP,
......@@ -88,41 +126,44 @@ static inline std::vector<TestArg> get_nhwc_args() {
// current do not test this.
std::vector<float> scalar;
for (auto fmt : format_vec) {
for (auto border_type : border_mode_vec) {
param.format = fmt;
param.border_type = border_type;
param.scalar = 12.f;
args.emplace_back(
param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2},
TensorShape{70000, 2, 2, 1});
args.emplace_back(
param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2},
TensorShape{1, 2, 2, 1});
args.emplace_back(
param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2},
TensorShape{1, 2, 2, 3});
args.emplace_back(
param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2},
TensorShape{1, 2, 2, 66});
args.emplace_back(
param, TensorShape{1, 100, 100, 10}, TensorShape{1, 100, 100, 2},
TensorShape{1, 100, 100, 10});
args.emplace_back(
param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2},
TensorShape{2, 100, 200, 4});
args.emplace_back(
param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2},
TensorShape{2, 20, 30, 4});
args.emplace_back(
param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2},
TensorShape{2, 20, 30, 4});
for (auto interp_mode : interp_mode_vec) {
for (auto border_type : border_mode_vec) {
param.format = fmt;
param.imode = interp_mode;
param.border_type = border_type;
param.scalar = 12.f;
args.emplace_back(
param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2},
TensorShape{70000, 2, 2, 1});
args.emplace_back(
param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2},
TensorShape{1, 2, 2, 1});
args.emplace_back(
param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2},
TensorShape{1, 2, 2, 3});
args.emplace_back(
param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2},
TensorShape{1, 2, 2, 66});
args.emplace_back(
param, TensorShape{1, 100, 100, 10},
TensorShape{1, 100, 100, 2}, TensorShape{1, 100, 100, 10});
args.emplace_back(
param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2},
TensorShape{2, 100, 200, 4});
args.emplace_back(
param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2},
TensorShape{2, 20, 30, 4});
args.emplace_back(
param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2},
TensorShape{2, 20, 30, 4});
}
}
}
return args;
......
......@@ -58,6 +58,11 @@ static void set_nchw_args(std::vector<TestArg>& args) {
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3});
args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
param.imode = param::Resize::InterpolationMode::NEAREST;
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8});
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3});
args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4});
}
static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) {
......@@ -75,6 +80,25 @@ static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) {
return args;
}
static inline std::vector<TestArg> get_nhwc_args() {
std::vector<TestArg> args;
param::Resize param;
param.format = param::Resize::Format::NHWC;
param.imode = param::Resize::InterpolationMode::LINEAR;
args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2});
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2});
args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2});
param.imode = param::Resize::InterpolationMode::NEAREST;
args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2});
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2});
args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2});
return args;
}
static inline std::vector<TestArg> get_nhwcd4_args() {
std::vector<TestArg> args;
......@@ -83,6 +107,9 @@ static inline std::vector<TestArg> get_nhwcd4_args() {
param.imode = param::Resize::InterpolationMode::LINEAR;
args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4});
args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4});
param.imode = param::Resize::InterpolationMode::NEAREST;
args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4});
args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4});
return args;
}
......
......@@ -351,7 +351,7 @@ def remap(
"reflect_101", "wrap".
scalar: value used in case of a constant border. Default: 0
interp_mode: interpolation methods.
Default: "linear". Currently only support "linear" mode.
Default: "linear". Currently also support "nearest" mode.
Returns:
output tensor.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册