diff --git a/paddle/phi/kernels/gpu/atan2_grad_kernel.cu b/paddle/phi/kernels/gpu/atan2_grad_kernel.cu
index 7e68610af1d54df072b7d625738d494ad2343c80..0e0b4329fa08ae9d582eeecc7e511ba77bc70167 100644
--- a/paddle/phi/kernels/gpu/atan2_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/atan2_grad_kernel.cu
@@ -22,4 +22,5 @@ PD_REGISTER_KERNEL(atan2_grad,
                    phi::Atan2GradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/atan2_kernel.cu b/paddle/phi/kernels/gpu/atan2_kernel.cu
index d77c21473b4c44d3a54a549a8c5a0391fcf7c5dd..ed66318fc25285bf79724760d8ea5cd7ca1dd4b3 100644
--- a/paddle/phi/kernels/gpu/atan2_kernel.cu
+++ b/paddle/phi/kernels/gpu/atan2_kernel.cu
@@ -23,6 +23,7 @@ PD_REGISTER_KERNEL(atan2,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int,
                    int64_t) {
   kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
diff --git a/paddle/phi/kernels/impl/atan2_kernel_impl.h b/paddle/phi/kernels/impl/atan2_kernel_impl.h
index b7799a777046f4d63034df1aff23c68915d8d7a8..578bac61701de97e96f752b8a97c1e8aa726ccb0 100644
--- a/paddle/phi/kernels/impl/atan2_kernel_impl.h
+++ b/paddle/phi/kernels/impl/atan2_kernel_impl.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 #include "paddle/phi/kernels/atan2_kernel.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 
diff --git a/python/paddle/fluid/tests/unittests/test_atan2_op.py b/python/paddle/fluid/tests/unittests/test_atan2_op.py
index 33ee52bf6ac692831e5d881af82d3f8b92bb13bd..7dd3ceaca8b008f147c85193e65787e3487045fc 100644
--- a/python/paddle/fluid/tests/unittests/test_atan2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_atan2_op.py
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 import paddle.fluid.core as core
@@ -129,6 +129,38 @@ class TestAtan2API(unittest.TestCase):
             run(place)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestAtan2BF16OP(OpTest):
+    """Check the atan2 op's output and gradients with bfloat16 inputs."""
+
+    def setUp(self):
+        self.op_type = 'atan2'
+        self.python_api = paddle.atan2
+        # Inputs/outputs below go through convert_float_to_uint16.
+        self.dtype = np.uint16
+        # x1 in [-1, -0.1) and x2 in [0.1, 1): both operands stay away from 0.
+        x1 = np.random.uniform(-1, -0.1, [15, 17]).astype('float32')
+        x2 = np.random.uniform(0.1, 1, [15, 17]).astype('float32')
+        out = np.arctan2(x1, x2)
+
+        self.inputs = {
+            'X1': convert_float_to_uint16(x1),
+            'X2': convert_float_to_uint16(x2),
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X1', 'X2'], 'Out')
+
+
 class TestAtan2Error(unittest.TestCase):
     def test_mismatch(self):
         paddle.enable_static()