未验证 提交 5e1d0b5c 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] bugfix for bilinear_interp_v2_grad (#36160)

上级 1b1210ea
...@@ -1198,7 +1198,12 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, ...@@ -1198,7 +1198,12 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout); out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
} else if ("bicubic" == interp_method) { } else if ("bicubic" == interp_method) {
- KeBicubicInterpFw<T><<<config.block_per_grid, 512, 0,
+ #ifdef __HIPCC__
constexpr int thread_per_block = 256;
#else
constexpr int thread_per_block = 512;
#endif
KeBicubicInterpFw<T><<<config.block_per_grid, thread_per_block, 0,
ctx.cuda_device_context().stream()>>>( ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout); out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
...@@ -1606,9 +1611,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, ...@@ -1606,9 +1611,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0; const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false; bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false;
bool optimize_flag = false; bool optimize_flag = false;
#ifndef __HIPCC__
optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6)) optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6))
? true ? true
: ((in_h == 1 && in_w == 1) ? true : false); : ((in_h == 1 && in_w == 1) ? true : false);
#endif
if (optimize_flag & is_nchw) { if (optimize_flag & is_nchw) {
KeBilinearInterpBwShareMemory< KeBilinearInterpBwShareMemory<
...@@ -1623,7 +1630,12 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, ...@@ -1623,7 +1630,12 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
ratio_h, ratio_w, align_type_value, is_nchw); ratio_h, ratio_w, align_type_value, is_nchw);
} }
} else if ("bicubic" == interp_method) { } else if ("bicubic" == interp_method) {
- KeBicubicInterpBw<T><<<config.block_per_grid, 512, 0,
+ #ifdef __HIPCC__
constexpr int thread_per_block = 256;
#else
constexpr int thread_per_block = 512;
#endif
KeBicubicInterpBw<T><<<config.block_per_grid, thread_per_block, 0,
ctx.cuda_device_context().stream()>>>( ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout); n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册