diff --git a/test/torch/test_funcs.py b/test/torch/test_funcs.py index 315b47039f98cfca49c91a94c4299121324aab47..4e17619e9069e7ac2f1568fd6622ba7197728664 100644 --- a/test/torch/test_funcs.py +++ b/test/torch/test_funcs.py @@ -4,7 +4,7 @@ import torch import treetensor.torch as ttorch -# noinspection DuplicatedCode +# noinspection DuplicatedCode,PyUnresolvedReferences @pytest.mark.unittest class TestTorchFuncs: def test_tensor(self): @@ -567,3 +567,80 @@ class TestTorchFuncs: 'a': [1.0, 2.0, 1.5], 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, })) == torch.tensor(11.0) + + def test_clone(self): + t1 = ttorch.clone(torch.tensor([1.0, 2.0, 1.5])) + assert isinstance(t1, torch.Tensor) + assert (t1 == torch.tensor([1.0, 2.0, 1.5])).all() + + t2 = ttorch.clone(ttorch.tensor({ + 'a': [1.0, 2.0, 1.5], + 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + })) + assert (t2 == ttorch.tensor({ + 'a': [1.0, 2.0, 1.5], + 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + })).all() + + def test_dot(self): + t1 = ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3])) + assert isinstance(t1, torch.Tensor) + assert t1.tolist() == 8 + + t2 = ttorch.dot( + ttorch.tensor({ + 'a': [1, 2, 3], + 'b': {'x': [3, 4]}, + }), + ttorch.tensor({ + 'a': [5, 6, 7], + 'b': {'x': [1, 2]}, + }) + ) + assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all() + + def test_matmul(self): + t1 = ttorch.matmul( + torch.tensor([[1, 2], [3, 4]]), + torch.tensor([[5, 6], [7, 2]]), + ) + assert isinstance(t1, torch.Tensor) + assert (t1 == torch.tensor([[19, 10], [43, 26]])).all() + + t2 = ttorch.matmul( + ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': {'x': [3, 4, 5, 6]}, + }), + ttorch.tensor({ + 'a': [[5, 6], [7, 2]], + 'b': {'x': [4, 3, 2, 1]}, + }), + ) + assert (t2 == ttorch.tensor({ + 'a': [[19, 10], [43, 26]], + 'b': {'x': 40} + })).all() + + def test_mm(self): + t1 = ttorch.mm( + torch.tensor([[1, 2], [3, 4]]), + torch.tensor([[5, 6], [7, 2]]), + ) + assert isinstance(t1, torch.Tensor) + assert (t1 == torch.tensor([[19, 10], [43, 26]])).all() + + t2 = ttorch.mm( + ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': {'x': [[3, 4, 5], [6, 7, 8]]}, + }), + ttorch.tensor({ + 'a': [[5, 6], [7, 2]], + 'b': {'x': [[6, 5], [4, 3], [2, 1]]}, + }), + ) + assert (t2 == ttorch.tensor({ + 'a': [[19, 10], [43, 26]], + 'b': {'x': [[44, 32], [80, 59]]}, + })).all() diff --git a/test/torch/test_tensor.py b/test/torch/test_tensor.py index 7f2a03df56d46873f6c288bcf6416229d0176436..66703a40816a312f1e2703368c85df08dcc58695 100644 --- a/test/torch/test_tensor.py +++ b/test/torch/test_tensor.py @@ -10,6 +10,7 @@ from treetensor.common import Object _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) +# noinspection PyUnresolvedReferences @pytest.mark.unittest class TestTorchTensor: _DEMO_1 = ttorch.Tensor({ @@ -208,3 +209,75 @@ class TestTorchTensor: 'a': [True, False], 'b': {'x': [[True, True], [False, True]]} })).all() + + def test_clone(self): + t1 = ttorch.tensor([1.0, 2.0, 1.5]).clone() + assert isinstance(t1, torch.Tensor) + assert (t1 == torch.tensor([1.0, 2.0, 1.5])).all() + + t2 = ttorch.tensor({ + 'a': [1.0, 2.0, 1.5], + 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + }).clone() + assert (t2 == ttorch.tensor({ + 'a': [1.0, 2.0, 1.5], + 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + })).all() + + def test_dot(self): + t1 = torch.tensor([1, 2]).dot(torch.tensor([2, 3])) + assert isinstance(t1, torch.Tensor) + assert t1.tolist() == 8 + + t2 = ttorch.tensor({ + 'a': [1, 2, 3], + 'b': {'x': [3, 4]}, + }).dot( + ttorch.tensor({ + 'a': [5, 6, 7], + 'b': {'x': [1, 2]}, + }) + ) + assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all() + + def test_matmul(self): + t1 = torch.tensor([[1, 2], [3, 4]]).matmul( + torch.tensor([[5, 6], [7, 2]]), + ) + assert isinstance(t1, torch.Tensor) + assert (t1 == torch.tensor([[19, 10], [43, 26]])).all() + + t2 = ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': {'x': [3, 4, 5, 6]}, + }).matmul( + ttorch.tensor({ + 'a': [[5, 6], [7, 2]], + 'b': {'x': [4, 3, 2, 1]}, + }), + ) + assert (t2 == ttorch.tensor({ + 'a': [[19, 10], [43, 26]], + 'b': {'x': 40} + })).all() + + def test_mm(self): + t1 = torch.tensor([[1, 2], [3, 4]]).mm( + torch.tensor([[5, 6], [7, 2]]), + ) + assert isinstance(t1, torch.Tensor) + assert (t1 == torch.tensor([[19, 10], [43, 26]])).all() + + t2 = ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': {'x': [[3, 4, 5], [6, 7, 8]]}, + }).mm( + ttorch.tensor({ + 'a': [[5, 6], [7, 2]], + 'b': {'x': [[6, 5], [4, 3], [2, 1]]}, + }), + ) + assert (t2 == ttorch.tensor({ + 'a': [[19, 10], [43, 26]], + 'b': {'x': [[44, 32], [80, 59]]}, + })).all() diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index 3908edb6998b7d1d1c092a49b149395f75d9bd54..6cb91283670a0b7d0bc81a2fcaa845715b5b5877 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -20,7 +20,8 @@ __all__ = [ 'all', 'any', 'min', 'max', 'sum', 'eq', 'ne', 'lt', 'le', 'gt', 'ge', - 'equal', 'tensor', + 'equal', 'tensor', 'clone', + 'dot', 'matmul', 'mm', ] func_treelize = post_process(post_process(args_mapping( @@ -816,3 +817,140 @@ def tensor(*args, **kwargs): [False, True]]) """ return torch.tensor(*args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.clone) +@func_treelize() +def clone(input, *args, **kwargs): + """ + In ``treetensor``, you can create a clone of the original tree with :func:`treetensor.torch.clone`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.clone(torch.tensor([[1, 2], [3, 4]])) + tensor([[1, 2], + [3, 4]]) + + >>> ttorch.clone(ttorch.tensor({ + ... 'a': [[1, 2], [3, 4]], + ... 'b': {'x': [[5], [6], [7]]}, + ... })) + + ├── a --> tensor([[1, 2], + │ [3, 4]]) + └── b --> + └── x --> tensor([[5], + [6], + [7]]) + """ + return torch.clone(input, *args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.dot) +@func_treelize() +def dot(input, other, *args, **kwargs): + """ + In ``treetensor``, you can get the dot product of 2 tree tensors with :func:`treetensor.torch.dot`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3])) + tensor(8) + + >>> ttorch.dot( + ... ttorch.tensor({ + ... 'a': [1, 2, 3], + ... 'b': {'x': [3, 4]}, + ... }), + ... ttorch.tensor({ + ... 'a': [5, 6, 7], + ... 'b': {'x': [1, 2]}, + ... }) + ... ) + + ├── a --> tensor(38) + └── b --> + └── x --> tensor(11) + """ + return torch.dot(input, other, *args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.matmul) +@func_treelize() +def matmul(input, other, *args, **kwargs): + """ + In ``treetensor``, you can create a matrix product with :func:`treetensor.torch.matmul`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.matmul( + ... torch.tensor([[1, 2], [3, 4]]), + ... torch.tensor([[5, 6], [7, 2]]), + ... ) + tensor([[19, 10], + [43, 26]]) + + >>> ttorch.matmul( + ... ttorch.tensor({ + ... 'a': [[1, 2], [3, 4]], + ... 'b': {'x': [3, 4, 5, 6]}, + ... }), + ... ttorch.tensor({ + ... 'a': [[5, 6], [7, 2]], + ... 'b': {'x': [4, 3, 2, 1]}, + ... }), + ... ) + + ├── a --> tensor([[19, 10], + │ [43, 26]]) + └── b --> + └── x --> tensor(40) + """ + return torch.matmul(input, other, *args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.mm) +@func_treelize() +def mm(input, mat2, *args, **kwargs): + """ + In ``treetensor``, you can create a matrix multiplication with :func:`treetensor.torch.mm`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.mm( + ... torch.tensor([[1, 2], [3, 4]]), + ... torch.tensor([[5, 6], [7, 2]]), + ... ) + tensor([[19, 10], + [43, 26]]) + + >>> ttorch.mm( + ... ttorch.tensor({ + ... 'a': [[1, 2], [3, 4]], + ... 'b': {'x': [[3, 4, 5], [6, 7, 8]]}, + ... }), + ... ttorch.tensor({ + ... 'a': [[5, 6], [7, 2]], + ... 'b': {'x': [[6, 5], [4, 3], [2, 1]]}, + ... }), + ... ) + + ├── a --> tensor([[19, 10], + │ [43, 26]]) + └── b --> + └── x --> tensor([[44, 32], + [80, 59]]) + """ + return torch.mm(input, mat2, *args, **kwargs) diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index 77a9f69ce5d6d02a2a9660df00642ea94ce2731e..788e226ed34e79e995e9dcd1cfe69a8e006bbcc7 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -261,3 +261,35 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): See :func:`treetensor.torch.ge`. """ return self >= other + + @doc_from(torch.Tensor.clone) + @method_treelize() + def clone(self, *args, **kwargs): + """ + See :func:`treetensor.torch.clone`. + """ + return self.clone(*args, **kwargs) + + @doc_from(torch.Tensor.dot) + @method_treelize() + def dot(self, other, *args, **kwargs): + """ + See :func:`treetensor.torch.dot`. + """ + return self.dot(other, *args, **kwargs) + + @doc_from(torch.Tensor.mm) + @method_treelize() + def mm(self, mat2, *args, **kwargs): + """ + See :func:`treetensor.torch.mm`. + """ + return self.mm(mat2, *args, **kwargs) + + @doc_from(torch.Tensor.matmul) + @method_treelize() + def matmul(self, tensor2, *args, **kwargs): + """ + See :func:`treetensor.torch.matmul`. + """ + return self.matmul(tensor2, *args, **kwargs)