diff --git a/test/numpy/__init__.py b/test/numpy/__init__.py index 043647c4381f3f63d2e5c80be70cc12cf1f76b16..099cc30b9463365f7cdb4bb0f4c6e338cc6cbf48 100644 --- a/test/numpy/__init__.py +++ b/test/numpy/__init__.py @@ -1,2 +1,2 @@ -from .test_fake import TestNumpyFake -from .test_numpy import TestNumpyReal +from .test_funcs import TestNumpyFuncs +from .test_numpy import TestNumpyNumpy diff --git a/test/numpy/test_fake.py b/test/numpy/test_fake.py deleted file mode 100644 index 1d2a3856702331a09eb9f67ea612f77c20e0395c..0000000000000000000000000000000000000000 --- a/test/numpy/test_fake.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest - -from treetensor import TreeNumpy - -try: - import numpy as np -except ImportError: - need_fake = True - from treetensor.numpy.fake import FakeTreeNumpy -else: - need_fake = False - -unittest_mark = pytest.mark.unittest if need_fake else pytest.mark.ignore - - -@unittest_mark -class TestNumpyFake: - def test_base(self): - assert TreeNumpy is FakeTreeNumpy diff --git a/test/numpy/test_funcs.py b/test/numpy/test_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..23faca1d80965bb06763a44e75f883f417d264fe --- /dev/null +++ b/test/numpy/test_funcs.py @@ -0,0 +1,83 @@ +import numpy as np +import pytest + +from treetensor.numpy import TreeNumpy, all_array_equal, equal, array_equal + + +# noinspection DuplicatedCode +@pytest.mark.unittest +class TestNumpyFuncs: + _DEMO_1 = TreeNumpy({ + 'a': np.array([[1, 2, 3], [5, 6, 7]]), + 'b': np.array([1, 3, 5, 7]), + 'x': { + 'c': np.array([3, 5, 7]), + 'd': np.array([[7, 9]]), + } + }) + + _DEMO_2 = TreeNumpy({ + 'a': np.array([[1, 2, 3], [5, 6, 8]]), + 'b': np.array([1, 3, 5, 7]), + 'x': { + 'c': np.array([3, 5, 7]), + 'd': np.array([[7, 9]]), + } + }) + + _DEMO_3 = TreeNumpy({ + 'a': np.array([[1, 2, 3], [5, 6, 7]]), + 'b': np.array([1, 3, 5, 7]), + 'x': { + 'c': np.array([3, 5, 7]), + 'd': np.array([[7, 9]]), + } + }) + + def test_all_array_equal(self): + assert not all_array_equal(self._DEMO_1, self._DEMO_2) + assert all_array_equal(self._DEMO_1, self._DEMO_3) + assert not all_array_equal(np.array([1, 2, 3]), np.array([1, 2, 4])) + assert all_array_equal(np.array([1, 2, 3]), np.array([1, 2, 3])) + + def test_equal(self): + assert all_array_equal( + equal(self._DEMO_1, self._DEMO_2), + TreeNumpy({ + 'a': np.array([[True, True, True], [True, True, False]]), + 'b': np.array([True, True, True, True]), + 'x': { + 'c': np.array([True, True, True]), + 'd': np.array([[True, True]]), + } + }) + ) + assert all_array_equal( + equal(self._DEMO_1, self._DEMO_3), + TreeNumpy({ + 'a': np.array([[True, True, True], [True, True, True]]), + 'b': np.array([True, True, True, True]), + 'x': { + 'c': np.array([True, True, True]), + 'd': np.array([[True, True]]), + } + }) + ) + + def test_array_equal(self): + assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({ + 'a': False, + 'b': True, + 'x': { + 'c': True, + 'd': True, + } + }) + assert array_equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({ + 'a': True, + 'b': True, + 'x': { + 'c': True, + 'd': True, + } + }) diff --git a/test/numpy/test_numpy.py b/test/numpy/test_numpy.py index 0c14a45933c083479b6cb236b42efc49dd3c665e..56757c73cf43f0661a9c764ae887bd02dda1b1f3 100644 --- a/test/numpy/test_numpy.py +++ b/test/numpy/test_numpy.py @@ -1,23 +1,11 @@ +import numpy as np import pytest -from treetensor import TreeNumpy +from treetensor.numpy import TreeNumpy -try: - import numpy as np -except ImportError: - need_real = False -else: - need_real = True - from treetensor.numpy.numpy import TreeNumpy as RealTreeNumpy - -unittest_mark = pytest.mark.unittest if need_real else pytest.mark.ignore - - -@unittest_mark -class TestNumpyReal: - def test_base(self): - assert TreeNumpy is RealTreeNumpy +@pytest.mark.unittest +class TestNumpyNumpy: _DEMO_1 = TreeNumpy({ 'a': np.array([[1, 2, 3], [4, 5, 6]]), 'b': np.array([1, 3, 5, 7]), diff --git a/test/tensor/__init__.py b/test/tensor/__init__.py index d88a680cd1a6fa1edea1f9a073ba6f477922b73a..a78ed870d9be2e5912ab08c243239a40d7167194 100644 --- a/test/tensor/__init__.py +++ b/test/tensor/__init__.py @@ -1 +1,2 @@ +from .test_funcs import TestTensorFuncs from .test_treetensor import TestTensorTreetensor diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d7a6f85bd71c5c3e729457a073dd1466fd0629 --- /dev/null +++ b/test/tensor/test_funcs.py @@ -0,0 +1,104 @@ +import pytest +import torch + +from treetensor.tensor import TreeTensor, zeros, all_equal, zeros_like, ones, ones_like + + +# noinspection DuplicatedCode +@pytest.mark.unittest +class TestTensorFuncs: + def test_all_equal(self): + assert all_equal( + torch.tensor([[1, 2, 3], [4, 5, 6]]), + torch.tensor([[1, 2, 3], [4, 5, 6]]), + ) + assert all_equal( + TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + }), + TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + }), + ) + + def test_zeros(self): + assert all_equal(zeros((2, 3)), torch.zeros(2, 3)) + assert all_equal(zeros({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + }), TreeTensor({ + 'a': torch.zeros(2, 3), + 'b': torch.zeros(5, 6), + 'x': { + 'c': torch.zeros(2, 3, 4), + } + })) + + def test_zeros_like(self): + assert all_equal( + zeros_like(TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + })), + TreeTensor({ + 'a': torch.tensor([[0, 0, 0], [0, 0, 0]]), + 'b': torch.tensor([0, 0, 0, 0]), + 'x': { + 'c': torch.tensor([0, 0, 0]), + 'd': torch.tensor([[[0, 0]]]), + } + }) + ) + + def test_ones(self): + assert all_equal(ones((2, 3)), torch.ones(2, 3)) + assert all_equal(ones({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + }), TreeTensor({ + 'a': torch.ones(2, 3), + 'b': torch.ones(5, 6), + 'x': { + 'c': torch.ones(2, 3, 4), + } + })) + + def test_ones_like(self): + assert all_equal( + ones_like(TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + })), + TreeTensor({ + 'a': torch.tensor([[1, 1, 1], [1, 1, 1]]), + 'b': torch.tensor([1, 1, 1, 1]), + 'x': { + 'c': torch.tensor([1, 1, 1]), + 'd': torch.tensor([[[1, 1]]]), + } + }) + ) diff --git a/test/tensor/test_treetensor.py b/test/tensor/test_treetensor.py index 0024ec7ca66df4110a0a1c6a9b5d447a785e5617..c3aef76143a60e4cff6f944bd378069ba83f939e 100644 --- a/test/tensor/test_treetensor.py +++ b/test/tensor/test_treetensor.py @@ -1,7 +1,12 @@ +import numpy as np import pytest import torch +from treevalue import func_treelize -from treetensor import TreeTensor +from treetensor.numpy import all_array_equal, TreeNumpy +from treetensor.tensor import TreeTensor, all_equal + +_all_is = func_treelize(return_type=TreeTensor)(lambda x, y: x is y) @pytest.mark.unittest @@ -17,3 +22,27 @@ class TestTensorTreetensor: def test_numel(self): assert self._DEMO_1.numel() == 18 + + def test_numpy(self): + assert all_array_equal(self._DEMO_1.numpy(), TreeNumpy({ + 'a': np.array([[1, 2, 3], [4, 5, 6]]), + 'b': np.array([[1, 2], [5, 6]]), + 'x': { + 'c': np.array([3, 5, 6, 7]), + 'd': np.array([[[1, 2], [8, 9]]]), + } + })) + + def test_cpu(self): + assert all_equal(self._DEMO_1.cpu(), self._DEMO_1) + assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values())) + + def test_to(self): + assert all_equal(self._DEMO_1.to(torch.float32), TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), + 'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32), + 'x': { + 'c': torch.tensor([3, 5, 6, 7], dtype=torch.float32), + 'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32), + } + })) diff --git a/treetensor/__init__.py b/treetensor/__init__.py index e04893b671edd3b66659f223b6c1e211b885e669..41a44f63736a7b2fe4548ed66c375558a0b2bda9 100644 --- a/treetensor/__init__.py +++ b/treetensor/__init__.py @@ -1,2 +1,2 @@ -from .numpy import * -from .tensor import * +from .numpy import TreeNumpy +from .tensor import TreeTensor diff --git a/treetensor/numpy/__init__.py b/treetensor/numpy/__init__.py index 75b713a3ac6924904a6d013022a1c4027fef3160..017c4291927d891f53a1b6ad4882c50c1fd98eab 100644 --- a/treetensor/numpy/__init__.py +++ b/treetensor/numpy/__init__.py @@ -1,6 +1,2 @@ -try: - import numpy as np -except ImportError: # numpy not exist - from .fake import FakeTreeNumpy as TreeNumpy -else: - from .numpy import TreeNumpy as TreeNumpy +from .funcs import equal, array_equal, all_array_equal +from .numpy import TreeNumpy diff --git a/treetensor/numpy/fake.py b/treetensor/numpy/fake.py deleted file mode 100644 index 6ff89f87bd5fab259e0b2b4f072c231e6da9dcb5..0000000000000000000000000000000000000000 --- a/treetensor/numpy/fake.py +++ /dev/null @@ -1,5 +0,0 @@ -from treevalue import general_tree_value - - -class FakeTreeNumpy(general_tree_value()): - pass diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..bf937e14f913c3373d4b005bbd588c4efc280543 --- /dev/null +++ b/treetensor/numpy/funcs.py @@ -0,0 +1,19 @@ +from functools import partial + +import numpy as np +from treevalue import func_treelize, TreeValue + +from .numpy import TreeNumpy + +_treelize = partial(func_treelize, return_type=TreeNumpy) + +equal = _treelize()(np.equal) +array_equal = _treelize()(np.array_equal) + + +def all_array_equal(tx, ty, *args, **kwargs) -> bool: + _result = array_equal(tx, ty, *args, **kwargs) + if isinstance(tx, TreeValue) and isinstance(ty, TreeValue): + return _result.reduce(lambda **kws: all(list(kws.values()))) + else: + return _result diff --git a/treetensor/tensor/__init__.py b/treetensor/tensor/__init__.py index 8629dcfbf253b2de268391b0c23a61e3c1191300..31276672baacd02019b0094ee87e1919aac32da4 100644 --- a/treetensor/tensor/__init__.py +++ b/treetensor/tensor/__init__.py @@ -1,4 +1,4 @@ # noinspection PyUnresolvedReferences from .funcs import zeros_like, full_like, ones_like, randint_like, randn_like, empty_like, zeros, randn, randint, \ - func_treelize, ones, empty, full + func_treelize, ones, empty, full, all, eq, equal, all_eq, all_equal from .treetensor import TreeTensor diff --git a/treetensor/tensor/funcs.py b/treetensor/tensor/funcs.py index e0f0492698548f29d4554186cb1310b6517e9229..c6d26712790f9fcfe9a7e7613a1ef3d94b4f6656 100644 --- a/treetensor/tensor/funcs.py +++ b/treetensor/tensor/funcs.py @@ -7,6 +7,7 @@ from treevalue import func_treelize, TreeValue from .treetensor import TreeTensor _treelize = partial(func_treelize, return_type=TreeTensor) +_python_all = all def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **kwargs_): @@ -28,6 +29,7 @@ def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **k return _decorator +# Tensor generation based on shapes zeros = _size_based_treelize()(torch.zeros) randn = _size_based_treelize()(torch.randn) randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint) @@ -35,9 +37,31 @@ ones = _size_based_treelize()(torch.ones) full = _size_based_treelize()(torch.full) empty = _size_based_treelize()(torch.empty) +# Tensor generation based on another tensor zeros_like = _treelize()(torch.zeros_like) randn_like = _treelize()(torch.randn_like) randint_like = _treelize()(torch.randint_like) ones_like = _treelize()(torch.ones_like) full_like = _treelize()(torch.full_like) empty_like = _treelize()(torch.empty_like) + +# Tensor operators +all = _treelize()(torch.all) +eq = _treelize()(torch.eq) +equal = _treelize()(torch.equal) + + +def all_eq(tx, ty, *args, **kwargs) -> bool: + _result = eq(tx, ty, *args, **kwargs) + if isinstance(tx, TreeValue) and isinstance(ty, TreeValue): + return _result.reduce(lambda **kws: _python_all(kws.values())) + else: + return _result + + +def all_equal(tx, ty, *args, **kwargs) -> bool: + _result = equal(tx, ty, *args, **kwargs) + if isinstance(tx, TreeValue) and isinstance(ty, TreeValue): + return _result.reduce(lambda **kws: _python_all(kws.values())) + else: + return _result