diff --git a/test/torch/__init__.py b/test/torch/__init__.py index a78ed870d9be2e5912ab08c243239a40d7167194..1a713f0eb1fd5423dd3deafd14f69df479f6e27e 100644 --- a/test/torch/__init__.py +++ b/test/torch/__init__.py @@ -1,2 +1,2 @@ -from .test_funcs import TestTensorFuncs -from .test_treetensor import TestTensorTreetensor +from .test_funcs import TestTorchFuncs +from .test_tensor import TestTorchTensor diff --git a/test/torch/test_funcs.py b/test/torch/test_funcs.py index 23a2a4ec485f0b96c4c5182f59ace0b73cc7d69d..ec7c35712898be5e5376059ff5f2f11fa5e3bf67 100644 --- a/test/torch/test_funcs.py +++ b/test/torch/test_funcs.py @@ -7,7 +7,34 @@ import treetensor.torch as ttorch # noinspection DuplicatedCode @pytest.mark.unittest -class TestTensorFuncs: +class TestTorchFuncs: + def test_tensor(self): + t1 = ttorch.tensor(True) + assert isinstance(t1, torch.Tensor) + assert t1 + + t2 = ttorch.tensor([[1, 2, 3], [4, 5, 6]]) + assert isinstance(t2, torch.Tensor) + assert (t2 == torch.tensor([[1, 2, 3], [4, 5, 6]])).all() + + t3 = ttorch.tensor({ + 'a': [1, 2], + 'b': [[3, 4], [5, 6.2]], + 'x': { + 'c': True, + 'd': [False, True], + } + }) + assert isinstance(t3, ttorch.Tensor) + assert (t3 == ttorch.Tensor({ + 'a': torch.tensor([1, 2]), + 'b': torch.tensor([[3, 4], [5, 6.2]]), + 'x': { + 'c': torch.tensor(True), + 'd': torch.tensor([False, True]), + } + })).all() + def test_zeros(self): assert ttorch.all(ttorch.zeros((2, 3)) == torch.zeros(2, 3)) assert ttorch.all(ttorch.zeros(TreeValue({ @@ -16,7 +43,7 @@ class TestTensorFuncs: 'x': { 'c': (2, 3, 4), } - })) == ttorch.TreeTensor({ + })) == ttorch.Tensor({ 'a': torch.zeros(2, 3), 'b': torch.zeros(5, 6), 'x': { @@ -30,14 +57,14 @@ class TestTensorFuncs: torch.tensor([[0, 0, 0], [0, 0, 0]]), ) assert ttorch.all( - ttorch.zeros_like(ttorch.TreeTensor({ + ttorch.zeros_like(({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { 'c': torch.tensor([5, 6, 7]), 'd': torch.tensor([[[8, 9]]]), } - })) == ttorch.TreeTensor({ + })) == ttorch.Tensor({ 'a': torch.tensor([[0, 0, 0], [0, 0, 0]]), 'b': torch.tensor([0, 0, 0, 0]), 'x': { @@ -55,7 +82,7 @@ class TestTensorFuncs: 'x': { 'c': (2, 3, 4), } - })) == ttorch.TreeTensor({ + })) == ttorch.Tensor({ 'a': torch.ones(2, 3), 'b': torch.ones(5, 6), 'x': { @@ -69,14 +96,14 @@ class TestTensorFuncs: torch.tensor([[1, 1, 1], [1, 1, 1]]) ) assert ttorch.all( - ttorch.ones_like(ttorch.TreeTensor({ + ttorch.ones_like(({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { 'c': torch.tensor([5, 6, 7]), 'd': torch.tensor([[[8, 9]]]), } - })) == ttorch.TreeTensor({ + })) == ttorch.Tensor({ 'a': torch.tensor([[1, 1, 1], [1, 1, 1]]), 'b': torch.tensor([1, 1, 1, 1]), 'x': { @@ -99,7 +126,7 @@ class TestTensorFuncs: 'c': (2, 3, 4), } })) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -113,7 +140,7 @@ class TestTensorFuncs: assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 assert _target.shape == torch.Size([200, 300]) - _target = ttorch.randn_like(ttorch.TreeTensor({ + _target = ttorch.randn_like(({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32), 'x': { @@ -121,7 +148,7 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]], dtype=torch.float32), } })) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -140,7 +167,7 @@ class TestTensorFuncs: })) assert ttorch.all(_target < 10) assert ttorch.all(-10 <= _target) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -157,7 +184,7 @@ class TestTensorFuncs: })) assert ttorch.all(_target < 10) assert ttorch.all(0 <= _target) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -166,7 +193,7 @@ class TestTensorFuncs: }) def test_randint_like(self): - _target = ttorch.randint_like(ttorch.TreeTensor({ + _target = ttorch.randint_like(({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { @@ -176,7 +203,7 @@ class TestTensorFuncs: }), -10, 10) assert ttorch.all(_target < 10) assert ttorch.all(-10 <= _target) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -185,7 +212,7 @@ class TestTensorFuncs: } }) - _target = ttorch.randint_like(ttorch.TreeTensor({ + _target = ttorch.randint_like(({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { @@ -195,7 +222,7 @@ class TestTensorFuncs: }), 10) assert ttorch.all(_target < 10) assert ttorch.all(0 <= _target) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -213,7 +240,7 @@ class TestTensorFuncs: } }), 233) assert ttorch.all(_target == 233) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -222,7 +249,7 @@ class TestTensorFuncs: }) def test_full_like(self): - _target = ttorch.full_like(ttorch.TreeTensor({ + _target = ttorch.full_like(({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { @@ -231,7 +258,7 @@ class TestTensorFuncs: } }), 233) assert ttorch.all(_target == 233) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -248,7 +275,7 @@ class TestTensorFuncs: 'c': (2, 3, 4), } })) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -257,7 +284,7 @@ class TestTensorFuncs: }) def test_empty_like(self): - _target = ttorch.empty_like(ttorch.TreeTensor({ + _target = ttorch.empty_like(({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { @@ -265,7 +292,7 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]]), } })) - assert _target.shape == ttorch.TreeSize({ + assert _target.shape == ttorch.Size({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -290,7 +317,7 @@ class TestTensorFuncs: assert r3 == torch.tensor(False) assert not r3 - r4 = ttorch.all(ttorch.TreeTensor({ + r4 = ttorch.all(({ 'a': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, True]), })).all() @@ -298,7 +325,7 @@ class TestTensorFuncs: assert r4 == torch.tensor(True) assert r4 - r5 = ttorch.all(ttorch.TreeTensor({ + r5 = ttorch.all(({ 'a': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, False]), })).all() @@ -306,7 +333,7 @@ class TestTensorFuncs: assert r5 == torch.tensor(False) assert not r5 - r6 = ttorch.all(ttorch.TreeTensor({ + r6 = ttorch.all(({ 'a': torch.tensor([False, False, False]), 'b': torch.tensor([False, False, False]), })).all() @@ -330,7 +357,7 @@ class TestTensorFuncs: assert r3 == torch.tensor(False) assert not r3 - r4 = ttorch.any(ttorch.TreeTensor({ + r4 = ttorch.any(({ 'a': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, True]), })).all() @@ -338,7 +365,7 @@ class TestTensorFuncs: assert r4 == torch.tensor(True) assert r4 - r5 = ttorch.any(ttorch.TreeTensor({ + r5 = ttorch.any(({ 'a': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, False]), })).all() @@ -346,7 +373,7 @@ class TestTensorFuncs: assert r5 == torch.tensor(True) assert r5 - r6 = ttorch.any(ttorch.TreeTensor({ + r6 = ttorch.any(({ 'a': torch.tensor([False, False, False]), 'b': torch.tensor([False, False, False]), })).all() @@ -360,17 +387,17 @@ class TestTensorFuncs: assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all() assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all() - assert ttorch.eq(ttorch.TreeTensor({ + assert ttorch.eq(({ 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6]), - }), ttorch.TreeTensor({ + }), ({ 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6]), })).all() - assert not ttorch.eq(ttorch.TreeTensor({ + assert not ttorch.eq(({ 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6]), - }), ttorch.TreeTensor({ + }), ({ 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 5]), })).all() @@ -384,20 +411,20 @@ class TestTensorFuncs: assert isinstance(p2, bool) assert not p2 - p3 = ttorch.equal(ttorch.TreeTensor({ + p3 = ttorch.equal(({ 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6]), - }), ttorch.TreeTensor({ + }), ({ 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6]), })) assert isinstance(p3, bool) assert p3 - p4 = ttorch.equal(ttorch.TreeTensor({ + p4 = ttorch.equal(({ 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6]), - }), ttorch.TreeTensor({ + }), ({ 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 5]), })) diff --git a/test/torch/test_treetensor.py b/test/torch/test_tensor.py similarity index 90% rename from test/torch/test_treetensor.py rename to test/torch/test_tensor.py index ed436370b85b0a6abb0290b5959701e1ba8efd5d..18360afe4b0af97ac98542029e502232d7e3e9b8 100644 --- a/test/torch/test_treetensor.py +++ b/test/torch/test_tensor.py @@ -6,12 +6,12 @@ from treevalue import func_treelize import treetensor.numpy as tnp import treetensor.torch as ttorch -_all_is = func_treelize(return_type=ttorch.TreeTensor)(lambda x, y: x is y) +_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) @pytest.mark.unittest -class TestTensorTreetensor: - _DEMO_1 = ttorch.TreeTensor({ +class TestTorchTensor: + _DEMO_1 = ttorch.Tensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([[1, 2], [5, 6]]), 'x': { @@ -20,7 +20,7 @@ class TestTensorTreetensor: } }) - _DEMO_2 = ttorch.TreeTensor({ + _DEMO_2 = ttorch.Tensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([[1, 2], [5, 60]]), 'x': { @@ -47,7 +47,7 @@ class TestTensorTreetensor: assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values())) def test_to(self): - assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.TreeTensor({ + assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32), 'x': { diff --git a/treetensor/__init__.py b/treetensor/__init__.py index 85e9a7cad40b27a4e3ff0a1d58037bc7356647c2..3b27b659ce5ce47c4a42033529205285ecbbce96 100644 --- a/treetensor/__init__.py +++ b/treetensor/__init__.py @@ -1,3 +1,3 @@ from .common import TreeObject from .numpy import TreeNumpy -from .torch import TreeTensor +from .torch import Tensor diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index 6ca9399af3790d60a5c561e8aeedd9cec52b1280..de688e73717bfd7b877ff86b6c55ef107d14295f 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -1,11 +1,13 @@ import builtins import torch +from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize +from treevalue.utils import post_process -from .tensor import TreeTensor, tireduce +from .tensor import Tensor, tireduce from ..common import TreeObject, ireduce -from ..utils import replaceable_partial, doc_from +from ..utils import replaceable_partial, doc_from, args_mapping __all__ = [ 'zeros', 'zeros_like', @@ -16,9 +18,13 @@ __all__ = [ 'empty', 'empty_like', 'all', 'any', 'eq', 'equal', + 'tensor', ] -func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor) +func_treelize = post_process(post_process(args_mapping( + lambda i, x: Tensor(x) if isinstance(x, (dict, TreeValue)) else x)))( + replaceable_partial(original_func_treelize, return_type=Tensor) +) @doc_from(torch.zeros) @@ -102,18 +108,20 @@ def all(input_, *args, **kwargs): Example:: + >>> import torch + >>> import treetensor.torch as ttorch >>> all(torch.tensor([True, True])) # the same as torch.all torch.tensor(True) - >>> all(TreeTensor({ - >>> 'a': torch.tensor([True, True]), - >>> 'b': torch.tensor([True, True]), + >>> all(ttorch.tensor({ + >>> 'a': [True, True], + >>> 'b': [True, True], >>> })) torch.tensor(True) - >>> all(TreeTensor({ - >>> 'a': torch.tensor([True, True]), - >>> 'b': torch.tensor([True, False]), + >>> all(Tensor({ + >>> 'a': [True, True], + >>> 'b': [True, False], >>> })) torch.tensor(False) @@ -139,3 +147,30 @@ def eq(input_, other, *args, **kwargs): @func_treelize() def equal(input_, other, *args, **kwargs): return torch.equal(input_, other, *args, **kwargs) + + +@doc_from(torch.tensor) +@func_treelize() +def tensor(*args, **kwargs): + """ + In ``treetensor``, you can create a tree tensor with simple data structure. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.tensor(True) # the same as torch.tensor(True) + torch.tensor(True) + + >>> ttorch.tensor([1, 2, 3]) # the same as torch.tensor([1, 2, 3]) + torch.tensor([1, 2, 3]) + + >>> ttorch.tensor({'a': 1, 'b': [1, 2, 3], 'c': [[True, False], [False, True]]}) + ttorch.Tensor({ + 'a': torch.tensor(1), + 'b': torch.tensor([1, 2, 3]), + 'c': torch.tensor([[True, False], [False, True]]), + }) + + """ + return torch.tensor(*args, **kwargs) diff --git a/treetensor/torch/size.py b/treetensor/torch/size.py index 0aa014661951142d375802df5660fdb04396c7eb..b5cb3f1dd7bf8523bbbeaff24ece6a4bfe1bcd75 100644 --- a/treetensor/torch/size.py +++ b/treetensor/torch/size.py @@ -7,12 +7,12 @@ from ..utils import replaceable_partial func_treelize = replaceable_partial(original_func_treelize) __all__ = [ - 'TreeSize' + 'Size' ] # noinspection PyTypeChecker -class TreeSize(TreeObject): +class Size(TreeObject): @func_treelize(return_type=TreeObject) def numel(self: torch.Size) -> TreeObject: return self.numel() diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index 2a760dcfaf246a81ded7786ce870fbf28e10c1de..aadb345103d88219a9de1d4eca5772e6a3164c73 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -3,13 +3,13 @@ import torch from treevalue import method_treelize from treevalue.utils import pre_process -from .size import TreeSize +from .size import Size from ..common import TreeObject, TreeData, ireduce from ..numpy import TreeNumpy from ..utils import inherit_names, current_names, doc_from __all__ = [ - 'TreeTensor' + 'Tensor' ] _reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {})) @@ -19,7 +19,7 @@ tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduc # noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList @current_names() @inherit_names(TreeData) -class TreeTensor(TreeData): +class Tensor(TreeData): @doc_from(torch.Tensor.numpy) @method_treelize(return_type=TreeNumpy) def numpy(self: torch.Tensor) -> np.ndarray: @@ -53,7 +53,7 @@ class TreeTensor(TreeData): @property @doc_from(torch.Tensor.shape) - @method_treelize(return_type=TreeSize) + @method_treelize(return_type=Size) def shape(self: torch.Tensor): return self.shape diff --git a/treetensor/utils/func.py b/treetensor/utils/func.py index 61c15f758b6af426e9bbe374068668da67a03cfd..10cce7a0db83d4815a38ca035fac2bfc040cea00 100644 --- a/treetensor/utils/func.py +++ b/treetensor/utils/func.py @@ -1,10 +1,29 @@ +from functools import wraps +from typing import Callable, Union, Any + __all__ = [ 'replaceable_partial', + 'args_mapping', ] def replaceable_partial(func, **kws): + @wraps(func) def _new_func(*args, **kwargs): return func(*args, **{**kws, **kwargs}) return _new_func + + +def args_mapping(mapper: Callable[[Union[int, str], Any], Any]): + def _decorator(func): + @wraps(func) + def _new_func(*args, **kwargs): + return func( + *(mapper(i, x) for i, x in enumerate(args)), + **{k: mapper(k, v) for k, v in kwargs.items()}, + ) + + return _new_func + + return _decorator