diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index a05f2858c0df3b0d80e6dbfd4758270902e1403e..7e002c8154147ddb83ac195ff40deeb6da434f6b 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -83,12 +83,13 @@ struct DataTypeTrait {
   _ForEachDataTypeHelper_(                                                \
       callback, ::paddle::platform::complex<double>, COMPLEX128);
 
-#define _ForEachDataTypeNormal_(callback)                                 \
-  _ForEachDataTypeHelper_(callback, float, FP32);                         \
-  _ForEachDataTypeHelper_(callback, double, FP64);                        \
-  _ForEachDataTypeHelper_(callback, int, INT32);                          \
-  _ForEachDataTypeHelper_(callback, int64_t, INT64);                      \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);
+#define _ForEachDataTypeNormal_(callback)                                 \
+  _ForEachDataTypeHelper_(callback, float, FP32);                         \
+  _ForEachDataTypeHelper_(callback, double, FP64);                        \
+  _ForEachDataTypeHelper_(callback, int, INT32);                          \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                      \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);   \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16);
 
 // For the use of thrust, as index-type elements can be only integers.
 #define _ForEachDataTypeTiny_(callback) \
diff --git a/paddle/fluid/operators/isfinite_op.cu b/paddle/fluid/operators/isfinite_op.cu
old mode 100644
new mode 100755
index d8e18f58fa9f2d3ffd712c07ba39eed7124e82b4..80a65cbda916b7f27dfd02cae30b0e6faa1e22c5
--- a/paddle/fluid/operators/isfinite_op.cu
+++ b/paddle/fluid/operators/isfinite_op.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
#include "paddle/fluid/operators/isfinite_op.h" +#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -22,18 +23,21 @@ REGISTER_OP_CUDA_KERNEL( ops::OverflowKernel, ops::OverflowKernel, ops::OverflowKernel, - ops::OverflowKernel); + ops::OverflowKernel, + ops::OverflowKernel); REGISTER_OP_CUDA_KERNEL( isnan, ops::OverflowKernel, ops::OverflowKernel, ops::OverflowKernel, - ops::OverflowKernel); + ops::OverflowKernel, + ops::OverflowKernel); REGISTER_OP_CUDA_KERNEL( isfinite, ops::OverflowKernel, ops::OverflowKernel, ops::OverflowKernel, - ops::OverflowKernel); + ops::OverflowKernel, + ops::OverflowKernel); diff --git a/paddle/phi/kernels/cpu/isfinite_kernel.cc b/paddle/phi/kernels/cpu/isfinite_kernel.cc index 85d125794871d327d0150b5dd475487652eb3775..c9f69c5f7e4f5e6551bf38c573eb9d8be3438408 100644 --- a/paddle/phi/kernels/cpu/isfinite_kernel.cc +++ b/paddle/phi/kernels/cpu/isfinite_kernel.cc @@ -25,6 +25,7 @@ PD_REGISTER_KERNEL(isinf, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); @@ -37,6 +38,7 @@ PD_REGISTER_KERNEL(isnan, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); @@ -49,6 +51,7 @@ PD_REGISTER_KERNEL(isfinite, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); diff --git a/paddle/phi/kernels/funcs/isfinite_functor.h b/paddle/phi/kernels/funcs/isfinite_functor.h index 1dc4fd57b48574594d54a56825fefa5539c85573..795b8f275c87ea3cef3d86601b878c389d982ef0 100644 --- a/paddle/phi/kernels/funcs/isfinite_functor.h +++ b/paddle/phi/kernels/funcs/isfinite_functor.h @@ -45,6 +45,13 @@ struct IsNanFunctor { } }; +template <> +struct IsNanFunctor { + HOSTDEVICE bool operator()(const phi::dtype::bfloat16& a) const { + return phi::dtype::isnan(a); + } +}; + template struct IsInfFunctor { HOSTDEVICE bool operator()(const T& a) const { @@ -69,6 +76,13 @@ struct IsInfFunctor { } }; +template <> +struct IsInfFunctor { + HOSTDEVICE bool operator()(const phi::dtype::bfloat16& a) const { + return phi::dtype::isinf(a); + } +}; + template struct IsFiniteFunctor { HOSTDEVICE bool operator()(const T& a) const { @@ -94,5 +108,12 @@ struct IsFiniteFunctor { } }; +template <> +struct IsFiniteFunctor { + HOSTDEVICE bool operator()(const phi::dtype::bfloat16& a) const { + return phi::dtype::isfinite(a); + } +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/isfinite_kernel.cu b/paddle/phi/kernels/gpu/isfinite_kernel.cu index e8c2fa022ec7a56819a4e611b6c9e365953b7dfc..9bde1d7a5bd38725addc852e7b8d96642b756a32 100644 --- a/paddle/phi/kernels/gpu/isfinite_kernel.cu +++ b/paddle/phi/kernels/gpu/isfinite_kernel.cu @@ -25,6 +25,7 @@ PD_REGISTER_KERNEL(isinf, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); @@ -37,6 +38,7 @@ PD_REGISTER_KERNEL(isnan, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); @@ -49,6 +51,7 @@ PD_REGISTER_KERNEL(isfinite, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) { kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); diff --git a/python/paddle/fluid/tests/unittests/test_isfinite_op.py b/python/paddle/fluid/tests/unittests/test_isfinite_op.py old mode 100644 new mode 
index 6599f66140c2298fd655edfc408985c4503e0cc9..efda5d502c6a6aa288be32153aa2188e36cbe85d
--- a/python/paddle/fluid/tests/unittests/test_isfinite_op.py
+++ b/python/paddle/fluid/tests/unittests/test_isfinite_op.py
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 from paddle.fluid import core
 
@@ -48,6 +48,28 @@ class TestFP16Inf(TestInf):
         self.dtype = np.float16
 
 
+# BFP16 isinf Test
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestInfBF16(OpTest):
+    def setUp(self):
+        self.op_type = "isinf"
+        self.dtype = np.uint16
+        x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
+        x[0] = np.inf
+        x[-1] = np.inf
+
+        out = np.array(True)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': out}
+
+    def test_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+
 class TestNAN(OpTest):
     def setUp(self):
         self.op_type = "isnan"
@@ -76,6 +98,28 @@ class TestFP16NAN(TestNAN):
         self.dtype = np.float16
 
 
+# BFP16 isnan Test
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestNANBF16(OpTest):
+    def setUp(self):
+        self.op_type = "isnan"
+        self.dtype = np.uint16
+        x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
+        x[0] = np.nan
+        x[-1] = np.nan
+
+        out = np.array(True)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': out}
+
+    def test_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+
 class TestIsfinite(OpTest):
     def setUp(self):
         self.op_type = "isfinite"
@@ -105,5 +149,27 @@ class TestFP16Isfinite(TestIsfinite):
         self.dtype = np.float16
 
 
+# BFP16 isfinite Test
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestIsfiniteBF16(OpTest):
+    def setUp(self):
+        self.op_type = "isfinite"
+        self.dtype = np.uint16
+        x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
+        x[0] = np.inf
+        x[-1] = np.nan
+
+        out = np.array(False)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': out}
+
+    def test_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index ba7efb7956f77d3aaedd5df313f5dbb08d0c2c1b..1e969be880401ef5e6f1444473112eb04e9f1b1e 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -3466,7 +3466,14 @@ def isfinite(x, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            [
+                'float16',
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+                'uint16',
+            ],
             'isfinite',
         )
         out = helper.create_variable_for_type_inference('bool')
@@ -3502,7 +3509,17 @@ def isinf(x, name=None):
     else:
         helper = LayerHelper("isinf_v2", **locals())
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isinf'
+            x,
+            'x',
+            [
+                'float16',
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+                'uint16',
+            ],
+            'isinf',
         )
         out = helper.create_variable_for_type_inference(dtype='bool')
         helper.append_op(type="isinf_v2", inputs={"X": x}, outputs={"Out": out})
@@ -3535,7 +3552,17 @@ def isnan(x, name=None):
     else:
helper = LayerHelper("isnan_v2", **locals()) check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'isnan' + x, + 'x', + [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'uint16', + ], + 'isnan', ) out = helper.create_variable_for_type_inference(dtype='bool') helper.append_op(type="isnan_v2", inputs={"X": x}, outputs={"Out": out})