提交 f6e128a6 编写于 作者: HansBug's avatar HansBug 😆

test(hansbug): add all documentation for treetensor.torch part

上级 04bebab0
......@@ -5,6 +5,7 @@ from treevalue import func_treelize, typetrans, TreeValue
import treetensor.numpy as tnp
import treetensor.torch as ttorch
from treetensor.common import Object
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
......@@ -69,8 +70,141 @@ class TestTorchTensor:
}))
def test_all(self):
assert (self._DEMO_1 == self._DEMO_1).all()
assert not (self._DEMO_1 == self._DEMO_2).all()
t1 = ttorch.Tensor({
'a': [True, True],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1
t2 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert not t2
def test_tolist(self):
pass
assert self._DEMO_1.tolist() == Object({
'a': [[1, 2, 3], [4, 5, 6]],
'b': [[1, 2], [5, 6]],
'x': {
'c': [3, 5, 6, 7],
'd': [[[1, 2], [8, 9]]],
}
})
def test_any(self):
t1 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1
t2 = ttorch.Tensor({
'a': [False, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert not t2
def test_max(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).max()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 3
def test_min(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).min()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == -1
def test_sum(self):
t1 = ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}).sum()
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 7
def test_eq(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) == ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, False],
'b': {'x': [[False, True], [False, False]]}
})).all()
def test_ne(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) != ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, True],
'b': {'x': [[True, False], [True, True]]}
})).all()
def test_lt(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) < ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, True],
'b': {'x': [[False, False], [True, False]]}
})).all()
def test_le(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) <= ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, True],
'b': {'x': [[False, True], [True, False]]}
})).all()
def test_gt(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) > ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [False, False],
'b': {'x': [[True, False], [False, True]]}
})).all()
def test_ge(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
}) >= ttorch.Tensor({
'a': [1, 21],
'b': {'x': [[-1, 3], [12, -10]]}
})) == ttorch.Tensor({
'a': [True, False],
'b': {'x': [[True, True], [False, True]]}
})).all()
......@@ -27,8 +27,8 @@ def print_tree(tree: TreeValue, repr_: Callable = str,
Arguments:
- tree (:obj:`TreeValue`): Given tree object.
- repr\_ (:obj:`Callable`): Representation function, default is ``str``.
- ascii\_ (:obj:`bool`): Use ascii to print the tree, default is ``False``.
- repr\\_ (:obj:`Callable`): Representation function, default is ``str``.
- ascii\\_ (:obj:`bool`): Use ascii to print the tree, default is ``False``.
- show_node_id (:obj:`bool`): Show node id of the tree, default is ``True``.
- file: Output file of this print procedure, default is ``None`` which means to stdout.
"""
......
......@@ -88,7 +88,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
>>> 'b': [1, 2, 3],
>>> 'c': True,
>>> }).tolist()
TreeObject({
Object({
'a': [[1, 2], [3, 4]],
'b': [1, 2, 3],
'c': True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册