提交 745b5b8e 编写于 作者: HansBug's avatar HansBug 😆

doc, dev, test(hansbug): add test for treetensor.common and treetensor.tensor and size

上级 4086271e
from .common import *
from .config import *
from .numpy import *
from .torch import *
from .test_trees import TestCommonTrees
import io
import pytest
import torch
from treevalue import typetrans, TreeValue, general_tree_value
from treetensor.common import Object, print_tree
def text_compares(expected, actual):
elines = expected.splitlines()
alines = actual.splitlines()
assert len(elines) == len(alines), f"""Lines not match,
Expected: {len(elines)} lines
Actual: {len(alines)} lines
"""
for i, (e, a) in enumerate(zip(elines, alines)):
assert e.rstrip() == a.rstrip(), f"""Line {i} not match,
Expected: {e}
Actual: {a}
"""
@pytest.mark.unittest
class TestCommonTrees:
def test_object(self):
t = Object(1)
assert isinstance(t, int)
assert t == 1
assert Object({'a': 1, 'b': 2}) == typetrans(TreeValue({
'a': 1, 'b': 2
}), Object)
def test_print_tree(self):
class _TempTree(general_tree_value()):
def __repr__(self):
with io.StringIO() as sfile:
print_tree(self, repr_=repr, ascii_=False, show_node_id=False, file=sfile)
return sfile.getvalue()
def __str__(self):
return self.__repr__()
text_compares("""<_TempTree>
├── a --> 1
└── b --> 2
""".rstrip(), str(_TempTree({
'a': 1, 'b': 2
})).rstrip())
class _TmpObject:
def __init__(self, v):
self.__v = v
def __repr__(self):
return self.__v
tx = _TempTree({
'a': 1, 'b': 2, 'c': torch.tensor([[1, 2, ], [3, 4]]),
'd': {'x': torch.tensor([[1], [2], [3, ], [4]])},
'e': _TmpObject('line after this\nhahahaha'),
})
tx.d.y = tx
text_compares("""<_TempTree>
├── a --> 1
├── b --> 2
├── c --> tensor([[1, 2],
│ [3, 4]])
├── d --> <_TempTree>
│ ├── x --> tensor([[1],
│ │ [2],
│ │ [3],
│ │ [4]])
│ └── y --> <_TempTree>
│ (The same address as <root>)
└── e --> line after this
hahahaha
""".rstrip(), str(tx).rstrip())
with io.StringIO() as sf:
print_tree(1, file=sf)
text_compares("1", sf.getvalue())
from .test_funcs import TestTorchFuncs
from .test_size import TestTorchSize
from .test_tensor import TestTorchTensor
......@@ -401,6 +401,111 @@ class TestTorchFuncs:
'b': torch.tensor([4, 5, 5]),
})).all()
def test_ne(self):
assert (ttorch.ne(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, True],
[True, False]])).all()
assert (ttorch.ne(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, True], [True, False]],
'b': [True, True, False],
})).all()
def test_lt(self):
assert (ttorch.lt(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, False],
[True, False]])).all()
assert (ttorch.lt(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, False], [True, False]],
'b': [True, False, False],
})).all()
def test_le(self):
assert (ttorch.le(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[True, False],
[True, True]])).all()
assert (ttorch.le(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[True, False], [True, True]],
'b': [True, False, True],
})).all()
def test_gt(self):
assert (ttorch.gt(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[False, True],
[False, False]])).all()
assert (ttorch.gt(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[False, True], [False, False]],
'b': [False, True, False],
})).all()
def test_ge(self):
assert (ttorch.ge(
torch.tensor([[1, 2], [3, 4]]),
torch.tensor([[1, 1], [4, 4]]),
) == torch.tensor([[True, True],
[False, True]])).all()
assert (ttorch.ge(
ttorch.tensor({
'a': [[1, 2], [3, 4]],
'b': [1.0, 1.5, 2.0],
}),
ttorch.tensor({
'a': [[1, 1], [4, 4]],
'b': [1.3, 1.2, 2.0],
}),
) == ttorch.tensor({
'a': [[True, True], [False, True]],
'b': [False, True, True],
})).all()
def test_equal(self):
p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]))
assert isinstance(p1, bool)
......@@ -429,3 +534,36 @@ class TestTorchFuncs:
}))
assert isinstance(p4, bool)
assert not p4
def test_min(self):
t1 = ttorch.min(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
assert t1 == torch.tensor(1.0)
assert ttorch.min(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == ttorch.tensor({
'a': 1.0,
'b': {'x': 0.9},
})
def test_max(self):
t1 = ttorch.max(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
assert t1 == torch.tensor(2.0)
assert ttorch.max(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == ttorch.tensor({
'a': 2.0,
'b': {'x': 2.5, }
})
def test_sum(self):
assert ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) == torch.tensor(4.5)
assert ttorch.sum(ttorch.tensor({
'a': [1.0, 2.0, 1.5],
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == torch.tensor(11.0)
import pytest
import torch
from treevalue import func_treelize, typetrans, TreeValue
import treetensor.torch as ttorch
from treetensor.common import Object
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
@pytest.mark.unittest
class TestTorchSize:
def test_init(self):
t1 = ttorch.Size([1, 2, 3])
assert isinstance(t1, torch.Size)
assert t1 == torch.Size([1, 2, 3])
t2 = ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
})
assert isinstance(t2, ttorch.Size)
assert typetrans(t2, TreeValue) == TreeValue({
'a': torch.Size([1, 2, 3]),
'b': {'x': torch.Size([3, 4, ])},
'c': torch.Size([5]),
})
def test_numel(self):
assert ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
}).numel() == 23
def test_index(self):
assert ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
}).index(3) == Object({
'a': 2,
'b': {'x': 0},
'c': None
})
with pytest.raises(ValueError):
ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
}).index(100)
def test_count(self):
assert ttorch.Size({
'a': [1, 2, 3],
'b': {'x': [3, 4, ]},
'c': [5],
}).count(3) == 2
import numpy as np
import pytest
import torch
from treevalue import func_treelize
from treevalue import func_treelize, typetrans, TreeValue
import treetensor.numpy as tnp
import treetensor.torch as ttorch
......@@ -29,6 +29,18 @@ class TestTorchTensor:
}
})
def test_init(self):
assert (ttorch.Tensor([1, 2, 3]) == torch.tensor([1, 2, 3])).all()
assert (ttorch.Tensor([1, 2, 3], dtype=torch.float32) == torch.FloatTensor([1, 2, 3])).all()
assert (self._DEMO_1 == typetrans(TreeValue({
'a': ttorch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': ttorch.tensor([[1, 2], [5, 6]]),
'x': {
'c': ttorch.tensor([3, 5, 6, 7]),
'd': ttorch.tensor([[[1, 2], [8, 9]]]),
}
}), ttorch.Tensor)).all()
def test_numel(self):
assert self._DEMO_1.numel() == 18
......@@ -59,3 +71,6 @@ class TestTorchTensor:
def test_all(self):
assert (self._DEMO_1 == self._DEMO_1).all()
assert not (self._DEMO_1 == self._DEMO_2).all()
def test_tolist(self):
pass
......@@ -18,15 +18,8 @@ __all__ = [
]
def _tree_title(node: TreeValue):
_tree = get_data_property(node)
return "<{cls} {id}>".format(
cls=node.__class__.__name__,
id=hex(id(_tree.actual())),
)
def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, file=None):
def print_tree(tree: TreeValue, repr_: Callable = str,
ascii_: bool = False, show_node_id: bool = True, file=None):
print_to_file = partial(builtins.print, file=file)
node_ids = {}
if ascii_:
......@@ -45,7 +38,10 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil
_need_iter = True
if isinstance(node, TreeValue):
_node_id = id(get_data_property(node).actual())
if show_node_id:
_content = f'<{node.__class__.__name__} {hex(_node_id)}>'
else:
_content = f'<{node.__class__.__name__}>'
if _node_id in node_ids.keys():
_str_old_path = '.'.join(('<root>', *node_ids[_node_id]))
_content = f'{_content}{os.linesep}(The same address as {_str_old_path})'
......@@ -98,25 +94,24 @@ class BaseTreeStruct(general_tree_value()):
return self.__repr__()
def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True):
def clsmeta(func, allow_dict: bool = False):
class _TempTreeValue(TreeValue):
pass
_types = (
TreeValue,
TreeValue, BaseTree,
*((dict,) if allow_dict else ()),
*((BaseTree,) if allow_data else ()),
)
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))(
replaceable_partial(original_func_treelize, return_type=_TempTreeValue)
)
_torch_size = func_treelize()(cls)
_wrapped_func = func_treelize()(func)
class _MetaClass(type):
def __call__(cls, *args, **kwargs):
_result = _torch_size(*args, **kwargs)
_result = _wrapped_func(*args, **kwargs)
if isinstance(_result, _TempTreeValue):
return type.__call__(cls, _result)
else:
......@@ -125,5 +120,9 @@ def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True):
return _MetaClass
class Object(BaseTreeStruct):
def _object(obj):
return obj
class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)):
pass
from ..common import BaseTreeStruct
__all__ = ['TreeTorch']
__all__ = ['Torch']
class TreeTorch(BaseTreeStruct):
class Torch(BaseTreeStruct):
pass
......@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.utils import post_process
from .base import TreeTorch
from .base import Torch
from ..common import Object, clsmeta, ireduce
from ..utils import replaceable_partial, doc_from, current_names, args_mapping
......@@ -44,7 +44,31 @@ def _post_index(func):
# noinspection PyTypeChecker
@current_names()
class Size(TreeTorch, metaclass=clsmeta(torch.Size, allow_dict=True)):
class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
def __init__(self, data):
"""
In :class:`treetensor.torch.Size`, it's similar with the original :class:`torch.Size`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size([1, 2, 3])
torch.Size([1, 2, 3])
>>> ttorch.Size({
... 'a': [1, 2, 3],
... 'b': {'x': [3, 4, ]},
... 'c': [5],
... })
<Size 0x7fe00b115970>
├── a --> torch.Size([1, 2, 3])
├── b --> <Size 0x7fe00b115250>
│ └── x --> torch.Size([3, 4])
└── c --> torch.Size([5])
"""
super(Torch, self).__init__(data)
@doc_from(torch.Size.numel)
@ireduce(sum)
@func_treelize(return_type=Object)
......
......@@ -3,7 +3,7 @@ import torch
from treevalue import method_treelize
from treevalue.utils import pre_process
from .base import TreeTorch
from .base import Torch
from .size import Size
from ..common import Object, ireduce, clsmeta
from ..numpy import ndarray
......@@ -29,7 +29,40 @@ def _to_tensor(*args, **kwargs):
# noinspection PyTypeChecker
@current_names()
class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
# noinspection PyUnusedLocal
def __init__(self, data, *args, **kwargs):
"""
In :class:`treetensor.torch.Tensor`, it's similar but a little bit different with the
original :class:`torch.Tensor`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> torch.Tensor([1, 2, 3]) # in torch.Tensor, default type is float32
tensor([1., 2., 3.])
>>> ttorch.Tensor([1, 2, 3]) # a native Tensor object, its type is auto detected with torch.tensor
tensor([1, 2, 3])
>>> ttorch.Tensor([1, 2, 3], dtype=torch.float32) # with float32 type
tensor([1., 2., 3.])
>>> ttorch.Tensor({
... 'a': [1, 2, 3],
... 'b': {'x': [4.0, 5, 6]},
... 'c': [[True, ], [False, ]],
... }) # a tree-based Tensor object
<Tensor 0x7f537bb9a880>
├── a --> tensor([1, 2, 3])
├── b --> <Tensor 0x7f537bb9a0d0>
│ └── x --> tensor([4., 5., 6.])
└── c --> tensor([[ True],
[False]])
"""
super(Torch, self).__init__(data)
@doc_from(torch.Tensor.numpy)
@method_treelize(return_type=ndarray)
def numpy(self: torch.Tensor) -> np.ndarray:
......@@ -140,6 +173,7 @@ class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.shape
# noinspection PyArgumentList
@doc_from(torch.Tensor.all)
@tireduce(torch.all)
@method_treelize(return_type=Object)
......@@ -149,6 +183,7 @@ class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.all(*args, **kwargs)
# noinspection PyArgumentList
@doc_from(torch.Tensor.any)
@tireduce(torch.any)
@method_treelize(return_type=Object)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册