From f6e128a69ff108c161be319dc8a34264ab11eaf9 Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 20 Sep 2021 20:56:05 +0800 Subject: [PATCH] test(hansbug): add all documentation for treetensor.torch part --- test/torch/test_tensor.py | 140 ++++++++++++++++++++++++++++++++++++- treetensor/common/trees.py | 4 +- treetensor/torch/tensor.py | 2 +- 3 files changed, 140 insertions(+), 6 deletions(-) diff --git a/test/torch/test_tensor.py b/test/torch/test_tensor.py index b2d6ffd72..7f2a03df5 100644 --- a/test/torch/test_tensor.py +++ b/test/torch/test_tensor.py @@ -5,6 +5,7 @@ from treevalue import func_treelize, typetrans, TreeValue import treetensor.numpy as tnp import treetensor.torch as ttorch +from treetensor.common import Object _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) @@ -69,8 +70,141 @@ class TestTorchTensor: })) def test_all(self): - assert (self._DEMO_1 == self._DEMO_1).all() - assert not (self._DEMO_1 == self._DEMO_2).all() + t1 = ttorch.Tensor({ + 'a': [True, True], + 'b': {'x': [[True, True, ], [True, True, ]]} + }).all() + assert isinstance(t1, torch.Tensor) + assert t1.dtype == torch.bool + assert t1 + + t2 = ttorch.Tensor({ + 'a': [True, False], + 'b': {'x': [[True, True, ], [True, True, ]]} + }).all() + assert isinstance(t2, torch.Tensor) + assert t2.dtype == torch.bool + assert not t2 def test_tolist(self): - pass + assert self._DEMO_1.tolist() == Object({ + 'a': [[1, 2, 3], [4, 5, 6]], + 'b': [[1, 2], [5, 6]], + 'x': { + 'c': [3, 5, 6, 7], + 'd': [[[1, 2], [8, 9]]], + } + }) + + def test_any(self): + t1 = ttorch.Tensor({ + 'a': [True, False], + 'b': {'x': [[False, False, ], [False, False, ]]} + }).any() + assert isinstance(t1, torch.Tensor) + assert t1.dtype == torch.bool + assert t1 + + t2 = ttorch.Tensor({ + 'a': [False, False], + 'b': {'x': [[False, False, ], [False, False, ]]} + }).any() + assert isinstance(t2, torch.Tensor) + assert t2.dtype == torch.bool + assert not t2 + + def test_max(self): + t1 = ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }).max() + assert isinstance(t1, torch.Tensor) + assert t1.tolist() == 3 + + def test_min(self): + t1 = ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }).min() + assert isinstance(t1, torch.Tensor) + assert t1.tolist() == -1 + + def test_sum(self): + t1 = ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }).sum() + assert isinstance(t1, torch.Tensor) + assert t1.tolist() == 7 + + def test_eq(self): + assert ((ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }) == ttorch.Tensor({ + 'a': [1, 21], + 'b': {'x': [[-1, 3], [12, -10]]} + })) == ttorch.Tensor({ + 'a': [True, False], + 'b': {'x': [[False, True], [False, False]]} + })).all() + + def test_ne(self): + assert ((ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }) != ttorch.Tensor({ + 'a': [1, 21], + 'b': {'x': [[-1, 3], [12, -10]]} + })) == ttorch.Tensor({ + 'a': [False, True], + 'b': {'x': [[True, False], [True, True]]} + })).all() + + def test_lt(self): + assert ((ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }) < ttorch.Tensor({ + 'a': [1, 21], + 'b': {'x': [[-1, 3], [12, -10]]} + })) == ttorch.Tensor({ + 'a': [False, True], + 'b': {'x': [[False, False], [True, False]]} + })).all() + + def test_le(self): + assert ((ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }) <= ttorch.Tensor({ + 'a': [1, 21], + 'b': {'x': [[-1, 3], [12, -10]]} + })) == ttorch.Tensor({ + 'a': [True, True], + 'b': {'x': [[False, True], [True, False]]} + })).all() + + def test_gt(self): + assert ((ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }) > ttorch.Tensor({ + 'a': [1, 21], + 'b': {'x': [[-1, 3], [12, -10]]} + })) == ttorch.Tensor({ + 'a': [False, False], + 'b': {'x': [[True, False], [False, True]]} + })).all() + + def test_ge(self): + assert ((ttorch.Tensor({ + 'a': [1, 2], + 'b': {'x': [[0, 3], [2, -1]]} + }) >= ttorch.Tensor({ + 'a': [1, 21], + 'b': {'x': [[-1, 3], [12, -10]]} + })) == ttorch.Tensor({ + 'a': [True, False], + 'b': {'x': [[True, True], [False, True]]} + })).all() diff --git a/treetensor/common/trees.py b/treetensor/common/trees.py index 48c98cbd8..a3fd995fe 100644 --- a/treetensor/common/trees.py +++ b/treetensor/common/trees.py @@ -27,8 +27,8 @@ def print_tree(tree: TreeValue, repr_: Callable = str, Arguments: - tree (:obj:`TreeValue`): Given tree object. - - repr\_ (:obj:`Callable`): Representation function, default is ``str``. - - ascii\_ (:obj:`bool`): Use ascii to print the tree, default is ``False``. + - repr\\_ (:obj:`Callable`): Representation function, default is ``str``. + - ascii\\_ (:obj:`bool`): Use ascii to print the tree, default is ``False``. - show_node_id (:obj:`bool`): Show node id of the tree, default is ``True``. - file: Output file of this print procedure, default is ``None`` which means to stdout. """ diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index a3cf332f9..77a9f69ce 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -88,7 +88,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): >>> 'b': [1, 2, 3], >>> 'c': True, >>> }).tolist() - TreeObject({ + Object({ 'a': [[1, 2], [3, 4]], 'b': [1, 2, 3], 'c': True, -- GitLab