diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py index c825eba6c42c9e4188bca26f20305f69db94de09..53be6d5726570f760db6e169b8b5956ce3dcfed4 100644 --- a/test/tensor/test_funcs.py +++ b/test/tensor/test_funcs.py @@ -240,6 +240,40 @@ class TestTensorFuncs: } }) + def test_empty(self): + _target = ttorch.empty(TreeValue({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + })) + assert _target.shape == ttorch.TreeSize({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([5, 6]), + 'x': { + 'c': torch.Size([2, 3, 4]), + } + }) + + def test_empty_like(self): + _target = ttorch.empty_like(ttorch.TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + })) + assert _target.shape == ttorch.TreeSize({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([4]), + 'x': { + 'c': torch.Size([3]), + 'd': torch.Size([1, 1, 2]), + } + }) + def test_all(self): r1 = ttorch.all(torch.tensor([True, True, True])) assert torch.is_tensor(r1) @@ -340,3 +374,32 @@ class TestTensorFuncs: 'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 5]), })).all() + + def test_equal(self): + p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])) + assert isinstance(p1, bool) + assert p1 + + p2 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4])) + assert isinstance(p2, bool) + assert not p2 + + p3 = ttorch.equal(ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + }), ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + })) + assert isinstance(p3, bool) + assert p3 + + p4 = ttorch.equal(ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + }), ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 5]), + })) + assert isinstance(p4, bool) + assert not p4 diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index 6322783f227fcc580f3a741537f8882e0f2e3ed9..149825c0b47492016aa1f3b234b0bddd0f7c1c60 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -1,3 +1,5 @@ +import builtins + import numpy as np from treevalue import func_treelize as original_func_treelize @@ -13,13 +15,13 @@ __all__ = [ func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNumpy) -@ireduce(all) +@ireduce(builtins.all) @func_treelize(return_type=TreeObject) def all(a, *args, **kwargs): return np.all(a, *args, **kwargs) -@ireduce(any) +@ireduce(builtins.any) @func_treelize() def any(a, *args, **kwargs): return np.any(a, *args, **kwargs) diff --git a/treetensor/tensor/funcs.py b/treetensor/tensor/funcs.py index d9ab078a60cfcfd6ee14e0c934a8163ade780640..71e46c5bff3a203ffc8ca5f90739621a21159835 100644 --- a/treetensor/tensor/funcs.py +++ b/treetensor/tensor/funcs.py @@ -1,8 +1,10 @@ +import builtins + import torch from treevalue import func_treelize as original_func_treelize from .tensor import TreeTensor, tireduce -from ..common import TreeObject +from ..common import TreeObject, ireduce from ..utils import replaceable_partial func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor) @@ -96,6 +98,7 @@ def eq(input_, other, *args, **kwargs): return torch.eq(input_, other, *args, **kwargs) +@ireduce(builtins.all) @func_treelize() def equal(input_, other, *args, **kwargs): return torch.equal(input_, other, *args, **kwargs)