diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py index e9d7a6f85bd71c5c3e729457a073dd1466fd0629..7c60767d62cd12e3826f9868b9ddc3dfd45c6f21 100644 --- a/test/tensor/test_funcs.py +++ b/test/tensor/test_funcs.py @@ -1,7 +1,9 @@ import pytest import torch -from treetensor.tensor import TreeTensor, zeros, all_equal, zeros_like, ones, ones_like +from treetensor.tensor import TreeTensor, zeros, all_equal, zeros_like, ones, ones_like, randint, randint_like, randn, \ + randn_like, full, full_like +from treetensor.tensor import all as _tensor_all # noinspection DuplicatedCode @@ -48,6 +50,10 @@ class TestTensorFuncs: })) def test_zeros_like(self): + assert all_equal( + zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])), + torch.tensor([[0, 0, 0], [0, 0, 0]]), + ) assert all_equal( zeros_like(TreeTensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), @@ -84,6 +90,10 @@ class TestTensorFuncs: })) def test_ones_like(self): + assert all_equal( + ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])), + torch.tensor([[1, 1, 1], [1, 1, 1]]) + ) assert all_equal( ones_like(TreeTensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), @@ -102,3 +112,157 @@ class TestTensorFuncs: } }) ) + + def test_randn(self): + _target = randn((200, 300)) + assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02 + assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 + assert _target.shape == torch.Size([200, 300]) + + _target = randn({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + }) + assert _target.raw_shape == TreeTensor({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([5, 6]), + 'x': { + 'c': torch.Size([2, 3, 4]), + } + }) + + def test_randn_like(self): + _target = randn_like(torch.ones(200, 300)) + assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02 + assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 + assert _target.shape == torch.Size([200, 300]) + + _target = randn_like(TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), + 'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32), + 'x': { + 'c': torch.tensor([5, 6, 7], dtype=torch.float32), + 'd': torch.tensor([[[8, 9]]], dtype=torch.float32), + } + })) + assert _target.raw_shape == TreeTensor({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([4]), + 'x': { + 'c': torch.Size([3]), + 'd': torch.Size([1, 1, 2]), + } + }) + + def test_randint(self): + _target = randint({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + }, -10, 10) + assert _tensor_all(_target < 10).all() + assert _tensor_all(-10 <= _target).all() + assert _target.raw_shape == TreeTensor({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([5, 6]), + 'x': { + 'c': torch.Size([2, 3, 4]), + } + }) + + _target = randint({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + }, 10) + assert _tensor_all(_target < 10).all() + assert _tensor_all(0 <= _target).all() + assert _target.raw_shape == TreeTensor({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([5, 6]), + 'x': { + 'c': torch.Size([2, 3, 4]), + } + }) + + def test_randint_like(self): + _target = randint_like(TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + }), -10, 10) + assert _tensor_all(_target < 10).all() + assert _tensor_all(-10 <= _target).all() + assert _target.raw_shape == TreeTensor({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([4]), + 'x': { + 'c': torch.Size([3]), + 'd': torch.Size([1, 1, 2]), + } + }) + + _target = randint_like(TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + }), 10) + assert _tensor_all(_target < 10).all() + assert _tensor_all(0 <= _target).all() + assert _target.raw_shape == TreeTensor({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([4]), + 'x': { + 'c': torch.Size([3]), + 'd': torch.Size([1, 1, 2]), + } + }) + + def test_full(self): + _target = full({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + }, 233) + assert _tensor_all(_target.tensor_eq(233)).all() + assert _target.raw_shape == TreeTensor({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([5, 6]), + 'x': { + 'c': torch.Size([2, 3, 4]), + } + }) + + def test_full_like(self): + _target = full_like(TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + }), 233) + assert _tensor_all(_target.tensor_eq(233)).all() + assert _target.raw_shape == TreeTensor({ + 'a': torch.Size([2, 3]), + 'b': torch.Size([4]), + 'x': { + 'c': torch.Size([3]), + 'd': torch.Size([1, 1, 2]), + } + }) diff --git a/treetensor/common/__init__.py b/treetensor/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35dea9c8bd9b9a757f9dfddd50b4b5e94d9fd230 --- /dev/null +++ b/treetensor/common/__init__.py @@ -0,0 +1 @@ +from .treelist import TreeList diff --git a/treetensor/common/treelist.py b/treetensor/common/treelist.py new file mode 100644 index 0000000000000000000000000000000000000000..5436b2f1d9fd5b3bed7443df2cdf167c4717acec --- /dev/null +++ b/treetensor/common/treelist.py @@ -0,0 +1,5 @@ +from treevalue import general_tree_value + + +class TreeList(general_tree_value()): + pass diff --git a/treetensor/numpy/numpy.py b/treetensor/numpy/numpy.py index 1b38d912696b8cf0241989dd8666b47e95a590f7..365fe529eb6d3279928c1679fe19c052a23c991d 100644 --- a/treetensor/numpy/numpy.py +++ b/treetensor/numpy/numpy.py @@ -1,4 +1,6 @@ -from treevalue import general_tree_value +from treevalue import general_tree_value, method_treelize + +from ..common import TreeList class TreeNumpy(general_tree_value()): @@ -7,6 +9,8 @@ class TreeNumpy(general_tree_value()): Real numpy tree. """ + tolist = method_treelize(return_type=TreeList)(lambda d: d.tolist()) + @property def size(self) -> int: return self \ diff --git a/treetensor/tensor/funcs.py b/treetensor/tensor/funcs.py index c6d26712790f9fcfe9a7e7613a1ef3d94b4f6656..beeaa868ba1a7ba64f38a679fb61865ca099f629 100644 --- a/treetensor/tensor/funcs.py +++ b/treetensor/tensor/funcs.py @@ -34,7 +34,7 @@ zeros = _size_based_treelize()(torch.zeros) randn = _size_based_treelize()(torch.randn) randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint) ones = _size_based_treelize()(torch.ones) -full = _size_based_treelize()(torch.full) +full = _size_based_treelize(tuple_=True)(torch.full) empty = _size_based_treelize()(torch.empty) # Tensor generation based on another tensor diff --git a/treetensor/tensor/treetensor.py b/treetensor/tensor/treetensor.py index 9f78950c2a474d145c672894d2a92edead316d25..d537c351cdd24126ea2da4720406bf264232108c 100644 --- a/treetensor/tensor/treetensor.py +++ b/treetensor/tensor/treetensor.py @@ -1,9 +1,35 @@ +from functools import partial +from operator import __eq__ + from torch import Tensor -from treevalue import general_tree_value, method_treelize +from treevalue import general_tree_value, method_treelize, TreeValue +from ..common import TreeList from ..numpy import TreeNumpy +def _same_merge(eq, hash_, **kwargs): + kws = { + key: value for key, value in kwargs.items() + if not (isinstance(value, TreeValue) and not value) + } + + class _Wrapper: + def __init__(self, v): + self.v = v + + def __hash__(self): + return hash_(self.v) + + def __eq__(self, other): + return eq(self.v, other.v) + + if len(set(_Wrapper(v) for v in kws.values())) == 1: + return list(kws.values())[0] + else: + return TreeTensor(kws) + + # noinspection PyTypeChecker,PyShadowingBuiltins class TreeTensor(general_tree_value()): def numel(self) -> int: @@ -11,7 +37,46 @@ class TreeTensor(general_tree_value()): .map(lambda t: t.numel()) \ .reduce(lambda **kws: sum(kws.values())) + @property + def raw_shape(self): + return self.map(lambda t: t.shape) + + @property + def shape(self): + return self.raw_shape.reduce(partial(_same_merge, __eq__, hash)) + numpy = method_treelize(return_type=TreeNumpy)(Tensor.numpy) + tolist = method_treelize(return_type=TreeList)(Tensor.tolist) cpu = method_treelize()(Tensor.cpu) cuda = method_treelize()(Tensor.cuda) to = method_treelize()(Tensor.to) + + @method_treelize() + def __lt__(self, other): + return self < other + + @method_treelize() + def __le__(self, other): + return self <= other + + @method_treelize() + def __gt__(self, other): + return self > other + + @method_treelize() + def __ge__(self, other): + return self >= other + + @method_treelize() + def tensor_eq(self, other): + return self == other + + @method_treelize() + def tensor_ne(self, other): + return self != other + + def all(self): + return self.reduce(lambda **kws: all(kws.values())) + + def any(self): + return self.reduce(lambda **kws: any(kws.values()))