diff --git a/test/torch/test_funcs.py b/test/torch/test_funcs.py index 4e17619e9069e7ac2f1568fd6622ba7197728664..1323ac2f8f00b374a9c72a611b7181ba9e4c8cf5 100644 --- a/test/torch/test_funcs.py +++ b/test/torch/test_funcs.py @@ -644,3 +644,45 @@ class TestTorchFuncs: 'a': [[19, 10], [43, 26]], 'b': {'x': [[44, 32], [80, 59]]}, })).all() + + def test_isfinite(self): + t1 = ttorch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([True, False, True, False, False])).all() + + t2 = ttorch.isfinite(ttorch.tensor({ + 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + })) + assert (t2 == ttorch.tensor({ + 'a': [True, False, True, False, False], + 'b': {'x': [[True, False, True], [False, True, False]]}, + })) + + def test_isinf(self): + t1 = ttorch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([False, True, False, True, False])).all() + + t2 = ttorch.isinf(ttorch.tensor({ + 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + })) + assert (t2 == ttorch.tensor({ + 'a': [False, True, False, True, False], + 'b': {'x': [[False, True, False], [True, False, False]]}, + })) + + def test_isnan(self): + t1 = ttorch.isnan(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([False, False, False, False, True])).all() + + t2 = ttorch.isnan(ttorch.tensor({ + 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + })) + assert (t2 == ttorch.tensor({ + 'a': [False, False, False, False, True], + 'b': {'x': [[False, False, False], [False, False, True]]}, + })).all() diff --git a/test/torch/test_tensor.py b/test/torch/test_tensor.py index 66703a40816a312f1e2703368c85df08dcc58695..80fe85ec766fcaf2a61d700ef63801817427a765 100644 --- a/test/torch/test_tensor.py +++ b/test/torch/test_tensor.py @@ -281,3 +281,45 @@ class TestTorchTensor: 'a': [[19, 10], [43, 26]], 'b': {'x': [[44, 32], [80, 59]]}, })).all() + + def test_isfinite(self): + t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite() + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([True, False, True, False, False])).all() + + t2 = ttorch.tensor({ + 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + }).isfinite() + assert (t2 == ttorch.tensor({ + 'a': [True, False, True, False, False], + 'b': {'x': [[True, False, True], [False, True, False]]}, + })) + + def test_isinf(self): + t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf() + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([False, True, False, True, False])).all() + + t2 = ttorch.tensor({ + 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + }).isinf() + assert (t2 == ttorch.tensor({ + 'a': [False, True, False, True, False], + 'b': {'x': [[False, True, False], [True, False, False]]}, + })) + + def test_isnan(self): + t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan() + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([False, False, False, False, True])).all() + + t2 = ttorch.tensor({ + 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + }).isnan() + assert (t2 == ttorch.tensor({ + 'a': [False, False, False, False, True], + 'b': {'x': [[False, False, False], [False, False, True]]}, + })).all() diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index 6cb91283670a0b7d0bc81a2fcaa845715b5b5877..d6907f1a744ac3aeca047b947f7e1227efea5900 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -22,6 +22,7 @@ __all__ = [ 'eq', 'ne', 'lt', 'le', 'gt', 'ge', 'equal', 'tensor', 'clone', 'dot', 'matmul', 'mm', + 'isfinite', 'isinf', 'isnan', ] func_treelize = post_process(post_process(args_mapping( @@ -954,3 +955,87 @@ def mm(input, mat2, *args, **kwargs): [80, 59]]) """ return torch.mm(input, mat2, *args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.isfinite) +@func_treelize() +def isfinite(input): + """ + In ``treetensor``, you can get a tree of new tensors with boolean elements + representing if each element is `finite` or not. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([ True, False, True, False, False]) + + >>> ttorch.isfinite(ttorch.tensor({ + ... 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + ... 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + ... })) + + ├── a --> tensor([ True, False, True, False, False]) + └── b --> + └── x --> tensor([[ True, False, True], + [False, True, False]]) + """ + return torch.isfinite(input) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.isinf) +@func_treelize() +def isinf(input): + """ + In ``treetensor``, you can test if each element of ``input`` + is infinite (positive or negative infinity) or not. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([False, True, False, True, False]) + + >>> ttorch.isinf(ttorch.tensor({ + ... 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + ... 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + ... })) + + ├── a --> tensor([False, True, False, True, False]) + └── b --> + └── x --> tensor([[False, True, False], + [ True, False, False]]) + """ + return torch.isinf(input) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.isnan) +@func_treelize() +def isnan(input): + """ + In ``treetensor``, you get a tree of new tensors with boolean elements representing + if each element of ``input`` is NaN or not + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.isnan(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([False, False, False, False, True]) + + >>> ttorch.isnan(ttorch.tensor({ + ... 'a': [1, float('inf'), 2, float('-inf'), float('nan')], + ... 'b': {'x': [[1, float('inf'), -2], [float('-inf'), 3, float('nan')]]} + ... })) + + ├── a --> tensor([False, False, False, False, True]) + └── b --> + └── x --> tensor([[False, False, False], + [False, False, True]]) + """ + return torch.isnan(input) diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index 788e226ed34e79e995e9dcd1cfe69a8e006bbcc7..2447e56b69677a522f3c685b2c125f59ed9047c1 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -293,3 +293,27 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): See :func:`treetensor.torch.matmul`. """ return self.matmul(tensor2, *args, **kwargs) + + @doc_from(torch.Tensor.isfinite) + @method_treelize() + def isfinite(self): + """ + See :func:`treetensor.torch.isfinite`. + """ + return self.isfinite() + + @doc_from(torch.Tensor.isinf) + @method_treelize() + def isinf(self): + """ + See :func:`treetensor.torch.isinf`. + """ + return self.isinf() + + @doc_from(torch.Tensor.isnan) + @method_treelize() + def isnan(self): + """ + See :func:`treetensor.torch.isnan`. + """ + return self.isnan()