提交 4086271e 编写于 作者: HansBug's avatar HansBug 😆

dev, doc(hansbug): add new function for Size class && add plenty of new documentations

上级 7ff4109d
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import pytest import pytest
import treetensor.numpy as tnp import treetensor.numpy as tnp
from treetensor.common import TreeObject from treetensor.common import Object
# noinspection DuplicatedCode # noinspection DuplicatedCode
...@@ -209,7 +209,7 @@ class TestNumpyArray: ...@@ -209,7 +209,7 @@ class TestNumpyArray:
})).all() })).all()
def test_tolist(self): def test_tolist(self):
assert self._DEMO_1.tolist() == TreeObject({ assert self._DEMO_1.tolist() == Object({
'a': [[1, 2, 3], [4, 5, 6]], 'a': [[1, 2, 3], [4, 5, 6]],
'b': [1, 3, 5, 7], 'b': [1, 3, 5, 7],
'x': { 'x': {
...@@ -217,7 +217,7 @@ class TestNumpyArray: ...@@ -217,7 +217,7 @@ class TestNumpyArray:
'd': [3, 9, 11.0], 'd': [3, 9, 11.0],
} }
}) })
assert self._DEMO_2.tolist() == TreeObject({ assert self._DEMO_2.tolist() == Object({
'a': [[1, 22, 3], [4, 5, 6]], 'a': [[1, 22, 3], [4, 5, 6]],
'b': [1, 3, 5, 7], 'b': [1, 3, 5, 7],
'x': { 'x': {
...@@ -225,7 +225,7 @@ class TestNumpyArray: ...@@ -225,7 +225,7 @@ class TestNumpyArray:
'd': [3, 9, 11.0], 'd': [3, 9, 11.0],
} }
}) })
assert self._DEMO_3.tolist() == TreeObject({ assert self._DEMO_3.tolist() == Object({
'a': [[0, 0, 0], [0, 0, 0]], 'a': [[0, 0, 0], [0, 0, 0]],
'b': [0, 0, 0, 0], 'b': [0, 0, 0, 0],
'x': { 'x': {
......
...@@ -12,20 +12,20 @@ _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) ...@@ -12,20 +12,20 @@ _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
@pytest.mark.unittest @pytest.mark.unittest
class TestTorchTensor: class TestTorchTensor:
_DEMO_1 = ttorch.Tensor({ _DEMO_1 = ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': [[1, 2, 3], [4, 5, 6]],
'b': torch.tensor([[1, 2], [5, 6]]), 'b': [[1, 2], [5, 6]],
'x': { 'x': {
'c': torch.tensor([3, 5, 6, 7]), 'c': [3, 5, 6, 7],
'd': torch.tensor([[[1, 2], [8, 9]]]), 'd': [[[1, 2], [8, 9]]],
} }
}) })
_DEMO_2 = ttorch.Tensor({ _DEMO_2 = ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': [[1, 2, 3], [4, 5, 6]],
'b': torch.tensor([[1, 2], [5, 60]]), 'b': [[1, 2], [5, 60]],
'x': { 'x': {
'c': torch.tensor([3, 5, 6, 7]), 'c': [3, 5, 6, 7],
'd': torch.tensor([[[1, 2], [8, 9]]]), 'd': [[[1, 2], [8, 9]]],
} }
}) })
...@@ -48,11 +48,11 @@ class TestTorchTensor: ...@@ -48,11 +48,11 @@ class TestTorchTensor:
def test_to(self): def test_to(self):
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({ assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'a': torch.FloatTensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32), 'b': torch.FloatTensor([[1, 2], [5, 6]]),
'x': { 'x': {
'c': torch.tensor([3, 5, 6, 7], dtype=torch.float32), 'c': torch.FloatTensor([3, 5, 6, 7]),
'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32), 'd': torch.FloatTensor([[[1, 2], [8, 9]]]),
} }
})) }))
......
from .common import TreeObject from .common import Object
from .numpy import ndarray from .numpy import ndarray
from .torch import Tensor from .torch import Tensor
import builtins import builtins
import io import io
import os import os
from abc import ABCMeta
from functools import partial from functools import partial
from typing import Optional, Tuple, Callable from typing import Optional, Tuple, Callable
from treevalue import func_treelize as original_func_treelize
from treevalue import general_tree_value, TreeValue from treevalue import general_tree_value, TreeValue
from treevalue.tree.common import BaseTree
from treevalue.tree.tree.tree import get_data_property from treevalue.tree.tree.tree import get_data_property
from treevalue.utils import post_process
from ..utils import replaceable_partial, args_mapping
__all__ = [ __all__ = [
'BaseTreeStruct', "TreeObject", 'print_tree', 'BaseTreeStruct', "Object",
'print_tree', 'clsmeta',
] ]
...@@ -78,7 +83,7 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil ...@@ -78,7 +83,7 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil
print(repr_(tree), file=file) print(repr_(tree), file=file)
class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): class BaseTreeStruct(general_tree_value()):
""" """
Overview: Overview:
Base structure of all the trees in ``treetensor``. Base structure of all the trees in ``treetensor``.
...@@ -93,5 +98,32 @@ class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): ...@@ -93,5 +98,32 @@ class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
return self.__repr__() return self.__repr__()
class TreeObject(BaseTreeStruct): def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True):
class _TempTreeValue(TreeValue):
pass
_types = (
TreeValue,
*((dict,) if allow_dict else ()),
*((BaseTree,) if allow_data else ()),
)
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))(
replaceable_partial(original_func_treelize, return_type=_TempTreeValue)
)
_torch_size = func_treelize()(cls)
class _MetaClass(type):
def __call__(cls, *args, **kwargs):
_result = _torch_size(*args, **kwargs)
if isinstance(_result, _TempTreeValue):
return type.__call__(cls, _result)
else:
return _result
return _MetaClass
class Object(BaseTreeStruct):
pass pass
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
from treevalue import method_treelize from treevalue import method_treelize
from .base import TreeNumpy from .base import TreeNumpy
from ..common import TreeObject, ireduce from ..common import Object, ireduce
from ..utils import current_names from ..utils import current_names
__all__ = [ __all__ = [
...@@ -18,34 +18,34 @@ class ndarray(TreeNumpy): ...@@ -18,34 +18,34 @@ class ndarray(TreeNumpy):
Real numpy tree. Real numpy tree.
""" """
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def tolist(self: np.ndarray): def tolist(self: np.ndarray):
return self.tolist() return self.tolist()
@property @property
@ireduce(sum) @ireduce(sum)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def size(self: np.ndarray) -> int: def size(self: np.ndarray) -> int:
return self.size return self.size
@property @property
@ireduce(sum) @ireduce(sum)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def nbytes(self: np.ndarray) -> int: def nbytes(self: np.ndarray) -> int:
return self.nbytes return self.nbytes
@ireduce(sum) @ireduce(sum)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def sum(self: np.ndarray, *args, **kwargs): def sum(self: np.ndarray, *args, **kwargs):
return self.sum(*args, **kwargs) return self.sum(*args, **kwargs)
@ireduce(all) @ireduce(all)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def all(self: np.ndarray, *args, **kwargs): def all(self: np.ndarray, *args, **kwargs):
return self.all(*args, **kwargs) return self.all(*args, **kwargs)
@ireduce(any) @ireduce(any)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def any(self: np.ndarray, *args, **kwargs): def any(self: np.ndarray, *args, **kwargs):
return self.any(*args, **kwargs) return self.any(*args, **kwargs)
......
...@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize ...@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize
from treevalue.utils import post_process from treevalue.utils import post_process
from .array import ndarray from .array import ndarray
from ..common import ireduce, TreeObject from ..common import ireduce, Object
from ..utils import replaceable_partial, doc_from, args_mapping from ..utils import replaceable_partial, doc_from, args_mapping
__all__ = [ __all__ = [
...@@ -22,7 +22,7 @@ func_treelize = post_process(post_process(args_mapping( ...@@ -22,7 +22,7 @@ func_treelize = post_process(post_process(args_mapping(
@doc_from(np.all) @doc_from(np.all)
@ireduce(builtins.all) @ireduce(builtins.all)
@func_treelize(return_type=TreeObject) @func_treelize(return_type=Object)
def all(a, *args, **kwargs): def all(a, *args, **kwargs):
return np.all(a, *args, **kwargs) return np.all(a, *args, **kwargs)
......
"""
Overview:
Common functions, based on ``torch`` module.
"""
import builtins import builtins
import torch import torch
from treevalue import TreeValue from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.utils import post_process from treevalue.utils import post_process
from .tensor import Tensor, tireduce from .tensor import Tensor, tireduce
from ..common import TreeObject, ireduce from ..common import Object, ireduce
from ..utils import replaceable_partial, doc_from, args_mapping from ..utils import replaceable_partial, doc_from, args_mapping
__all__ = [ __all__ = [
...@@ -28,7 +24,7 @@ __all__ = [ ...@@ -28,7 +24,7 @@ __all__ = [
] ]
func_treelize = post_process(post_process(args_mapping( func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeValue)) else x)))( lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=Tensor) replaceable_partial(original_func_treelize, return_type=Tensor)
) )
...@@ -355,7 +351,7 @@ def empty_like(input, *args, **kwargs): ...@@ -355,7 +351,7 @@ def empty_like(input, *args, **kwargs):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
@doc_from(torch.all) @doc_from(torch.all)
@tireduce(torch.all) @tireduce(torch.all)
@func_treelize(return_type=TreeObject) @func_treelize(return_type=Object)
def all(input, *args, **kwargs): def all(input, *args, **kwargs):
""" """
In ``treetensor``, you can get the ``all`` result of a whole tree with this function. In ``treetensor``, you can get the ``all`` result of a whole tree with this function.
...@@ -394,7 +390,7 @@ def all(input, *args, **kwargs): ...@@ -394,7 +390,7 @@ def all(input, *args, **kwargs):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
@doc_from(torch.any) @doc_from(torch.any)
@tireduce(torch.any) @tireduce(torch.any)
@func_treelize(return_type=TreeObject) @func_treelize(return_type=Object)
def any(input, *args, **kwargs): def any(input, *args, **kwargs):
""" """
In ``treetensor``, you can get the ``any`` result of a whole tree with this function. In ``treetensor``, you can get the ``any`` result of a whole tree with this function.
...@@ -433,7 +429,7 @@ def any(input, *args, **kwargs): ...@@ -433,7 +429,7 @@ def any(input, *args, **kwargs):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
@doc_from(torch.min) @doc_from(torch.min)
@tireduce(torch.min) @tireduce(torch.min)
@func_treelize(return_type=TreeObject) @func_treelize(return_type=Object)
def min(input, *args, **kwargs): def min(input, *args, **kwargs):
""" """
In ``treetensor``, you can get the ``min`` result of a whole tree with this function. In ``treetensor``, you can get the ``min`` result of a whole tree with this function.
...@@ -472,7 +468,7 @@ def min(input, *args, **kwargs): ...@@ -472,7 +468,7 @@ def min(input, *args, **kwargs):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
@doc_from(torch.max) @doc_from(torch.max)
@tireduce(torch.max) @tireduce(torch.max)
@func_treelize(return_type=TreeObject) @func_treelize(return_type=Object)
def max(input, *args, **kwargs): def max(input, *args, **kwargs):
""" """
In ``treetensor``, you can get the ``max`` result of a whole tree with this function. In ``treetensor``, you can get the ``max`` result of a whole tree with this function.
...@@ -511,7 +507,7 @@ def max(input, *args, **kwargs): ...@@ -511,7 +507,7 @@ def max(input, *args, **kwargs):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
@doc_from(torch.sum) @doc_from(torch.sum)
@tireduce(torch.sum) @tireduce(torch.sum)
@func_treelize(return_type=TreeObject) @func_treelize(return_type=Object)
def sum(input, *args, **kwargs): def sum(input, *args, **kwargs):
""" """
In ``treetensor``, you can get the ``sum`` result of a whole tree with this function. In ``treetensor``, you can get the ``sum`` result of a whole tree with this function.
......
from functools import wraps
import torch import torch
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.utils import post_process
from .base import TreeTorch from .base import TreeTorch
from ..common import TreeObject from ..common import Object, clsmeta, ireduce
from ..utils import replaceable_partial, doc_from, current_names from ..utils import replaceable_partial, doc_from, current_names, args_mapping
func_treelize = replaceable_partial(original_func_treelize) func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize)
)
__all__ = [ __all__ = [
'Size' 'Size'
] ]
def _post_index(func):
def _has_non_none(tree):
if isinstance(tree, TreeValue):
for _, value in tree:
if _has_non_none(value):
return True
return False
else:
return tree is not None
@wraps(func)
def _new_func(self, value, *args, **kwargs):
_tree = func(self, value, *args, **kwargs)
if not _has_non_none(_tree):
raise ValueError(f'Can not find {repr(value)} in all the sizes.')
else:
return _tree
return _new_func
# noinspection PyTypeChecker # noinspection PyTypeChecker
@current_names() @current_names()
class Size(TreeTorch): class Size(TreeTorch, metaclass=clsmeta(torch.Size, allow_dict=True)):
@doc_from(torch.Size.numel) @doc_from(torch.Size.numel)
@func_treelize(return_type=TreeObject) @ireduce(sum)
def numel(self: torch.Size) -> TreeObject: @func_treelize(return_type=Object)
def numel(self: torch.Size) -> Object:
"""
Get the numel sum of the sizes in this tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... }).numel()
26
"""
return self.numel() return self.numel()
@doc_from(torch.Size.index) @doc_from(torch.Size.index)
@func_treelize(return_type=TreeObject) @_post_index
def index(self: torch.Size, *args, **kwargs) -> TreeObject: @func_treelize(return_type=Object)
return self.index(*args, **kwargs) def index(self: torch.Size, value, *args, **kwargs) -> Object:
"""
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... 'c': [3, 5],
... }).index(2)
<Object 0x7fb412780e80>
├── a --> 1
├── b --> <Object 0x7fb412780eb8>
│ └── x --> 1
└── c --> None
.. note::
This method's behaviour is different from the :func:`torch.Size.index`.
No :class:`ValueError` will be raised unless the value can not be found
in any of the sizes, instead there will be nones returned in the tree.
"""
try:
return self.index(value, *args, **kwargs)
except ValueError:
return None
@doc_from(torch.Size.count) @doc_from(torch.Size.count)
@func_treelize(return_type=TreeObject) @ireduce(sum)
def count(self: torch.Size, *args, **kwargs) -> TreeObject: @func_treelize(return_type=Object)
def count(self: torch.Size, *args, **kwargs) -> Object:
"""
Get the occurrence count of the sizes in this tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... }).count(2)
2
"""
return self.count(*args, **kwargs) return self.count(*args, **kwargs)
"""
Overview:
``Tensor`` class, based on ``torch`` module.
"""
import numpy as np import numpy as np
import torch import torch
from treevalue import method_treelize from treevalue import method_treelize
...@@ -10,7 +5,7 @@ from treevalue.utils import pre_process ...@@ -10,7 +5,7 @@ from treevalue.utils import pre_process
from .base import TreeTorch from .base import TreeTorch
from .size import Size from .size import Size
from ..common import TreeObject, ireduce from ..common import Object, ireduce, clsmeta
from ..numpy import ndarray from ..numpy import ndarray
from ..utils import current_names, doc_from from ..utils import current_names, doc_from
...@@ -22,9 +17,19 @@ _reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {})) ...@@ -22,9 +17,19 @@ _reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce) tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce)
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList def _to_tensor(*args, **kwargs):
if (len(args) == 1 and not kwargs) or \
(not args and set(kwargs.keys()) == {'data'}):
data = args[0] if len(args) == 1 else kwargs['data']
if isinstance(data, torch.Tensor):
return data
return torch.tensor(*args, **kwargs)
# noinspection PyTypeChecker
@current_names() @current_names()
class Tensor(TreeTorch): class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@doc_from(torch.Tensor.numpy) @doc_from(torch.Tensor.numpy)
@method_treelize(return_type=ndarray) @method_treelize(return_type=ndarray)
def numpy(self: torch.Tensor) -> np.ndarray: def numpy(self: torch.Tensor) -> np.ndarray:
...@@ -36,7 +41,7 @@ class Tensor(TreeTorch): ...@@ -36,7 +41,7 @@ class Tensor(TreeTorch):
return self.numpy() return self.numpy()
@doc_from(torch.Tensor.tolist) @doc_from(torch.Tensor.tolist)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def tolist(self: torch.Tensor): def tolist(self: torch.Tensor):
""" """
Get the dump result of tree tensor. Get the dump result of tree tensor.
...@@ -106,7 +111,7 @@ class Tensor(TreeTorch): ...@@ -106,7 +111,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.numel) @doc_from(torch.Tensor.numel)
@ireduce(sum) @ireduce(sum)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def numel(self: torch.Tensor): def numel(self: torch.Tensor):
""" """
See :func:`treetensor.torch.numel` See :func:`treetensor.torch.numel`
...@@ -137,7 +142,7 @@ class Tensor(TreeTorch): ...@@ -137,7 +142,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.all) @doc_from(torch.Tensor.all)
@tireduce(torch.all) @tireduce(torch.all)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def all(self: torch.Tensor, *args, **kwargs) -> bool: def all(self: torch.Tensor, *args, **kwargs) -> bool:
""" """
See :func:`treetensor.torch.all` See :func:`treetensor.torch.all`
...@@ -146,7 +151,7 @@ class Tensor(TreeTorch): ...@@ -146,7 +151,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.any) @doc_from(torch.Tensor.any)
@tireduce(torch.any) @tireduce(torch.any)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def any(self: torch.Tensor, *args, **kwargs) -> bool: def any(self: torch.Tensor, *args, **kwargs) -> bool:
""" """
See :func:`treetensor.torch.any` See :func:`treetensor.torch.any`
...@@ -155,7 +160,7 @@ class Tensor(TreeTorch): ...@@ -155,7 +160,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.max) @doc_from(torch.Tensor.max)
@tireduce(torch.max) @tireduce(torch.max)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def max(self: torch.Tensor, *args, **kwargs): def max(self: torch.Tensor, *args, **kwargs):
""" """
See :func:`treetensor.torch.max` See :func:`treetensor.torch.max`
...@@ -164,7 +169,7 @@ class Tensor(TreeTorch): ...@@ -164,7 +169,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.min) @doc_from(torch.Tensor.min)
@tireduce(torch.min) @tireduce(torch.min)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def min(self: torch.Tensor, *args, **kwargs): def min(self: torch.Tensor, *args, **kwargs):
""" """
See :func:`treetensor.torch.min` See :func:`treetensor.torch.min`
...@@ -173,7 +178,7 @@ class Tensor(TreeTorch): ...@@ -173,7 +178,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.sum) @doc_from(torch.Tensor.sum)
@tireduce(torch.sum) @tireduce(torch.sum)
@method_treelize(return_type=TreeObject) @method_treelize(return_type=Object)
def sum(self: torch.Tensor, *args, **kwargs): def sum(self: torch.Tensor, *args, **kwargs):
""" """
See :func:`treetensor.torch.sum` See :func:`treetensor.torch.sum`
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册