提交 4086271e 编写于 作者: HansBug's avatar HansBug 😆

dev, doc(hansbug): add new function for Size class && add plenty of new documentations

上级 7ff4109d
......@@ -2,7 +2,7 @@ import numpy as np
import pytest
import treetensor.numpy as tnp
from treetensor.common import TreeObject
from treetensor.common import Object
# noinspection DuplicatedCode
......@@ -209,7 +209,7 @@ class TestNumpyArray:
})).all()
def test_tolist(self):
assert self._DEMO_1.tolist() == TreeObject({
assert self._DEMO_1.tolist() == Object({
'a': [[1, 2, 3], [4, 5, 6]],
'b': [1, 3, 5, 7],
'x': {
......@@ -217,7 +217,7 @@ class TestNumpyArray:
'd': [3, 9, 11.0],
}
})
assert self._DEMO_2.tolist() == TreeObject({
assert self._DEMO_2.tolist() == Object({
'a': [[1, 22, 3], [4, 5, 6]],
'b': [1, 3, 5, 7],
'x': {
......@@ -225,7 +225,7 @@ class TestNumpyArray:
'd': [3, 9, 11.0],
}
})
assert self._DEMO_3.tolist() == TreeObject({
assert self._DEMO_3.tolist() == Object({
'a': [[0, 0, 0], [0, 0, 0]],
'b': [0, 0, 0, 0],
'x': {
......
......@@ -12,20 +12,20 @@ _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
@pytest.mark.unittest
class TestTorchTensor:
_DEMO_1 = ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 6]]),
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 6]],
'x': {
'c': torch.tensor([3, 5, 6, 7]),
'd': torch.tensor([[[1, 2], [8, 9]]]),
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
_DEMO_2 = ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 60]]),
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 60]],
'x': {
'c': torch.tensor([3, 5, 6, 7]),
'd': torch.tensor([[[1, 2], [8, 9]]]),
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
......@@ -48,11 +48,11 @@ class TestTorchTensor:
def test_to(self):
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32),
'a': torch.FloatTensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.FloatTensor([[1, 2], [5, 6]]),
'x': {
'c': torch.tensor([3, 5, 6, 7], dtype=torch.float32),
'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32),
'c': torch.FloatTensor([3, 5, 6, 7]),
'd': torch.FloatTensor([[[1, 2], [8, 9]]]),
}
}))
......
from .common import TreeObject
from .common import Object
from .numpy import ndarray
from .torch import Tensor
import builtins
import io
import os
from abc import ABCMeta
from functools import partial
from typing import Optional, Tuple, Callable
from treevalue import func_treelize as original_func_treelize
from treevalue import general_tree_value, TreeValue
from treevalue.tree.common import BaseTree
from treevalue.tree.tree.tree import get_data_property
from treevalue.utils import post_process
from ..utils import replaceable_partial, args_mapping
__all__ = [
'BaseTreeStruct', "TreeObject", 'print_tree',
'BaseTreeStruct', "Object",
'print_tree', 'clsmeta',
]
......@@ -78,7 +83,7 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil
print(repr_(tree), file=file)
class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
class BaseTreeStruct(general_tree_value()):
"""
Overview:
Base structure of all the trees in ``treetensor``.
......@@ -93,5 +98,32 @@ class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
return self.__repr__()
class TreeObject(BaseTreeStruct):
def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True):
class _TempTreeValue(TreeValue):
pass
_types = (
TreeValue,
*((dict,) if allow_dict else ()),
*((BaseTree,) if allow_data else ()),
)
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))(
replaceable_partial(original_func_treelize, return_type=_TempTreeValue)
)
_torch_size = func_treelize()(cls)
class _MetaClass(type):
def __call__(cls, *args, **kwargs):
_result = _torch_size(*args, **kwargs)
if isinstance(_result, _TempTreeValue):
return type.__call__(cls, _result)
else:
return _result
return _MetaClass
class Object(BaseTreeStruct):
pass
......@@ -2,7 +2,7 @@ import numpy as np
from treevalue import method_treelize
from .base import TreeNumpy
from ..common import TreeObject, ireduce
from ..common import Object, ireduce
from ..utils import current_names
__all__ = [
......@@ -18,34 +18,34 @@ class ndarray(TreeNumpy):
Real numpy tree.
"""
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def tolist(self: np.ndarray):
return self.tolist()
@property
@ireduce(sum)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def size(self: np.ndarray) -> int:
return self.size
@property
@ireduce(sum)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def nbytes(self: np.ndarray) -> int:
return self.nbytes
@ireduce(sum)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def sum(self: np.ndarray, *args, **kwargs):
return self.sum(*args, **kwargs)
@ireduce(all)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def all(self: np.ndarray, *args, **kwargs):
return self.all(*args, **kwargs)
@ireduce(any)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def any(self: np.ndarray, *args, **kwargs):
return self.any(*args, **kwargs)
......
......@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize
from treevalue.utils import post_process
from .array import ndarray
from ..common import ireduce, TreeObject
from ..common import ireduce, Object
from ..utils import replaceable_partial, doc_from, args_mapping
__all__ = [
......@@ -22,7 +22,7 @@ func_treelize = post_process(post_process(args_mapping(
@doc_from(np.all)
@ireduce(builtins.all)
@func_treelize(return_type=TreeObject)
@func_treelize(return_type=Object)
def all(a, *args, **kwargs):
return np.all(a, *args, **kwargs)
......
"""
Overview:
Common functions, based on ``torch`` module.
"""
import builtins
import torch
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.utils import post_process
from .tensor import Tensor, tireduce
from ..common import TreeObject, ireduce
from ..common import Object, ireduce
from ..utils import replaceable_partial, doc_from, args_mapping
__all__ = [
......@@ -28,7 +24,7 @@ __all__ = [
]
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeValue)) else x)))(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=Tensor)
)
......@@ -355,7 +351,7 @@ def empty_like(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.all)
@tireduce(torch.all)
@func_treelize(return_type=TreeObject)
@func_treelize(return_type=Object)
def all(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``all`` result of a whole tree with this function.
......@@ -394,7 +390,7 @@ def all(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.any)
@tireduce(torch.any)
@func_treelize(return_type=TreeObject)
@func_treelize(return_type=Object)
def any(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``any`` result of a whole tree with this function.
......@@ -433,7 +429,7 @@ def any(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.min)
@tireduce(torch.min)
@func_treelize(return_type=TreeObject)
@func_treelize(return_type=Object)
def min(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``min`` result of a whole tree with this function.
......@@ -472,7 +468,7 @@ def min(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.max)
@tireduce(torch.max)
@func_treelize(return_type=TreeObject)
@func_treelize(return_type=Object)
def max(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``max`` result of a whole tree with this function.
......@@ -511,7 +507,7 @@ def max(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.sum)
@tireduce(torch.sum)
@func_treelize(return_type=TreeObject)
@func_treelize(return_type=Object)
def sum(input, *args, **kwargs):
"""
In ``treetensor``, you can get the ``sum`` result of a whole tree with this function.
......
from functools import wraps
import torch
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.utils import post_process
from .base import TreeTorch
from ..common import TreeObject
from ..utils import replaceable_partial, doc_from, current_names
from ..common import Object, clsmeta, ireduce
from ..utils import replaceable_partial, doc_from, current_names, args_mapping
func_treelize = replaceable_partial(original_func_treelize)
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize)
)
__all__ = [
'Size'
]
def _post_index(func):
def _has_non_none(tree):
if isinstance(tree, TreeValue):
for _, value in tree:
if _has_non_none(value):
return True
return False
else:
return tree is not None
@wraps(func)
def _new_func(self, value, *args, **kwargs):
_tree = func(self, value, *args, **kwargs)
if not _has_non_none(_tree):
raise ValueError(f'Can not find {repr(value)} in all the sizes.')
else:
return _tree
return _new_func
# noinspection PyTypeChecker
@current_names()
class Size(TreeTorch):
class Size(TreeTorch, metaclass=clsmeta(torch.Size, allow_dict=True)):
@doc_from(torch.Size.numel)
@func_treelize(return_type=TreeObject)
def numel(self: torch.Size) -> TreeObject:
@ireduce(sum)
@func_treelize(return_type=Object)
def numel(self: torch.Size) -> Object:
"""
Get the numel sum of the sizes in this tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... }).numel()
26
"""
return self.numel()
@doc_from(torch.Size.index)
@func_treelize(return_type=TreeObject)
def index(self: torch.Size, *args, **kwargs) -> TreeObject:
return self.index(*args, **kwargs)
@_post_index
@func_treelize(return_type=Object)
def index(self: torch.Size, value, *args, **kwargs) -> Object:
"""
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... 'c': [3, 5],
... }).index(2)
<Object 0x7fb412780e80>
├── a --> 1
├── b --> <Object 0x7fb412780eb8>
│ └── x --> 1
└── c --> None
.. note::
This method's behaviour is different from the :func:`torch.Size.index`.
No :class:`ValueError` will be raised unless the value can not be found
in any of the sizes, instead there will be nones returned in the tree.
"""
try:
return self.index(value, *args, **kwargs)
except ValueError:
return None
@doc_from(torch.Size.count)
@func_treelize(return_type=TreeObject)
def count(self: torch.Size, *args, **kwargs) -> TreeObject:
@ireduce(sum)
@func_treelize(return_type=Object)
def count(self: torch.Size, *args, **kwargs) -> Object:
"""
Get the occurrence count of the sizes in this tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... }).count(2)
2
"""
return self.count(*args, **kwargs)
"""
Overview:
``Tensor`` class, based on ``torch`` module.
"""
import numpy as np
import torch
from treevalue import method_treelize
......@@ -10,7 +5,7 @@ from treevalue.utils import pre_process
from .base import TreeTorch
from .size import Size
from ..common import TreeObject, ireduce
from ..common import Object, ireduce, clsmeta
from ..numpy import ndarray
from ..utils import current_names, doc_from
......@@ -22,9 +17,19 @@ _reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce)
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
def _to_tensor(*args, **kwargs):
if (len(args) == 1 and not kwargs) or \
(not args and set(kwargs.keys()) == {'data'}):
data = args[0] if len(args) == 1 else kwargs['data']
if isinstance(data, torch.Tensor):
return data
return torch.tensor(*args, **kwargs)
# noinspection PyTypeChecker
@current_names()
class Tensor(TreeTorch):
class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@doc_from(torch.Tensor.numpy)
@method_treelize(return_type=ndarray)
def numpy(self: torch.Tensor) -> np.ndarray:
......@@ -36,7 +41,7 @@ class Tensor(TreeTorch):
return self.numpy()
@doc_from(torch.Tensor.tolist)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def tolist(self: torch.Tensor):
"""
Get the dump result of tree tensor.
......@@ -106,7 +111,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.numel)
@ireduce(sum)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def numel(self: torch.Tensor):
"""
See :func:`treetensor.torch.numel`
......@@ -137,7 +142,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.all)
@tireduce(torch.all)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def all(self: torch.Tensor, *args, **kwargs) -> bool:
"""
See :func:`treetensor.torch.all`
......@@ -146,7 +151,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.any)
@tireduce(torch.any)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def any(self: torch.Tensor, *args, **kwargs) -> bool:
"""
See :func:`treetensor.torch.any`
......@@ -155,7 +160,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.max)
@tireduce(torch.max)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def max(self: torch.Tensor, *args, **kwargs):
"""
See :func:`treetensor.torch.max`
......@@ -164,7 +169,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.min)
@tireduce(torch.min)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def min(self: torch.Tensor, *args, **kwargs):
"""
See :func:`treetensor.torch.min`
......@@ -173,7 +178,7 @@ class Tensor(TreeTorch):
@doc_from(torch.Tensor.sum)
@tireduce(torch.sum)
@method_treelize(return_type=TreeObject)
@method_treelize(return_type=Object)
def sum(self: torch.Tensor, *args, **kwargs):
"""
See :func:`treetensor.torch.sum`
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册