提交 745b5b8e 编写于 作者: HansBug's avatar HansBug 😆

doc, dev, test(hansbug): add test for treetensor.common and treetensor.tensor and size

上级 4086271e
from .common import *
from .config import * from .config import *
from .numpy import * from .numpy import *
from .torch import * from .torch import *
from .test_trees import TestCommonTrees
import io
import pytest
import torch
from treevalue import typetrans, TreeValue, general_tree_value
from treetensor.common import Object, print_tree
def text_compares(expected, actual):
elines = expected.splitlines()
alines = actual.splitlines()
assert len(elines) == len(alines), f"""Lines not match,
Expected: {len(elines)} lines
Actual: {len(alines)} lines
"""
for i, (e, a) in enumerate(zip(elines, alines)):
assert e.rstrip() == a.rstrip(), f"""Line {i} not match,
Expected: {e}
Actual: {a}
"""
@pytest.mark.unittest
class TestCommonTrees:
def test_object(self):
t = Object(1)
assert isinstance(t, int)
assert t == 1
assert Object({'a': 1, 'b': 2}) == typetrans(TreeValue({
'a': 1, 'b': 2
}), Object)
def test_print_tree(self):
class _TempTree(general_tree_value()):
def __repr__(self):
with io.StringIO() as sfile:
print_tree(self, repr_=repr, ascii_=False, show_node_id=False, file=sfile)
return sfile.getvalue()
def __str__(self):
return self.__repr__()
text_compares("""<_TempTree>
├── a --> 1
└── b --> 2
""".rstrip(), str(_TempTree({
'a': 1, 'b': 2
})).rstrip())
class _TmpObject:
def __init__(self, v):
self.__v = v
def __repr__(self):
return self.__v
tx = _TempTree({
'a': 1, 'b': 2, 'c': torch.tensor([[1, 2, ], [3, 4]]),
'd': {'x': torch.tensor([[1], [2], [3, ], [4]])},
'e': _TmpObject('line after this\nhahahaha'),
})
tx.d.y = tx
text_compares("""<_TempTree>
├── a --> 1
├── b --> 2
├── c --> tensor([[1, 2],
│ [3, 4]])
├── d --> <_TempTree>
│ ├── x --> tensor([[1],
│ │ [2],
│ │ [3],
│ │ [4]])
│ └── y --> <_TempTree>
│ (The same address as <root>)
└── e --> line after this
hahahaha
""".rstrip(), str(tx).rstrip())
with io.StringIO() as sf:
print_tree(1, file=sf)
text_compares("1", sf.getvalue())
from .test_funcs import TestTorchFuncs from .test_funcs import TestTorchFuncs
from .test_size import TestTorchSize
from .test_tensor import TestTorchTensor from .test_tensor import TestTorchTensor
...@@ -401,6 +401,111 @@ class TestTorchFuncs: ...@@ -401,6 +401,111 @@ class TestTorchFuncs:
'b': torch.tensor([4, 5, 5]), 'b': torch.tensor([4, 5, 5]),
})).all() })).all()
def test_ne(self):
assert (ttorch.ne(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, True],
[True, False]])).all()
assert (ttorch.ne(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, True], [True, False]],
'b': [True, True, False],
})).all()
def test_lt(self):
assert (ttorch.lt(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, False],
[True, False]])).all()
assert (ttorch.lt(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, False], [True, False]],
'b': [True, False, False],
})).all()
def test_le(self):
assert (ttorch.le(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[True, False],
[True, True]])).all()
assert (ttorch.le(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[True, False], [True, True]],
'b': [True, False, True],
})).all()
def test_gt(self):
assert (ttorch.gt(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, True],
[False, False]])).all()
assert (ttorch.gt(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, True], [False, False]],
'b': [False, True, False],
})).all()
def test_ge(self):
assert (ttorch.ge(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[True, True],
[False, True]])).all()
assert (ttorch.ge(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[True, True], [False, True]],
'b': [False, True, True],
})).all()
def test_equal(self): def test_equal(self):
p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])) p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]))
assert isinstance(p1, bool) assert isinstance(p1, bool)
...@@ -429,3 +534,36 @@ class TestTorchFuncs: ...@@ -429,3 +534,36 @@ class TestTorchFuncs:
})) }))
assert isinstance(p4, bool) assert isinstance(p4, bool)
assert not p4 assert not p4
def test_min(self):
t1 = ttorch.min(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
assert t1 == torch.tensor(1.0)
assert ttorch.min(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == ttorch.tensor({
'a': 1.0,
'b': {'x': 0.9},
})
def test_max(self):
t1 = ttorch.max(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
assert t1 == torch.tensor(2.0)
assert ttorch.max(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == ttorch.tensor({
'a': 2.0,
'b': {'x': 2.5, }
})
def test_sum(self):
assert ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) == torch.tensor(4.5)
assert ttorch.sum(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == torch.tensor(11.0)
import pytest
import torch
from treevalue import func_treelize, typetrans, TreeValue
import treetensor.torch as ttorch
from treetensor.common import Object
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
@pytest.mark.unittest
class TestTorchSize:
def test_init(self):
t1 = ttorch.Size([1, 2, 3])
assert isinstance(t1, torch.Size)
assert t1 == torch.Size([1, 2, 3])
t2 = ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
})
assert isinstance(t2, ttorch.Size)
assert typetrans(t2, TreeValue) == TreeValue({
'a': torch.Size([1, 2, 3]),
'b': {'x': torch.Size([3, 4, ])},
'c': torch.Size([5]),
})
def test_numel(self):
assert ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
}).numel() == 23
def test_index(self):
assert ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
}).index(3) == Object({
'a': 2,
'b': {'x': 0},
'c': None
})
with pytest.raises(ValueError):
ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
}).index(100)
def test_count(self):
assert ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
}).count(3) == 2
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from treevalue import func_treelize from treevalue import func_treelize, typetrans, TreeValue
import treetensor.numpy as tnp import treetensor.numpy as tnp
import treetensor.torch as ttorch import treetensor.torch as ttorch
...@@ -29,6 +29,18 @@ class TestTorchTensor: ...@@ -29,6 +29,18 @@ class TestTorchTensor:
} }
}) })
def test_init(self):
assert (ttorch.Tensor([1, 2, 3]) == torch.tensor([1, 2, 3])).all()
assert (ttorch.Tensor([1, 2, 3], dtype=torch.float32) == torch.FloatTensor([1, 2, 3])).all()
assert (self._DEMO_1 == typetrans(TreeValue({
'a': ttorch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': ttorch.tensor([[1, 2], [5, 6]]),
'x': {
'c': ttorch.tensor([3, 5, 6, 7]),
'd': ttorch.tensor([[[1, 2], [8, 9]]]),
}
}), ttorch.Tensor)).all()
def test_numel(self): def test_numel(self):
assert self._DEMO_1.numel() == 18 assert self._DEMO_1.numel() == 18
...@@ -59,3 +71,6 @@ class TestTorchTensor: ...@@ -59,3 +71,6 @@ class TestTorchTensor:
def test_all(self): def test_all(self):
assert (self._DEMO_1 == self._DEMO_1).all() assert (self._DEMO_1 == self._DEMO_1).all()
assert not (self._DEMO_1 == self._DEMO_2).all() assert not (self._DEMO_1 == self._DEMO_2).all()
def test_tolist(self):
pass
...@@ -18,15 +18,8 @@ __all__ = [ ...@@ -18,15 +18,8 @@ __all__ = [
] ]
def _tree_title(node: TreeValue): def print_tree(tree: TreeValue, repr_: Callable = str,
_tree = get_data_property(node) ascii_: bool = False, show_node_id: bool = True, file=None):
return "<{cls} {id}>".format(
cls=node.__class__.__name__,
id=hex(id(_tree.actual())),
)
def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, file=None):
print_to_file = partial(builtins.print, file=file) print_to_file = partial(builtins.print, file=file)
node_ids = {} node_ids = {}
if ascii_: if ascii_:
...@@ -45,7 +38,10 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil ...@@ -45,7 +38,10 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil
_need_iter = True _need_iter = True
if isinstance(node, TreeValue): if isinstance(node, TreeValue):
_node_id = id(get_data_property(node).actual()) _node_id = id(get_data_property(node).actual())
_content = f'<{node.__class__.__name__} {hex(_node_id)}>' if show_node_id:
_content = f'<{node.__class__.__name__} {hex(_node_id)}>'
else:
_content = f'<{node.__class__.__name__}>'
if _node_id in node_ids.keys(): if _node_id in node_ids.keys():
_str_old_path = '.'.join(('<root>', *node_ids[_node_id])) _str_old_path = '.'.join(('<root>', *node_ids[_node_id]))
_content = f'{_content}{os.linesep}(The same address as {_str_old_path})' _content = f'{_content}{os.linesep}(The same address as {_str_old_path})'
...@@ -98,25 +94,24 @@ class BaseTreeStruct(general_tree_value()): ...@@ -98,25 +94,24 @@ class BaseTreeStruct(general_tree_value()):
return self.__repr__() return self.__repr__()
def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True): def clsmeta(func, allow_dict: bool = False):
class _TempTreeValue(TreeValue): class _TempTreeValue(TreeValue):
pass pass
_types = ( _types = (
TreeValue, TreeValue, BaseTree,
*((dict,) if allow_dict else ()), *((dict,) if allow_dict else ()),
*((BaseTree,) if allow_data else ()),
) )
func_treelize = post_process(post_process(args_mapping( func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))( lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))(
replaceable_partial(original_func_treelize, return_type=_TempTreeValue) replaceable_partial(original_func_treelize, return_type=_TempTreeValue)
) )
_torch_size = func_treelize()(cls) _wrapped_func = func_treelize()(func)
class _MetaClass(type): class _MetaClass(type):
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
_result = _torch_size(*args, **kwargs) _result = _wrapped_func(*args, **kwargs)
if isinstance(_result, _TempTreeValue): if isinstance(_result, _TempTreeValue):
return type.__call__(cls, _result) return type.__call__(cls, _result)
else: else:
...@@ -125,5 +120,9 @@ def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True): ...@@ -125,5 +120,9 @@ def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True):
return _MetaClass return _MetaClass
class Object(BaseTreeStruct): def _object(obj):
return obj
class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)):
pass pass
from ..common import BaseTreeStruct from ..common import BaseTreeStruct
__all__ = ['TreeTorch'] __all__ = ['Torch']
class TreeTorch(BaseTreeStruct): class Torch(BaseTreeStruct):
pass pass
...@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize ...@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree from treevalue.tree.common import BaseTree
from treevalue.utils import post_process from treevalue.utils import post_process
from .base import TreeTorch from .base import Torch
from ..common import Object, clsmeta, ireduce from ..common import Object, clsmeta, ireduce
from ..utils import replaceable_partial, doc_from, current_names, args_mapping from ..utils import replaceable_partial, doc_from, current_names, args_mapping
...@@ -44,7 +44,31 @@ def _post_index(func): ...@@ -44,7 +44,31 @@ def _post_index(func):
# noinspection PyTypeChecker # noinspection PyTypeChecker
@current_names() @current_names()
class Size(TreeTorch, metaclass=clsmeta(torch.Size, allow_dict=True)): class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
def __init__(self, data):
"""
In :class:`treetensor.torch.Size`, it's similar with the original :class:`torch.Size`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size([1, 2, 3])
torch.Size([1, 2, 3])
>>> ttorch.Size({
... 'a': [1, 2, 3],
... 'b': {'x': [3, 4, ]},
... 'c': [5],
... })
<Size 0x7fe00b115970>
├── a --> torch.Size([1, 2, 3])
├── b --> <Size 0x7fe00b115250>
│ └── x --> torch.Size([3, 4])
└── c --> torch.Size([5])
"""
super(Torch, self).__init__(data)
@doc_from(torch.Size.numel) @doc_from(torch.Size.numel)
@ireduce(sum) @ireduce(sum)
@func_treelize(return_type=Object) @func_treelize(return_type=Object)
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from treevalue import method_treelize from treevalue import method_treelize
from treevalue.utils import pre_process from treevalue.utils import pre_process
from .base import TreeTorch from .base import Torch
from .size import Size from .size import Size
from ..common import Object, ireduce, clsmeta from ..common import Object, ireduce, clsmeta
from ..numpy import ndarray from ..numpy import ndarray
...@@ -29,7 +29,40 @@ def _to_tensor(*args, **kwargs): ...@@ -29,7 +29,40 @@ def _to_tensor(*args, **kwargs):
# noinspection PyTypeChecker # noinspection PyTypeChecker
@current_names() @current_names()
class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)): class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
# noinspection PyUnusedLocal
def __init__(self, data, *args, **kwargs):
"""
In :class:`treetensor.torch.Tensor`, it's similar but a little bit different with the
original :class:`torch.Tensor`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> torch.Tensor([1, 2, 3]) # in torch.Tensor, default type is float32
tensor([1., 2., 3.])
>>> ttorch.Tensor([1, 2, 3]) # a native Tensor object, its type is auto detected with torch.tensor
tensor([1, 2, 3])
>>> ttorch.Tensor([1, 2, 3], dtype=torch.float32) # with float32 type
tensor([1., 2., 3.])
>>> ttorch.Tensor({
... 'a': [1, 2, 3],
... 'b': {'x': [4.0, 5, 6]},
... 'c': [[True, ], [False, ]],
... }) # a tree-based Tensor object
<Tensor 0x7f537bb9a880>
├── a --> tensor([1, 2, 3])
├── b --> <Tensor 0x7f537bb9a0d0>
│ └── x --> tensor([4., 5., 6.])
└── c --> tensor([[ True],
[False]])
"""
super(Torch, self).__init__(data)
@doc_from(torch.Tensor.numpy) @doc_from(torch.Tensor.numpy)
@method_treelize(return_type=ndarray) @method_treelize(return_type=ndarray)
def numpy(self: torch.Tensor) -> np.ndarray: def numpy(self: torch.Tensor) -> np.ndarray:
...@@ -140,6 +173,7 @@ class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)): ...@@ -140,6 +173,7 @@ class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
""" """
return self.shape return self.shape
# noinspection PyArgumentList
@doc_from(torch.Tensor.all) @doc_from(torch.Tensor.all)
@tireduce(torch.all) @tireduce(torch.all)
@method_treelize(return_type=Object) @method_treelize(return_type=Object)
...@@ -149,6 +183,7 @@ class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)): ...@@ -149,6 +183,7 @@ class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
""" """
return self.all(*args, **kwargs) return self.all(*args, **kwargs)
# noinspection PyArgumentList
@doc_from(torch.Tensor.any) @doc_from(torch.Tensor.any)
@tireduce(torch.any) @tireduce(torch.any)
@method_treelize(return_type=Object) @method_treelize(return_type=Object)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册