diff --git a/test/__init__.py b/test/__init__.py index 1105316c9cc682599cdd5fdc453b274f1f5aac8e..8e761ea4e21f918370bce08c39c85c11407833af 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,3 +1,4 @@ +from .common import * from .config import * from .numpy import * from .torch import * diff --git a/test/common/__init__.py b/test/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..858f44159e7abcc9da1d598801d7e6239595e370 --- /dev/null +++ b/test/common/__init__.py @@ -0,0 +1 @@ +from .test_trees import TestCommonTrees diff --git a/test/common/test_trees.py b/test/common/test_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..c420a963c177a56dfa009561b3d0eecaea04c0c5 --- /dev/null +++ b/test/common/test_trees.py @@ -0,0 +1,84 @@ +import io + +import pytest +import torch +from treevalue import typetrans, TreeValue, general_tree_value + +from treetensor.common import Object, print_tree + + +def text_compares(expected, actual): + elines = expected.splitlines() + alines = actual.splitlines() + assert len(elines) == len(alines), f"""Lines not match, +Expected: {len(elines)} lines + Actual: {len(alines)} lines +""" + + for i, (e, a) in enumerate(zip(elines, alines)): + assert e.rstrip() == a.rstrip(), f"""Line {i} not match, +Expected: {e} + Actual: {a} +""" + + +@pytest.mark.unittest +class TestCommonTrees: + def test_object(self): + t = Object(1) + assert isinstance(t, int) + assert t == 1 + + assert Object({'a': 1, 'b': 2}) == typetrans(TreeValue({ + 'a': 1, 'b': 2 + }), Object) + + def test_print_tree(self): + class _TempTree(general_tree_value()): + def __repr__(self): + with io.StringIO() as sfile: + print_tree(self, repr_=repr, ascii_=False, show_node_id=False, file=sfile) + return sfile.getvalue() + + def __str__(self): + return self.__repr__() + + text_compares("""<_TempTree> +├── a --> 1 +└── b --> 2 + """.rstrip(), str(_TempTree({ + 'a': 1, 'b': 2 + })).rstrip()) + + class _TmpObject: + def __init__(self, v): + self.__v = v + + def __repr__(self): + return self.__v + + tx = _TempTree({ + 'a': 1, 'b': 2, 'c': torch.tensor([[1, 2, ], [3, 4]]), + 'd': {'x': torch.tensor([[1], [2], [3, ], [4]])}, + 'e': _TmpObject('line after this\nhahahaha'), + }) + tx.d.y = tx + text_compares("""<_TempTree> +├── a --> 1 +├── b --> 2 +├── c --> tensor([[1, 2], +│ [3, 4]]) +├── d --> <_TempTree> +│ ├── x --> tensor([[1], +│ │ [2], +│ │ [3], +│ │ [4]]) +│ └── y --> <_TempTree> +│ (The same address as ) +└── e --> line after this + hahahaha + """.rstrip(), str(tx).rstrip()) + + with io.StringIO() as sf: + print_tree(1, file=sf) + text_compares("1", sf.getvalue()) diff --git a/test/torch/__init__.py b/test/torch/__init__.py index 1a713f0eb1fd5423dd3deafd14f69df479f6e27e..803180b8d951e2b78bac74caced4526acf978046 100644 --- a/test/torch/__init__.py +++ b/test/torch/__init__.py @@ -1,2 +1,3 @@ from .test_funcs import TestTorchFuncs +from .test_size import TestTorchSize from .test_tensor import TestTorchTensor diff --git a/test/torch/test_funcs.py b/test/torch/test_funcs.py index 25b71b16856537c483e5148d1f0c47f11cdcd4f0..315b47039f98cfca49c91a94c4299121324aab47 100644 --- a/test/torch/test_funcs.py +++ b/test/torch/test_funcs.py @@ -401,6 +401,111 @@ class TestTorchFuncs: 'b': torch.tensor([4, 5, 5]), })).all() + def test_ne(self): + assert (ttorch.ne( + torch.tensor([[1, 2], [3, 4]]), + torch.tensor([[1, 1], [4, 4]]), + ) == torch.tensor([[False, True], + [True, False]])).all() + + assert (ttorch.ne( + ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': [1.0, 1.5, 2.0], + }), + ttorch.tensor({ + 'a': [[1, 1], [4, 4]], + 'b': [1.3, 1.2, 2.0], + }), + ) == ttorch.tensor({ + 'a': [[False, True], [True, False]], + 'b': [True, True, False], + })).all() + + def test_lt(self): + assert (ttorch.lt( + torch.tensor([[1, 2], [3, 4]]), + torch.tensor([[1, 1], [4, 4]]), + ) == torch.tensor([[False, False], + [True, False]])).all() + + assert (ttorch.lt( + ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': [1.0, 1.5, 2.0], + }), + ttorch.tensor({ + 'a': [[1, 1], [4, 4]], + 'b': [1.3, 1.2, 2.0], + }), + ) == ttorch.tensor({ + 'a': [[False, False], [True, False]], + 'b': [True, False, False], + })).all() + + def test_le(self): + assert (ttorch.le( + torch.tensor([[1, 2], [3, 4]]), + torch.tensor([[1, 1], [4, 4]]), + ) == torch.tensor([[True, False], + [True, True]])).all() + + assert (ttorch.le( + ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': [1.0, 1.5, 2.0], + }), + ttorch.tensor({ + 'a': [[1, 1], [4, 4]], + 'b': [1.3, 1.2, 2.0], + }), + ) == ttorch.tensor({ + 'a': [[True, False], [True, True]], + 'b': [True, False, True], + })).all() + + def test_gt(self): + assert (ttorch.gt( + torch.tensor([[1, 2], [3, 4]]), + torch.tensor([[1, 1], [4, 4]]), + ) == torch.tensor([[False, True], + [False, False]])).all() + + assert (ttorch.gt( + ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': [1.0, 1.5, 2.0], + }), + ttorch.tensor({ + 'a': [[1, 1], [4, 4]], + 'b': [1.3, 1.2, 2.0], + }), + ) == ttorch.tensor({ + 'a': [[False, True], [False, False]], + 'b': [False, True, False], + })).all() + + def test_ge(self): + assert (ttorch.ge( + torch.tensor([[1, 2], [3, 4]]), + torch.tensor([[1, 1], [4, 4]]), + ) == torch.tensor([[True, True], + [False, True]])).all() + + assert (ttorch.ge( + ttorch.tensor({ + 'a': [[1, 2], [3, 4]], + 'b': [1.0, 1.5, 2.0], + }), + ttorch.tensor({ + 'a': [[1, 1], [4, 4]], + 'b': [1.3, 1.2, 2.0], + }), + ) == ttorch.tensor({ + 'a': [[True, True], [False, True]], + 'b': [False, True, True], + })).all() + def test_equal(self): p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])) assert isinstance(p1, bool) @@ -429,3 +534,36 @@ class TestTorchFuncs: })) assert isinstance(p4, bool) assert not p4 + + def test_min(self): + t1 = ttorch.min(torch.tensor([1.0, 2.0, 1.5])) + assert isinstance(t1, torch.Tensor) + assert t1 == torch.tensor(1.0) + + assert ttorch.min(ttorch.tensor({ + 'a': [1.0, 2.0, 1.5], + 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + })) == ttorch.tensor({ + 'a': 1.0, + 'b': {'x': 0.9}, + }) + + def test_max(self): + t1 = ttorch.max(torch.tensor([1.0, 2.0, 1.5])) + assert isinstance(t1, torch.Tensor) + assert t1 == torch.tensor(2.0) + + assert ttorch.max(ttorch.tensor({ + 'a': [1.0, 2.0, 1.5], + 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + })) == ttorch.tensor({ + 'a': 2.0, + 'b': {'x': 2.5, } + }) + + def test_sum(self): + assert ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) == torch.tensor(4.5) + assert ttorch.sum(ttorch.tensor({ + 'a': [1.0, 2.0, 1.5], + 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, + })) == torch.tensor(11.0) diff --git a/test/torch/test_size.py b/test/torch/test_size.py new file mode 100644 index 0000000000000000000000000000000000000000..bfaa2a14bf8dbd86170a7516e84c29349ebef287 --- /dev/null +++ b/test/torch/test_size.py @@ -0,0 +1,60 @@ +import pytest +import torch +from treevalue import func_treelize, typetrans, TreeValue + +import treetensor.torch as ttorch +from treetensor.common import Object + +_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) + + +@pytest.mark.unittest +class TestTorchSize: + def test_init(self): + t1 = ttorch.Size([1, 2, 3]) + assert isinstance(t1, torch.Size) + assert t1 == torch.Size([1, 2, 3]) + + t2 = ttorch.Size({ + 'a': [1, 2, 3], + 'b': {'x': [3, 4, ]}, + 'c': [5], + }) + assert isinstance(t2, ttorch.Size) + assert typetrans(t2, TreeValue) == TreeValue({ + 'a': torch.Size([1, 2, 3]), + 'b': {'x': torch.Size([3, 4, ])}, + 'c': torch.Size([5]), + }) + + def test_numel(self): + assert ttorch.Size({ + 'a': [1, 2, 3], + 'b': {'x': [3, 4, ]}, + 'c': [5], + }).numel() == 23 + + def test_index(self): + assert ttorch.Size({ + 'a': [1, 2, 3], + 'b': {'x': [3, 4, ]}, + 'c': [5], + }).index(3) == Object({ + 'a': 2, + 'b': {'x': 0}, + 'c': None + }) + + with pytest.raises(ValueError): + ttorch.Size({ + 'a': [1, 2, 3], + 'b': {'x': [3, 4, ]}, + 'c': [5], + }).index(100) + + def test_count(self): + assert ttorch.Size({ + 'a': [1, 2, 3], + 'b': {'x': [3, 4, ]}, + 'c': [5], + }).count(3) == 2 diff --git a/test/torch/test_tensor.py b/test/torch/test_tensor.py index 79ff0d6673c35899527acb0d5282a1842c75a216..b2d6ffd72af9990cab61dc5303d340d71802aecc 100644 --- a/test/torch/test_tensor.py +++ b/test/torch/test_tensor.py @@ -1,7 +1,7 @@ import numpy as np import pytest import torch -from treevalue import func_treelize +from treevalue import func_treelize, typetrans, TreeValue import treetensor.numpy as tnp import treetensor.torch as ttorch @@ -29,6 +29,18 @@ class TestTorchTensor: } }) + def test_init(self): + assert (ttorch.Tensor([1, 2, 3]) == torch.tensor([1, 2, 3])).all() + assert (ttorch.Tensor([1, 2, 3], dtype=torch.float32) == torch.FloatTensor([1, 2, 3])).all() + assert (self._DEMO_1 == typetrans(TreeValue({ + 'a': ttorch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': ttorch.tensor([[1, 2], [5, 6]]), + 'x': { + 'c': ttorch.tensor([3, 5, 6, 7]), + 'd': ttorch.tensor([[[1, 2], [8, 9]]]), + } + }), ttorch.Tensor)).all() + def test_numel(self): assert self._DEMO_1.numel() == 18 @@ -59,3 +71,6 @@ class TestTorchTensor: def test_all(self): assert (self._DEMO_1 == self._DEMO_1).all() assert not (self._DEMO_1 == self._DEMO_2).all() + + def test_tolist(self): + pass diff --git a/treetensor/common/trees.py b/treetensor/common/trees.py index ee0757be2faa5df2f9d8b70282f48304f6f25cda..885a824027d48bf282cabcde39e0fbb2ec93a9e3 100644 --- a/treetensor/common/trees.py +++ b/treetensor/common/trees.py @@ -18,15 +18,8 @@ __all__ = [ ] -def _tree_title(node: TreeValue): - _tree = get_data_property(node) - return "<{cls} {id}>".format( - cls=node.__class__.__name__, - id=hex(id(_tree.actual())), - ) - - -def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, file=None): +def print_tree(tree: TreeValue, repr_: Callable = str, + ascii_: bool = False, show_node_id: bool = True, file=None): print_to_file = partial(builtins.print, file=file) node_ids = {} if ascii_: @@ -45,7 +38,10 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil _need_iter = True if isinstance(node, TreeValue): _node_id = id(get_data_property(node).actual()) - _content = f'<{node.__class__.__name__} {hex(_node_id)}>' + if show_node_id: + _content = f'<{node.__class__.__name__} {hex(_node_id)}>' + else: + _content = f'<{node.__class__.__name__}>' if _node_id in node_ids.keys(): _str_old_path = '.'.join(('', *node_ids[_node_id])) _content = f'{_content}{os.linesep}(The same address as {_str_old_path})' @@ -98,25 +94,24 @@ class BaseTreeStruct(general_tree_value()): return self.__repr__() -def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True): +def clsmeta(func, allow_dict: bool = False): class _TempTreeValue(TreeValue): pass _types = ( - TreeValue, + TreeValue, BaseTree, *((dict,) if allow_dict else ()), - *((BaseTree,) if allow_data else ()), ) func_treelize = post_process(post_process(args_mapping( lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))( replaceable_partial(original_func_treelize, return_type=_TempTreeValue) ) - _torch_size = func_treelize()(cls) + _wrapped_func = func_treelize()(func) class _MetaClass(type): def __call__(cls, *args, **kwargs): - _result = _torch_size(*args, **kwargs) + _result = _wrapped_func(*args, **kwargs) if isinstance(_result, _TempTreeValue): return type.__call__(cls, _result) else: @@ -125,5 +120,9 @@ def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True): return _MetaClass -class Object(BaseTreeStruct): +def _object(obj): + return obj + + +class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)): pass diff --git a/treetensor/torch/base.py b/treetensor/torch/base.py index 6b704206e5c4356873991b8743c54aa6da0baf96..c166ede370ec478f7b0390188b42ebb6100bf139 100644 --- a/treetensor/torch/base.py +++ b/treetensor/torch/base.py @@ -1,7 +1,7 @@ from ..common import BaseTreeStruct -__all__ = ['TreeTorch'] +__all__ = ['Torch'] -class TreeTorch(BaseTreeStruct): +class Torch(BaseTreeStruct): pass diff --git a/treetensor/torch/size.py b/treetensor/torch/size.py index 2824b8585e2cb8b0ee5cd5468f39438ca6e6be03..5203ef4b08b0182809b427c5dc48a70fd8ceb9ac 100644 --- a/treetensor/torch/size.py +++ b/treetensor/torch/size.py @@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize from treevalue.tree.common import BaseTree from treevalue.utils import post_process -from .base import TreeTorch +from .base import Torch from ..common import Object, clsmeta, ireduce from ..utils import replaceable_partial, doc_from, current_names, args_mapping @@ -44,7 +44,31 @@ def _post_index(func): # noinspection PyTypeChecker @current_names() -class Size(TreeTorch, metaclass=clsmeta(torch.Size, allow_dict=True)): +class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)): + def __init__(self, data): + """ + In :class:`treetensor.torch.Size`, it's similar with the original :class:`torch.Size`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.Size([1, 2, 3]) + torch.Size([1, 2, 3]) + + >>> ttorch.Size({ + ... 'a': [1, 2, 3], + ... 'b': {'x': [3, 4, ]}, + ... 'c': [5], + ... }) + + ├── a --> torch.Size([1, 2, 3]) + ├── b --> + │ └── x --> torch.Size([3, 4]) + └── c --> torch.Size([5]) + """ + super(Torch, self).__init__(data) + @doc_from(torch.Size.numel) @ireduce(sum) @func_treelize(return_type=Object) diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index d877a2506554a4d283624a77f53d8c17edc79d71..a3cf332f958351bc2b2ca0e4029cea5ffc5b110c 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -3,7 +3,7 @@ import torch from treevalue import method_treelize from treevalue.utils import pre_process -from .base import TreeTorch +from .base import Torch from .size import Size from ..common import Object, ireduce, clsmeta from ..numpy import ndarray @@ -29,7 +29,40 @@ def _to_tensor(*args, **kwargs): # noinspection PyTypeChecker @current_names() -class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)): +class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): + # noinspection PyUnusedLocal + def __init__(self, data, *args, **kwargs): + """ + In :class:`treetensor.torch.Tensor`, it's similar but a little bit different with the + original :class:`torch.Tensor`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> torch.Tensor([1, 2, 3]) # in torch.Tensor, default type is float32 + tensor([1., 2., 3.]) + + >>> ttorch.Tensor([1, 2, 3]) # a native Tensor object, its type is auto detected with torch.tensor + tensor([1, 2, 3]) + + >>> ttorch.Tensor([1, 2, 3], dtype=torch.float32) # with float32 type + tensor([1., 2., 3.]) + + >>> ttorch.Tensor({ + ... 'a': [1, 2, 3], + ... 'b': {'x': [4.0, 5, 6]}, + ... 'c': [[True, ], [False, ]], + ... }) # a tree-based Tensor object + + ├── a --> tensor([1, 2, 3]) + ├── b --> + │ └── x --> tensor([4., 5., 6.]) + └── c --> tensor([[ True], + [False]]) + """ + super(Torch, self).__init__(data) + @doc_from(torch.Tensor.numpy) @method_treelize(return_type=ndarray) def numpy(self: torch.Tensor) -> np.ndarray: @@ -140,6 +173,7 @@ class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.shape + # noinspection PyArgumentList @doc_from(torch.Tensor.all) @tireduce(torch.all) @method_treelize(return_type=Object) @@ -149,6 +183,7 @@ class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.all(*args, **kwargs) + # noinspection PyArgumentList @doc_from(torch.Tensor.any) @tireduce(torch.any) @method_treelize(return_type=Object)