diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index abe3184ee918543bda7bae8b9c0d891535780410..fc8a9b52bc08f1c691703aaeba0da72667a7dfc1 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -230,6 +230,7 @@ from .tensor.math import sqrt  # noqa: F401
 from .tensor.math import square  # noqa: F401
 from .tensor.math import stanh  # noqa: F401
 from .tensor.math import sum  # noqa: F401
+from .tensor.math import nan_to_num  # noqa: F401
 from .tensor.math import nansum  # noqa: F401
 from .tensor.math import nanmean  # noqa: F401
 from .tensor.math import count_nonzero  # noqa: F401
@@ -666,6 +667,7 @@ __all__ = [  # noqa
     'renorm',
     'take_along_axis',
     'put_along_axis',
+    'nan_to_num',
     'heaviside',
     'tril_indices',
     'index_add',
diff --git a/python/paddle/fluid/tests/unittests/test_nan_to_num_op.py b/python/paddle/fluid/tests/unittests/test_nan_to_num_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a5cb1f3881d7079ba6d50354cc4d63a6f4dbfe
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_nan_to_num_op.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import Optional
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+
+# from op_test import OpTest
+
+
+def np_nan_to_num(
+    x: np.ndarray,
+    nan: float = 0.0,
+    posinf: Optional[float] = None,
+    neginf: Optional[float] = None,
+) -> np.ndarray:
+    return np.nan_to_num(x, copy=True, nan=nan, posinf=posinf, neginf=neginf)
+
+
+def np_nan_to_num_op(
+    x: np.ndarray,
+    nan: float,
+    replace_posinf_with_max: bool,
+    posinf: float,
+    replace_neginf_with_min: bool,
+    neginf: float,
+) -> np.ndarray:
+    if replace_posinf_with_max:
+        posinf = None
+    if replace_neginf_with_min:
+        neginf = None
+    return np.nan_to_num(x, copy=True, nan=nan, posinf=posinf, neginf=neginf)
+
+
+def np_nan_to_num_grad(x: np.ndarray, dout: np.ndarray) -> np.ndarray:
+    dx = np.copy(dout)
+    dx[np.isnan(x) | (x == np.inf) | (x == -np.inf)] = 0
+    return dx
+
+
+class TestNanToNum(unittest.TestCase):
+    def setUp(self):
+        self.place = (
+            paddle.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def test_static(self):
+        x_np = np.array([[1, np.nan, -2], [np.inf, 0, -np.inf]]).astype(
+            np.float32
+        )
+        out1_np = np_nan_to_num(x_np)
+        out2_np = np_nan_to_num(x_np, 1.0)
+        out3_np = np_nan_to_num(x_np, 1.0, 9.0)
+        out4_np = np_nan_to_num(x_np, 1.0, 9.0, -12.0)
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data('X', x_np.shape)
+            out1 = paddle.nan_to_num(x)
+            out2 = paddle.nan_to_num(x, 1.0)
+            out3 = paddle.nan_to_num(x, 1.0, 9.0)
+            out4 = paddle.nan_to_num(x, 1.0, 9.0, -12.0)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': x_np}, fetch_list=[out1, out2, out3, out4])
+
+        self.assertTrue(np.allclose(out1_np, res[0]))
+        self.assertTrue(np.allclose(out2_np, res[1]))
+        self.assertTrue(np.allclose(out3_np, res[2]))
+        self.assertTrue(np.allclose(out4_np, res[3]))
+
+    def test_dygraph(self):
+
+        paddle.disable_static(place=self.place)
+
+        with paddle.fluid.dygraph.guard():
+            # NOTE(tiancaishaonvjituizi): float64 input fails the test
+            x_np = np.array([[1, np.nan, -2], [np.inf, 0, -np.inf]]).astype(
+                np.float32
+                # np.float64
+            )
+            x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
+
+            out_tensor = paddle.nan_to_num(x_tensor)
+            out_np = np_nan_to_num(x_np)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+            out_tensor = paddle.nan_to_num(x_tensor, 1.0, None, None)
+            out_np = np_nan_to_num(x_np, 1, None, None)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+            out_tensor = paddle.nan_to_num(x_tensor, 1.0, 2.0, None)
+            out_np = np_nan_to_num(x_np, 1, 2, None)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+            out_tensor = paddle.nan_to_num(x_tensor, 1.0, None, -10.0)
+            out_np = np_nan_to_num(x_np, 1, None, -10)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+            out_tensor = paddle.nan_to_num(x_tensor, 1.0, 100.0, -10.0)
+            out_np = np_nan_to_num(x_np, 1, 100, -10)
+            self.assertTrue(np.allclose(out_tensor.numpy(), out_np))
+
+        paddle.enable_static()
+
+    def test_check_grad(self):
+        paddle.disable_static(place=self.place)
+        x_np = np.array([[1, np.nan, -2], [np.inf, 0, -np.inf]]).astype(
+            np.float32
+        )
+        x_tensor = paddle.to_tensor(x_np, stop_gradient=False)
+
+        y = paddle.nan_to_num(x_tensor)
+        dx = paddle.grad(y, x_tensor)[0].numpy()
+
+        np_grad = np_nan_to_num_grad(x_np, np.ones_like(x_np))
+        self.assertTrue(np.allclose(np_grad, dx))
+
+        paddle.enable_static()
+
+
+# class BaseTestCases:
+#
+#     class BaseOpTest(OpTest):
+#
+#         def setUp(self):
+#             self.op_type = "nan_to_num"
+#             input = np.arange(100, dtype=np.float64)
+#             input[5] = np.nan
+#             input[29] = np.inf
+#             input[97] = -np.inf
+#             self.inputs = {'X': input}
+#             self.attrs = self._attrs()
+#             self.outputs = {
+#                 'Out': np_nan_to_num_op(self.inputs['X'], **self.attrs)
+#             }
+#             paddle.enable_static()
+#
+#         def test_check_output(self):
+#             self.check_output()
+#
+#         def test_check_grad(self):
+#             input = self.inputs['X']
+#             dout = np.ones_like(input) / input.size
+#             self.check_grad(
+#                 ['X'],
+#                 'Out',
+#                 user_defined_grads=[np_nan_to_num_grad(self.inputs['X'], dout)])
+#
+#         def _attrs(self):
+#             raise NotImplementedError()
+#
+#
+# class TestNanToNumOp1(BaseTestCases.BaseOpTest):
+#
+#     def _attrs(self):
+#         return {
+#             'nan': 0.0,
+#             'replace_posinf_with_max': True,
+#             'posinf': -1,
+#             'replace_neginf_with_min': True,
+#             'neginf': -10
+#         }
+#
+#
+# class TestNanToNumOp2(BaseTestCases.BaseOpTest):
+#
+#     def _attrs(self):
+#         return {
+#             'nan': 2.0,
+#             'replace_posinf_with_max': False,
+#             'posinf': -1,
+#             'replace_neginf_with_min': True,
+#             'neginf': -10
+#         }
+#
+#
+# class TestNanToNumOp3(BaseTestCases.BaseOpTest):
+#
+#     def _attrs(self):
+#         return {
+#             'nan': 0.0,
+#             'replace_posinf_with_max': False,
+#             'posinf': -1,
+#             'replace_neginf_with_min': False,
+#             'neginf': -10
+#         }
+
+if __name__ == "__main__":
+    unittest.main()
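The reference gradient above (`np_nan_to_num_grad`) encodes the key property of the backward pass: wherever the forward pass substitutes a constant for NaN, +inf, or -inf, the output no longer depends on the input at that position, so the upstream gradient is zeroed there and passed through unchanged everywhere else. A minimal standalone NumPy sketch of that convention (illustration only, not part of the patch):

```python
import numpy as np

x = np.array([1.0, np.nan, np.inf, -np.inf, -2.0], dtype=np.float32)
dout = np.full_like(x, 0.5)  # stand-in for the upstream gradient

# Gradient passes through finite entries; it is cut to zero exactly where
# the forward pass replaced a NaN or an infinity with a constant.
dx = np.copy(dout)
dx[np.isnan(x) | (x == np.inf) | (x == -np.inf)] = 0
print(dx)  # [0.5 0.  0.  0.  0.5]
```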
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index f08d339c049a1ad530213a492854b675cc55fc6c..4c1ec078380506c399fe3866cad752a5dcb9a624 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -169,6 +169,7 @@ from .math import sqrt_  # noqa: F401
 from .math import square  # noqa: F401
 from .math import stanh  # noqa: F401
 from .math import sum  # noqa: F401
+from .math import nan_to_num  # noqa: F401
 from .math import nansum  # noqa: F401
 from .math import nanmean  # noqa: F401
 from .math import count_nonzero  # noqa: F401
@@ -350,6 +351,7 @@ tensor_method_func = [  # noqa
     'square',
     'stanh',
     'sum',
+    'nan_to_num',
     'nansum',
     'nanmean',
     'count_nonzero',
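Registering `'nan_to_num'` in `tensor_method_func` matters beyond the import: that list is what Paddle uses to attach these functions to `Tensor`, so the new op should also be reachable as a tensor method. A quick sketch, assuming a build that includes this patch:

```python
import paddle

x = paddle.to_tensor([float('nan'), 1.0, float('inf')])
# Method form, exposed by the tensor_method_func registration above.
print(x.nan_to_num())        # [0., 1., 3.4028235e+38]
# Free-function form; both should give identical results.
print(paddle.nan_to_num(x))  # [0., 1., 3.4028235e+38]
```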
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index cd8dbd08b94d865116bc68f390259f5427721873..91388a6f99a02bf7ebde19c3d6c2dcce1409fa58 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1364,6 +1364,54 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     return out
 
 
+def nan_to_num(x, nan=0.0, posinf=None, neginf=None, name=None):
+    """
+    Replaces NaN, positive infinity, and negative infinity values in the input tensor.
+
+    Args:
+        x (Tensor): An N-D Tensor, the data type is float32, float64.
+        nan (float, optional): The value to replace NaNs with. Default is 0.
+        posinf (float, optional): If a float, the value to replace positive infinity values with. If None, positive infinity values are replaced with the greatest finite value representable by the input's dtype. Default is None.
+        neginf (float, optional): If a float, the value to replace negative infinity values with. If None, negative infinity values are replaced with the lowest finite value representable by the input's dtype. Default is None.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: Results of nan_to_num operation on the input Tensor ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([float('nan'), 0.3, float('+inf'), float('-inf')], dtype='float32')
+            out1 = paddle.nan_to_num(x)  # [0, 0.3, 3.4028235e+38, -3.4028235e+38]
+            out2 = paddle.nan_to_num(x, nan=1)  # [1, 0.3, 3.4028235e+38, -3.4028235e+38]
+            out3 = paddle.nan_to_num(x, posinf=5)  # [0, 0.3, 5, -3.4028235e+38]
+            out4 = paddle.nan_to_num(x, nan=10, neginf=-99)  # [10, 0.3, 3.4028235e+38, -99]
+    """
+    # NOTE(tiancaishaonvjituizi): it seems that paddle handles the dtype of Python
+    # float numbers incorrectly, so we have to explicitly construct tensors here
+    posinf_value = paddle.full_like(x, float("+inf"))
+    neginf_value = paddle.full_like(x, float("-inf"))
+    nan = paddle.full_like(x, nan)
+    assert x.dtype in [paddle.float32, paddle.float64]
+    is_float32 = x.dtype == paddle.float32
+    if posinf is None:
+        posinf = (
+            np.finfo(np.float32).max if is_float32 else np.finfo(np.float64).max
+        )
+    posinf = paddle.full_like(x, posinf)
+    if neginf is None:
+        neginf = (
+            np.finfo(np.float32).min if is_float32 else np.finfo(np.float64).min
+        )
+    neginf = paddle.full_like(x, neginf)
+    x = paddle.where(paddle.isnan(x), nan, x)
+    x = paddle.where(x == posinf_value, posinf, x)
+    x = paddle.where(x == neginf_value, neginf, x)
+    return x
+
+
 def nansum(x, axis=None, dtype=None, keepdim=False, name=None):
     """
     Computes the sum of tensor elements over the given axis, treating Not a Numbers (NaNs) as zero.
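Note that `nan_to_num` above is composed from existing primitives (`full_like`, `isnan`, `where`) rather than backed by a dedicated kernel: sentinel tensors holding ±inf are captured before any replacement, NaNs are rewritten first, and the two infinities are then matched against the sentinels. The same composition in plain NumPy, for intuition (a sketch only; `nan_to_num_ref` is an illustrative name, not part of the patch):

```python
import numpy as np

def nan_to_num_ref(x, nan=0.0, posinf=None, neginf=None):
    # Defaults mirror the patch: the largest/smallest finite value of the dtype.
    finfo = np.finfo(x.dtype)
    posinf = finfo.max if posinf is None else posinf
    neginf = finfo.min if neginf is None else neginf
    # Same order as the paddle.where chain: NaN first, then +inf, then -inf.
    out = np.where(np.isnan(x), np.full_like(x, nan), x)
    out = np.where(out == np.inf, np.full_like(x, posinf), out)
    out = np.where(out == -np.inf, np.full_like(x, neginf), out)
    return out

x = np.array([np.nan, 0.3, np.inf, -np.inf], dtype=np.float32)
print(nan_to_num_ref(x))  # approx. [0., 0.3, 3.4028235e+38, -3.4028235e+38]
```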