From 14dd68e1d7d05552f0f5de02adb5de76271f71d0 Mon Sep 17 00:00:00 2001 From: liuruyan <44316842+liuruyan@users.noreply.github.com> Date: Thu, 2 Feb 2023 10:15:18 +0800 Subject: [PATCH] Fix the FP16 precision problem of add_n. (#50129) --- paddle/phi/kernels/gpu/add_n_kernel.cu | 20 +++--- .../fluid/tests/unittests/test_add_n_op.py | 64 +++++++++++++++++++ 2 files changed, 75 insertions(+), 9 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_add_n_op.py diff --git a/paddle/phi/kernels/gpu/add_n_kernel.cu b/paddle/phi/kernels/gpu/add_n_kernel.cu index f32ba597f5b..69bc248a7e2 100644 --- a/paddle/phi/kernels/gpu/add_n_kernel.cu +++ b/paddle/phi/kernels/gpu/add_n_kernel.cu @@ -14,11 +14,10 @@ #include "paddle/phi/kernels/add_n_kernel.h" -#include "paddle/phi/kernels/impl/add_n_kernel_impl.h" - #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/memcpy.h" - +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/impl/add_n_kernel_impl.h" namespace phi { #define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) @@ -38,16 +37,18 @@ __global__ void Sum2CUDAKernel(const T *in_0, template __global__ void SumArrayCUDAKernel( T **in, T *out, int64_t N, size_t in_size, bool read_dst) { + using MPType = typename phi::dtype::MPTypeTrait::Type; int id = blockIdx.x * blockDim.x + threadIdx.x; while (id < N) { - T total(read_dst ? out[id] : static_cast(0)); + MPType total(read_dst ? static_cast(out[id]) + : static_cast(0)); for (int i = 0; i < in_size; ++i) { const T *tmp = in[i]; if (tmp) { - total += tmp[id]; + total += static_cast(tmp[id]); } } - out[id] = total; + out[id] = static_cast(total); id += blockDim.x * gridDim.x; } } @@ -116,11 +117,12 @@ void AddNKernel(const Context &dev_ctx, int64_t length_0 = in_0.numel(); int64_t length_1 = in_1.numel(); if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) { + using MPType = typename phi::dtype::MPTypeTrait::Type; auto result = EigenVector::Flatten(*out); auto &place = *dev_ctx.eigen_device(); - auto in_0_e = EigenVector::Flatten(in_0); - auto in_1_e = EigenVector::Flatten(in_1); - result.device(place) = in_0_e + in_1_e; + auto in_0_e = EigenVector::Flatten(in_0).template cast(); + auto in_1_e = EigenVector::Flatten(in_1).template cast(); + result.device(place) = (in_0_e + in_1_e).template cast(); } else if (length_0 && in_0.IsInitialized()) { auto result = EigenVector::Flatten(*out); auto &place = *dev_ctx.eigen_device(); diff --git a/python/paddle/fluid/tests/unittests/test_add_n_op.py b/python/paddle/fluid/tests/unittests/test_add_n_op.py new file mode 100644 index 00000000000..3ca485b1419 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_add_n_op.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np + +import paddle + + +class TestAddnOp(unittest.TestCase): + def setUp(self): + np.random.seed(20) + l = 32 + self.x_np = np.random.random([l, 16, 256]) + + def check_main(self, x_np, dtype, axis=None): + paddle.disable_static() + x = [] + for i in range(x_np.shape[0]): + val = paddle.to_tensor(x_np[i].astype(dtype)) + val.stop_gradient = False + x.append(val) + y = paddle.add_n(x) + x_g = paddle.grad(y, x) + y_np = y.numpy().astype('float32') + x_g_np = [] + for val in x_g: + x_g_np.append(val.numpy().astype('float32')) + paddle.enable_static() + return y_np, x_g_np + + def test_add_n_fp16(self): + if not paddle.is_compiled_with_cuda(): + return + y_np_16, x_g_np_16 = self.check_main(self.x_np, 'float16') + y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32') + + np.testing.assert_allclose(y_np_16, y_np_32, rtol=1e-03) + for i in range(len(x_g_np_32)): + np.testing.assert_allclose(x_g_np_16[i], x_g_np_32[i], rtol=1e-03) + + def test_add_n_api(self): + if not paddle.is_compiled_with_cuda(): + return + + y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32') + y_np_gt = np.sum(self.x_np, axis=0).astype('float32') + + np.testing.assert_allclose(y_np_32, y_np_gt, rtol=1e-06) + + +if __name__ == "__main__": + unittest.main() -- GitLab