diff --git a/paddle/phi/kernels/gpu/allclose_kernel.cu b/paddle/phi/kernels/gpu/allclose_kernel.cu
index fa6a8fce0bf861047cc8da47b2da03a69005f8af..99ccfcd8667e6d7723b3f7cd6aeafbf20caa0267 100644
--- a/paddle/phi/kernels/gpu/allclose_kernel.cu
+++ b/paddle/phi/kernels/gpu/allclose_kernel.cu
@@ -16,6 +16,8 @@
 
 #include "glog/logging.h"
 
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
 
@@ -31,14 +33,16 @@ __global__ void AllcloseCUDAKernel(const T* in_data,
                                    bool* out_data) {
   unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
   bool val;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
-    const T a = in_data[i], b = other_data[i];
+    const MPType a = static_cast<MPType>(in_data[i]);
+    const MPType b = static_cast<MPType>(other_data[i]);
     if (isnan(a) || isnan(b)) {
       val = equal_nan && isnan(a) == isnan(b);
     } else {
-      T left = (a > b ? a - b : b - a);
-      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
-      T diff = (left > right ? left - right : right - left);
+      MPType left = (a > b ? a - b : b - a);
+      MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      MPType diff = (left > right ? left - right : right - left);
       val = a == b || left <= right || diff <= 1e-15;
     }
     if (!val) *out_data = false;
@@ -92,7 +96,12 @@ void AllCloseKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    allclose, GPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
+PD_REGISTER_KERNEL(allclose,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AllCloseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
diff --git a/paddle/phi/kernels/gpu/isclose_kernel.cu b/paddle/phi/kernels/gpu/isclose_kernel.cu
index 34774ec715c48de953945f94624c2c3cfe742d30..cfae8d0bbda29a529434f0c29c8dc85ca6131c6b 100644
--- a/paddle/phi/kernels/gpu/isclose_kernel.cu
+++ b/paddle/phi/kernels/gpu/isclose_kernel.cu
@@ -15,8 +15,14 @@
 #include "paddle/phi/kernels/isclose_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/isclose_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
+PD_REGISTER_KERNEL(isclose,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::IscloseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/impl/isclose_kernel_impl.h b/paddle/phi/kernels/impl/isclose_kernel_impl.h
index 4ee6831d7af1142b0a428f5e2563e1d9fbebebf3..4d9d0cd7b866df5bc50022801964c40a1763a9f2 100644
--- a/paddle/phi/kernels/impl/isclose_kernel_impl.h
+++ b/paddle/phi/kernels/impl/isclose_kernel_impl.h
@@ -18,6 +18,7 @@
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -109,14 +110,16 @@ __global__ void IscloseCUDAKernel(const T* in_data,
                                   bool* out_data) {
   unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
   bool val;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
-    const T a = in_data[i], b = other_data[i];
+    const MPType a = static_cast<MPType>(in_data[i]);
+    const MPType b = static_cast<MPType>(other_data[i]);
     if (isnan(a) || isnan(b)) {
       val = equal_nan && isnan(a) == isnan(b);
     } else {
-      T left = (a > b ? a - b : b - a);
-      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
-      T diff = (left > right ? left - right : right - left);
+      MPType left = (a > b ? a - b : b - a);
+      MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      MPType diff = (left > right ? left - right : right - left);
       val = a == b || left <= right || diff <= 1e-15;
     }
     out_data[i] = val;
diff --git a/python/paddle/fluid/tests/unittests/test_allclose_op.py b/python/paddle/fluid/tests/unittests/test_allclose_op.py
index c4cde0ec49ee99efac14c597b800332f14c2e7cf..53753a1764651cf9ffc42c4af0f941d09ae95b54 100644
--- a/python/paddle/fluid/tests/unittests/test_allclose_op.py
+++ b/python/paddle/fluid/tests/unittests/test_allclose_op.py
@@ -18,6 +18,7 @@ import numpy as np
 from op_test import OpTest
 
 import paddle
+import paddle.fluid.core as core
 
 
 class TestAllcloseOp(OpTest):
@@ -134,7 +135,7 @@ class TestAllcloseError(unittest.TestCase):
             with paddle.static.program_guard(
                 paddle.static.Program(), paddle.static.Program()
             ):
-                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
                 y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
                 result = paddle.allclose(x, y)
 
@@ -170,6 +171,36 @@ class TestAllcloseError(unittest.TestCase):
         self.assertRaises(TypeError, test_equal_nan)
 
 
+class TestAllcloseOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_data = np.random.rand(10, 10).astype('float16')
+        y_data = np.random.rand(10, 10).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
+            out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestAllcloseOpFloat16(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float16")
+        self.other = np.array([10]).astype("float16")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, check_eager=True)
+
+
 class TestAllcloseOpFloat32(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([10.1]).astype("float32")
diff --git a/python/paddle/fluid/tests/unittests/test_isclose_op.py b/python/paddle/fluid/tests/unittests/test_isclose_op.py
index fc2a5cd5ebef2d54bffc4fcd980caa9f1a9e9392..c587420a0ce1ac40542e8cc695201aa96018b9c7 100644
--- a/python/paddle/fluid/tests/unittests/test_isclose_op.py
+++ b/python/paddle/fluid/tests/unittests/test_isclose_op.py
@@ -18,6 +18,7 @@ import numpy as np
 from op_test import OpTest
 
 import paddle
+import paddle.fluid.core as core
 
 
 class TestIscloseOp(OpTest):
@@ -166,7 +167,7 @@ class TestIscloseError(unittest.TestCase):
             with paddle.static.program_guard(
                 paddle.static.Program(), paddle.static.Program()
             ):
-                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
                 y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
                 result = paddle.isclose(x, y)
 
@@ -203,6 +204,36 @@ class TestIscloseError(unittest.TestCase):
         self.assertRaises(TypeError, test_equal_nan)
 
 
+class TestIscloseOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_data = np.random.rand(10, 10).astype('float16')
+        y_data = np.random.rand(10, 10).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestIscloseOpFloat16(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float16")
+        self.other = np.array([10]).astype("float16")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, check_eager=True)
+
+
 class TestIscloseOpFloat32(TestIscloseOp):
     def set_args(self):
         self.input = np.array([10.1]).astype("float32")
diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py
index bb2060e55cfa73843f37ab727bd84c9f7a3cdfb0..e332e554a556932de3eea69a3ecd433deb496901 100644
--- a/python/paddle/tensor/logic.py
+++ b/python/paddle/tensor/logic.py
@@ -361,8 +361,8 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     two tensors are elementwise equal within a tolerance.
 
     Args:
-        x(Tensor): The input tensor, it's data type should be float32, float64..
-        y(Tensor): The input tensor, it's data type should be float32, float64..
+        x(Tensor): The input tensor, its data type should be float16, float32, float64.
+        y(Tensor): The input tensor, its data type should be float16, float32, float64.
         rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
         atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): ${equal_nan_comment}.
@@ -401,8 +401,12 @@
     if in_dygraph_mode():
         return _C_ops.allclose(x, y, rtol, atol, equal_nan)
     else:
-        check_variable_and_dtype(x, "input", ['float32', 'float64'], 'allclose')
-        check_variable_and_dtype(y, "input", ['float32', 'float64'], 'allclose')
+        check_variable_and_dtype(
+            x, "input", ['float16', 'float32', 'float64'], 'allclose'
+        )
+        check_variable_and_dtype(
+            y, "input", ['float16', 'float32', 'float64'], 'allclose'
+        )
         check_type(rtol, 'rtol', float, 'allclose')
         check_type(atol, 'atol', float, 'allclose')
         check_type(equal_nan, 'equal_nan', bool, 'allclose')
@@ -989,8 +993,8 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     two tensors are elementwise equal within a tolerance.
 
     Args:
-        x(Tensor): The input tensor, it's data type should be float32, float64.
-        y(Tensor): The input tensor, it's data type should be float32, float64.
+        x(Tensor): The input tensor, its data type should be float16, float32, float64.
+        y(Tensor): The input tensor, its data type should be float16, float32, float64.
         rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
         atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
@@ -1027,8 +1031,12 @@
     if in_dygraph_mode():
         return _C_ops.isclose(x, y, rtol, atol, equal_nan)
     else:
-        check_variable_and_dtype(x, "input", ['float32', 'float64'], 'isclose')
-        check_variable_and_dtype(y, "input", ['float32', 'float64'], 'isclose')
+        check_variable_and_dtype(
+            x, "input", ['float16', 'float32', 'float64'], 'isclose'
+        )
+        check_variable_and_dtype(
+            y, "input", ['float16', 'float32', 'float64'], 'isclose'
+        )
         check_type(rtol, 'rtol', float, 'isclose')
         check_type(atol, 'atol', float, 'isclose')
         check_type(equal_nan, 'equal_nan', bool, 'isclose')
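
Usage sketch (illustrative, not part of the patch): a minimal eager-mode example of the float16 support the kernels above register. It assumes a CUDA build of Paddle running on a GPU where float16 is supported; the values and tolerances mirror the new TestAllcloseOpFloat16/TestIscloseOpFloat16 cases.

    # Illustrative only; requires a CUDA build and a float16-capable GPU.
    import numpy as np
    import paddle

    paddle.set_device('gpu')  # the float16 kernels are registered for GPU only

    x = paddle.to_tensor(np.array([10.1, 0.5]).astype('float16'))
    y = paddle.to_tensor(np.array([10.0, 0.5]).astype('float16'))

    print(paddle.isclose(x, y, rtol=0.01, atol=0.0))   # elementwise bool tensor
    print(paddle.allclose(x, y, rtol=0.01, atol=0.0))  # single bool tensor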