From 8f298c55eb4980fddf7bf623b2a0702c353a1c1e Mon Sep 17 00:00:00 2001 From: HansBug Date: Tue, 7 Sep 2021 19:55:54 +0800 Subject: [PATCH] dev(hansbug): add plenty of basic functions --- test/numpy/__init__.py | 4 +- test/numpy/test_fake.py | 19 ------ test/numpy/test_funcs.py | 83 ++++++++++++++++++++++++++ test/numpy/test_numpy.py | 20 ++----- test/tensor/__init__.py | 1 + test/tensor/test_funcs.py | 104 +++++++++++++++++++++++++++++++++ test/tensor/test_treetensor.py | 31 +++++++++- treetensor/__init__.py | 4 +- treetensor/numpy/__init__.py | 8 +-- treetensor/numpy/fake.py | 5 -- treetensor/numpy/funcs.py | 19 ++++++ treetensor/tensor/__init__.py | 2 +- treetensor/tensor/funcs.py | 24 ++++++++ 13 files changed, 272 insertions(+), 52 deletions(-) delete mode 100644 test/numpy/test_fake.py create mode 100644 test/numpy/test_funcs.py create mode 100644 test/tensor/test_funcs.py delete mode 100644 treetensor/numpy/fake.py create mode 100644 treetensor/numpy/funcs.py diff --git a/test/numpy/__init__.py b/test/numpy/__init__.py index 043647c43..099cc30b9 100644 --- a/test/numpy/__init__.py +++ b/test/numpy/__init__.py @@ -1,2 +1,2 @@ -from .test_fake import TestNumpyFake -from .test_numpy import TestNumpyReal +from .test_funcs import TestNumpyFuncs +from .test_numpy import TestNumpyNumpy diff --git a/test/numpy/test_fake.py b/test/numpy/test_fake.py deleted file mode 100644 index 1d2a38567..000000000 --- a/test/numpy/test_fake.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest - -from treetensor import TreeNumpy - -try: - import numpy as np -except ImportError: - need_fake = True - from treetensor.numpy.fake import FakeTreeNumpy -else: - need_fake = False - -unittest_mark = pytest.mark.unittest if need_fake else pytest.mark.ignore - - -@unittest_mark -class TestNumpyFake: - def test_base(self): - assert TreeNumpy is FakeTreeNumpy diff --git a/test/numpy/test_funcs.py b/test/numpy/test_funcs.py new file mode 100644 index 000000000..23faca1d8 --- /dev/null +++ b/test/numpy/test_funcs.py @@ -0,0 +1,83 @@ +import numpy as np +import pytest + +from treetensor.numpy import TreeNumpy, all_array_equal, equal, array_equal + + +# noinspection DuplicatedCode +@pytest.mark.unittest +class TestNumpyFuncs: + _DEMO_1 = TreeNumpy({ + 'a': np.array([[1, 2, 3], [5, 6, 7]]), + 'b': np.array([1, 3, 5, 7]), + 'x': { + 'c': np.array([3, 5, 7]), + 'd': np.array([[7, 9]]), + } + }) + + _DEMO_2 = TreeNumpy({ + 'a': np.array([[1, 2, 3], [5, 6, 8]]), + 'b': np.array([1, 3, 5, 7]), + 'x': { + 'c': np.array([3, 5, 7]), + 'd': np.array([[7, 9]]), + } + }) + + _DEMO_3 = TreeNumpy({ + 'a': np.array([[1, 2, 3], [5, 6, 7]]), + 'b': np.array([1, 3, 5, 7]), + 'x': { + 'c': np.array([3, 5, 7]), + 'd': np.array([[7, 9]]), + } + }) + + def test_all_array_equal(self): + assert not all_array_equal(self._DEMO_1, self._DEMO_2) + assert all_array_equal(self._DEMO_1, self._DEMO_3) + assert not all_array_equal(np.array([1, 2, 3]), np.array([1, 2, 4])) + assert all_array_equal(np.array([1, 2, 3]), np.array([1, 2, 3])) + + def test_equal(self): + assert all_array_equal( + equal(self._DEMO_1, self._DEMO_2), + TreeNumpy({ + 'a': np.array([[True, True, True], [True, True, False]]), + 'b': np.array([True, True, True, True]), + 'x': { + 'c': np.array([True, True, True]), + 'd': np.array([[True, True]]), + } + }) + ) + assert all_array_equal( + equal(self._DEMO_1, self._DEMO_3), + TreeNumpy({ + 'a': np.array([[True, True, True], [True, True, True]]), + 'b': np.array([True, True, True, True]), + 'x': { + 'c': np.array([True, True, True]), + 'd': np.array([[True, True]]), + } + }) + ) + + def test_array_equal(self): + assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({ + 'a': False, + 'b': True, + 'x': { + 'c': True, + 'd': True, + } + }) + assert array_equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({ + 'a': True, + 'b': True, + 'x': { + 'c': True, + 'd': True, + } + }) diff --git a/test/numpy/test_numpy.py b/test/numpy/test_numpy.py index 0c14a4593..56757c73c 100644 --- a/test/numpy/test_numpy.py +++ b/test/numpy/test_numpy.py @@ -1,23 +1,11 @@ +import numpy as np import pytest -from treetensor import TreeNumpy +from treetensor.numpy import TreeNumpy -try: - import numpy as np -except ImportError: - need_real = False -else: - need_real = True - from treetensor.numpy.numpy import TreeNumpy as RealTreeNumpy - -unittest_mark = pytest.mark.unittest if need_real else pytest.mark.ignore - - -@unittest_mark -class TestNumpyReal: - def test_base(self): - assert TreeNumpy is RealTreeNumpy +@pytest.mark.unittest +class TestNumpyNumpy: _DEMO_1 = TreeNumpy({ 'a': np.array([[1, 2, 3], [4, 5, 6]]), 'b': np.array([1, 3, 5, 7]), diff --git a/test/tensor/__init__.py b/test/tensor/__init__.py index d88a680cd..a78ed870d 100644 --- a/test/tensor/__init__.py +++ b/test/tensor/__init__.py @@ -1 +1,2 @@ +from .test_funcs import TestTensorFuncs from .test_treetensor import TestTensorTreetensor diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py new file mode 100644 index 000000000..e9d7a6f85 --- /dev/null +++ b/test/tensor/test_funcs.py @@ -0,0 +1,104 @@ +import pytest +import torch + +from treetensor.tensor import TreeTensor, zeros, all_equal, zeros_like, ones, ones_like + + +# noinspection DuplicatedCode +@pytest.mark.unittest +class TestTensorFuncs: + def test_all_equal(self): + assert all_equal( + torch.tensor([[1, 2, 3], [4, 5, 6]]), + torch.tensor([[1, 2, 3], [4, 5, 6]]), + ) + assert all_equal( + TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + }), + TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + }), + ) + + def test_zeros(self): + assert all_equal(zeros((2, 3)), torch.zeros(2, 3)) + assert all_equal(zeros({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + }), TreeTensor({ + 'a': torch.zeros(2, 3), + 'b': torch.zeros(5, 6), + 'x': { + 'c': torch.zeros(2, 3, 4), + } + })) + + def test_zeros_like(self): + assert all_equal( + zeros_like(TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + })), + TreeTensor({ + 'a': torch.tensor([[0, 0, 0], [0, 0, 0]]), + 'b': torch.tensor([0, 0, 0, 0]), + 'x': { + 'c': torch.tensor([0, 0, 0]), + 'd': torch.tensor([[[0, 0]]]), + } + }) + ) + + def test_ones(self): + assert all_equal(ones((2, 3)), torch.ones(2, 3)) + assert all_equal(ones({ + 'a': (2, 3), + 'b': (5, 6), + 'x': { + 'c': (2, 3, 4), + } + }), TreeTensor({ + 'a': torch.ones(2, 3), + 'b': torch.ones(5, 6), + 'x': { + 'c': torch.ones(2, 3, 4), + } + })) + + def test_ones_like(self): + assert all_equal( + ones_like(TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([1, 2, 3, 4]), + 'x': { + 'c': torch.tensor([5, 6, 7]), + 'd': torch.tensor([[[8, 9]]]), + } + })), + TreeTensor({ + 'a': torch.tensor([[1, 1, 1], [1, 1, 1]]), + 'b': torch.tensor([1, 1, 1, 1]), + 'x': { + 'c': torch.tensor([1, 1, 1]), + 'd': torch.tensor([[[1, 1]]]), + } + }) + ) diff --git a/test/tensor/test_treetensor.py b/test/tensor/test_treetensor.py index 0024ec7ca..c3aef7614 100644 --- a/test/tensor/test_treetensor.py +++ b/test/tensor/test_treetensor.py @@ -1,7 +1,12 @@ +import numpy as np import pytest import torch +from treevalue import func_treelize -from treetensor import TreeTensor +from treetensor.numpy import all_array_equal, TreeNumpy +from treetensor.tensor import TreeTensor, all_equal + +_all_is = func_treelize(return_type=TreeTensor)(lambda x, y: x is y) @pytest.mark.unittest @@ -17,3 +22,27 @@ class TestTensorTreetensor: def test_numel(self): assert self._DEMO_1.numel() == 18 + + def test_numpy(self): + assert all_array_equal(self._DEMO_1.numpy(), TreeNumpy({ + 'a': np.array([[1, 2, 3], [4, 5, 6]]), + 'b': np.array([[1, 2], [5, 6]]), + 'x': { + 'c': np.array([3, 5, 6, 7]), + 'd': np.array([[[1, 2], [8, 9]]]), + } + })) + + def test_cpu(self): + assert all_equal(self._DEMO_1.cpu(), self._DEMO_1) + assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values())) + + def test_to(self): + assert all_equal(self._DEMO_1.to(torch.float32), TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), + 'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32), + 'x': { + 'c': torch.tensor([3, 5, 6, 7], dtype=torch.float32), + 'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32), + } + })) diff --git a/treetensor/__init__.py b/treetensor/__init__.py index e04893b67..41a44f637 100644 --- a/treetensor/__init__.py +++ b/treetensor/__init__.py @@ -1,2 +1,2 @@ -from .numpy import * -from .tensor import * +from .numpy import TreeNumpy +from .tensor import TreeTensor diff --git a/treetensor/numpy/__init__.py b/treetensor/numpy/__init__.py index 75b713a3a..017c42919 100644 --- a/treetensor/numpy/__init__.py +++ b/treetensor/numpy/__init__.py @@ -1,6 +1,2 @@ -try: - import numpy as np -except ImportError: # numpy not exist - from .fake import FakeTreeNumpy as TreeNumpy -else: - from .numpy import TreeNumpy as TreeNumpy +from .funcs import equal, array_equal, all_array_equal +from .numpy import TreeNumpy diff --git a/treetensor/numpy/fake.py b/treetensor/numpy/fake.py deleted file mode 100644 index 6ff89f87b..000000000 --- a/treetensor/numpy/fake.py +++ /dev/null @@ -1,5 +0,0 @@ -from treevalue import general_tree_value - - -class FakeTreeNumpy(general_tree_value()): - pass diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py new file mode 100644 index 000000000..bf937e14f --- /dev/null +++ b/treetensor/numpy/funcs.py @@ -0,0 +1,19 @@ +from functools import partial + +import numpy as np +from treevalue import func_treelize, TreeValue + +from .numpy import TreeNumpy + +_treelize = partial(func_treelize, return_type=TreeNumpy) + +equal = _treelize()(np.equal) +array_equal = _treelize()(np.array_equal) + + +def all_array_equal(tx, ty, *args, **kwargs) -> bool: + _result = array_equal(tx, ty, *args, **kwargs) + if isinstance(tx, TreeValue) and isinstance(ty, TreeValue): + return _result.reduce(lambda **kws: all(list(kws.values()))) + else: + return _result diff --git a/treetensor/tensor/__init__.py b/treetensor/tensor/__init__.py index 8629dcfbf..31276672b 100644 --- a/treetensor/tensor/__init__.py +++ b/treetensor/tensor/__init__.py @@ -1,4 +1,4 @@ # noinspection PyUnresolvedReferences from .funcs import zeros_like, full_like, ones_like, randint_like, randn_like, empty_like, zeros, randn, randint, \ - func_treelize, ones, empty, full + func_treelize, ones, empty, full, all, eq, equal, all_eq, all_equal from .treetensor import TreeTensor diff --git a/treetensor/tensor/funcs.py b/treetensor/tensor/funcs.py index e0f049269..c6d267127 100644 --- a/treetensor/tensor/funcs.py +++ b/treetensor/tensor/funcs.py @@ -7,6 +7,7 @@ from treevalue import func_treelize, TreeValue from .treetensor import TreeTensor _treelize = partial(func_treelize, return_type=TreeTensor) +_python_all = all def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **kwargs_): @@ -28,6 +29,7 @@ def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **k return _decorator +# Tensor generation based on shapes zeros = _size_based_treelize()(torch.zeros) randn = _size_based_treelize()(torch.randn) randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint) @@ -35,9 +37,31 @@ ones = _size_based_treelize()(torch.ones) full = _size_based_treelize()(torch.full) empty = _size_based_treelize()(torch.empty) +# Tensor generation based on another tensor zeros_like = _treelize()(torch.zeros_like) randn_like = _treelize()(torch.randn_like) randint_like = _treelize()(torch.randint_like) ones_like = _treelize()(torch.ones_like) full_like = _treelize()(torch.full_like) empty_like = _treelize()(torch.empty_like) + +# Tensor operators +all = _treelize()(torch.all) +eq = _treelize()(torch.eq) +equal = _treelize()(torch.equal) + + +def all_eq(tx, ty, *args, **kwargs) -> bool: + _result = eq(tx, ty, *args, **kwargs) + if isinstance(tx, TreeValue) and isinstance(ty, TreeValue): + return _result.reduce(lambda **kws: _python_all(kws.values())) + else: + return _result + + +def all_equal(tx, ty, *args, **kwargs) -> bool: + _result = equal(tx, ty, *args, **kwargs) + if isinstance(tx, TreeValue) and isinstance(ty, TreeValue): + return _result.reduce(lambda **kws: _python_all(kws.values())) + else: + return _result -- GitLab