提交 0d7ace15 编写于 作者: M Megvii Engine Team

fix(mgb/dnn): suport fp16 for resize nhwc

GitOrigin-RevId: bb04d2a801b5cbe9c8704ce922842231c4158a4c
上级 cfed86f9
......@@ -67,8 +67,12 @@ void ResizeBackward::check_exec(
auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert(
param().format == Param::Format::NCHW && grad.dtype == dtype::Float32(),
"Backward resize only supports Float32 and NCHW.");
(param().format == Param::Format::NCHW ||
param().format == Param::Format::NHWC) &&
(grad.dtype == dtype::Float32() DNN_INC_FLOAT16(
|| grad.dtype == dtype::Float16())),
"Backward resize only supports NCHW and NHWC, the dtype only supports "
"Float32 and Float16.");
}
std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) {
......
......@@ -11,26 +11,56 @@ void ResizeBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
check_exec(diff.layout, grad.layout, workspace.size);
auto stream = cuda_stream(this->handle());
auto N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
IW = grad.layout.shape[3], OH = diff.layout.shape[2],
OW = diff.layout.shape[3];
bool is_nhwc = param().format == param::Resize::Format::NHWC;
size_t N, C, IH, IW, OH, OW;
if (is_nhwc) {
if (param().imode != Param::InterpolationMode::LINEAR &&
is_nhwc_contig_wc(grad.layout)) {
megdnn_assert(
0,
"unsupport mode in resizeBackward, only support param().imode = "
"LINEAR");
}
N = grad.layout.shape[0];
C = grad.layout.shape[3];
IH = grad.layout.shape[1];
IW = grad.layout.shape[2];
OH = diff.layout.shape[1];
OW = diff.layout.shape[2];
} else {
N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
IW = grad.layout.shape[3], OH = diff.layout.shape[2], OW = diff.layout.shape[3];
}
size_t max_batch_x_channel = max_batch_x_channel_size();
dt_float32* diff_ptr = diff.ptr<dt_float32>();
dt_float32* grad_ptr = grad.ptr<dt_float32>();
size_t max_batch_size = max_batch_x_channel / C;
while (N > 0) {
size_t curr_batch_size = N > max_batch_size ? max_batch_size : N;
resize::backward_data_proxy(
resize::get_imode(param().imode), diff_ptr, grad_ptr, curr_batch_size,
C, IH, IW, OH, OW, stream);
if (N <= max_batch_size) {
break;
} else {
N -= max_batch_size;
diff_ptr += curr_batch_size * diff.layout.stride[0];
grad_ptr += curr_batch_size * grad.layout.stride[0];
switch (grad.layout.dtype.enumv()) {
#define cb(_t) \
case DTypeTrait<_t>::enumv: { \
typedef DTypeTrait<_t>::ctype ct; \
ct* diff_ptr = diff.ptr<ct>(); \
ct* grad_ptr = grad.ptr<ct>(); \
resize::backward_data_proxy( \
is_nhwc, resize::get_imode(param().imode), diff_ptr, grad_ptr, \
curr_batch_size, C, IH, IW, OH, OW, stream); \
if (N <= max_batch_size) { \
return; \
} else { \
N -= max_batch_size; \
diff_ptr += curr_batch_size * diff.layout.stride[0]; \
grad_ptr += curr_batch_size * grad.layout.stride[0]; \
} \
break; \
}
cb(megdnn::dtype::Float32);
DNN_INC_FLOAT16(cb(megdnn::dtype::Float16));
default:
megdnn_throw(ssprintf(
"unsupported dtype: %s in resize backward",
grad.layout.dtype.name()));
}
#undef cb
}
}
......
#include "src/common/rounding_converter.cuh"
#include "src/cuda/resize/common.cuh"
#include "src/cuda/resize/common.h"
......@@ -11,9 +12,52 @@ namespace megdnn {
namespace cuda {
namespace resize {
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_nhwc_kernel(
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) {
OutputConverter output_converter;
int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
hidden += n * C * OH * OW;
dst += n * C * IH * IW;
if (ow < OW && oh < OH) {
float alphah, alphaw;
int ih0, iw0;
get_origin_coord(scale_h, IH, oh, alphah, ih0);
get_origin_coord(scale_w, IW, ow, alphaw, iw0);
int ih1 = ih0 + 1;
int iw1 = iw0 + 1;
float nalphaw = 1.0f - alphaw;
float nalphah = 1.0f - alphah;
for (int c = 0; c < C; ++c) {
atomic_add(
dst + (ih0 * IW + iw0) * C + c,
output_converter(
hidden[(oh * OW + ow) * C + c] * nalphaw * nalphah));
atomic_add(
dst + (ih0 * IW + iw1) * C + c,
output_converter(
hidden[(oh * OW + ow) * C + c] * alphaw * nalphah));
atomic_add(
dst + (ih1 * IW + iw0) * C + c,
output_converter(
hidden[(oh * OW + ow) * C + c] * nalphaw * alphah));
atomic_add(
dst + (ih1 * IW + iw1) * C + c,
output_converter(hidden[(oh * OW + ow) * C + c] * alphaw * alphah));
}
}
}
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_linear_kernel(
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) {
OutputConverter output_converter;
int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
......@@ -31,19 +75,29 @@ __global__ void resize_bwd_linear_kernel(
float nalphaw = 1.0f - alphaw;
float nalphah = 1.0f - alphah;
for (int c = 0; c < C; ++c) {
atomicAdd(dst + ih0 * IW + iw0, hidden[oh * OW + ow] * nalphaw * nalphah);
atomicAdd(dst + ih0 * IW + iw1, hidden[oh * OW + ow] * alphaw * nalphah);
atomicAdd(dst + ih1 * IW + iw0, hidden[oh * OW + ow] * nalphaw * alphah);
atomicAdd(dst + ih1 * IW + iw1, hidden[oh * OW + ow] * alphaw * alphah);
atomic_add(
dst + ih0 * IW + iw0,
output_converter(hidden[oh * OW + ow] * nalphaw * nalphah));
atomic_add(
dst + ih0 * IW + iw1,
output_converter(hidden[oh * OW + ow] * alphaw * nalphah));
atomic_add(
dst + ih1 * IW + iw0,
output_converter(hidden[oh * OW + ow] * nalphaw * alphah));
atomic_add(
dst + ih1 * IW + iw1,
output_converter(hidden[oh * OW + ow] * alphaw * alphah));
hidden += OH * OW;
dst += IH * IW;
}
}
}
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_nearest_kernel(
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) {
OutputConverter output_converter;
int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
......@@ -54,16 +108,18 @@ __global__ void resize_bwd_nearest_kernel(
int iw = get_nearest_src(scale_w, IW, ow);
for (int c = 0; c < C; ++c) {
atomicAdd(dst + ih * IW + iw, hidden[oh * OW + ow]);
atomic_add(dst + ih * IW + iw, output_converter(hidden[oh * OW + ow]));
hidden += OH * OW;
dst += IH * IW;
}
}
}
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_cubic_kernel(
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) {
OutputConverter output_converter;
int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
......@@ -85,9 +141,10 @@ __global__ void resize_bwd_cubic_kernel(
int ih = saturate(ih0 + kh, 0, IH - 1);
for (int kw = 0; kw < ksize; kw++) {
int iw = saturate(iw0 + kw, 0, IW - 1);
atomicAdd(
atomic_add(
dst + ih * IW + iw,
hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]);
output_converter(
hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]));
}
}
......@@ -97,41 +154,59 @@ __global__ void resize_bwd_cubic_kernel(
}
}
template <typename ctype>
void backward_data_proxy(
InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream) {
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
int C, int IH, int IW, int OH, int OW, cudaStream_t stream) {
const int BY = 16, BX = 32;
{
dim3 threads(BX, BY);
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N);
cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * C * IH * IW, stream));
cuda_check(cudaMemsetAsync(grad, 0, sizeof(ctype) * N * C * IH * IW, stream));
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
switch (imode) {
case InterpolationMode::INTER_LINEAR: {
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
case InterpolationMode::INTER_NEAREST: {
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
case InterpolationMode::INTER_CUBIC: {
resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
default: {
megdnn_throw("unsupported interpolation mode");
break;
if (is_nhwc) {
resize_bwd_nhwc_kernel<ctype, rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
} else {
switch (imode) {
case InterpolationMode::INTER_LINEAR: {
resize_bwd_linear_kernel<ctype, rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
case InterpolationMode::INTER_NEAREST: {
resize_bwd_nearest_kernel<ctype, rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
case InterpolationMode::INTER_CUBIC: {
resize_bwd_cubic_kernel<ctype, rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
default: {
megdnn_throw("unsupported interpolation mode");
break;
}
}
}
}
after_kernel_launch();
}
#define INST(ctype) \
template void backward_data_proxy( \
bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
int, cudaStream_t);
INST(dt_float32);
DNN_INC_FLOAT16(INST(dt_float16));
#undef INST
} // namespace resize
} // namespace cuda
} // namespace megdnn
......
......@@ -20,9 +20,10 @@ void forward_proxy_nchw4(
const ctype* src, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
cudaStream_t stream);
template <typename ctype>
void backward_data_proxy(
InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH,
int IW, int OH, int OW, cudaStream_t stream);
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
int C, int IH, int IW, int OH, int OW, cudaStream_t stream);
} // namespace resize
} // namespace cuda
......
......@@ -148,6 +148,11 @@ void ResizeImpl::exec(
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float32>(),
dst.ptr<dt_float32>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream);
} else if (src.layout.dtype == dtype::Float16{}) {
resize::forward_proxy(
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float16>(),
dst.ptr<dt_float16>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream);
} else if (src.layout.dtype == dtype::Uint8()) {
resize::forward_proxy(
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_uint8>(),
......
......@@ -298,6 +298,7 @@ void forward_proxy_nchw4(
INST(float)
INST(uint8_t)
INST(int8_t)
DNN_INC_FLOAT16(INST(dt_float16))
#undef INST
#define INST(ctype) \
......
......@@ -387,40 +387,53 @@ void ResizeImpl::exec(
}
}
void ResizeBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
check_exec(diff.layout, grad.layout, workspace.size);
megdnn_assert(
param().format == param::Resize::Format::NCHW, "invalid resize format");
const int N = grad.layout.shape[0], C = grad.layout.shape[1],
IH = grad.layout.shape[2], IW = grad.layout.shape[3];
const int OH = diff.layout.shape[2], OW = diff.layout.shape[3];
const float* hptr_ = diff.ptr<dt_float32>();
float* sptr_ = grad.ptr<dt_float32>();
// ***************************Backward*************************** //
template <typename ctype>
void ResizeBackwardImpl::kern_naive(
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
int C, int IH, int IW, int OH, int OW) {
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
rounding::RoundingConverter<ctype> output_converter;
auto kern = [=]() {
auto hptr = hptr_;
auto sptr = sptr_;
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW);
auto hptr = diff;
auto sptr = grad;
std::memset(sptr, 0, sizeof(ctype) * N * C * IH * IW);
rep(n, N) {
rep(oh, OH) rep(ow, OW) {
switch (param().imode) {
switch (imode) {
case InterpolationMode::INTER_LINEAR: {
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;
std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
param().imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
param().imode, scale_w, IW, ow);
rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow];
sptr[c * IH * IW + ih0 * IW + iw0] += ah0 * aw0 * hidden;
sptr[c * IH * IW + ih1 * IW + iw0] += ah1 * aw0 * hidden;
sptr[c * IH * IW + ih0 * IW + iw1] += ah0 * aw1 * hidden;
sptr[c * IH * IW + ih1 * IW + iw1] += ah1 * aw1 * hidden;
std::tie(ah0, ih0, ah1, ih1) =
get_nearest_linear_coord(imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) =
get_nearest_linear_coord(imode, scale_w, IW, ow);
if (is_nhwc) {
rep(c, C) {
sptr[(ih0 * IW + iw0) * C + c] += output_converter(
hptr[(oh * OW + ow) * C + c] * ah0 * aw0);
sptr[(ih0 * IW + iw1) * C + c] += output_converter(
hptr[(oh * OW + ow) * C + c] * ah0 * aw1);
sptr[(ih1 * IW + iw0) * C + c] += output_converter(
hptr[(oh * OW + ow) * C + c] * ah1 * aw0);
sptr[(ih1 * IW + iw1) * C + c] += output_converter(
hptr[(oh * OW + ow) * C + c] * ah1 * aw1);
}
} else {
rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow];
sptr[c * IH * IW + ih0 * IW + iw0] +=
output_converter(ah0 * aw0 * hidden);
sptr[c * IH * IW + ih1 * IW + iw0] +=
output_converter(ah1 * aw0 * hidden);
sptr[c * IH * IW + ih0 * IW + iw1] +=
output_converter(ah0 * aw1 * hidden);
sptr[c * IH * IW + ih1 * IW + iw1] +=
output_converter(ah1 * aw1 * hidden);
}
}
break;
}
......@@ -429,7 +442,7 @@ void ResizeBackwardImpl::exec(
auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) {
sptr[c * IH * IW + ih * IW + iw] +=
hptr[c * OH * OW + oh * OW + ow];
output_converter(hptr[c * OH * OW + oh * OW + ow]);
}
break;
}
......@@ -452,9 +465,9 @@ void ResizeBackwardImpl::exec(
int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
rep(kw, ksize) {
int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
sptr[c * IH * IW + h * IW + w] +=
sptr[c * IH * IW + h * IW + w] += output_converter(
hptr[c * OH * OW + oh * OW + ow] *
h_coeff[kh] * w_coeff[kw];
h_coeff[kh] * w_coeff[kw]);
}
}
}
......@@ -473,4 +486,59 @@ void ResizeBackwardImpl::exec(
MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
}
#define INST(ctype) \
template void ResizeBackwardImpl::kern_naive( \
bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
int);
INST(dt_float32);
DNN_INC_FLOAT16(INST(dt_float16));
#undef INST
void ResizeBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
check_exec(diff.layout, grad.layout, workspace.size);
megdnn_assert(
param().format == param::Resize::Format::NCHW ||
param().format == param::Resize::Format::NHWC,
"invalid resize format");
size_t N, C, IH, IW, OH, OW;
bool is_nhwc = param().format == param::Resize::Format::NHWC;
if (is_nhwc) {
if (param().imode != Param::InterpolationMode::LINEAR &&
is_nhwc_contig_wc(grad.layout)) {
megdnn_assert(
0,
"unsupport mode in resizeBackward, only support param().imode = "
"LINEAR");
}
N = grad.layout.shape[0];
C = grad.layout.shape[3];
IH = grad.layout.shape[1];
IW = grad.layout.shape[2];
OH = diff.layout.shape[1];
OW = diff.layout.shape[2];
} else {
N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
IW = grad.layout.shape[3];
OH = diff.layout.shape[2], OW = diff.layout.shape[3];
}
switch (grad.layout.dtype.enumv()) {
#define cb(_t) \
case DTypeTrait<_t>::enumv: { \
typedef DTypeTrait<_t>::ctype ct; \
ct* diff_ptr = diff.ptr<ct>(); \
ct* grad_ptr = grad.ptr<ct>(); \
ResizeBackwardImpl::kern_naive( \
is_nhwc, param().imode, diff_ptr, grad_ptr, N, C, IH, IW, OH, OW); \
break; \
}
cb(megdnn::dtype::Float32);
DNN_INC_FLOAT16(cb(megdnn::dtype::Float16));
default:
megdnn_throw(ssprintf(
"unsupported dtype: %s in resize backward",
grad.layout.dtype.name()));
}
}
// vim: syntax=cpp.doxygen
......@@ -75,6 +75,12 @@ public:
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
return 0;
}
private:
template <typename ctype>
void kern_naive(
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad,
int N, int C, int IH, int IW, int OH, int OW);
};
} // namespace naive
......
......@@ -61,13 +61,67 @@ TEST_F(CUDA, RESIZE_FORWARD) {
.set_epsilon(1)
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
}
}
TEST_F(CUDA, RESIZE_NHWC) {
using namespace resize;
std::vector<TestArg> args;
param::Resize param;
param.format = param::Resize::Format::NHWC;
param.imode = param::Resize::InterpolationMode::LINEAR;
args.emplace_back(param, TensorShape{1, 1, 4, 5}, TensorShape{1, 1, 8, 5});
args.emplace_back(param, TensorShape{2, 6, 4, 5}, TensorShape{2, 3, 8, 5});
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2});
Checker<ResizeBackward> checkerBackward(handle_cuda());
for (auto&& arg : args) {
checkerBackward.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checkerBackward.set_param(arg.param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
Checker<ResizeForward> checkerForward(handle_cuda());
for (auto&& arg : args) {
checkerForward.set_param(arg.param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checkerForward.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
}
TEST_F(CUDA, RESIZE_NCHW4) {
using namespace resize;
Checker<Resize> checker(handle_cuda());
auto args = get_nchw4_args();
for (auto&& arg : args) {
checker.set_param(arg.param)
......@@ -113,6 +167,24 @@ TEST_F(CUDA, RESIZE_BACKWARD) {
param.format = param::Resize::Format::NCHW;
param.imode = imode;
checker.set_param(param);
checker.set_dtype(0, dtype::Float16());
checker.set_dtype(1, dtype::Float16());
checker.set_epsilon(1 + 1e-3);
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}});
checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}});
}
for (auto imode : modes) {
Checker<ResizeBackward> checker(handle_cuda());
param::Resize param;
param.format = param::Resize::Format::NCHW;
param.imode = imode;
checker.set_param(param);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册