From effd51c721606d873d90e41eb1d9c0617a9451bd Mon Sep 17 00:00:00 2001
From: ZZK <359521840@qq.com>
Date: Fri, 16 Dec 2022 14:45:46 +0800
Subject: [PATCH] Fix bilinear interp fp16 diff (#49095)

* cast to higher precision type to prevent fp16 diff problem

* fix bilinear backward and add more unittest case
---
 .../kernels/gpu/interpolate_grad_kernel.cu    | 67 ++++++++++---------
 paddle/phi/kernels/gpu/interpolate_kernel.cu  | 43 ++++++------
 .../unittests/test_bilinear_interp_v2_op.py   | 50 +++++++++++++-
 3 files changed, 106 insertions(+), 54 deletions(-)

diff --git a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
index cb1d959e30..e9f3643f66 100644
--- a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
@@ -251,7 +251,8 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
                                               float ratio_w,
                                               const float align_type_value,
                                               bool is_nchw) {
-  __shared__ T s_data[2][1024];
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  __shared__ MT s_data[2][1024];
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   int in_chw = in_h * in_w * num_channels;
@@ -263,18 +264,18 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
     int out_id_w = tid % out_chw;
     const int in_img_size = in_h * in_w;
     const int out_img_size = out_h * out_w;
-    T value = out[out_id_h * out_chw + out_id_w];
+    MT value = static_cast<MT>(out[out_id_h * out_chw + out_id_w]);
     int channel_id = out_id_w / out_img_size;
     int out_img_idy = (out_id_w % out_img_size) / out_w;
     int out_img_idx = tid % out_w;
 
     int in_img_idx, in_img_idy, w_id, h_id;
-    T w1lambda, h1lambda, w2lambda, h2lambda;
-    T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
-                             align_type_value);
-    T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
-                             align_type_value);
+    MT w1lambda, h1lambda, w2lambda, h2lambda;
+    MT src_w = static_cast<MT>(ratio_w * (out_img_idx + align_type_value) -
+                               align_type_value);
+    MT src_h = static_cast<MT>(ratio_h * (out_img_idy + align_type_value) -
+                               align_type_value);
 
     PreCalculatorForLinearInterpInputIndex(
         &in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w);
@@ -289,8 +290,8 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
       int bot_right_index = input_index + h_id * in_w + w_id;
       int in_top_min_index, in_bot_min_index;
 
-      s_data[0][threadIdx.x] = static_cast<T>(0);
-      s_data[1][threadIdx.x] = static_cast<T>(0);
+      s_data[0][threadIdx.x] = static_cast<MT>(0);
+      s_data[1][threadIdx.x] = static_cast<MT>(0);
       int remain = nthreads - (tid & (-blockDim.x));
       int in_top_max_index =
          phi::funcs::blockReduceMax(top_right_index, FINAL_MASK);
@@ -327,9 +328,9 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
 
       if (threadIdx.x <= upper_limit_share_idx) {
         phi::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x],
-                           s_data[0][threadIdx.x]);
+                           static_cast<T>(s_data[0][threadIdx.x]));
         phi::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x],
-                           s_data[1][threadIdx.x]);
+                           static_cast<T>(s_data[1][threadIdx.x]));
       }
     }
   }
@@ -358,6 +359,7 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
   int stride = blockDim.x * gridDim.x;
   int num_out = n * num_channels * out_h * out_w;
   int num_in = n * num_channels * in_h * in_w;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
 
   for (; index < num_out; index += stride) {
     int index_tmp = index;
@@ -367,29 +369,29 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
     int nc = index_tmp / out_h;
 
     int h1, y_id;
-    T h1lambda, h0lambda;
-    T src_y =
-        static_cast<T>(ratio_h * (h2 + align_type_value) - align_type_value);
+    MT h1lambda, h0lambda;
+    MT src_y =
+        static_cast<MT>(ratio_h * (h2 + align_type_value) - align_type_value);
     PreCalculatorForLinearInterpInputIndex(
         &h1, &y_id, &h1lambda, &h0lambda, src_y, in_h);
 
     int w1, x_id;
-    T w1lambda, w0lambda;
-    T src_x =
-        static_cast<T>(ratio_w * (w2 + align_type_value) - align_type_value);
+    MT w1lambda, w0lambda;
+    MT src_x =
+        static_cast<MT>(ratio_w * (w2 + align_type_value) - align_type_value);
     PreCalculatorForLinearInterpInputIndex(
         &w1, &x_id, &w1lambda, &w0lambda, src_x, in_w);
 
-    T d2val = out[index];
+    MT d2val = static_cast<MT>(out[index]);
 
     phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
-                       h0lambda * w0lambda * d2val);
+                       static_cast<T>(h0lambda * w0lambda * d2val));
     phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
-                       h0lambda * w1lambda * d2val);
+                       static_cast<T>(h0lambda * w1lambda * d2val));
     phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
-                       h1lambda * w0lambda * d2val);
+                       static_cast<T>(h1lambda * w0lambda * d2val));
     phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
-                       h1lambda * w1lambda * d2val);
+                       static_cast<T>(h1lambda * w1lambda * d2val));
   }
 }
 
@@ -411,6 +413,7 @@ __global__ void KeBilinearInterpBw(T* in,
   int stride = blockDim.x * gridDim.x;
   int in_chw = in_h * in_w * num_channels;
   int nthreads = n * out_chw;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
 
   for (; tid < nthreads; tid += stride) {
     auto out_id_divmod = divmods.output_w_div.Divmod(tid);
@@ -424,28 +427,28 @@ __global__ void KeBilinearInterpBw(T* in,
         divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
 
     int in_img_idx, in_img_idy, w_id, h_id;
-    T w1lambda, h1lambda, w2lambda, h2lambda;
-    T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
-                             align_type_value);
-    T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
-                             align_type_value);
+    MT w1lambda, h1lambda, w2lambda, h2lambda;
+    MT src_w = static_cast<MT>(ratio_w * (out_img_idx + align_type_value) -
+                               align_type_value);
+    MT src_h = static_cast<MT>(ratio_h * (out_img_idy + align_type_value) -
+                               align_type_value);
 
     PreCalculatorForLinearInterpInputIndex(
         &in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w);
     PreCalculatorForLinearInterpInputIndex(
         &in_img_idy, &h_id, &h1lambda, &h2lambda, src_h, in_h);
 
-    T value = out[tid];
+    MT value = static_cast<MT>(out[tid]);
     T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
                     in_img_idx * num_channels + channel_id];
-    phi::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
+    phi::CudaAtomicAdd(&in_pos[0], static_cast<T>(h2lambda * w2lambda * value));
     phi::CudaAtomicAdd(&in_pos[w_id * num_channels],
-                       h2lambda * w1lambda * value);
+                       static_cast<T>(h2lambda * w1lambda * value));
     phi::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
-                       h1lambda * w2lambda * value);
+                       static_cast<T>(h1lambda * w2lambda * value));
     phi::CudaAtomicAdd(
         &in_pos[h_id * in_w * num_channels + w_id * num_channels],
-        h1lambda * w1lambda * value);
+        static_cast<T>(h1lambda * w1lambda * value));
   }
 }
 
diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu
index 8ca24b3e4f..2510ff8a54 100644
--- a/paddle/phi/kernels/gpu/interpolate_kernel.cu
+++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu
@@ -209,6 +209,7 @@ __global__ void KeBilinearInterpFw(const T* in,
                                    const float ratio_w,
                                    const float align_type_value,
                                    funcs::FastDivModForInterpolate divmods) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   int nthreads = output_h * output_w;
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
@@ -225,11 +226,11 @@ __global__ void KeBilinearInterpFw(const T* in,
         divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
 
     int in_img_idx, in_img_idy, h_id, w_id;
-    T h1lambda, w1lambda, h2lambda, w2lambda;
-    T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
-                             align_type_value);
-    T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
-                             align_type_value);
+    MT h1lambda, w1lambda, h2lambda, w2lambda;
+    MT src_w = static_cast<MT>(ratio_w * (out_img_idx + align_type_value) -
+                               align_type_value);
+    MT src_h = static_cast<MT>(ratio_h * (out_img_idy + align_type_value) -
+                               align_type_value);
 
     PreCalculatorForLinearInterpInputIndex(
         &in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
@@ -241,12 +242,13 @@ __global__ void KeBilinearInterpFw(const T* in,
         &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
             in_img_idx * num_channels + channel_id];
     out[tid] =
-        h2lambda *
-            (w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
+        h2lambda * (w2lambda * static_cast<MT>(in_pos[0]) +
+                    w1lambda * static_cast<MT>(in_pos[w_id * num_channels])) +
         h1lambda *
-            (w2lambda * in_pos[h_id * in_img_w * num_channels] +
-             w1lambda *
-                 in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
+            (w2lambda *
+                 static_cast<MT>(in_pos[h_id * in_img_w * num_channels]) +
+             w1lambda * static_cast<MT>(in_pos[h_id * in_img_w * num_channels +
+                                               w_id * num_channels]));
   }
 }
 
@@ -261,17 +263,18 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
                                        const float ratio_h,
                                        const float ratio_w,
                                        const float align_type_value) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
   int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
   int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
   int nc_stride = blockDim.z * gridDim.z;
 
   int in_img_idx, in_img_idy, h_id, w_id;
-  T h1lambda, w1lambda, h2lambda, w2lambda;
-  T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
-                           align_type_value);
-  T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
-                           align_type_value);
+  MT h1lambda, w1lambda, h2lambda, w2lambda;
+  MT src_w = static_cast<MT>(ratio_w * (out_img_idx + align_type_value) -
+                             align_type_value);
+  MT src_h = static_cast<MT>(ratio_h * (out_img_idy + align_type_value) -
+                             align_type_value);
 
   PreCalculatorForLinearInterpInputIndex(
       &in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
@@ -288,10 +291,12 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
   if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
     while (nc_id < nc) {
       const T* in_pos = &in[in_index];
-      out[out_index] =
-          h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
-          h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
-                      w1lambda * in_pos[h_id * in_img_w + w_id]);
+      out[out_index] = static_cast<T>(
+          h2lambda * (w2lambda * static_cast<MT>(in_pos[0]) +
+                      w1lambda * static_cast<MT>(in_pos[w_id])) +
+          h1lambda *
+              (w2lambda * static_cast<MT>(in_pos[h_id * in_img_w]) +
+               w1lambda * static_cast<MT>(in_pos[h_id * in_img_w + w_id])));
 
       in_index += in_index_stride;
       out_index += out_index_stride;
diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_v2_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_v2_op.py
index 01d5759500..ed7b1375e5 100755
--- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_v2_op.py
@@ -733,7 +733,7 @@ class TestBilinearInterpOpAPI_dy4(unittest.TestCase):
 @unittest.skipIf(
     not fluid.core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
-class TestBilinearInterpOpForFloat16(unittest.TestCase):
+class TestBilinearInterpOpZoomOutForFloat16(unittest.TestCase):
     def init_test_case(self):
         self.interp_method = 'bilinear'
         self.input_shape = [2, 3, 5, 5]
@@ -768,8 +768,52 @@ class TestBilinearInterpOpForFloat16(unittest.TestCase):
         y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
         y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
 
-        np.testing.assert_allclose(y_np_1, y_np_2)
-        np.testing.assert_allclose(x_g_np_1, x_g_np_2)
+        np.testing.assert_allclose(y_np_1, y_np_2, atol=1e-3, rtol=1e-3)
+        # Since atomicAdd on half brings some diff, relax the tolerance to 1e-2.
+        np.testing.assert_allclose(x_g_np_1, x_g_np_2, atol=1e-2, rtol=1e-2)
+
+
+@unittest.skipIf(
+    not fluid.core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestBilinearInterpOpZoomInForFloat16(unittest.TestCase):
+    def init_test_case(self):
+        self.interp_method = 'bilinear'
+        self.input_shape = [2, 3, 5, 5]
+        self.out_size = np.array([10, 10]).astype("int32")
+        self.align_corners = True
+        self.align_mode = 1
+        self.data_layout = 'NCHW'
+
+    def check_main(self, x_np, dtype):
+        paddle.disable_static()
+        x_np = x_np.astype(dtype)
+        x = paddle.to_tensor(x_np)
+        x.stop_gradient = False
+        y = interpolate(
+            x,
+            size=self.out_size.tolist(),
+            mode=self.interp_method,
+            align_mode=self.align_mode,
+            align_corners=self.align_corners,
+            data_format=self.data_layout,
+        )
+        x_g = paddle.grad(y, x)
+        y_np = y[0].numpy().astype('float32')
+        x_g_np = x_g[0].numpy().astype('float32')
+        paddle.enable_static()
+        return y_np, x_g_np
+
+    def test_main(self):
+        self.init_test_case()
+        x_np = np.random.random(self.input_shape).astype("float16")
+
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
+
+        np.testing.assert_allclose(y_np_1, y_np_2, atol=1e-3, rtol=1e-3)
+        # Since atomicAdd on half brings some diff, relax the tolerance to 1e-2.
+        np.testing.assert_allclose(x_g_np_1, x_g_np_2, atol=1e-2, rtol=1e-2)
 
 
 if __name__ == "__main__":
--
GitLab
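
The pattern applied throughout this patch: keep fp16 as the storage type T, but do the interpolation arithmetic in the higher-precision type MT selected by phi::dtype::MPTypeTrait, casting back to T only at the final store or atomicAdd. The standalone CUDA sketch below only illustrates that cast placement; MPTypeTraitSketch and LerpSketch are hypothetical stand-ins, not Paddle's MPTypeTrait or the patched kernels.

#include <cuda_fp16.h>

// Hypothetical stand-in for phi::dtype::MPTypeTrait<T>: pick the type used
// for arithmetic (float when the storage type is __half, T itself otherwise).
template <typename T>
struct MPTypeTraitSketch {
  using Type = T;
};
template <>
struct MPTypeTraitSketch<__half> {
  using Type = float;
};

// Minimal 1-D linear interpolation kernel following the same pattern as the
// patched kernels: load T, promote to MT, do the weighted sum in MT, and cast
// back to T only when the result is written.
template <typename T>
__global__ void LerpSketch(const T* in, T* out, int n, float ratio) {
  using MT = typename MPTypeTraitSketch<T>::Type;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n - 1) return;
  MT lambda = static_cast<MT>(ratio);    // interpolation weight in MT
  MT left = static_cast<MT>(in[i]);      // promote fp16 operands to float
  MT right = static_cast<MT>(in[i + 1]);
  MT value = (static_cast<MT>(1) - lambda) * left + lambda * right;
  out[i] = static_cast<T>(value);        // demote to fp16 only on the store
}

With T = __half the rounding then happens once per output element instead of once per intermediate multiply-add, which shrinks the fp16/fp32 mismatch the PR title refers to.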
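The looser gradient tolerance in the tests has a separate cause: even with the lambdas computed in float, the backward kernels still accumulate into fp16 memory with atomicAdd, so every accumulation rounds to half precision. The micro-kernel below is a hypothetical illustration of that error source, not part of the patch (fp16 atomicAdd assumes compute capability 7.0 or newer).

#include <cuda_fp16.h>

// Hypothetical illustration of why the gradient check uses a looser 1e-2
// tolerance: many small contributions added atomically into one fp16 cell
// round on every add, while the fp32 cell keeps far more precision.
__global__ void AccumulateSketch(__half* sum_fp16, float* sum_fp32, int iters) {
  for (int i = 0; i < iters; ++i) {
    atomicAdd(sum_fp16, __float2half(1e-3f));  // rounds to fp16 each step
    atomicAdd(sum_fp32, 1e-3f);                // accumulates in fp32
  }
}

This matches the comment added in the unit tests: the forward results are compared at 1e-3 while the gradients are compared at 1e-2.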