From 0d7ace15c87b15c2b28a0ce5e2ce7224e166cb92 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 20 Jun 2022 17:09:51 +0800 Subject: [PATCH] fix(mgb/dnn): suport fp16 for resize nhwc GitOrigin-RevId: bb04d2a801b5cbe9c8704ce922842231c4158a4c --- dnn/src/common/resize.cpp | 8 +- dnn/src/cuda/resize/backward.cpp | 60 +++++++++---- dnn/src/cuda/resize/backward.cu | 139 +++++++++++++++++++++++------- dnn/src/cuda/resize/common.h | 5 +- dnn/src/cuda/resize/forward.cpp | 5 ++ dnn/src/cuda/resize/forward.cu | 1 + dnn/src/naive/resize/opr_impl.cpp | 124 ++++++++++++++++++++------ dnn/src/naive/resize/opr_impl.h | 6 ++ dnn/test/cuda/resize.cpp | 74 +++++++++++++++- 9 files changed, 342 insertions(+), 80 deletions(-) diff --git a/dnn/src/common/resize.cpp b/dnn/src/common/resize.cpp index e6fe8282e..4bc3302c2 100644 --- a/dnn/src/common/resize.cpp +++ b/dnn/src/common/resize.cpp @@ -67,8 +67,12 @@ void ResizeBackward::check_exec( auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); megdnn_assert( - param().format == Param::Format::NCHW && grad.dtype == dtype::Float32(), - "Backward resize only supports Float32 and NCHW."); + (param().format == Param::Format::NCHW || + param().format == Param::Format::NHWC) && + (grad.dtype == dtype::Float32() DNN_INC_FLOAT16( + || grad.dtype == dtype::Float16())), + "Backward resize only supports NCHW and NHWC, the dtype only supports " + "Float32 and Float16."); } std::pair ResizeBase::get_cubic_coord(float scale, int idx) { diff --git a/dnn/src/cuda/resize/backward.cpp b/dnn/src/cuda/resize/backward.cpp index bfcd0ef1b..3928d2247 100644 --- a/dnn/src/cuda/resize/backward.cpp +++ b/dnn/src/cuda/resize/backward.cpp @@ -11,26 +11,56 @@ void ResizeBackwardImpl::exec( _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { check_exec(diff.layout, grad.layout, workspace.size); auto stream = cuda_stream(this->handle()); - auto N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2], - IW = grad.layout.shape[3], OH = diff.layout.shape[2], - OW = diff.layout.shape[3]; + bool is_nhwc = param().format == param::Resize::Format::NHWC; + size_t N, C, IH, IW, OH, OW; + if (is_nhwc) { + if (param().imode != Param::InterpolationMode::LINEAR && + is_nhwc_contig_wc(grad.layout)) { + megdnn_assert( + 0, + "unsupport mode in resizeBackward, only support param().imode = " + "LINEAR"); + } + N = grad.layout.shape[0]; + C = grad.layout.shape[3]; + IH = grad.layout.shape[1]; + IW = grad.layout.shape[2]; + OH = diff.layout.shape[1]; + OW = diff.layout.shape[2]; + } else { + N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2], + IW = grad.layout.shape[3], OH = diff.layout.shape[2], OW = diff.layout.shape[3]; + } size_t max_batch_x_channel = max_batch_x_channel_size(); - dt_float32* diff_ptr = diff.ptr(); - dt_float32* grad_ptr = grad.ptr(); size_t max_batch_size = max_batch_x_channel / C; while (N > 0) { size_t curr_batch_size = N > max_batch_size ? max_batch_size : N; - resize::backward_data_proxy( - resize::get_imode(param().imode), diff_ptr, grad_ptr, curr_batch_size, - C, IH, IW, OH, OW, stream); - - if (N <= max_batch_size) { - break; - } else { - N -= max_batch_size; - diff_ptr += curr_batch_size * diff.layout.stride[0]; - grad_ptr += curr_batch_size * grad.layout.stride[0]; + switch (grad.layout.dtype.enumv()) { +#define cb(_t) \ + case DTypeTrait<_t>::enumv: { \ + typedef DTypeTrait<_t>::ctype ct; \ + ct* diff_ptr = diff.ptr(); \ + ct* grad_ptr = grad.ptr(); \ + resize::backward_data_proxy( \ + is_nhwc, resize::get_imode(param().imode), diff_ptr, grad_ptr, \ + curr_batch_size, C, IH, IW, OH, OW, stream); \ + if (N <= max_batch_size) { \ + return; \ + } else { \ + N -= max_batch_size; \ + diff_ptr += curr_batch_size * diff.layout.stride[0]; \ + grad_ptr += curr_batch_size * grad.layout.stride[0]; \ + } \ + break; \ + } + cb(megdnn::dtype::Float32); + DNN_INC_FLOAT16(cb(megdnn::dtype::Float16)); + default: + megdnn_throw(ssprintf( + "unsupported dtype: %s in resize backward", + grad.layout.dtype.name())); } +#undef cb } } diff --git a/dnn/src/cuda/resize/backward.cu b/dnn/src/cuda/resize/backward.cu index 8bea9f2a9..dadab55be 100644 --- a/dnn/src/cuda/resize/backward.cu +++ b/dnn/src/cuda/resize/backward.cu @@ -1,3 +1,4 @@ +#include "src/common/rounding_converter.cuh" #include "src/cuda/resize/common.cuh" #include "src/cuda/resize/common.h" @@ -11,9 +12,52 @@ namespace megdnn { namespace cuda { namespace resize { +template +__global__ void resize_bwd_nhwc_kernel( + const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, + float scale_h, float scale_w) { + OutputConverter output_converter; + int n = blockIdx.z; + int ow = blockIdx.x * blockDim.x + threadIdx.x; + int oh = blockIdx.y * blockDim.y + threadIdx.y; + hidden += n * C * OH * OW; + dst += n * C * IH * IW; + if (ow < OW && oh < OH) { + float alphah, alphaw; + int ih0, iw0; + get_origin_coord(scale_h, IH, oh, alphah, ih0); + get_origin_coord(scale_w, IW, ow, alphaw, iw0); + + int ih1 = ih0 + 1; + int iw1 = iw0 + 1; + + float nalphaw = 1.0f - alphaw; + float nalphah = 1.0f - alphah; + for (int c = 0; c < C; ++c) { + atomic_add( + dst + (ih0 * IW + iw0) * C + c, + output_converter( + hidden[(oh * OW + ow) * C + c] * nalphaw * nalphah)); + atomic_add( + dst + (ih0 * IW + iw1) * C + c, + output_converter( + hidden[(oh * OW + ow) * C + c] * alphaw * nalphah)); + atomic_add( + dst + (ih1 * IW + iw0) * C + c, + output_converter( + hidden[(oh * OW + ow) * C + c] * nalphaw * alphah)); + atomic_add( + dst + (ih1 * IW + iw1) * C + c, + output_converter(hidden[(oh * OW + ow) * C + c] * alphaw * alphah)); + } + } +} + +template __global__ void resize_bwd_linear_kernel( - const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, + const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, float scale_h, float scale_w) { + OutputConverter output_converter; int n = blockIdx.z; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; @@ -31,19 +75,29 @@ __global__ void resize_bwd_linear_kernel( float nalphaw = 1.0f - alphaw; float nalphah = 1.0f - alphah; for (int c = 0; c < C; ++c) { - atomicAdd(dst + ih0 * IW + iw0, hidden[oh * OW + ow] * nalphaw * nalphah); - atomicAdd(dst + ih0 * IW + iw1, hidden[oh * OW + ow] * alphaw * nalphah); - atomicAdd(dst + ih1 * IW + iw0, hidden[oh * OW + ow] * nalphaw * alphah); - atomicAdd(dst + ih1 * IW + iw1, hidden[oh * OW + ow] * alphaw * alphah); + atomic_add( + dst + ih0 * IW + iw0, + output_converter(hidden[oh * OW + ow] * nalphaw * nalphah)); + atomic_add( + dst + ih0 * IW + iw1, + output_converter(hidden[oh * OW + ow] * alphaw * nalphah)); + atomic_add( + dst + ih1 * IW + iw0, + output_converter(hidden[oh * OW + ow] * nalphaw * alphah)); + atomic_add( + dst + ih1 * IW + iw1, + output_converter(hidden[oh * OW + ow] * alphaw * alphah)); hidden += OH * OW; dst += IH * IW; } } } +template __global__ void resize_bwd_nearest_kernel( - const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, + const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, float scale_h, float scale_w) { + OutputConverter output_converter; int n = blockIdx.z; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; @@ -54,16 +108,18 @@ __global__ void resize_bwd_nearest_kernel( int iw = get_nearest_src(scale_w, IW, ow); for (int c = 0; c < C; ++c) { - atomicAdd(dst + ih * IW + iw, hidden[oh * OW + ow]); + atomic_add(dst + ih * IW + iw, output_converter(hidden[oh * OW + ow])); hidden += OH * OW; dst += IH * IW; } } } +template __global__ void resize_bwd_cubic_kernel( - const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, + const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, float scale_h, float scale_w) { + OutputConverter output_converter; int n = blockIdx.z; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; @@ -85,9 +141,10 @@ __global__ void resize_bwd_cubic_kernel( int ih = saturate(ih0 + kh, 0, IH - 1); for (int kw = 0; kw < ksize; kw++) { int iw = saturate(iw0 + kw, 0, IW - 1); - atomicAdd( + atomic_add( dst + ih * IW + iw, - hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]); + output_converter( + hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw])); } } @@ -97,41 +154,59 @@ __global__ void resize_bwd_cubic_kernel( } } +template void backward_data_proxy( - InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH, - int IW, int OH, int OW, cudaStream_t stream) { + bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N, + int C, int IH, int IW, int OH, int OW, cudaStream_t stream) { const int BY = 16, BX = 32; { dim3 threads(BX, BY); dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N); - cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * C * IH * IW, stream)); + cuda_check(cudaMemsetAsync(grad, 0, sizeof(ctype) * N * C * IH * IW, stream)); float scale_h = static_cast(OH) / IH; float scale_w = static_cast(OW) / IW; - switch (imode) { - case InterpolationMode::INTER_LINEAR: { - resize_bwd_linear_kernel<<>>( - diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); - break; - } - case InterpolationMode::INTER_NEAREST: { - resize_bwd_nearest_kernel<<>>( - diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); - break; - } - case InterpolationMode::INTER_CUBIC: { - resize_bwd_cubic_kernel<<>>( - diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); - break; - } - default: { - megdnn_throw("unsupported interpolation mode"); - break; + if (is_nhwc) { + resize_bwd_nhwc_kernel> + <<>>( + diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); + } else { + switch (imode) { + case InterpolationMode::INTER_LINEAR: { + resize_bwd_linear_kernel> + <<>>( + diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); + break; + } + case InterpolationMode::INTER_NEAREST: { + resize_bwd_nearest_kernel> + <<>>( + diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); + break; + } + case InterpolationMode::INTER_CUBIC: { + resize_bwd_cubic_kernel> + <<>>( + diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); + break; + } + default: { + megdnn_throw("unsupported interpolation mode"); + break; + } } } } after_kernel_launch(); } +#define INST(ctype) \ + template void backward_data_proxy( \ + bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \ + int, cudaStream_t); +INST(dt_float32); +DNN_INC_FLOAT16(INST(dt_float16)); +#undef INST + } // namespace resize } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/resize/common.h b/dnn/src/cuda/resize/common.h index e9a6205d1..96813e392 100644 --- a/dnn/src/cuda/resize/common.h +++ b/dnn/src/cuda/resize/common.h @@ -20,9 +20,10 @@ void forward_proxy_nchw4( const ctype* src, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, cudaStream_t stream); +template void backward_data_proxy( - InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH, - int IW, int OH, int OW, cudaStream_t stream); + bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N, + int C, int IH, int IW, int OH, int OW, cudaStream_t stream); } // namespace resize } // namespace cuda diff --git a/dnn/src/cuda/resize/forward.cpp b/dnn/src/cuda/resize/forward.cpp index 31ab025b8..0fadab64a 100644 --- a/dnn/src/cuda/resize/forward.cpp +++ b/dnn/src/cuda/resize/forward.cpp @@ -148,6 +148,11 @@ void ResizeImpl::exec( is_nhwc, resize::get_imode((param().imode)), src.ptr(), dst.ptr(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, S_IH, S_IW, stream); + } else if (src.layout.dtype == dtype::Float16{}) { + resize::forward_proxy( + is_nhwc, resize::get_imode((param().imode)), src.ptr(), + dst.ptr(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, + S_IH, S_IW, stream); } else if (src.layout.dtype == dtype::Uint8()) { resize::forward_proxy( is_nhwc, resize::get_imode((param().imode)), src.ptr(), diff --git a/dnn/src/cuda/resize/forward.cu b/dnn/src/cuda/resize/forward.cu index 2170c5504..a09ee7232 100644 --- a/dnn/src/cuda/resize/forward.cu +++ b/dnn/src/cuda/resize/forward.cu @@ -298,6 +298,7 @@ void forward_proxy_nchw4( INST(float) INST(uint8_t) INST(int8_t) +DNN_INC_FLOAT16(INST(dt_float16)) #undef INST #define INST(ctype) \ diff --git a/dnn/src/naive/resize/opr_impl.cpp b/dnn/src/naive/resize/opr_impl.cpp index ac5430eed..661d7dc56 100644 --- a/dnn/src/naive/resize/opr_impl.cpp +++ b/dnn/src/naive/resize/opr_impl.cpp @@ -387,40 +387,53 @@ void ResizeImpl::exec( } } -void ResizeBackwardImpl::exec( - _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { - check_exec(diff.layout, grad.layout, workspace.size); - megdnn_assert( - param().format == param::Resize::Format::NCHW, "invalid resize format"); - const int N = grad.layout.shape[0], C = grad.layout.shape[1], - IH = grad.layout.shape[2], IW = grad.layout.shape[3]; - const int OH = diff.layout.shape[2], OW = diff.layout.shape[3]; - const float* hptr_ = diff.ptr(); - float* sptr_ = grad.ptr(); +// ***************************Backward*************************** // +template +void ResizeBackwardImpl::kern_naive( + bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N, + int C, int IH, int IW, int OH, int OW) { float scale_h = static_cast(OH) / IH; float scale_w = static_cast(OW) / IW; + rounding::RoundingConverter output_converter; auto kern = [=]() { - auto hptr = hptr_; - auto sptr = sptr_; - std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); + auto hptr = diff; + auto sptr = grad; + std::memset(sptr, 0, sizeof(ctype) * N * C * IH * IW); rep(n, N) { rep(oh, OH) rep(ow, OW) { - switch (param().imode) { + switch (imode) { case InterpolationMode::INTER_LINEAR: { int ih0, ih1, iw0, iw1; float ah0, ah1, aw0, aw1; - std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( - param().imode, scale_h, IH, oh); - std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( - param().imode, scale_w, IW, ow); - - rep(c, C) { - float hidden = hptr[c * OH * OW + oh * OW + ow]; - sptr[c * IH * IW + ih0 * IW + iw0] += ah0 * aw0 * hidden; - sptr[c * IH * IW + ih1 * IW + iw0] += ah1 * aw0 * hidden; - sptr[c * IH * IW + ih0 * IW + iw1] += ah0 * aw1 * hidden; - sptr[c * IH * IW + ih1 * IW + iw1] += ah1 * aw1 * hidden; + std::tie(ah0, ih0, ah1, ih1) = + get_nearest_linear_coord(imode, scale_h, IH, oh); + std::tie(aw0, iw0, aw1, iw1) = + get_nearest_linear_coord(imode, scale_w, IW, ow); + + if (is_nhwc) { + rep(c, C) { + sptr[(ih0 * IW + iw0) * C + c] += output_converter( + hptr[(oh * OW + ow) * C + c] * ah0 * aw0); + sptr[(ih0 * IW + iw1) * C + c] += output_converter( + hptr[(oh * OW + ow) * C + c] * ah0 * aw1); + sptr[(ih1 * IW + iw0) * C + c] += output_converter( + hptr[(oh * OW + ow) * C + c] * ah1 * aw0); + sptr[(ih1 * IW + iw1) * C + c] += output_converter( + hptr[(oh * OW + ow) * C + c] * ah1 * aw1); + } + } else { + rep(c, C) { + float hidden = hptr[c * OH * OW + oh * OW + ow]; + sptr[c * IH * IW + ih0 * IW + iw0] += + output_converter(ah0 * aw0 * hidden); + sptr[c * IH * IW + ih1 * IW + iw0] += + output_converter(ah1 * aw0 * hidden); + sptr[c * IH * IW + ih0 * IW + iw1] += + output_converter(ah0 * aw1 * hidden); + sptr[c * IH * IW + ih1 * IW + iw1] += + output_converter(ah1 * aw1 * hidden); + } } break; } @@ -429,7 +442,7 @@ void ResizeBackwardImpl::exec( auto iw = get_nearest_src(scale_w, IW, ow); rep(c, static_cast(C)) { sptr[c * IH * IW + ih * IW + iw] += - hptr[c * OH * OW + oh * OW + ow]; + output_converter(hptr[c * OH * OW + oh * OW + ow]); } break; } @@ -452,9 +465,9 @@ void ResizeBackwardImpl::exec( int h = saturate(ih0 + kh, 0, IH - 1); rep(kw, ksize) { int w = saturate(iw0 + kw, 0, IW - 1); - sptr[c * IH * IW + h * IW + w] += + sptr[c * IH * IW + h * IW + w] += output_converter( hptr[c * OH * OW + oh * OW + ow] * - h_coeff[kh] * w_coeff[kw]; + h_coeff[kh] * w_coeff[kw]); } } } @@ -473,4 +486,59 @@ void ResizeBackwardImpl::exec( MEGDNN_DISPATCH_CPU_KERN_OPR(kern()); } +#define INST(ctype) \ + template void ResizeBackwardImpl::kern_naive( \ + bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \ + int); +INST(dt_float32); +DNN_INC_FLOAT16(INST(dt_float16)); +#undef INST + +void ResizeBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { + check_exec(diff.layout, grad.layout, workspace.size); + megdnn_assert( + param().format == param::Resize::Format::NCHW || + param().format == param::Resize::Format::NHWC, + "invalid resize format"); + size_t N, C, IH, IW, OH, OW; + bool is_nhwc = param().format == param::Resize::Format::NHWC; + if (is_nhwc) { + if (param().imode != Param::InterpolationMode::LINEAR && + is_nhwc_contig_wc(grad.layout)) { + megdnn_assert( + 0, + "unsupport mode in resizeBackward, only support param().imode = " + "LINEAR"); + } + N = grad.layout.shape[0]; + C = grad.layout.shape[3]; + IH = grad.layout.shape[1]; + IW = grad.layout.shape[2]; + OH = diff.layout.shape[1]; + OW = diff.layout.shape[2]; + } else { + N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2], + IW = grad.layout.shape[3]; + OH = diff.layout.shape[2], OW = diff.layout.shape[3]; + } + switch (grad.layout.dtype.enumv()) { +#define cb(_t) \ + case DTypeTrait<_t>::enumv: { \ + typedef DTypeTrait<_t>::ctype ct; \ + ct* diff_ptr = diff.ptr(); \ + ct* grad_ptr = grad.ptr(); \ + ResizeBackwardImpl::kern_naive( \ + is_nhwc, param().imode, diff_ptr, grad_ptr, N, C, IH, IW, OH, OW); \ + break; \ + } + cb(megdnn::dtype::Float32); + DNN_INC_FLOAT16(cb(megdnn::dtype::Float16)); + default: + megdnn_throw(ssprintf( + "unsupported dtype: %s in resize backward", + grad.layout.dtype.name())); + } +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/resize/opr_impl.h b/dnn/src/naive/resize/opr_impl.h index 734f977ba..b5b6bc430 100644 --- a/dnn/src/naive/resize/opr_impl.h +++ b/dnn/src/naive/resize/opr_impl.h @@ -75,6 +75,12 @@ public: size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { return 0; } + +private: + template + void kern_naive( + bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, + int N, int C, int IH, int IW, int OH, int OW); }; } // namespace naive diff --git a/dnn/test/cuda/resize.cpp b/dnn/test/cuda/resize.cpp index 3a2b09d5c..22183be68 100644 --- a/dnn/test/cuda/resize.cpp +++ b/dnn/test/cuda/resize.cpp @@ -61,13 +61,67 @@ TEST_F(CUDA, RESIZE_FORWARD) { .set_epsilon(1) .execs({arg.src, arg.dst}); } + + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_epsilon(1e-3) + .execs({arg.src, arg.dst}); + } + } +} + +TEST_F(CUDA, RESIZE_NHWC) { + using namespace resize; + std::vector args; + + param::Resize param; + param.format = param::Resize::Format::NHWC; + param.imode = param::Resize::InterpolationMode::LINEAR; + + args.emplace_back(param, TensorShape{1, 1, 4, 5}, TensorShape{1, 1, 8, 5}); + args.emplace_back(param, TensorShape{2, 6, 4, 5}, TensorShape{2, 3, 8, 5}); + args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2}); + + Checker checkerBackward(handle_cuda()); + + for (auto&& arg : args) { + checkerBackward.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_epsilon(1e-3) + .execs({arg.src, arg.dst}); + } + + for (auto&& arg : args) { + checkerBackward.set_param(arg.param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_epsilon(1e-3) + .execs({arg.src, arg.dst}); + } + + Checker checkerForward(handle_cuda()); + for (auto&& arg : args) { + checkerForward.set_param(arg.param) + .set_dtype(0, dtype::Float16()) + .set_dtype(1, dtype::Float16()) + .set_epsilon(1e-3) + .execs({arg.src, arg.dst}); + } + for (auto&& arg : args) { + checkerForward.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_epsilon(1e-3) + .execs({arg.src, arg.dst}); } } TEST_F(CUDA, RESIZE_NCHW4) { using namespace resize; Checker checker(handle_cuda()); - auto args = get_nchw4_args(); for (auto&& arg : args) { checker.set_param(arg.param) @@ -113,6 +167,24 @@ TEST_F(CUDA, RESIZE_BACKWARD) { param.format = param::Resize::Format::NCHW; param.imode = imode; checker.set_param(param); + checker.set_dtype(0, dtype::Float16()); + checker.set_dtype(1, dtype::Float16()); + checker.set_epsilon(1 + 1e-3); + + checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}}); + checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}}); + checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}}); + checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}}); + } + + for (auto imode : modes) { + Checker checker(handle_cuda()); + param::Resize param; + param.format = param::Resize::Format::NCHW; + param.imode = imode; + checker.set_param(param); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}}); checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}}); -- GitLab