From 4086271e1fe53b8f274a36a744f6e1a66d8c19b7 Mon Sep 17 00:00:00 2001 From: HansBug Date: Sun, 19 Sep 2021 22:54:47 +0800 Subject: [PATCH] dev, doc(hansbug): add new function for Size class && add plenty of new documentations --- test/numpy/test_array.py | 8 +-- test/torch/test_tensor.py | 24 ++++----- treetensor/__init__.py | 2 +- treetensor/common/trees.py | 40 ++++++++++++-- treetensor/numpy/array.py | 14 ++--- treetensor/numpy/funcs.py | 4 +- treetensor/torch/funcs.py | 20 +++---- treetensor/torch/size.py | 107 +++++++++++++++++++++++++++++++++---- treetensor/torch/tensor.py | 35 ++++++------ 9 files changed, 186 insertions(+), 68 deletions(-) diff --git a/test/numpy/test_array.py b/test/numpy/test_array.py index df9004e34..d4aa0ef03 100644 --- a/test/numpy/test_array.py +++ b/test/numpy/test_array.py @@ -2,7 +2,7 @@ import numpy as np import pytest import treetensor.numpy as tnp -from treetensor.common import TreeObject +from treetensor.common import Object # noinspection DuplicatedCode @@ -209,7 +209,7 @@ class TestNumpyArray: })).all() def test_tolist(self): - assert self._DEMO_1.tolist() == TreeObject({ + assert self._DEMO_1.tolist() == Object({ 'a': [[1, 2, 3], [4, 5, 6]], 'b': [1, 3, 5, 7], 'x': { @@ -217,7 +217,7 @@ class TestNumpyArray: 'd': [3, 9, 11.0], } }) - assert self._DEMO_2.tolist() == TreeObject({ + assert self._DEMO_2.tolist() == Object({ 'a': [[1, 22, 3], [4, 5, 6]], 'b': [1, 3, 5, 7], 'x': { @@ -225,7 +225,7 @@ class TestNumpyArray: 'd': [3, 9, 11.0], } }) - assert self._DEMO_3.tolist() == TreeObject({ + assert self._DEMO_3.tolist() == Object({ 'a': [[0, 0, 0], [0, 0, 0]], 'b': [0, 0, 0, 0], 'x': { diff --git a/test/torch/test_tensor.py b/test/torch/test_tensor.py index 93debafc9..79ff0d667 100644 --- a/test/torch/test_tensor.py +++ b/test/torch/test_tensor.py @@ -12,20 +12,20 @@ _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) @pytest.mark.unittest class TestTorchTensor: _DEMO_1 = ttorch.Tensor({ - 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), - 'b': torch.tensor([[1, 2], [5, 6]]), + 'a': [[1, 2, 3], [4, 5, 6]], + 'b': [[1, 2], [5, 6]], 'x': { - 'c': torch.tensor([3, 5, 6, 7]), - 'd': torch.tensor([[[1, 2], [8, 9]]]), + 'c': [3, 5, 6, 7], + 'd': [[[1, 2], [8, 9]]], } }) _DEMO_2 = ttorch.Tensor({ - 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), - 'b': torch.tensor([[1, 2], [5, 60]]), + 'a': [[1, 2, 3], [4, 5, 6]], + 'b': [[1, 2], [5, 60]], 'x': { - 'c': torch.tensor([3, 5, 6, 7]), - 'd': torch.tensor([[[1, 2], [8, 9]]]), + 'c': [3, 5, 6, 7], + 'd': [[[1, 2], [8, 9]]], } }) @@ -48,11 +48,11 @@ class TestTorchTensor: def test_to(self): assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({ - 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), - 'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32), + 'a': torch.FloatTensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.FloatTensor([[1, 2], [5, 6]]), 'x': { - 'c': torch.tensor([3, 5, 6, 7], dtype=torch.float32), - 'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32), + 'c': torch.FloatTensor([3, 5, 6, 7]), + 'd': torch.FloatTensor([[[1, 2], [8, 9]]]), } })) diff --git a/treetensor/__init__.py b/treetensor/__init__.py index db6684d66..cf34c4b5f 100644 --- a/treetensor/__init__.py +++ b/treetensor/__init__.py @@ -1,3 +1,3 @@ -from .common import TreeObject +from .common import Object from .numpy import ndarray from .torch import Tensor diff --git a/treetensor/common/trees.py b/treetensor/common/trees.py index 9b528d0c8..ee0757be2 100644 --- a/treetensor/common/trees.py +++ b/treetensor/common/trees.py @@ -1,15 +1,20 @@ import builtins import io import os -from abc import ABCMeta from functools import partial from typing import Optional, Tuple, Callable +from treevalue import func_treelize as original_func_treelize from treevalue import general_tree_value, TreeValue +from treevalue.tree.common import BaseTree from treevalue.tree.tree.tree import get_data_property +from treevalue.utils import post_process + +from ..utils import replaceable_partial, args_mapping __all__ = [ - 'BaseTreeStruct', "TreeObject", 'print_tree', + 'BaseTreeStruct', "Object", + 'print_tree', 'clsmeta', ] @@ -78,7 +83,7 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil print(repr_(tree), file=file) -class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): +class BaseTreeStruct(general_tree_value()): """ Overview: Base structure of all the trees in ``treetensor``. @@ -93,5 +98,32 @@ class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): return self.__repr__() -class TreeObject(BaseTreeStruct): +def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True): + class _TempTreeValue(TreeValue): + pass + + _types = ( + TreeValue, + *((dict,) if allow_dict else ()), + *((BaseTree,) if allow_data else ()), + ) + func_treelize = post_process(post_process(args_mapping( + lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))( + replaceable_partial(original_func_treelize, return_type=_TempTreeValue) + ) + + _torch_size = func_treelize()(cls) + + class _MetaClass(type): + def __call__(cls, *args, **kwargs): + _result = _torch_size(*args, **kwargs) + if isinstance(_result, _TempTreeValue): + return type.__call__(cls, _result) + else: + return _result + + return _MetaClass + + +class Object(BaseTreeStruct): pass diff --git a/treetensor/numpy/array.py b/treetensor/numpy/array.py index a30ee1131..7cbcc9a8a 100644 --- a/treetensor/numpy/array.py +++ b/treetensor/numpy/array.py @@ -2,7 +2,7 @@ import numpy as np from treevalue import method_treelize from .base import TreeNumpy -from ..common import TreeObject, ireduce +from ..common import Object, ireduce from ..utils import current_names __all__ = [ @@ -18,34 +18,34 @@ class ndarray(TreeNumpy): Real numpy tree. """ - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def tolist(self: np.ndarray): return self.tolist() @property @ireduce(sum) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def size(self: np.ndarray) -> int: return self.size @property @ireduce(sum) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def nbytes(self: np.ndarray) -> int: return self.nbytes @ireduce(sum) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def sum(self: np.ndarray, *args, **kwargs): return self.sum(*args, **kwargs) @ireduce(all) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def all(self: np.ndarray, *args, **kwargs): return self.all(*args, **kwargs) @ireduce(any) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def any(self: np.ndarray, *args, **kwargs): return self.any(*args, **kwargs) diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index 49a4742b3..c4f5a522b 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize from treevalue.utils import post_process from .array import ndarray -from ..common import ireduce, TreeObject +from ..common import ireduce, Object from ..utils import replaceable_partial, doc_from, args_mapping __all__ = [ @@ -22,7 +22,7 @@ func_treelize = post_process(post_process(args_mapping( @doc_from(np.all) @ireduce(builtins.all) -@func_treelize(return_type=TreeObject) +@func_treelize(return_type=Object) def all(a, *args, **kwargs): return np.all(a, *args, **kwargs) diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index 0acba8718..3908edb69 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -1,17 +1,13 @@ -""" -Overview: - Common functions, based on ``torch`` module. -""" - import builtins import torch from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize +from treevalue.tree.common import BaseTree from treevalue.utils import post_process from .tensor import Tensor, tireduce -from ..common import TreeObject, ireduce +from ..common import Object, ireduce from ..utils import replaceable_partial, doc_from, args_mapping __all__ = [ @@ -28,7 +24,7 @@ __all__ = [ ] func_treelize = post_process(post_process(args_mapping( - lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeValue)) else x)))( + lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( replaceable_partial(original_func_treelize, return_type=Tensor) ) @@ -355,7 +351,7 @@ def empty_like(input, *args, **kwargs): # noinspection PyShadowingBuiltins @doc_from(torch.all) @tireduce(torch.all) -@func_treelize(return_type=TreeObject) +@func_treelize(return_type=Object) def all(input, *args, **kwargs): """ In ``treetensor``, you can get the ``all`` result of a whole tree with this function. @@ -394,7 +390,7 @@ def all(input, *args, **kwargs): # noinspection PyShadowingBuiltins @doc_from(torch.any) @tireduce(torch.any) -@func_treelize(return_type=TreeObject) +@func_treelize(return_type=Object) def any(input, *args, **kwargs): """ In ``treetensor``, you can get the ``any`` result of a whole tree with this function. @@ -433,7 +429,7 @@ def any(input, *args, **kwargs): # noinspection PyShadowingBuiltins @doc_from(torch.min) @tireduce(torch.min) -@func_treelize(return_type=TreeObject) +@func_treelize(return_type=Object) def min(input, *args, **kwargs): """ In ``treetensor``, you can get the ``min`` result of a whole tree with this function. @@ -472,7 +468,7 @@ def min(input, *args, **kwargs): # noinspection PyShadowingBuiltins @doc_from(torch.max) @tireduce(torch.max) -@func_treelize(return_type=TreeObject) +@func_treelize(return_type=Object) def max(input, *args, **kwargs): """ In ``treetensor``, you can get the ``max`` result of a whole tree with this function. @@ -511,7 +507,7 @@ def max(input, *args, **kwargs): # noinspection PyShadowingBuiltins @doc_from(torch.sum) @tireduce(torch.sum) -@func_treelize(return_type=TreeObject) +@func_treelize(return_type=Object) def sum(input, *args, **kwargs): """ In ``treetensor``, you can get the ``sum`` result of a whole tree with this function. diff --git a/treetensor/torch/size.py b/treetensor/torch/size.py index 7b47cd2b5..2824b8585 100644 --- a/treetensor/torch/size.py +++ b/treetensor/torch/size.py @@ -1,31 +1,116 @@ +from functools import wraps + import torch +from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize +from treevalue.tree.common import BaseTree +from treevalue.utils import post_process from .base import TreeTorch -from ..common import TreeObject -from ..utils import replaceable_partial, doc_from, current_names +from ..common import Object, clsmeta, ireduce +from ..utils import replaceable_partial, doc_from, current_names, args_mapping -func_treelize = replaceable_partial(original_func_treelize) +func_treelize = post_process(post_process(args_mapping( + lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( + replaceable_partial(original_func_treelize) +) __all__ = [ 'Size' ] +def _post_index(func): + def _has_non_none(tree): + if isinstance(tree, TreeValue): + for _, value in tree: + if _has_non_none(value): + return True + + return False + else: + return tree is not None + + @wraps(func) + def _new_func(self, value, *args, **kwargs): + _tree = func(self, value, *args, **kwargs) + if not _has_non_none(_tree): + raise ValueError(f'Can not find {repr(value)} in all the sizes.') + else: + return _tree + + return _new_func + + # noinspection PyTypeChecker @current_names() -class Size(TreeTorch): +class Size(TreeTorch, metaclass=clsmeta(torch.Size, allow_dict=True)): @doc_from(torch.Size.numel) - @func_treelize(return_type=TreeObject) - def numel(self: torch.Size) -> TreeObject: + @ireduce(sum) + @func_treelize(return_type=Object) + def numel(self: torch.Size) -> Object: + """ + Get the numel sum of the sizes in this tree. + + Example:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.Size({ + ... 'a': [1, 2], + ... 'b': {'x': [3, 2, 4]}, + ... }).numel() + 26 + """ return self.numel() @doc_from(torch.Size.index) - @func_treelize(return_type=TreeObject) - def index(self: torch.Size, *args, **kwargs) -> TreeObject: - return self.index(*args, **kwargs) + @_post_index + @func_treelize(return_type=Object) + def index(self: torch.Size, value, *args, **kwargs) -> Object: + """ + + Example:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.Size({ + ... 'a': [1, 2], + ... 'b': {'x': [3, 2, 4]}, + ... 'c': [3, 5], + ... }).index(2) + + ├── a --> 1 + ├── b --> + │ └── x --> 1 + └── c --> None + + .. note:: + + This method's behaviour is different from the :func:`torch.Size.index`. + No :class:`ValueError` will be raised unless the value can not be found + in any of the sizes, instead there will be nones returned in the tree. + """ + try: + return self.index(value, *args, **kwargs) + except ValueError: + return None @doc_from(torch.Size.count) - @func_treelize(return_type=TreeObject) - def count(self: torch.Size, *args, **kwargs) -> TreeObject: + @ireduce(sum) + @func_treelize(return_type=Object) + def count(self: torch.Size, *args, **kwargs) -> Object: + """ + Get the occurrence count of the sizes in this tree. + + Example:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.Size({ + ... 'a': [1, 2], + ... 'b': {'x': [3, 2, 4]}, + ... }).count(2) + 2 + """ return self.count(*args, **kwargs) diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index ae50a2eec..d877a2506 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -1,8 +1,3 @@ -""" -Overview: - ``Tensor`` class, based on ``torch`` module. -""" - import numpy as np import torch from treevalue import method_treelize @@ -10,7 +5,7 @@ from treevalue.utils import pre_process from .base import TreeTorch from .size import Size -from ..common import TreeObject, ireduce +from ..common import Object, ireduce, clsmeta from ..numpy import ndarray from ..utils import current_names, doc_from @@ -22,9 +17,19 @@ _reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {})) tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce) -# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList +def _to_tensor(*args, **kwargs): + if (len(args) == 1 and not kwargs) or \ + (not args and set(kwargs.keys()) == {'data'}): + data = args[0] if len(args) == 1 else kwargs['data'] + if isinstance(data, torch.Tensor): + return data + + return torch.tensor(*args, **kwargs) + + +# noinspection PyTypeChecker @current_names() -class Tensor(TreeTorch): +class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)): @doc_from(torch.Tensor.numpy) @method_treelize(return_type=ndarray) def numpy(self: torch.Tensor) -> np.ndarray: @@ -36,7 +41,7 @@ class Tensor(TreeTorch): return self.numpy() @doc_from(torch.Tensor.tolist) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def tolist(self: torch.Tensor): """ Get the dump result of tree tensor. @@ -106,7 +111,7 @@ class Tensor(TreeTorch): @doc_from(torch.Tensor.numel) @ireduce(sum) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def numel(self: torch.Tensor): """ See :func:`treetensor.torch.numel` @@ -137,7 +142,7 @@ class Tensor(TreeTorch): @doc_from(torch.Tensor.all) @tireduce(torch.all) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def all(self: torch.Tensor, *args, **kwargs) -> bool: """ See :func:`treetensor.torch.all` @@ -146,7 +151,7 @@ class Tensor(TreeTorch): @doc_from(torch.Tensor.any) @tireduce(torch.any) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def any(self: torch.Tensor, *args, **kwargs) -> bool: """ See :func:`treetensor.torch.any` @@ -155,7 +160,7 @@ class Tensor(TreeTorch): @doc_from(torch.Tensor.max) @tireduce(torch.max) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def max(self: torch.Tensor, *args, **kwargs): """ See :func:`treetensor.torch.max` @@ -164,7 +169,7 @@ class Tensor(TreeTorch): @doc_from(torch.Tensor.min) @tireduce(torch.min) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def min(self: torch.Tensor, *args, **kwargs): """ See :func:`treetensor.torch.min` @@ -173,7 +178,7 @@ class Tensor(TreeTorch): @doc_from(torch.Tensor.sum) @tireduce(torch.sum) - @method_treelize(return_type=TreeObject) + @method_treelize(return_type=Object) def sum(self: torch.Tensor, *args, **kwargs): """ See :func:`treetensor.torch.sum` -- GitLab