diff --git a/dnn/src/common/remap.cpp b/dnn/src/common/remap.cpp index 6c2e5a473ed945f8dcad39ff4bb7fdd5cec32da9..9d543bb397736ed2d600500ec42c9bbd723f790b 100644 --- a/dnn/src/common/remap.cpp +++ b/dnn/src/common/remap.cpp @@ -18,21 +18,22 @@ namespace megdnn { void RemapBase::deduce_layout_fwd( const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst) { - dst.dtype = src.dtype; - dst.ndim = src.ndim; - dst.shape[0] = src.shape[0]; - size_t height_index, channel_index; + size_t n = src.shape[0]; + size_t c, oh, ow; + oh = map_xy.shape[1]; + ow = map_xy.shape[2]; if (param().format == param::Remap::Format::NHWC) { - height_index = 1; - channel_index = 3; + c = src.shape[3]; + dst = TensorLayout(TensorShape({n, oh, ow, c}), src.dtype); + } else if (param().format == param::Remap::Format::NCHW) { + c = src.shape[1]; + dst = TensorLayout(TensorShape{n, c, oh, ow}, src.dtype, src.format); + } else if (param().format == param::Remap::Format::NHWCD4) { + c = src.shape[2]; + dst = TensorLayout{{n, oh, c, ow, 4}, src.dtype, src.format}; } else { - megdnn_assert(param().format == param::Remap::Format::NCHW); - height_index = 2; - channel_index = 1; + megdnn_throw("unsupport format"); } - dst.shape[height_index] = map_xy.shape[1]; - dst.shape[height_index + 1] = map_xy.shape[2]; - dst.shape[channel_index] = src.shape[channel_index]; } void RemapBase::check_layout_fwd( @@ -42,7 +43,7 @@ void RemapBase::check_layout_fwd( megdnn_layout_msg(dst); }; MEGDNN_MARK_USED_VAR(errmsg); - megdnn_assert(src.ndim == map_xy.ndim && src.ndim == dst.ndim && src.ndim == 4); + megdnn_assert(src.ndim == dst.ndim); megdnn_assert(dst.dtype == src.dtype); megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); megdnn_assert(map_xy.shape[3] == 2); @@ -64,10 +65,13 @@ void RemapBase::check_layout_fwd( megdnn_assert( dst.shape[2] == map_xy.shape[1] && dst.shape[3] == map_xy.shape[2], "%s", errmsg().c_str()); + } else if (param().format == param::Remap::Format::NHWCD4) { + megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str()); + megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str()); + megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str()); + megdnn_assert(param().format == Param::Format::NHWCD4); } else { - megdnn_throw( - "currently do not support other param.format except NHWC and " - "NCHW"); + megdnn_throw("unsupport format"); } } diff --git a/dnn/src/cuda/remap/backward_data.cpp b/dnn/src/cuda/remap/backward_data.cpp index a2ce89505fec850499cc5501974bb947ae1109d1..26b9933aab56453dd30924b6ccd614749c0cd609 100644 --- a/dnn/src/cuda/remap/backward_data.cpp +++ b/dnn/src/cuda/remap/backward_data.cpp @@ -22,8 +22,9 @@ void RemapBackwardDataImpl::exec( _megdnn_workspace workspace) { check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size); megdnn_assert( - param().imode == param::Remap::InterpolationMode::LINEAR, - "only support LINEAR interpolationMode"); + (param().imode == param::Remap::InterpolationMode::NEAREST) || + (param().imode == param::Remap::InterpolationMode::LINEAR), + "only support NEAREST and LINEAR interpolationMode"); megdnn_assert( param().format == param::Remap::Format::NCHW, "only support NCHW format for remap backward"); @@ -36,13 +37,15 @@ void RemapBackwardDataImpl::exec( OH = map_xy.layout.shape[1]; OW = map_xy.layout.shape[2]; -#define cb(dt, _format, bmode) \ +#define cb(dt, _format, bmode, inter_mode) \ if (param().format == param::Remap::Format::_format && \ - param().border_type == param::Remap::BorderMode::bmode) { \ + param().border_type == param::Remap::BorderMode::bmode && \ + param().imode == param::Remap::InterpolationMode::inter_mode) { \ using ctype = DTypeTrait
::ctype; \ remap::backwarddata_proxy< \ ctype, param_enumv::Remap::Format::_format, \ - ::BorderMode::BORDER_##bmode>( \ + ::BorderMode::BORDER_##bmode, \ + ::InterpolationMode::INTER_##inter_mode>( \ grad.compatible_ptr(), map_xy.compatible_ptr(), \ diff.compatible_ptr(), N, C, IH, IW, OH, OW, stream); \ break; \ @@ -50,11 +53,16 @@ void RemapBackwardDataImpl::exec( #define support_dtype(dt) \ case DTypeTrait
::enumv: { \ - cb(dt, NCHW, CONSTANT); \ - cb(dt, NCHW, REPLICATE); \ - cb(dt, NCHW, REFLECT); \ - cb(dt, NCHW, REFLECT_101); \ - cb(dt, NCHW, WRAP); \ + cb(dt, NCHW, CONSTANT, NEAREST); \ + cb(dt, NCHW, REPLICATE, NEAREST); \ + cb(dt, NCHW, REFLECT, NEAREST); \ + cb(dt, NCHW, REFLECT_101, NEAREST); \ + cb(dt, NCHW, WRAP, NEAREST); \ + cb(dt, NCHW, CONSTANT, LINEAR); \ + cb(dt, NCHW, REPLICATE, LINEAR); \ + cb(dt, NCHW, REFLECT, LINEAR); \ + cb(dt, NCHW, REFLECT_101, LINEAR); \ + cb(dt, NCHW, WRAP, LINEAR); \ megdnn_throw("unsupported border type in remap cuda"); \ } diff --git a/dnn/src/cuda/remap/backward_data.cu b/dnn/src/cuda/remap/backward_data.cu index 662f9ac62ad909b15d1250207af72517fba69cf5..0074f8bb8da7134dbe0f5b14efc9e60795721922 100644 --- a/dnn/src/cuda/remap/backward_data.cu +++ b/dnn/src/cuda/remap/backward_data.cu @@ -52,8 +52,49 @@ struct GetSrcData { } }; +__device__ inline float round_half_to_even(float f) { + const float round_away_from_zero = round(f); + const float diff = round_away_from_zero - f; + + if ((diff != 0.5f) && (diff != -0.5f)) { + return round_away_from_zero; + } + + if (fmod(round_away_from_zero, 2.0f) == 0.0f) { + return round_away_from_zero; + } + + return f - diff; +} + +template +__global__ void kern_general_nearest( + ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH, + int IW, int OH, int OW) { + int ow = blockIdx.x * blockDim.x + threadIdx.x; + int oh = blockIdx.y * blockDim.y + threadIdx.y; + grad += blockIdx.z * C * IH * IW; + diff += blockIdx.z * C * OH * OW; + map_xy += blockIdx.z * 2 * OH * OW; + + if (ow < OW && oh < OH) { + float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; + float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; + int col = static_cast(round_half_to_even(index_col)); + int row = static_cast(round_half_to_even(index_row)); + for (int c = 0; c < C; ++c) { + ctype hidden = diff[get_offset(oh, ow, c, OH, OW, C)]; + int idx = + GetSrcData::get_index(row, col, c, IH, IW, C); + if (idx != -1) { + atomic_add(grad + idx, hidden); + } + } + } +} + template -__global__ void kern_general( +__global__ void kern_general_linear( ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH, int IW, int OH, int OW) { int ow = blockIdx.x * blockDim.x + threadIdx.x; @@ -93,8 +134,8 @@ __global__ void kern_general( atomic_add(grad + a10, round_converter(u * (one - v) * hidden)); } - int a11 = GetSrcData:: - get_index(row + 1, col + 1, c, IH, IW, C); + int a11 = GetSrcData::get_index( + row + 1, col + 1, c, IH, IW, C); if (a11 != -1) { atomic_add(grad + a11, round_converter(u * v * hidden)); } @@ -102,7 +143,9 @@ __global__ void kern_general( } } -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void dispatch_backwarddata( ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, int IW, int OH, int OW, cudaStream_t stream) { @@ -115,8 +158,13 @@ void dispatch_backwarddata( cuda_check(cudaMemsetAsync( grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW, stream)); - kern_general - <<>>(grad, map_xy, diff, C, IH, IW, OH, OW); + if (imode == ::InterpolationMode::INTER_NEAREST) { + kern_general_nearest<<>>( + grad, map_xy, diff, C, IH, IW, OH, OW); + } else if (imode == ::InterpolationMode::INTER_LINEAR) { + kern_general_linear<<>>( + grad, map_xy, diff, C, IH, IW, OH, OW); + } N -= curr_batch_size; grad += curr_batch_size * C * IH * IW; @@ -131,27 +179,35 @@ namespace megdnn { namespace cuda { namespace remap { -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void backwarddata_proxy( ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, int IW, int OH, int OW, cudaStream_t stream) { - dispatch_backwarddata( + dispatch_backwarddata( grad, map_xy, diff, N, C, IH, IW, OH, OW, stream); after_kernel_launch(); } -#define INST(ctype, format, bmode) \ +#define INST(ctype, format, bmode, imode) \ template void backwarddata_proxy< \ - ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ + ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ + ::InterpolationMode::imode>( \ ctype*, const float*, const ctype*, int, int, int, int, int, int, \ cudaStream_t); -#define FOR_FORMAT_BMODE(ctype) \ - INST(ctype, NCHW, BORDER_CONSTANT) \ - INST(ctype, NCHW, BORDER_REPLICATE) \ - INST(ctype, NCHW, BORDER_REFLECT) \ - INST(ctype, NCHW, BORDER_REFLECT_101) \ - INST(ctype, NCHW, BORDER_WRAP) +#define FOR_FORMAT_BMODE(ctype) \ + INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) FOR_FORMAT_BMODE(float) DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) diff --git a/dnn/src/cuda/remap/backward_mat.cpp b/dnn/src/cuda/remap/backward_mat.cpp index e14d9b09f8a140a6724dfee9c11c3ef90869bc22..c1dcc53cbb5eb1751de906feea86b13501478ccc 100644 --- a/dnn/src/cuda/remap/backward_mat.cpp +++ b/dnn/src/cuda/remap/backward_mat.cpp @@ -22,8 +22,9 @@ void RemapBackwardMatImpl::exec( _megdnn_tensor_out grad, _megdnn_workspace workspace) { check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, workspace.size); megdnn_assert( - param().imode == param::Remap::InterpolationMode::LINEAR, - "only support LINEAR interpolationMode"); + (param().imode == param::Remap::InterpolationMode::NEAREST) || + (param().imode == param::Remap::InterpolationMode::LINEAR), + "only support NEAREST and LINEAR interpolationMode"); megdnn_assert( param().format == param::Remap::Format::NCHW, "only support NCHW format for remap backward"); @@ -36,13 +37,15 @@ void RemapBackwardMatImpl::exec( OH = map_xy.layout.shape[1]; OW = map_xy.layout.shape[2]; -#define cb(dt, _format, bmode) \ +#define cb(dt, _format, bmode, inter_mode) \ if (param().format == param::Remap::Format::_format && \ - param().border_type == param::Remap::BorderMode::bmode) { \ + param().border_type == param::Remap::BorderMode::bmode && \ + param().imode == param::Remap::InterpolationMode::inter_mode) { \ using ctype = DTypeTrait
::ctype; \ remap::backwardmat_proxy< \ ctype, param_enumv::Remap::Format::_format, \ - ::BorderMode::BORDER_##bmode>( \ + ::BorderMode::BORDER_##bmode, \ + ::InterpolationMode::INTER_##inter_mode>( \ src.compatible_ptr(), map_xy.compatible_ptr(), \ diff.compatible_ptr(), grad.compatible_ptr(), N, C, \ IH, IW, OH, OW, param().scalar, stream); \ @@ -51,11 +54,16 @@ void RemapBackwardMatImpl::exec( #define support_dtype(dt) \ case DTypeTrait
::enumv: { \ - cb(dt, NCHW, CONSTANT); \ - cb(dt, NCHW, REPLICATE); \ - cb(dt, NCHW, REFLECT); \ - cb(dt, NCHW, REFLECT_101); \ - cb(dt, NCHW, WRAP); \ + cb(dt, NCHW, CONSTANT, NEAREST); \ + cb(dt, NCHW, REPLICATE, NEAREST); \ + cb(dt, NCHW, REFLECT, NEAREST); \ + cb(dt, NCHW, REFLECT_101, NEAREST); \ + cb(dt, NCHW, WRAP, NEAREST); \ + cb(dt, NCHW, CONSTANT, LINEAR); \ + cb(dt, NCHW, REPLICATE, LINEAR); \ + cb(dt, NCHW, REFLECT, LINEAR); \ + cb(dt, NCHW, REFLECT_101, LINEAR); \ + cb(dt, NCHW, WRAP, LINEAR); \ megdnn_throw("unsupported border type in remap cuda"); \ } diff --git a/dnn/src/cuda/remap/backward_mat.cu b/dnn/src/cuda/remap/backward_mat.cu index f0f6498a0fb07647772c430b83f9f4ee662dc404..882bd8e2153932e019646a88278cf3f97d949372 100644 --- a/dnn/src/cuda/remap/backward_mat.cu +++ b/dnn/src/cuda/remap/backward_mat.cu @@ -53,7 +53,7 @@ struct GetSrcData { }; template -__global__ void kern_general( +__global__ void kern_general_linear( const ctype* src, const float* map_xy, const ctype* diff, float* __restrict grad, int C, int IH, int IW, int OH, int OW, float scalar) { int ow = blockIdx.x * blockDim.x + threadIdx.x; @@ -62,7 +62,6 @@ __global__ void kern_general( diff += blockIdx.z * C * OH * OW; map_xy += blockIdx.z * 2 * OH * OW; grad += blockIdx.z * 2 * OH * OW; - RoundingConverter round_converter; if (ow < OW && oh < OH) { float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; @@ -86,23 +85,25 @@ __global__ void kern_general( int a11 = GetSrcData::get_index( row + 1, col + 1, c, IH, IW, C); - dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); - dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); - dv -= ((a10 != -1) ? src[a10] : scalar) * u; - dv += ((a11 != -1) ? src[a11] : scalar) * u; + dv -= ((a00 != -1) ? static_cast(src[a00]) : scalar) * (one - u); + dv += ((a01 != -1) ? static_cast(src[a01]) : scalar) * (one - u); + dv -= ((a10 != -1) ? static_cast(src[a10]) : scalar) * u; + dv += ((a11 != -1) ? static_cast(src[a11]) : scalar) * u; - du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); - du -= ((a01 != -1) ? src[a01] : scalar) * v; - du += ((a10 != -1) ? src[a10] : scalar) * (one - v); - du += ((a11 != -1) ? src[a11] : scalar) * v; + du -= ((a00 != -1) ? static_cast(src[a00]) : scalar) * (one - v); + du -= ((a01 != -1) ? static_cast(src[a01]) : scalar) * v; + du += ((a10 != -1) ? static_cast(src[a10]) : scalar) * (one - v); + du += ((a11 != -1) ? static_cast(src[a11]) : scalar) * v; - grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv); - grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du); + grad[oh * OW * 2 + ow * 2 + 0] += hidden * dv; + grad[oh * OW * 2 + ow * 2 + 1] += hidden * du; } } } -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void dispatch_backwardmat( const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { @@ -115,8 +116,11 @@ void dispatch_backwardmat( cuda_check(cudaMemsetAsync( grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, stream)); - kern_general<<>>( - src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar); + + if (imode == ::InterpolationMode::INTER_LINEAR) { + kern_general_linear<<>>( + src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar); + } N -= curr_batch_size; src += curr_batch_size * C * IH * IW; @@ -132,27 +136,35 @@ namespace megdnn { namespace cuda { namespace remap { -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void backwardmat_proxy( const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { - dispatch_backwardmat( + dispatch_backwardmat( src, map_xy, diff, grad, N, C, IH, IW, OH, OW, scalar, stream); after_kernel_launch(); } -#define INST(ctype, format, bmode) \ - template void \ - backwardmat_proxy( \ +#define INST(ctype, format, bmode, imode) \ + template void backwardmat_proxy< \ + ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ + ::InterpolationMode::imode>( \ const ctype*, const float*, const ctype*, float*, int, int, int, int, int, \ int, float, cudaStream_t); -#define FOR_FORMAT_BMODE(ctype) \ - INST(ctype, NCHW, BORDER_CONSTANT) \ - INST(ctype, NCHW, BORDER_REPLICATE) \ - INST(ctype, NCHW, BORDER_REFLECT) \ - INST(ctype, NCHW, BORDER_REFLECT_101) \ - INST(ctype, NCHW, BORDER_WRAP) +#define FOR_FORMAT_BMODE(ctype) \ + INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) FOR_FORMAT_BMODE(float) DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) diff --git a/dnn/src/cuda/remap/common.h b/dnn/src/cuda/remap/common.h index 0d85f6a82c7e6c1ab9a92f36f18699edbcb6e81c..779c205606de28f6294a9b8d53a918d5dcd183a3 100644 --- a/dnn/src/cuda/remap/common.h +++ b/dnn/src/cuda/remap/common.h @@ -21,17 +21,23 @@ namespace remap { // all these kernels use LINEAR interpolation -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void forward_proxy( const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream); -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void backwarddata_proxy( ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, int IW, int OH, int OW, cudaStream_t stream); -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void backwardmat_proxy( const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream); diff --git a/dnn/src/cuda/remap/forward.cpp b/dnn/src/cuda/remap/forward.cpp index 4a40b551c9a630e4217c358a634d3fab8b4d14a2..2e1876cddf96771696cfecd6c748a5107e231c5a 100644 --- a/dnn/src/cuda/remap/forward.cpp +++ b/dnn/src/cuda/remap/forward.cpp @@ -30,8 +30,9 @@ void RemapImpl::exec( OW = map_xy.layout.shape[2]; megdnn_assert( - param().imode == param::Remap::InterpolationMode::LINEAR, - "only support LINEAR interpolationMode"); + (param().imode == param::Remap::InterpolationMode::NEAREST) || + (param().imode == param::Remap::InterpolationMode::LINEAR), + "only support NEAREST and LINEAR interpolationMode"); if (param().format == param::Remap::Format::NCHW) { N = src.layout.shape[0]; @@ -47,13 +48,15 @@ void RemapImpl::exec( megdnn_throw("unsupported format, cuda remap"); } -#define cb(dt, _format, bmode) \ +#define cb(dt, _format, bmode, inter_mode) \ if (param().format == param::Remap::Format::_format && \ - param().border_type == param::Remap::BorderMode::bmode) { \ + param().border_type == param::Remap::BorderMode::bmode && \ + param().imode == param::Remap::InterpolationMode::inter_mode) { \ using ctype = DTypeTrait
::ctype; \ remap::forward_proxy< \ ctype, param_enumv::Remap::Format::_format, \ - ::BorderMode::BORDER_##bmode>( \ + ::BorderMode::BORDER_##bmode, \ + ::InterpolationMode::INTER_##inter_mode>( \ src.compatible_ptr(), map_xy.compatible_ptr(), \ dst.compatible_ptr(), N, C, IH, IW, OH, OW, param().scalar, \ stream); \ @@ -62,16 +65,26 @@ void RemapImpl::exec( #define support_dtype(dt) \ case DTypeTrait
::enumv: { \ - cb(dt, NCHW, CONSTANT); \ - cb(dt, NCHW, REPLICATE); \ - cb(dt, NCHW, REFLECT); \ - cb(dt, NCHW, REFLECT_101); \ - cb(dt, NCHW, WRAP); \ - cb(dt, NHWC, CONSTANT); \ - cb(dt, NHWC, REPLICATE); \ - cb(dt, NHWC, REFLECT); \ - cb(dt, NHWC, REFLECT_101); \ - cb(dt, NHWC, WRAP); \ + cb(dt, NCHW, CONSTANT, NEAREST); \ + cb(dt, NCHW, REPLICATE, NEAREST); \ + cb(dt, NCHW, REFLECT, NEAREST); \ + cb(dt, NCHW, REFLECT_101, NEAREST); \ + cb(dt, NCHW, WRAP, NEAREST); \ + cb(dt, NHWC, CONSTANT, NEAREST); \ + cb(dt, NHWC, REPLICATE, NEAREST); \ + cb(dt, NHWC, REFLECT, NEAREST); \ + cb(dt, NHWC, REFLECT_101, NEAREST); \ + cb(dt, NHWC, WRAP, NEAREST); \ + cb(dt, NCHW, CONSTANT, LINEAR); \ + cb(dt, NCHW, REPLICATE, LINEAR); \ + cb(dt, NCHW, REFLECT, LINEAR); \ + cb(dt, NCHW, REFLECT_101, LINEAR); \ + cb(dt, NCHW, WRAP, LINEAR); \ + cb(dt, NHWC, CONSTANT, LINEAR); \ + cb(dt, NHWC, REPLICATE, LINEAR); \ + cb(dt, NHWC, REFLECT, LINEAR); \ + cb(dt, NHWC, REFLECT_101, LINEAR); \ + cb(dt, NHWC, WRAP, LINEAR); \ megdnn_throw("unsupported border type in remap cuda"); \ } diff --git a/dnn/src/cuda/remap/forward.cu b/dnn/src/cuda/remap/forward.cu index 648c4188e81b86fbe5fec8c964cfd30f6d7fbd31..700008421d242f79a515f2f06096b60150564e6e 100644 --- a/dnn/src/cuda/remap/forward.cu +++ b/dnn/src/cuda/remap/forward.cu @@ -62,8 +62,23 @@ struct GetSrcData { } }; -template -__global__ void kern_general( +__device__ inline float round_half_to_even(float f) { + const float round_away_from_zero = round(f); + const float diff = round_away_from_zero - f; + + if ((diff != 0.5f) && (diff != -0.5f)) { + return round_away_from_zero; + } + + if (fmod(round_away_from_zero, 2.0f) == 0.0f) { + return round_away_from_zero; + } + + return f - diff; +} + +template +__global__ void kern_general_nearest( const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C, int IH, int IW, int OH, int OW, float scalar) { int ow = blockIdx.x * blockDim.x + threadIdx.x; @@ -71,37 +86,22 @@ __global__ void kern_general( sptr += blockIdx.z * C * IH * IW; dst += blockIdx.z * C * OH * OW; map_xy += blockIdx.z * 2 * OH * OW; - RoundingConverter round_converter; if (ow < OW && oh < OH) { float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; - int col = static_cast(floor(index_col)); - int row = static_cast(floor(index_row)); - float v = index_col - col; - float u = index_row - row; + int col = static_cast(round_half_to_even(index_col)); + int row = static_cast(round_half_to_even(index_row)); for (int c = 0; c < C; ++c) { - ctype a00 = GetSrcData::get( - sptr, row + 0, col + 0, c, IH, IW, C, scalar); - ctype a01 = GetSrcData::get( - sptr, row + 0, col + 1, c, IH, IW, C, scalar); - ctype a10 = GetSrcData::get( - sptr, row + 1, col + 0, c, IH, IW, C, scalar); - ctype a11 = GetSrcData::get( - sptr, row + 1, col + 1, c, IH, IW, C, scalar); - /* in remap, we use float as the type of intermediate result */ - float result = static_cast(a00) * (1.f - u) * (1.f - v) + - static_cast(a01) * (1.f - u) * v + - static_cast(a10) * (1.f - v) * u + - static_cast(a11) * u * v; - dst[get_offset(oh, ow, c, OH, OW, C)] = - round_converter(result); + dst[get_offset(oh, ow, c, OH, OW, C)] = + GetSrcData::get( + sptr, row, col, c, IH, IW, C, scalar); } } } -template -__global__ void kern_general_nhwc( +template +__global__ void kern_general_linear( const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C, int IH, int IW, int OH, int OW, float scalar) { int ow = blockIdx.x * blockDim.x + threadIdx.x; @@ -119,26 +119,27 @@ __global__ void kern_general_nhwc( float v = index_col - col; float u = index_row - row; for (int c = 0; c < C; ++c) { - ctype a00 = GetSrcData::get( + ctype a00 = GetSrcData::get( sptr, row + 0, col + 0, c, IH, IW, C, scalar); - ctype a01 = GetSrcData::get( + ctype a01 = GetSrcData::get( sptr, row + 0, col + 1, c, IH, IW, C, scalar); - ctype a10 = GetSrcData::get( + ctype a10 = GetSrcData::get( sptr, row + 1, col + 0, c, IH, IW, C, scalar); - ctype a11 = GetSrcData::get( + ctype a11 = GetSrcData::get( sptr, row + 1, col + 1, c, IH, IW, C, scalar); /* in remap, we use float as the type of intermediate result */ float result = static_cast(a00) * (1.f - u) * (1.f - v) + static_cast(a01) * (1.f - u) * v + static_cast(a10) * (1.f - v) * u + static_cast(a11) * u * v; - dst[get_offset(oh, ow, c, OH, OW, C)] = - round_converter(result); + dst[get_offset(oh, ow, c, OH, OW, C)] = round_converter(result); } } } -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void dispatch_forward( const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { @@ -150,11 +151,11 @@ void dispatch_forward( dim3 threads(BX, BY); dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); - if (format == param_enumv::Remap::Format::NCHW) { - kern_general<<>>( + if (imode == ::InterpolationMode::INTER_NEAREST) { + kern_general_nearest<<>>( src, map_xy, dst, C, IH, IW, OH, OW, scalar); - } else if (format == param_enumv::Remap::Format::NHWC) { - kern_general_nhwc<<>>( + } else if (imode == ::InterpolationMode::INTER_LINEAR) { + kern_general_linear<<>>( src, map_xy, dst, C, IH, IW, OH, OW, scalar); } @@ -171,32 +172,45 @@ namespace megdnn { namespace cuda { namespace remap { -template +template < + typename ctype, const uint32_t format, ::BorderMode bmode, + ::InterpolationMode imode> void forward_proxy( const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { - dispatch_forward( + dispatch_forward( src, map_xy, dst, N, C, IH, IW, OH, OW, scalar, stream); after_kernel_launch(); } -#define INST(ctype, format, bmode) \ - template void \ - forward_proxy( \ +#define INST(ctype, format, bmode, imode) \ + template void forward_proxy< \ + ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ + ::InterpolationMode::imode>( \ const ctype*, const float*, ctype*, int, int, int, int, int, int, float, \ cudaStream_t); -#define FOR_FORMAT_BMODE(ctype) \ - INST(ctype, NCHW, BORDER_CONSTANT) \ - INST(ctype, NCHW, BORDER_REPLICATE) \ - INST(ctype, NCHW, BORDER_REFLECT) \ - INST(ctype, NCHW, BORDER_REFLECT_101) \ - INST(ctype, NCHW, BORDER_WRAP) \ - INST(ctype, NHWC, BORDER_CONSTANT) \ - INST(ctype, NHWC, BORDER_REPLICATE) \ - INST(ctype, NHWC, BORDER_REFLECT) \ - INST(ctype, NHWC, BORDER_REFLECT_101) \ - INST(ctype, NHWC, BORDER_WRAP) +#define FOR_FORMAT_BMODE(ctype) \ + INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ + INST(ctype, NHWC, BORDER_CONSTANT, INTER_NEAREST) \ + INST(ctype, NHWC, BORDER_REPLICATE, INTER_NEAREST) \ + INST(ctype, NHWC, BORDER_REFLECT, INTER_NEAREST) \ + INST(ctype, NHWC, BORDER_REFLECT_101, INTER_NEAREST) \ + INST(ctype, NHWC, BORDER_WRAP, INTER_NEAREST) \ + INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ + INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) \ + INST(ctype, NHWC, BORDER_CONSTANT, INTER_LINEAR) \ + INST(ctype, NHWC, BORDER_REPLICATE, INTER_LINEAR) \ + INST(ctype, NHWC, BORDER_REFLECT, INTER_LINEAR) \ + INST(ctype, NHWC, BORDER_REFLECT_101, INTER_LINEAR) \ + INST(ctype, NHWC, BORDER_WRAP, INTER_LINEAR) FOR_FORMAT_BMODE(float) DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) diff --git a/dnn/src/naive/remap/opr_impl.cpp b/dnn/src/naive/remap/opr_impl.cpp index 56e8b632e6b51560e16212862c3d19f91c7557b8..440d73d2b995180f8a635c4acd3c719eb3658832 100644 --- a/dnn/src/naive/remap/opr_impl.cpp +++ b/dnn/src/naive/remap/opr_impl.cpp @@ -36,6 +36,12 @@ inline int get_offset( return height * w * c + width * c + channel; } +template <> +inline int get_offset( + int height, int width, int channel, int, int w, int c) { + return ((height * c + channel) * w + width) * 4; +} + template < typename ctype, param::Remap::Format format, param::Remap::BorderMode bordertype> @@ -80,8 +86,12 @@ void remap_LINEAR( const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, float scalar) { RoundingConverter round_converter; - for (int n = 0; n < N; - ++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) { + size_t c_scale = 1; + if (format == param::Remap::Format::NHWCD4) { + c_scale = 4; + } + for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW, + dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) { for (int h = 0; h < OH; ++h) { for (int w = 0; w < OW; ++w) { float index_col = map_xy[h * OW * 2 + w * 2 + 0]; @@ -92,18 +102,102 @@ void remap_LINEAR( float u = index_row - row; // alphah const float one = 1.f; for (int c = 0; c < C; ++c) { - ctype a00 = GetSrcData::get( - src, row + 0, col + 0, c, IH, IW, C, scalar); - ctype a01 = GetSrcData::get( - src, row + 0, col + 1, c, IH, IW, C, scalar); - ctype a10 = GetSrcData::get( - src, row + 1, col + 0, c, IH, IW, C, scalar); - ctype a11 = GetSrcData::get( - src, row + 1, col + 1, c, IH, IW, C, scalar); - - dst[get_offset(h, w, c, OH, OW, C)] = round_converter( - a00 * (one - v) * (one - u) + a01 * (one - u) * v + - a10 * (one - v) * u + a11 * u * v); + if (format == param::Remap::Format::NHWCD4) { + int idx00 = GetSrcData::get_index( + row + 0, col + 0, c, IH, IW, C); + int idx01 = GetSrcData::get_index( + row + 0, col + 1, c, IH, IW, C); + int idx10 = GetSrcData::get_index( + row + 1, col + 0, c, IH, IW, C); + int idx11 = GetSrcData::get_index( + row + 1, col + 1, c, IH, IW, C); + for (int c_inner = 0; c_inner < 4; ++c_inner) { + ctype a00 = (idx00 != -1) ? src[idx00 + c_inner] + : round_converter(scalar); + ctype a01 = (idx01 != -1) ? src[idx01 + c_inner] + : round_converter(scalar); + ctype a10 = (idx10 != -1) ? src[idx10 + c_inner] + : round_converter(scalar); + ctype a11 = (idx11 != -1) ? src[idx11 + c_inner] + : round_converter(scalar); + dst[get_offset(h, w, c, OH, OW, C) + c_inner] = + round_converter( + a00 * (one - v) * (one - u) + + a01 * (one - u) * v + a10 * (one - v) * u + + a11 * u * v); + } + } else { + ctype a00 = GetSrcData::get( + src, row + 0, col + 0, c, IH, IW, C, scalar); + ctype a01 = GetSrcData::get( + src, row + 0, col + 1, c, IH, IW, C, scalar); + ctype a10 = GetSrcData::get( + src, row + 1, col + 0, c, IH, IW, C, scalar); + ctype a11 = GetSrcData::get( + src, row + 1, col + 1, c, IH, IW, C, scalar); + + dst[get_offset(h, w, c, OH, OW, C)] = round_converter( + a00 * (one - v) * (one - u) + a01 * (one - u) * v + + a10 * (one - v) * u + a11 * u * v); + } + } + } + } + } +} + +namespace { + +inline float round_half_to_even(float f) { + const float round_away_from_zero = std::round(f); + const float diff = round_away_from_zero - f; + + if ((diff != 0.5f) && (diff != -0.5f)) { + return round_away_from_zero; + } + + if (std::fmod(round_away_from_zero, 2.0f) == 0.0f) { + return round_away_from_zero; + } + + return f - diff; +} + +} // anonymous namespace + +template < + typename ctype, param::Remap::Format format, + param::Remap::BorderMode bordertype> +void remap_NEAREST( + const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, + int OH, int OW, float scalar) { + RoundingConverter round_converter; + size_t c_scale = 1; + if (format == param::Remap::Format::NHWCD4) { + c_scale = 4; + } + for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW, + dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) { + for (int h = 0; h < OH; ++h) { + for (int w = 0; w < OW; ++w) { + float index_col = map_xy[h * OW * 2 + w * 2 + 0]; + float index_row = map_xy[h * OW * 2 + w * 2 + 1]; + int col = static_cast(round_half_to_even(index_col)); + int row = static_cast(round_half_to_even(index_row)); + for (int c = 0; c < C; ++c) { + if (format == param::Remap::Format::NHWCD4) { + int idx = GetSrcData::get_index( + row, col, c, IH, IW, C); + for (int c_inner = 0; c_inner < 4; ++c_inner) { + dst[get_offset(h, w, c, OH, OW, C) + c_inner] = + (idx != -1) ? (src[idx + c_inner]) + : round_converter(scalar); + } + } else { + dst[get_offset(h, w, c, OH, OW, C)] = + GetSrcData::get( + src, row, col, c, IH, IW, C, scalar); + } } } } @@ -161,13 +255,40 @@ void remap_LINEAR_backwarddata( } } +template < + typename ctype, param::Remap::Format format, + param::Remap::BorderMode bordertype> +void remap_NEAREST_backwarddata( + ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, + int IW, int OH, int OW) { + std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW); + for (int n = 0; n < N; + ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) { + for (int h = 0; h < OH; ++h) { + for (int w = 0; w < OW; ++w) { + float index_col = map_xy[h * OW * 2 + w * 2 + 0]; + float index_row = map_xy[h * OW * 2 + w * 2 + 1]; + int col = static_cast(round_half_to_even(index_col)); + int row = static_cast(round_half_to_even(index_row)); + for (int c = 0; c < C; ++c) { + ctype hidden = diff[get_offset(h, w, c, OH, OW, C)]; + int idx = GetSrcData::get_index( + row, col, c, IH, IW, C); + if (idx != -1) { + grad[idx] += hidden; + } + } + } + } + } +} + template < typename ctype, param::Remap::Format format, param::Remap::BorderMode bordertype> void remap_LINEAR_backwardmat( const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, int C, int IH, int IW, int OH, int OW, float scalar) { - RoundingConverter round_converter; std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2, grad += OH * OW * 2) { @@ -194,24 +315,38 @@ void remap_LINEAR_backwardmat( int a11 = GetSrcData::get_index( row + 1, col + 1, c, IH, IW, C); - dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); - dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); - dv -= ((a10 != -1) ? src[a10] : scalar) * u; - dv += ((a11 != -1) ? src[a11] : scalar) * u; - - du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); - du -= ((a01 != -1) ? src[a01] : scalar) * v; - du += ((a10 != -1) ? src[a10] : scalar) * (one - v); - du += ((a11 != -1) ? src[a11] : scalar) * v; - - grad[h * OW * 2 + w * 2 + 0] += round_converter(hidden * dv); - grad[h * OW * 2 + w * 2 + 1] += round_converter(hidden * du); + dv -= ((a00 != -1) ? static_cast(src[a00]) : scalar) * + (one - u); + dv += ((a01 != -1) ? static_cast(src[a01]) : scalar) * + (one - u); + dv -= ((a10 != -1) ? static_cast(src[a10]) : scalar) * u; + dv += ((a11 != -1) ? static_cast(src[a11]) : scalar) * u; + + du -= ((a00 != -1) ? static_cast(src[a00]) : scalar) * + (one - v); + du -= ((a01 != -1) ? static_cast(src[a01]) : scalar) * v; + du += ((a10 != -1) ? static_cast(src[a10]) : scalar) * + (one - v); + du += ((a11 != -1) ? static_cast(src[a11]) : scalar) * v; + + grad[h * OW * 2 + w * 2 + 0] += hidden * dv; + grad[h * OW * 2 + w * 2 + 1] += hidden * du; } } } } } +template < + typename ctype, param::Remap::Format format, + param::Remap::BorderMode bordertype> +void remap_NEAREST_backwardmat( + const ctype*, const float*, const ctype*, float* grad, int N, int, int, int, + int OH, int OW, float) { + std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); + return; +} + } // namespace void RemapImpl::exec( @@ -229,6 +364,11 @@ void RemapImpl::exec( C = src.layout.shape[3]; IH = src.layout.shape[1]; IW = src.layout.shape[2]; + } else if (param().format == param::Remap::Format::NHWCD4) { + N = src.layout.shape[0]; + C = src.layout.shape[2]; + IH = src.layout.shape[1]; + IW = src.layout.shape[3]; } else { megdnn_throw("unsupported format"); } @@ -255,11 +395,31 @@ void RemapImpl::exec( cb(dt, NCHW, REFLECT, LINEAR); \ cb(dt, NCHW, REFLECT_101, LINEAR); \ cb(dt, NCHW, WRAP, LINEAR); \ + cb(dt, NHWCD4, CONSTANT, LINEAR); \ + cb(dt, NHWCD4, REPLICATE, LINEAR); \ + cb(dt, NHWCD4, REFLECT, LINEAR); \ + cb(dt, NHWCD4, REFLECT_101, LINEAR); \ + cb(dt, NHWCD4, WRAP, LINEAR); \ cb(dt, NHWC, CONSTANT, LINEAR); \ cb(dt, NHWC, REPLICATE, LINEAR); \ cb(dt, NHWC, REFLECT, LINEAR); \ cb(dt, NHWC, REFLECT_101, LINEAR); \ cb(dt, NHWC, WRAP, LINEAR); \ + cb(dt, NCHW, CONSTANT, NEAREST); \ + cb(dt, NCHW, REPLICATE, NEAREST); \ + cb(dt, NCHW, REFLECT, NEAREST); \ + cb(dt, NCHW, REFLECT_101, NEAREST); \ + cb(dt, NCHW, WRAP, NEAREST); \ + cb(dt, NHWCD4, CONSTANT, NEAREST); \ + cb(dt, NHWCD4, REPLICATE, NEAREST); \ + cb(dt, NHWCD4, REFLECT, NEAREST); \ + cb(dt, NHWCD4, REFLECT_101, NEAREST); \ + cb(dt, NHWCD4, WRAP, NEAREST); \ + cb(dt, NHWC, CONSTANT, NEAREST); \ + cb(dt, NHWC, REPLICATE, NEAREST); \ + cb(dt, NHWC, REFLECT, NEAREST); \ + cb(dt, NHWC, REFLECT_101, NEAREST); \ + cb(dt, NHWC, WRAP, NEAREST); \ megdnn_throw( \ "format, border type or imode is incorrect in remap navie " \ "with dtype = " #dt); \ @@ -313,6 +473,11 @@ void RemapBackwardDataImpl::exec( cb(dt, NCHW, REFLECT, LINEAR); \ cb(dt, NCHW, REFLECT_101, LINEAR); \ cb(dt, NCHW, WRAP, LINEAR); \ + cb(dt, NCHW, CONSTANT, NEAREST); \ + cb(dt, NCHW, REPLICATE, NEAREST); \ + cb(dt, NCHW, REFLECT, NEAREST); \ + cb(dt, NCHW, REFLECT_101, NEAREST); \ + cb(dt, NCHW, WRAP, NEAREST); \ megdnn_throw( \ "format, border type or imode is incorrect in remap navie " \ "with dtype = " #dt); \ @@ -365,6 +530,11 @@ void RemapBackwardMatImpl::exec( cb(dt, NCHW, REFLECT, LINEAR); \ cb(dt, NCHW, REFLECT_101, LINEAR); \ cb(dt, NCHW, WRAP, LINEAR); \ + cb(dt, NCHW, CONSTANT, NEAREST); \ + cb(dt, NCHW, REPLICATE, NEAREST); \ + cb(dt, NCHW, REFLECT, NEAREST); \ + cb(dt, NCHW, REFLECT_101, NEAREST); \ + cb(dt, NCHW, WRAP, NEAREST); \ megdnn_throw( \ "format, border type or imode is incorrect in remap navie " \ "with dtype = " #dt); \ diff --git a/dnn/test/common/remap.h b/dnn/test/common/remap.h index 8c2198843e2a248357d8d056831c864fd9940308..9ae64d204c6fd69d87c3b8bc0d1fa646541feb98 100644 --- a/dnn/test/common/remap.h +++ b/dnn/test/common/remap.h @@ -34,53 +34,91 @@ static inline std::vector get_nchw_args() { param::Remap param; std::vector format_vec = {param::Remap::Format::NCHW}; + std::vector interp_mode_vec = { + param::Remap::InterpolationMode::NEAREST, + param::Remap::InterpolationMode::LINEAR}; std::vector border_mode_vec = { param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101, param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP, param::Remap::BorderMode::REPLICATE}; + // current do not test this. std::vector scalar; for (auto fmt : format_vec) { - for (auto border_type : border_mode_vec) { - param.format = fmt; - param.border_type = border_type; - args.emplace_back( - param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2}, - TensorShape{70000, 1, 2, 2}); - - args.emplace_back( - param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2}, - TensorShape{1, 1, 2, 2}); - - args.emplace_back( - param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2}, - TensorShape{1, 3, 2, 2}); - - args.emplace_back( - param, TensorShape{1, 10, 100, 100}, TensorShape{1, 100, 100, 2}, - TensorShape{1, 10, 100, 100}); - - args.emplace_back( - param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2}, - TensorShape{2, 4, 100, 200}); - - args.emplace_back( - param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2}, - TensorShape{2, 4, 20, 30}); - - args.emplace_back( - param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2}, - TensorShape{2, 4, 20, 30}); + for (auto interp_mode : interp_mode_vec) { + for (auto border_type : border_mode_vec) { + param.format = fmt; + param.imode = interp_mode; + param.border_type = border_type; + args.emplace_back( + param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2}, + TensorShape{70000, 1, 2, 2}); + + args.emplace_back( + param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2}, + TensorShape{1, 1, 2, 2}); + + args.emplace_back( + param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2}, + TensorShape{1, 3, 2, 2}); + + args.emplace_back( + param, TensorShape{1, 10, 100, 100}, + TensorShape{1, 100, 100, 2}, TensorShape{1, 10, 100, 100}); + + args.emplace_back( + param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2}, + TensorShape{2, 4, 100, 200}); + + args.emplace_back( + param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2}, + TensorShape{2, 4, 20, 30}); + + args.emplace_back( + param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2}, + TensorShape{2, 4, 20, 30}); + } } } return args; } +static inline std::vector get_nhwcd4_args() { + std::vector args; + + param::Remap param; + param.format = param::Remap::Format::NHWCD4; + param.imode = param::Remap::InterpolationMode::LINEAR; + param.border_type = param::Remap::BorderMode::CONSTANT; + // FIXME: when fractional part of bval is not zero, naive and opencl bankend may + // have different rounding result + param.scalar = 77; + args.emplace_back( + param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2}, + TensorShape{2, 4, 1, 6, 4}); + args.emplace_back( + param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2}, + TensorShape{2, 2, 1, 3, 4}); + + param.imode = param::Remap::InterpolationMode::NEAREST; + args.emplace_back( + param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2}, + TensorShape{2, 4, 1, 6, 4}); + args.emplace_back( + param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2}, + TensorShape{2, 2, 1, 3, 4}); + + return args; +} + static inline std::vector get_nhwc_args() { std::vector args; param::Remap param; std::vector format_vec = {param::Remap::Format::NHWC}; + std::vector interp_mode_vec = { + param::Remap::InterpolationMode::NEAREST, + param::Remap::InterpolationMode::LINEAR}; std::vector border_mode_vec = { param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101, param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP, @@ -88,41 +126,44 @@ static inline std::vector get_nhwc_args() { // current do not test this. std::vector scalar; for (auto fmt : format_vec) { - for (auto border_type : border_mode_vec) { - param.format = fmt; - param.border_type = border_type; - param.scalar = 12.f; - args.emplace_back( - param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2}, - TensorShape{70000, 2, 2, 1}); - - args.emplace_back( - param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2}, - TensorShape{1, 2, 2, 1}); - - args.emplace_back( - param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2}, - TensorShape{1, 2, 2, 3}); - - args.emplace_back( - param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2}, - TensorShape{1, 2, 2, 66}); - - args.emplace_back( - param, TensorShape{1, 100, 100, 10}, TensorShape{1, 100, 100, 2}, - TensorShape{1, 100, 100, 10}); - - args.emplace_back( - param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2}, - TensorShape{2, 100, 200, 4}); - - args.emplace_back( - param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2}, - TensorShape{2, 20, 30, 4}); - - args.emplace_back( - param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2}, - TensorShape{2, 20, 30, 4}); + for (auto interp_mode : interp_mode_vec) { + for (auto border_type : border_mode_vec) { + param.format = fmt; + param.imode = interp_mode; + param.border_type = border_type; + param.scalar = 12.f; + args.emplace_back( + param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2}, + TensorShape{70000, 2, 2, 1}); + + args.emplace_back( + param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2}, + TensorShape{1, 2, 2, 1}); + + args.emplace_back( + param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2}, + TensorShape{1, 2, 2, 3}); + + args.emplace_back( + param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2}, + TensorShape{1, 2, 2, 66}); + + args.emplace_back( + param, TensorShape{1, 100, 100, 10}, + TensorShape{1, 100, 100, 2}, TensorShape{1, 100, 100, 10}); + + args.emplace_back( + param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2}, + TensorShape{2, 100, 200, 4}); + + args.emplace_back( + param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2}, + TensorShape{2, 20, 30, 4}); + + args.emplace_back( + param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2}, + TensorShape{2, 20, 30, 4}); + } } } return args; diff --git a/dnn/test/common/resize.h b/dnn/test/common/resize.h index caddf8208c425a5254e51a5e98a3ab293b0208f6..6e78d74e8f1a718934c8f3ffee3791fb2c0055d9 100644 --- a/dnn/test/common/resize.h +++ b/dnn/test/common/resize.h @@ -58,6 +58,11 @@ static void set_nchw_args(std::vector& args) { args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3}); args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); + + param.imode = param::Resize::InterpolationMode::NEAREST; + args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); + args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3}); + args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); } static inline std::vector get_args(IMode imode = IMode::INTER_LINEAR) { @@ -75,6 +80,25 @@ static inline std::vector get_args(IMode imode = IMode::INTER_LINEAR) { return args; } +static inline std::vector get_nhwc_args() { + std::vector args; + + param::Resize param; + param.format = param::Resize::Format::NHWC; + param.imode = param::Resize::InterpolationMode::LINEAR; + + args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2}); + args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2}); + args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2}); + + param.imode = param::Resize::InterpolationMode::NEAREST; + args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2}); + args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2}); + args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2}); + + return args; +} + static inline std::vector get_nhwcd4_args() { std::vector args; @@ -83,6 +107,9 @@ static inline std::vector get_nhwcd4_args() { param.imode = param::Resize::InterpolationMode::LINEAR; args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4}); args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4}); + param.imode = param::Resize::InterpolationMode::NEAREST; + args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4}); + args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4}); return args; } diff --git a/imperative/python/megengine/functional/vision.py b/imperative/python/megengine/functional/vision.py index 0fd905ee4c9639590482464b58870a6811369caa..bf98e3e15f082a8071b090857f93aaddfe81996b 100644 --- a/imperative/python/megengine/functional/vision.py +++ b/imperative/python/megengine/functional/vision.py @@ -351,7 +351,7 @@ def remap( "reflect_101", "wrap". scalar: value used in case of a constant border. Default: 0 interp_mode: interpolation methods. - Default: "linear". Currently only support "linear" mode. + Default: "linear". Currently also support "nearest" mode. Returns: output tensor.