From c576169b8c4b18d4a714133e459c0380dacf84b9 Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Tue, 28 Sep 2021 10:14:19 +0800 Subject: [PATCH] [cherry-pick] [ROCM] bugfix for bilinear_interp_v2_grad (#36160) #36161 ATT, cherry-pick #36160 --- paddle/fluid/operators/interpolate_v2_op.cu | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/interpolate_v2_op.cu b/paddle/fluid/operators/interpolate_v2_op.cu index 6f8b89ce64..fe92281356 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cu +++ b/paddle/fluid/operators/interpolate_v2_op.cu @@ -1198,7 +1198,12 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout); } else if ("bicubic" == interp_method) { - KeBicubicInterpFw<<<<>>( input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout); @@ -1606,9 +1611,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0; bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false; bool optimize_flag = false; +#ifndef __HIPCC__ optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6)) ? true : ((in_h == 1 && in_w == 1) ? true : false); +#endif if (optimize_flag & is_nchw) { KeBilinearInterpBwShareMemory< @@ -1623,7 +1630,12 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, ratio_h, ratio_w, align_type_value, is_nchw); } } else if ("bicubic" == interp_method) { - KeBicubicInterpBw<<<<>>( input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout); -- GitLab