From f1799c33869ed8c54ea1d765be79dcbfca43480d Mon Sep 17 00:00:00 2001 From: HansBug Date: Tue, 21 Sep 2021 21:57:11 +0800 Subject: [PATCH] doc, test(hansbug): complete the math functions --- test/torch/test_funcs.py | 231 +++++++++++++++++++++++++++++++++ test/torch/test_tensor.py | 231 ++++++++++++++++++++++++++++++++- treetensor/torch/funcs.py | 265 +++++++++++++++++++++++++++++++++++++- 3 files changed, 721 insertions(+), 6 deletions(-) diff --git a/test/torch/test_funcs.py b/test/torch/test_funcs.py index 1323ac2f8..5bf386bce 100644 --- a/test/torch/test_funcs.py +++ b/test/torch/test_funcs.py @@ -686,3 +686,234 @@ class TestTorchFuncs: 'a': [False, False, False, False, True], 'b': {'x': [[False, False, False], [False, False, True]]}, })).all() + + def test_abs(self): + t1 = ttorch.abs(ttorch.tensor([12, 0, -3])) + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([12, 0, 3])).all() + + t2 = ttorch.abs(ttorch.tensor({ + 'a': [12, 0, -3], + 'b': {'x': [[-3, 1], [0, -2]]}, + })) + assert (t2 == ttorch.tensor({ + 'a': [12, 0, 3], + 'b': {'x': [[3, 1], [0, 2]]}, + })).all() + + def test_abs_(self): + t1 = ttorch.tensor([12, 0, -3]) + assert isinstance(t1, torch.Tensor) + + t1r = ttorch.abs_(t1) + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([12, 0, 3])).all() + + t2 = ttorch.tensor({ + 'a': [12, 0, -3], + 'b': {'x': [[-3, 1], [0, -2]]}, + }) + t2r = ttorch.abs_(t2) + assert t2r is t2 + assert (t2 == ttorch.tensor({ + 'a': [12, 0, 3], + 'b': {'x': [[3, 1], [0, 2]]}, + })).all() + + def test_clamp(self): + t1 = ttorch.clamp(ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]), min=-0.5, max=0.5) + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([-0.5000, 0.1734, -0.0478, 0.5000])) < 1e-6).all() + + t2 = ttorch.clamp(ttorch.tensor({ + 'a': [-1.7120, 0.1734, -0.0478, 2.0922], + 'b': {'x': [[-0.9049, 1.7029, -0.3697], [0.0489, -1.3127, -1.0221]]}, + }), min=-0.5, max=0.5) + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [-0.5000, 0.1734, -0.0478, 0.5000], + 'b': {'x': [[-0.5000, 0.5000, -0.3697], + [0.0489, -0.5000, -0.5000]]}, + })) < 1e-6).all() + + def test_clamp_(self): + t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]) + t1r = ttorch.clamp_(t1, min=-0.5, max=0.5) + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([-0.5000, 0.1734, -0.0478, 0.5000])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [-1.7120, 0.1734, -0.0478, 2.0922], + 'b': {'x': [[-0.9049, 1.7029, -0.3697], [0.0489, -1.3127, -1.0221]]}, + }) + t2r = ttorch.clamp_(t2, min=-0.5, max=0.5) + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [-0.5000, 0.1734, -0.0478, 0.5000], + 'b': {'x': [[-0.5000, 0.5000, -0.3697], + [0.0489, -0.5000, -0.5000]]}, + })) < 1e-6).all() + + def test_sign(self): + t1 = ttorch.sign(ttorch.tensor([12, 0, -3])) + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([1, 0, -1])).all() + + t2 = ttorch.sign(ttorch.tensor({ + 'a': [12, 0, -3], + 'b': {'x': [[-3, 1], [0, -2]]}, + })) + assert (t2 == ttorch.tensor({ + 'a': [1, 0, -1], + 'b': {'x': [[-1, 1], + [0, -1]]}, + })).all() + + def test_round(self): + t1 = ttorch.round(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[1., -2.], + [-2., 3.]])) < 1e-6).all() + + t2 = ttorch.round(ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + })) + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[1., -2.], + [-2., 3.]], + 'b': {'x': [[1., -4., 1.], + [-5., -2., 3.]]}, + })) < 1e-6).all() + + def test_round_(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + t1r = ttorch.round_(t1) + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[1., -2.], + [-2., 3.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }) + t2r = ttorch.round_(t2) + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[1., -2.], + [-2., 3.]], + 'b': {'x': [[1., -4., 1.], + [-5., -2., 3.]]}, + })) < 1e-6).all() + + def test_floor(self): + t1 = ttorch.floor(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[1., -2.], + [-3., 2.]])) < 1e-6).all() + + t2 = ttorch.floor(ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + })) + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[1., -2.], + [-3., 2.]], + 'b': {'x': [[1., -4., 1.], + [-5., -2., 2.]]}, + })) < 1e-6).all() + + def test_floor_(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + t1r = ttorch.floor_(t1) + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[1., -2.], + [-3., 2.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }) + t2r = ttorch.floor_(t2) + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[1., -2.], + [-3., 2.]], + 'b': {'x': [[1., -4., 1.], + [-5., -2., 2.]]}, + })) < 1e-6).all() + + def test_ceil(self): + t1 = ttorch.ceil(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[2., -1.], + [-2., 3.]])) < 1e-6).all() + + t2 = ttorch.ceil(ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + })) + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[2., -1.], + [-2., 3.]], + 'b': {'x': [[1., -3., 2.], + [-4., -2., 3.]]}, + })) < 1e-6).all() + + def test_ceil_(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + t1r = ttorch.ceil_(t1) + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[2., -1.], + [-2., 3.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }) + t2r = ttorch.ceil_(t2) + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[2., -1.], + [-2., 3.]], + 'b': {'x': [[1., -3., 2.], + [-4., -2., 3.]]}, + })) < 1e-6).all() + + def test_sigmoid(self): + t1 = ttorch.sigmoid(ttorch.tensor([1.0, 2.0, -1.5])) + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([0.7311, 0.8808, 0.1824])) < 1e-4).all() + + t2 = ttorch.sigmoid(ttorch.tensor({ + 'a': [1.0, 2.0, -1.5], + 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]}, + })) + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [0.7311, 0.8808, 0.1824], + 'b': {'x': [[0.6225, 0.7685], + [0.0759, 0.5622]]}, + })) < 1e-4).all() + + def test_sigmoid_(self): + t1 = ttorch.tensor([1.0, 2.0, -1.5]) + t1r = ttorch.sigmoid_(t1) + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([0.7311, 0.8808, 0.1824])) < 1e-4).all() + + t2 = ttorch.tensor({ + 'a': [1.0, 2.0, -1.5], + 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]}, + }) + t2r = ttorch.sigmoid_(t2) + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [0.7311, 0.8808, 0.1824], + 'b': {'x': [[0.6225, 0.7685], + [0.0759, 0.5622]]}, + })) < 1e-4).all() diff --git a/test/torch/test_tensor.py b/test/torch/test_tensor.py index 80fe85ec7..bfd8f8fe5 100644 --- a/test/torch/test_tensor.py +++ b/test/torch/test_tensor.py @@ -10,7 +10,7 @@ from treetensor.common import Object _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) -# noinspection PyUnresolvedReferences +# noinspection PyUnresolvedReferences,DuplicatedCode @pytest.mark.unittest class TestTorchTensor: _DEMO_1 = ttorch.Tensor({ @@ -323,3 +323,232 @@ class TestTorchTensor: 'a': [False, False, False, False, True], 'b': {'x': [[False, False, False], [False, False, True]]}, })).all() + + def test_abs(self): + t1 = ttorch.tensor([12, 0, -3]).abs() + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([12, 0, 3])).all() + + t2 = ttorch.tensor({ + 'a': [12, 0, -3], + 'b': {'x': [[-3, 1], [0, -2]]}, + }).abs() + assert (t2 == ttorch.tensor({ + 'a': [12, 0, 3], + 'b': {'x': [[3, 1], [0, 2]]}, + })).all() + + def test_abs_(self): + t1 = ttorch.tensor([12, 0, -3]) + t1r = t1.abs_() + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([12, 0, 3])).all() + + t2 = ttorch.tensor({ + 'a': [12, 0, -3], + 'b': {'x': [[-3, 1], [0, -2]]}, + }) + t2r = t2.abs_() + assert t2r is t2 + assert (t2 == ttorch.tensor({ + 'a': [12, 0, 3], + 'b': {'x': [[3, 1], [0, 2]]}, + })).all() + + def test_clamp(self): + t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]).clamp(min=-0.5, max=0.5) + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([-0.5000, 0.1734, -0.0478, 0.5000])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [-1.7120, 0.1734, -0.0478, 2.0922], + 'b': {'x': [[-0.9049, 1.7029, -0.3697], [0.0489, -1.3127, -1.0221]]}, + }).clamp(min=-0.5, max=0.5) + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [-0.5000, 0.1734, -0.0478, 0.5000], + 'b': {'x': [[-0.5000, 0.5000, -0.3697], + [0.0489, -0.5000, -0.5000]]}, + })) < 1e-6).all() + + def test_clamp_(self): + t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]) + t1r = t1.clamp_(min=-0.5, max=0.5) + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([-0.5000, 0.1734, -0.0478, 0.5000])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [-1.7120, 0.1734, -0.0478, 2.0922], + 'b': {'x': [[-0.9049, 1.7029, -0.3697], [0.0489, -1.3127, -1.0221]]}, + }) + t2r = t2.clamp_(min=-0.5, max=0.5) + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [-0.5000, 0.1734, -0.0478, 0.5000], + 'b': {'x': [[-0.5000, 0.5000, -0.3697], + [0.0489, -0.5000, -0.5000]]}, + })) < 1e-6).all() + + def test_sign(self): + t1 = ttorch.tensor([12, 0, -3]).sign() + assert isinstance(t1, torch.Tensor) + assert (t1 == ttorch.tensor([1, 0, -1])).all() + + t2 = ttorch.tensor({ + 'a': [12, 0, -3], + 'b': {'x': [[-3, 1], [0, -2]]}, + }).sign() + assert (t2 == ttorch.tensor({ + 'a': [1, 0, -1], + 'b': {'x': [[-1, 1], + [0, -1]]}, + })).all() + + def test_round(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).round() + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[1., -2.], + [-2., 3.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }).round() + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[1., -2.], + [-2., 3.]], + 'b': {'x': [[1., -4., 1.], + [-5., -2., 3.]]}, + })) < 1e-6).all() + + def test_round_(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + t1r = t1.round_() + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[1., -2.], + [-2., 3.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }) + t2r = t2.round_() + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[1., -2.], + [-2., 3.]], + 'b': {'x': [[1., -4., 1.], + [-5., -2., 3.]]}, + })) < 1e-6).all() + + def test_floor(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).floor() + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[1., -2.], + [-3., 2.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }).floor() + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[1., -2.], + [-3., 2.]], + 'b': {'x': [[1., -4., 1.], + [-5., -2., 2.]]}, + })) < 1e-6).all() + + def test_floor_(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + t1r = t1.floor_() + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[1., -2.], + [-3., 2.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }) + t2r = t2.floor_() + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[1., -2.], + [-3., 2.]], + 'b': {'x': [[1., -4., 1.], + [-5., -2., 2.]]}, + })) < 1e-6).all() + + def test_ceil(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).ceil() + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[2., -1.], + [-2., 3.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }).ceil() + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[2., -1.], + [-2., 3.]], + 'b': {'x': [[1., -3., 2.], + [-4., -2., 3.]]}, + })) < 1e-6).all() + + def test_ceil_(self): + t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + t1r = t1.ceil_() + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([[2., -1.], + [-2., 3.]])) < 1e-6).all() + + t2 = ttorch.tensor({ + 'a': [[1.2, -1.8], [-2.3, 2.8]], + 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + }) + t2r = t2.ceil_() + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [[2., -1.], + [-2., 3.]], + 'b': {'x': [[1., -3., 2.], + [-4., -2., 3.]]}, + })) < 1e-6).all() + + def test_sigmoid(self): + t1 = ttorch.tensor([1.0, 2.0, -1.5]).sigmoid() + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([0.7311, 0.8808, 0.1824])) < 1e-4).all() + + t2 = ttorch.tensor({ + 'a': [1.0, 2.0, -1.5], + 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]}, + }).sigmoid() + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [0.7311, 0.8808, 0.1824], + 'b': {'x': [[0.6225, 0.7685], + [0.0759, 0.5622]]}, + })) < 1e-4).all() + + def test_sigmoid_(self): + t1 = ttorch.tensor([1.0, 2.0, -1.5]) + t1r = t1.sigmoid_() + assert t1r is t1 + assert isinstance(t1, torch.Tensor) + assert (ttorch.abs(t1 - ttorch.tensor([0.7311, 0.8808, 0.1824])) < 1e-4).all() + + t2 = ttorch.tensor({ + 'a': [1.0, 2.0, -1.5], + 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]}, + }) + t2r = t2.sigmoid_() + assert t2r is t2 + assert (ttorch.abs(t2 - ttorch.tensor({ + 'a': [0.7311, 0.8808, 0.1824], + 'b': {'x': [[0.6225, 0.7685], + [0.0759, 0.5622]]}, + })) < 1e-4).all() diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index 0a8de9eb4..f77fbea5b 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -1047,6 +1047,26 @@ def isnan(input): @doc_from(torch.abs) @func_treelize() def abs(input, *args, **kwargs): + """ + Computes the absolute value of each element in ``input``. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.abs(ttorch.tensor([12, 0, -3])) + tensor([12, 0, 3]) + + >>> ttorch.abs(ttorch.tensor({ + ... 'a': [12, 0, -3], + ... 'b': {'x': [[-3, 1], [0, -2]]}, + ... })) + + ├── a --> tensor([12, 0, 3]) + └── b --> + └── x --> tensor([[3, 1], + [0, 2]]) + """ return torch.abs(input, *args, **kwargs) @@ -1055,6 +1075,30 @@ def abs(input, *args, **kwargs): @return_self @func_treelize() def abs_(input): + """ + In-place version of :func:`treetensor.torch.abs`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> t = ttorch.tensor([12, 0, -3]) + >>> ttorch.abs_(t) + >>> t + tensor([12, 0, 3]) + + >>> t = ttorch.tensor({ + ... 'a': [12, 0, -3], + ... 'b': {'x': [[-3, 1], [0, -2]]}, + ... }) + >>> ttorch.abs_(t) + >>> t + + ├── a --> tensor([12, 0, 3]) + └── b --> + └── x --> tensor([[3, 1], + [0, 2]]) + """ return torch.abs_(input) @@ -1062,14 +1106,58 @@ def abs_(input): @doc_from(torch.clamp) @func_treelize() def clamp(input, *args, **kwargs): + """ + Clamp all elements in ``input`` into the range `[` ``min``, ``max`` `]`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.clamp(ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]), min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, 0.5000]) + + >>> ttorch.clamp(ttorch.tensor({ + ... 'a': [-1.7120, 0.1734, -0.0478, 2.0922], + ... 'b': {'x': [[-0.9049, 1.7029, -0.3697], [0.0489, -1.3127, -1.0221]]}, + ... }), min=-0.5, max=0.5) + + ├── a --> tensor([-0.5000, 0.1734, -0.0478, 0.5000]) + └── b --> + └── x --> tensor([[-0.5000, 0.5000, -0.3697], + [ 0.0489, -0.5000, -0.5000]]) + """ return torch.clamp(input, *args, **kwargs) -# noinspection PyShadowingBuiltins +# noinspection PyShadowingBuiltins,PyUnresolvedReferences @doc_from(torch.clamp_) @return_self @func_treelize() def clamp_(input, *args, **kwargs): + """ + In-place version of :func:`treetensor.torch.clamp`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> t = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]) + >>> ttorch.clamp_(t, min=-0.5, max=0.5) + >>> t + tensor([-0.5000, 0.1734, -0.0478, 0.5000]) + + >>> t = ttorch.tensor({ + ... 'a': [-1.7120, 0.1734, -0.0478, 2.0922], + ... 'b': {'x': [[-0.9049, 1.7029, -0.3697], [0.0489, -1.3127, -1.0221]]}, + ... }) + >>> ttorch.clamp_(t, min=-0.5, max=0.5) + >>> t + + ├── a --> tensor([-0.5000, 0.1734, -0.0478, 0.5000]) + └── b --> + └── x --> tensor([[-0.5000, 0.5000, -0.3697], + [ 0.0489, -0.5000, -0.5000]]) + """ return torch.clamp_(input, *args, **kwargs) @@ -1077,6 +1165,26 @@ def clamp_(input, *args, **kwargs): @doc_from(torch.sign) @func_treelize() def sign(input, *args, **kwargs): + """ + Returns a tree of new tensors with the signs of the elements of input. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.sign(ttorch.tensor([12, 0, -3])) + tensor([ 1, 0, -1]) + + >>> ttorch.sign(ttorch.tensor({ + ... 'a': [12, 0, -3], + ... 'b': {'x': [[-3, 1], [0, -2]]}, + ... })) + + ├── a --> tensor([ 1, 0, -1]) + └── b --> + └── x --> tensor([[-1, 1], + [ 0, -1]]) + """ return torch.sign(input, *args, **kwargs) @@ -1084,6 +1192,29 @@ def sign(input, *args, **kwargs): @doc_from(torch.round) @func_treelize() def round(input, *args, **kwargs): + """ + Returns a tree of new tensors with each of the elements of ``input`` + rounded to the closest integer. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.round(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) + tensor([[ 1., -2.], + [-2., 3.]]) + + >>> ttorch.round(ttorch.tensor({ + ... 'a': [[1.2, -1.8], [-2.3, 2.8]], + ... 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + ... })) + + ├── a --> tensor([[ 1., -2.], + │ [-2., 3.]]) + └── b --> + └── x --> tensor([[ 1., -4., 1.], + [-5., -2., 3.]]) + """ return torch.round(input, *args, **kwargs) @@ -1092,6 +1223,32 @@ def round(input, *args, **kwargs): @return_self @func_treelize() def round_(input): + """ + In-place version of :func:`treetensor.torch.round`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> t = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + >>> ttorch.round_(t) + >>> t + tensor([[ 1., -2.], + [-2., 3.]]) + + >>> t = ttorch.tensor({ + ... 'a': [[1.2, -1.8], [-2.3, 2.8]], + ... 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + ... }) + >>> ttorch.round_(t) + >>> t + + ├── a --> tensor([[ 1., -2.], + │ [-2., 3.]]) + └── b --> + └── x --> tensor([[ 1., -4., 1.], + [-5., -2., 3.]]) + """ return torch.round_(input) @@ -1099,6 +1256,29 @@ def round_(input): @doc_from(torch.floor) @func_treelize() def floor(input, *args, **kwargs): + """ + Returns a tree of new tensors with the floor of the elements of ``input``, + the largest integer less than or equal to each element. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.floor(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) + tensor([[ 1., -2.], + [-3., 2.]]) + + >>> ttorch.floor(ttorch.tensor({ + ... 'a': [[1.2, -1.8], [-2.3, 2.8]], + ... 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + ... })) + + ├── a --> tensor([[ 1., -2.], + │ [-3., 2.]]) + └── b --> + └── x --> tensor([[ 1., -4., 1.], + [-5., -2., 2.]]) + """ return torch.floor(input, *args, **kwargs) @@ -1107,6 +1287,32 @@ def floor(input, *args, **kwargs): @return_self @func_treelize() def floor_(input): + """ + In-place version of :func:`treetensor.torch.floor`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> t = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + >>> ttorch.floor_(t) + >>> t + tensor([[ 1., -2.], + [-3., 2.]]) + + >>> t = ttorch.tensor({ + ... 'a': [[1.2, -1.8], [-2.3, 2.8]], + ... 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + ... }) + >>> ttorch.floor_(t) + >>> t + + ├── a --> tensor([[ 1., -2.], + │ [-3., 2.]]) + └── b --> + └── x --> tensor([[ 1., -4., 1.], + [-5., -2., 2.]]) + """ return torch.floor_(input) @@ -1114,6 +1320,29 @@ def floor_(input): @doc_from(torch.ceil) @func_treelize() def ceil(input, *args, **kwargs): + """ + Returns a tree of new tensors with the ceil of the elements of ``input``, + the smallest integer greater than or equal to each element. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.ceil(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) + tensor([[ 2., -1.], + [-2., 3.]]) + + >>> ttorch.ceil(ttorch.tensor({ + ... 'a': [[1.2, -1.8], [-2.3, 2.8]], + ... 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + ... })) + + ├── a --> tensor([[ 2., -1.], + │ [-2., 3.]]) + └── b --> + └── x --> tensor([[ 1., -3., 2.], + [-4., -2., 3.]]) + """ return torch.ceil(input, *args, **kwargs) @@ -1122,6 +1351,32 @@ def ceil(input, *args, **kwargs): @return_self @func_treelize() def ceil_(input): + """ + In-place version of :func:`treetensor.torch.ceil`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> t = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) + >>> ttorch.ceil_(t) + >>> t + tensor([[ 2., -1.], + [-2., 3.]]) + + >>> t = ttorch.tensor({ + ... 'a': [[1.2, -1.8], [-2.3, 2.8]], + ... 'b': {'x': [[1.0, -3.9, 1.3], [-4.8, -2.0, 2.8]]}, + ... }) + >>> ttorch.ceil_(t) + >>> t + + ├── a --> tensor([[ 2., -1.], + │ [-2., 3.]]) + └── b --> + └── x --> tensor([[ 1., -3., 2.], + [-4., -2., 3.]]) + """ return torch.ceil_(input) @@ -1130,19 +1385,19 @@ def ceil_(input): @func_treelize() def sigmoid(input, *args, **kwargs): """ - Get a tree of new tensors with the sigmoid of the elements of ``input``. + Returns a tree of new tensors with the sigmoid of the elements of ``input``. Examples:: >>> import torch >>> import treetensor.torch as ttorch - >>> ttorch.tensor([1.0, 2.0, -1.5]).sigmoid() + >>> ttorch.sigmoid(ttorch.tensor([1.0, 2.0, -1.5])) tensor([0.7311, 0.8808, 0.1824]) - >>> ttorch.tensor({ + >>> ttorch.sigmoid(ttorch.tensor({ ... 'a': [1.0, 2.0, -1.5], ... 'b': {'x': [[0.5, 1.2], [-2.5, 0.25]]}, - ... }).sigmoid() + ... })) ├── a --> tensor([0.7311, 0.8808, 0.1824]) └── b --> -- GitLab