提交 0d7ace15 编写于 作者: M Megvii Engine Team

fix(mgb/dnn): suport fp16 for resize nhwc

GitOrigin-RevId: bb04d2a801b5cbe9c8704ce922842231c4158a4c
上级 cfed86f9
...@@ -67,8 +67,12 @@ void ResizeBackward::check_exec( ...@@ -67,8 +67,12 @@ void ResizeBackward::check_exec(
auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad); auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert( megdnn_assert(
param().format == Param::Format::NCHW && grad.dtype == dtype::Float32(), (param().format == Param::Format::NCHW ||
"Backward resize only supports Float32 and NCHW."); param().format == Param::Format::NHWC) &&
(grad.dtype == dtype::Float32() DNN_INC_FLOAT16(
|| grad.dtype == dtype::Float16())),
"Backward resize only supports NCHW and NHWC, the dtype only supports "
"Float32 and Float16.");
} }
std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) { std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) {
......
...@@ -11,26 +11,56 @@ void ResizeBackwardImpl::exec( ...@@ -11,26 +11,56 @@ void ResizeBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
check_exec(diff.layout, grad.layout, workspace.size); check_exec(diff.layout, grad.layout, workspace.size);
auto stream = cuda_stream(this->handle()); auto stream = cuda_stream(this->handle());
auto N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2], bool is_nhwc = param().format == param::Resize::Format::NHWC;
IW = grad.layout.shape[3], OH = diff.layout.shape[2], size_t N, C, IH, IW, OH, OW;
OW = diff.layout.shape[3]; if (is_nhwc) {
if (param().imode != Param::InterpolationMode::LINEAR &&
is_nhwc_contig_wc(grad.layout)) {
megdnn_assert(
0,
"unsupport mode in resizeBackward, only support param().imode = "
"LINEAR");
}
N = grad.layout.shape[0];
C = grad.layout.shape[3];
IH = grad.layout.shape[1];
IW = grad.layout.shape[2];
OH = diff.layout.shape[1];
OW = diff.layout.shape[2];
} else {
N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
IW = grad.layout.shape[3], OH = diff.layout.shape[2], OW = diff.layout.shape[3];
}
size_t max_batch_x_channel = max_batch_x_channel_size(); size_t max_batch_x_channel = max_batch_x_channel_size();
dt_float32* diff_ptr = diff.ptr<dt_float32>();
dt_float32* grad_ptr = grad.ptr<dt_float32>();
size_t max_batch_size = max_batch_x_channel / C; size_t max_batch_size = max_batch_x_channel / C;
while (N > 0) { while (N > 0) {
size_t curr_batch_size = N > max_batch_size ? max_batch_size : N; size_t curr_batch_size = N > max_batch_size ? max_batch_size : N;
resize::backward_data_proxy( switch (grad.layout.dtype.enumv()) {
resize::get_imode(param().imode), diff_ptr, grad_ptr, curr_batch_size, #define cb(_t) \
C, IH, IW, OH, OW, stream); case DTypeTrait<_t>::enumv: { \
typedef DTypeTrait<_t>::ctype ct; \
if (N <= max_batch_size) { ct* diff_ptr = diff.ptr<ct>(); \
break; ct* grad_ptr = grad.ptr<ct>(); \
} else { resize::backward_data_proxy( \
N -= max_batch_size; is_nhwc, resize::get_imode(param().imode), diff_ptr, grad_ptr, \
diff_ptr += curr_batch_size * diff.layout.stride[0]; curr_batch_size, C, IH, IW, OH, OW, stream); \
grad_ptr += curr_batch_size * grad.layout.stride[0]; if (N <= max_batch_size) { \
return; \
} else { \
N -= max_batch_size; \
diff_ptr += curr_batch_size * diff.layout.stride[0]; \
grad_ptr += curr_batch_size * grad.layout.stride[0]; \
} \
break; \
}
cb(megdnn::dtype::Float32);
DNN_INC_FLOAT16(cb(megdnn::dtype::Float16));
default:
megdnn_throw(ssprintf(
"unsupported dtype: %s in resize backward",
grad.layout.dtype.name()));
} }
#undef cb
} }
} }
......
#include "src/common/rounding_converter.cuh"
#include "src/cuda/resize/common.cuh" #include "src/cuda/resize/common.cuh"
#include "src/cuda/resize/common.h" #include "src/cuda/resize/common.h"
...@@ -11,9 +12,52 @@ namespace megdnn { ...@@ -11,9 +12,52 @@ namespace megdnn {
namespace cuda { namespace cuda {
namespace resize { namespace resize {
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_nhwc_kernel(
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) {
OutputConverter output_converter;
int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
hidden += n * C * OH * OW;
dst += n * C * IH * IW;
if (ow < OW && oh < OH) {
float alphah, alphaw;
int ih0, iw0;
get_origin_coord(scale_h, IH, oh, alphah, ih0);
get_origin_coord(scale_w, IW, ow, alphaw, iw0);
int ih1 = ih0 + 1;
int iw1 = iw0 + 1;
float nalphaw = 1.0f - alphaw;
float nalphah = 1.0f - alphah;
for (int c = 0; c < C; ++c) {
atomic_add(
dst + (ih0 * IW + iw0) * C + c,
output_converter(
hidden[(oh * OW + ow) * C + c] * nalphaw * nalphah));
atomic_add(
dst + (ih0 * IW + iw1) * C + c,
output_converter(
hidden[(oh * OW + ow) * C + c] * alphaw * nalphah));
atomic_add(
dst + (ih1 * IW + iw0) * C + c,
output_converter(
hidden[(oh * OW + ow) * C + c] * nalphaw * alphah));
atomic_add(
dst + (ih1 * IW + iw1) * C + c,
output_converter(hidden[(oh * OW + ow) * C + c] * alphaw * alphah));
}
}
}
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_linear_kernel( __global__ void resize_bwd_linear_kernel(
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) { float scale_h, float scale_w) {
OutputConverter output_converter;
int n = blockIdx.z; int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x; int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y; int oh = blockIdx.y * blockDim.y + threadIdx.y;
...@@ -31,19 +75,29 @@ __global__ void resize_bwd_linear_kernel( ...@@ -31,19 +75,29 @@ __global__ void resize_bwd_linear_kernel(
float nalphaw = 1.0f - alphaw; float nalphaw = 1.0f - alphaw;
float nalphah = 1.0f - alphah; float nalphah = 1.0f - alphah;
for (int c = 0; c < C; ++c) { for (int c = 0; c < C; ++c) {
atomicAdd(dst + ih0 * IW + iw0, hidden[oh * OW + ow] * nalphaw * nalphah); atomic_add(
atomicAdd(dst + ih0 * IW + iw1, hidden[oh * OW + ow] * alphaw * nalphah); dst + ih0 * IW + iw0,
atomicAdd(dst + ih1 * IW + iw0, hidden[oh * OW + ow] * nalphaw * alphah); output_converter(hidden[oh * OW + ow] * nalphaw * nalphah));
atomicAdd(dst + ih1 * IW + iw1, hidden[oh * OW + ow] * alphaw * alphah); atomic_add(
dst + ih0 * IW + iw1,
output_converter(hidden[oh * OW + ow] * alphaw * nalphah));
atomic_add(
dst + ih1 * IW + iw0,
output_converter(hidden[oh * OW + ow] * nalphaw * alphah));
atomic_add(
dst + ih1 * IW + iw1,
output_converter(hidden[oh * OW + ow] * alphaw * alphah));
hidden += OH * OW; hidden += OH * OW;
dst += IH * IW; dst += IH * IW;
} }
} }
} }
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_nearest_kernel( __global__ void resize_bwd_nearest_kernel(
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) { float scale_h, float scale_w) {
OutputConverter output_converter;
int n = blockIdx.z; int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x; int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y; int oh = blockIdx.y * blockDim.y + threadIdx.y;
...@@ -54,16 +108,18 @@ __global__ void resize_bwd_nearest_kernel( ...@@ -54,16 +108,18 @@ __global__ void resize_bwd_nearest_kernel(
int iw = get_nearest_src(scale_w, IW, ow); int iw = get_nearest_src(scale_w, IW, ow);
for (int c = 0; c < C; ++c) { for (int c = 0; c < C; ++c) {
atomicAdd(dst + ih * IW + iw, hidden[oh * OW + ow]); atomic_add(dst + ih * IW + iw, output_converter(hidden[oh * OW + ow]));
hidden += OH * OW; hidden += OH * OW;
dst += IH * IW; dst += IH * IW;
} }
} }
} }
template <typename ctype, typename OutputConverter>
__global__ void resize_bwd_cubic_kernel( __global__ void resize_bwd_cubic_kernel(
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
float scale_h, float scale_w) { float scale_h, float scale_w) {
OutputConverter output_converter;
int n = blockIdx.z; int n = blockIdx.z;
int ow = blockIdx.x * blockDim.x + threadIdx.x; int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y; int oh = blockIdx.y * blockDim.y + threadIdx.y;
...@@ -85,9 +141,10 @@ __global__ void resize_bwd_cubic_kernel( ...@@ -85,9 +141,10 @@ __global__ void resize_bwd_cubic_kernel(
int ih = saturate(ih0 + kh, 0, IH - 1); int ih = saturate(ih0 + kh, 0, IH - 1);
for (int kw = 0; kw < ksize; kw++) { for (int kw = 0; kw < ksize; kw++) {
int iw = saturate(iw0 + kw, 0, IW - 1); int iw = saturate(iw0 + kw, 0, IW - 1);
atomicAdd( atomic_add(
dst + ih * IW + iw, dst + ih * IW + iw,
hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]); output_converter(
hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]));
} }
} }
...@@ -97,41 +154,59 @@ __global__ void resize_bwd_cubic_kernel( ...@@ -97,41 +154,59 @@ __global__ void resize_bwd_cubic_kernel(
} }
} }
template <typename ctype>
void backward_data_proxy( void backward_data_proxy(
InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH, bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
int IW, int OH, int OW, cudaStream_t stream) { int C, int IH, int IW, int OH, int OW, cudaStream_t stream) {
const int BY = 16, BX = 32; const int BY = 16, BX = 32;
{ {
dim3 threads(BX, BY); dim3 threads(BX, BY);
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N); dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N);
cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * C * IH * IW, stream)); cuda_check(cudaMemsetAsync(grad, 0, sizeof(ctype) * N * C * IH * IW, stream));
float scale_h = static_cast<float>(OH) / IH; float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW; float scale_w = static_cast<float>(OW) / IW;
switch (imode) { if (is_nhwc) {
case InterpolationMode::INTER_LINEAR: { resize_bwd_nhwc_kernel<ctype, rounding::RoundingConverter<ctype>>
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>( <<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break; } else {
} switch (imode) {
case InterpolationMode::INTER_NEAREST: { case InterpolationMode::INTER_LINEAR: {
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>( resize_bwd_linear_kernel<ctype, rounding::RoundingConverter<ctype>>
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); <<<blocks, threads, 0, stream>>>(
break; diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
} break;
case InterpolationMode::INTER_CUBIC: { }
resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>( case InterpolationMode::INTER_NEAREST: {
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); resize_bwd_nearest_kernel<ctype, rounding::RoundingConverter<ctype>>
break; <<<blocks, threads, 0, stream>>>(
} diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
default: { break;
megdnn_throw("unsupported interpolation mode"); }
break; case InterpolationMode::INTER_CUBIC: {
resize_bwd_cubic_kernel<ctype, rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
break;
}
default: {
megdnn_throw("unsupported interpolation mode");
break;
}
} }
} }
} }
after_kernel_launch(); after_kernel_launch();
} }
#define INST(ctype) \
template void backward_data_proxy( \
bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
int, cudaStream_t);
INST(dt_float32);
DNN_INC_FLOAT16(INST(dt_float16));
#undef INST
} // namespace resize } // namespace resize
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
...@@ -20,9 +20,10 @@ void forward_proxy_nchw4( ...@@ -20,9 +20,10 @@ void forward_proxy_nchw4(
const ctype* src, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, const ctype* src, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
cudaStream_t stream); cudaStream_t stream);
template <typename ctype>
void backward_data_proxy( void backward_data_proxy(
InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH, bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
int IW, int OH, int OW, cudaStream_t stream); int C, int IH, int IW, int OH, int OW, cudaStream_t stream);
} // namespace resize } // namespace resize
} // namespace cuda } // namespace cuda
......
...@@ -148,6 +148,11 @@ void ResizeImpl::exec( ...@@ -148,6 +148,11 @@ void ResizeImpl::exec(
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float32>(), is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float32>(),
dst.ptr<dt_float32>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, dst.ptr<dt_float32>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream); S_IH, S_IW, stream);
} else if (src.layout.dtype == dtype::Float16{}) {
resize::forward_proxy(
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float16>(),
dst.ptr<dt_float16>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
S_IH, S_IW, stream);
} else if (src.layout.dtype == dtype::Uint8()) { } else if (src.layout.dtype == dtype::Uint8()) {
resize::forward_proxy( resize::forward_proxy(
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_uint8>(), is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_uint8>(),
......
...@@ -298,6 +298,7 @@ void forward_proxy_nchw4( ...@@ -298,6 +298,7 @@ void forward_proxy_nchw4(
INST(float) INST(float)
INST(uint8_t) INST(uint8_t)
INST(int8_t) INST(int8_t)
DNN_INC_FLOAT16(INST(dt_float16))
#undef INST #undef INST
#define INST(ctype) \ #define INST(ctype) \
......
...@@ -387,40 +387,53 @@ void ResizeImpl::exec( ...@@ -387,40 +387,53 @@ void ResizeImpl::exec(
} }
} }
void ResizeBackwardImpl::exec( // ***************************Backward*************************** //
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { template <typename ctype>
check_exec(diff.layout, grad.layout, workspace.size); void ResizeBackwardImpl::kern_naive(
megdnn_assert( bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
param().format == param::Resize::Format::NCHW, "invalid resize format"); int C, int IH, int IW, int OH, int OW) {
const int N = grad.layout.shape[0], C = grad.layout.shape[1],
IH = grad.layout.shape[2], IW = grad.layout.shape[3];
const int OH = diff.layout.shape[2], OW = diff.layout.shape[3];
const float* hptr_ = diff.ptr<dt_float32>();
float* sptr_ = grad.ptr<dt_float32>();
float scale_h = static_cast<float>(OH) / IH; float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW; float scale_w = static_cast<float>(OW) / IW;
rounding::RoundingConverter<ctype> output_converter;
auto kern = [=]() { auto kern = [=]() {
auto hptr = hptr_; auto hptr = diff;
auto sptr = sptr_; auto sptr = grad;
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); std::memset(sptr, 0, sizeof(ctype) * N * C * IH * IW);
rep(n, N) { rep(n, N) {
rep(oh, OH) rep(ow, OW) { rep(oh, OH) rep(ow, OW) {
switch (param().imode) { switch (imode) {
case InterpolationMode::INTER_LINEAR: { case InterpolationMode::INTER_LINEAR: {
int ih0, ih1, iw0, iw1; int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1; float ah0, ah1, aw0, aw1;
std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( std::tie(ah0, ih0, ah1, ih1) =
param().imode, scale_h, IH, oh); get_nearest_linear_coord(imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( std::tie(aw0, iw0, aw1, iw1) =
param().imode, scale_w, IW, ow); get_nearest_linear_coord(imode, scale_w, IW, ow);
rep(c, C) { if (is_nhwc) {
float hidden = hptr[c * OH * OW + oh * OW + ow]; rep(c, C) {
sptr[c * IH * IW + ih0 * IW + iw0] += ah0 * aw0 * hidden; sptr[(ih0 * IW + iw0) * C + c] += output_converter(
sptr[c * IH * IW + ih1 * IW + iw0] += ah1 * aw0 * hidden; hptr[(oh * OW + ow) * C + c] * ah0 * aw0);
sptr[c * IH * IW + ih0 * IW + iw1] += ah0 * aw1 * hidden; sptr[(ih0 * IW + iw1) * C + c] += output_converter(
sptr[c * IH * IW + ih1 * IW + iw1] += ah1 * aw1 * hidden; hptr[(oh * OW + ow) * C + c] * ah0 * aw1);
sptr[(ih1 * IW + iw0) * C + c] += output_converter(
hptr[(oh * OW + ow) * C + c] * ah1 * aw0);
sptr[(ih1 * IW + iw1) * C + c] += output_converter(
hptr[(oh * OW + ow) * C + c] * ah1 * aw1);
}
} else {
rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow];
sptr[c * IH * IW + ih0 * IW + iw0] +=
output_converter(ah0 * aw0 * hidden);
sptr[c * IH * IW + ih1 * IW + iw0] +=
output_converter(ah1 * aw0 * hidden);
sptr[c * IH * IW + ih0 * IW + iw1] +=
output_converter(ah0 * aw1 * hidden);
sptr[c * IH * IW + ih1 * IW + iw1] +=
output_converter(ah1 * aw1 * hidden);
}
} }
break; break;
} }
...@@ -429,7 +442,7 @@ void ResizeBackwardImpl::exec( ...@@ -429,7 +442,7 @@ void ResizeBackwardImpl::exec(
auto iw = get_nearest_src(scale_w, IW, ow); auto iw = get_nearest_src(scale_w, IW, ow);
rep(c, static_cast<int>(C)) { rep(c, static_cast<int>(C)) {
sptr[c * IH * IW + ih * IW + iw] += sptr[c * IH * IW + ih * IW + iw] +=
hptr[c * OH * OW + oh * OW + ow]; output_converter(hptr[c * OH * OW + oh * OW + ow]);
} }
break; break;
} }
...@@ -452,9 +465,9 @@ void ResizeBackwardImpl::exec( ...@@ -452,9 +465,9 @@ void ResizeBackwardImpl::exec(
int h = saturate<int, int>(ih0 + kh, 0, IH - 1); int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
rep(kw, ksize) { rep(kw, ksize) {
int w = saturate<int, int>(iw0 + kw, 0, IW - 1); int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
sptr[c * IH * IW + h * IW + w] += sptr[c * IH * IW + h * IW + w] += output_converter(
hptr[c * OH * OW + oh * OW + ow] * hptr[c * OH * OW + oh * OW + ow] *
h_coeff[kh] * w_coeff[kw]; h_coeff[kh] * w_coeff[kw]);
} }
} }
} }
...@@ -473,4 +486,59 @@ void ResizeBackwardImpl::exec( ...@@ -473,4 +486,59 @@ void ResizeBackwardImpl::exec(
MEGDNN_DISPATCH_CPU_KERN_OPR(kern()); MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
} }
#define INST(ctype) \
template void ResizeBackwardImpl::kern_naive( \
bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
int);
INST(dt_float32);
DNN_INC_FLOAT16(INST(dt_float16));
#undef INST
void ResizeBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
check_exec(diff.layout, grad.layout, workspace.size);
megdnn_assert(
param().format == param::Resize::Format::NCHW ||
param().format == param::Resize::Format::NHWC,
"invalid resize format");
size_t N, C, IH, IW, OH, OW;
bool is_nhwc = param().format == param::Resize::Format::NHWC;
if (is_nhwc) {
if (param().imode != Param::InterpolationMode::LINEAR &&
is_nhwc_contig_wc(grad.layout)) {
megdnn_assert(
0,
"unsupport mode in resizeBackward, only support param().imode = "
"LINEAR");
}
N = grad.layout.shape[0];
C = grad.layout.shape[3];
IH = grad.layout.shape[1];
IW = grad.layout.shape[2];
OH = diff.layout.shape[1];
OW = diff.layout.shape[2];
} else {
N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
IW = grad.layout.shape[3];
OH = diff.layout.shape[2], OW = diff.layout.shape[3];
}
switch (grad.layout.dtype.enumv()) {
#define cb(_t) \
case DTypeTrait<_t>::enumv: { \
typedef DTypeTrait<_t>::ctype ct; \
ct* diff_ptr = diff.ptr<ct>(); \
ct* grad_ptr = grad.ptr<ct>(); \
ResizeBackwardImpl::kern_naive( \
is_nhwc, param().imode, diff_ptr, grad_ptr, N, C, IH, IW, OH, OW); \
break; \
}
cb(megdnn::dtype::Float32);
DNN_INC_FLOAT16(cb(megdnn::dtype::Float16));
default:
megdnn_throw(ssprintf(
"unsupported dtype: %s in resize backward",
grad.layout.dtype.name()));
}
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -75,6 +75,12 @@ public: ...@@ -75,6 +75,12 @@ public:
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
return 0; return 0;
} }
private:
template <typename ctype>
void kern_naive(
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad,
int N, int C, int IH, int IW, int OH, int OW);
}; };
} // namespace naive } // namespace naive
......
...@@ -61,13 +61,67 @@ TEST_F(CUDA, RESIZE_FORWARD) { ...@@ -61,13 +61,67 @@ TEST_F(CUDA, RESIZE_FORWARD) {
.set_epsilon(1) .set_epsilon(1)
.execs({arg.src, arg.dst}); .execs({arg.src, arg.dst});
} }
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
}
}
TEST_F(CUDA, RESIZE_NHWC) {
using namespace resize;
std::vector<TestArg> args;
param::Resize param;
param.format = param::Resize::Format::NHWC;
param.imode = param::Resize::InterpolationMode::LINEAR;
args.emplace_back(param, TensorShape{1, 1, 4, 5}, TensorShape{1, 1, 8, 5});
args.emplace_back(param, TensorShape{2, 6, 4, 5}, TensorShape{2, 3, 8, 5});
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2});
Checker<ResizeBackward> checkerBackward(handle_cuda());
for (auto&& arg : args) {
checkerBackward.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checkerBackward.set_param(arg.param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
Checker<ResizeForward> checkerForward(handle_cuda());
for (auto&& arg : args) {
checkerForward.set_param(arg.param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
}
for (auto&& arg : args) {
checkerForward.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_epsilon(1e-3)
.execs({arg.src, arg.dst});
} }
} }
TEST_F(CUDA, RESIZE_NCHW4) { TEST_F(CUDA, RESIZE_NCHW4) {
using namespace resize; using namespace resize;
Checker<Resize> checker(handle_cuda()); Checker<Resize> checker(handle_cuda());
auto args = get_nchw4_args(); auto args = get_nchw4_args();
for (auto&& arg : args) { for (auto&& arg : args) {
checker.set_param(arg.param) checker.set_param(arg.param)
...@@ -113,6 +167,24 @@ TEST_F(CUDA, RESIZE_BACKWARD) { ...@@ -113,6 +167,24 @@ TEST_F(CUDA, RESIZE_BACKWARD) {
param.format = param::Resize::Format::NCHW; param.format = param::Resize::Format::NCHW;
param.imode = imode; param.imode = imode;
checker.set_param(param); checker.set_param(param);
checker.set_dtype(0, dtype::Float16());
checker.set_dtype(1, dtype::Float16());
checker.set_epsilon(1 + 1e-3);
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}});
checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}});
}
for (auto imode : modes) {
Checker<ResizeBackward> checker(handle_cuda());
param::Resize param;
param.format = param::Resize::Format::NCHW;
param.imode = imode;
checker.set_param(param);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}}); checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}}); checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册