From b94fe95a2119872aba08826778966f75d27e426d Mon Sep 17 00:00:00 2001 From: Infinity_lee Date: Fri, 17 Mar 2023 17:59:48 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=20No58=E3=80=91fix=20atan2?= =?UTF-8?q?=20(#51185)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/gpu/atan2_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/atan2_kernel.cu | 1 + paddle/phi/kernels/impl/atan2_kernel_impl.h | 1 + .../fluid/tests/unittests/test_atan2_op.py | 31 ++++++++++++++++++- 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/atan2_grad_kernel.cu b/paddle/phi/kernels/gpu/atan2_grad_kernel.cu index 7e68610af1..0e0b4329fa 100644 --- a/paddle/phi/kernels/gpu/atan2_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/atan2_grad_kernel.cu @@ -22,4 +22,5 @@ PD_REGISTER_KERNEL(atan2_grad, phi::Atan2GradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/atan2_kernel.cu b/paddle/phi/kernels/gpu/atan2_kernel.cu index d77c21473b..ed66318fc2 100644 --- a/paddle/phi/kernels/gpu/atan2_kernel.cu +++ b/paddle/phi/kernels/gpu/atan2_kernel.cu @@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(atan2, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); diff --git a/paddle/phi/kernels/impl/atan2_kernel_impl.h b/paddle/phi/kernels/impl/atan2_kernel_impl.h index b7799a7770..578bac6170 100644 --- a/paddle/phi/kernels/impl/atan2_kernel_impl.h +++ b/paddle/phi/kernels/impl/atan2_kernel_impl.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" #include "paddle/phi/kernels/atan2_kernel.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/python/paddle/fluid/tests/unittests/test_atan2_op.py b/python/paddle/fluid/tests/unittests/test_atan2_op.py index 33ee52bf6a..7dd3ceaca8 100644 --- a/python/paddle/fluid/tests/unittests/test_atan2_op.py +++ b/python/paddle/fluid/tests/unittests/test_atan2_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid.core as core @@ -129,6 +129,35 @@ class TestAtan2API(unittest.TestCase): run(place) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the bfloat16", +) +class TestAtan2BF16OP(OpTest): + def setUp(self): + self.op_type = 'atan2' + self.python_api = paddle.atan2 + self.dtype = np.uint16 + x1 = np.random.uniform(-1, -0.1, [15, 17]).astype('float32') + x2 = np.random.uniform(0.1, 1, [15, 17]).astype('float32') + out = np.arctan2(x1, x2) + + self.inputs = { + 'X1': convert_float_to_uint16(x1), + 'X2': convert_float_to_uint16(x2), + } + self.outputs = {'Out': convert_float_to_uint16(out)} + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place(place, ['X1', 'X2'], 'Out') + + class TestAtan2Error(unittest.TestCase): def test_mismatch(self): paddle.enable_static() -- GitLab