From 7cb29d8eae9fa081d85c05b12083344629b12e82 Mon Sep 17 00:00:00 2001 From: HansBug Date: Sun, 12 Sep 2021 09:57:41 +0800 Subject: [PATCH] dev,test(hansbug): add test for tensor/funcs.py --- test/tensor/test_funcs.py | 63 ++++++++++++++++++++++++++++++++++++++ treetensor/numpy/funcs.py | 6 ++-- treetensor/tensor/funcs.py | 5 ++- 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py index c825eba6c..53be6d572 100644 --- a/test/tensor/test_funcs.py +++ b/test/tensor/test_funcs.py @@ -240,6 +240,40 @@ class TestTensorFuncs: } }) + def test_empty(self): + _target = ttorch.empty(TreeValue({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + })) + assert _target.shape == ttorch.TreeSize({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([5, 6]), + 'x': { + 'c': torch.Size([2, 3, 4]), + } + }) + + def test_empty_like(self): + _target = ttorch.empty_like(ttorch.TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + })) + assert _target.shape == ttorch.TreeSize({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([4]), + 'x': { + 'c': torch.Size([3]), + 'd': torch.Size([1, 1, 2]), + } + }) + def test_all(self): r1 = ttorch.all(torch.tensor([True, True, True])) assert torch.is_tensor(r1) @@ -340,3 +374,32 @@ class TestTensorFuncs: 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 5]), })).all() + + def test_equal(self): + p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])) + assert isinstance(p1, bool) + assert p1 + + p2 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4])) + assert isinstance(p2, bool) + assert not p2 + + p3 = ttorch.equal(ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + }), ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + })) + assert isinstance(p3, bool) + assert p3 + + p4 = ttorch.equal(ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + }), ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 5]), + })) + assert isinstance(p4, bool) + assert not p4 diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index 6322783f2..149825c0b 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -1,3 +1,5 @@ +import builtins + import numpy as np from treevalue import func_treelize as original_func_treelize @@ -13,13 +15,13 @@ __all__ = [ func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNumpy) -@ireduce(all) +@ireduce(builtins.all) @func_treelize(return_type=TreeObject) def all(a, *args, **kwargs): return np.all(a, *args, **kwargs) -@ireduce(any) +@ireduce(builtins.any) @func_treelize() def any(a, *args, **kwargs): return np.any(a, *args, **kwargs) diff --git a/treetensor/tensor/funcs.py b/treetensor/tensor/funcs.py index d9ab078a6..71e46c5bf 100644 --- a/treetensor/tensor/funcs.py +++ b/treetensor/tensor/funcs.py @@ -1,8 +1,10 @@ +import builtins + import torch from treevalue import func_treelize as original_func_treelize from .tensor import TreeTensor, tireduce -from ..common import TreeObject +from ..common import TreeObject, ireduce from ..utils import replaceable_partial func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor) @@ -96,6 +98,7 @@ def eq(input_, other, *args, **kwargs): return torch.eq(input_, other, *args, **kwargs) +@ireduce(builtins.all) @func_treelize() def equal(input_, other, *args, **kwargs): return torch.equal(input_, other, *args, **kwargs) -- GitLab