From daa01e40734dd9e65c68b479f468446e4e893839 Mon Sep 17 00:00:00 2001 From: HansBug Date: Thu, 9 Sep 2021 13:49:39 +0800 Subject: [PATCH] dev(hansbug): add kwreduce and vreduce decorator && refactor mergable and conclusion functions --- test/numpy/test_funcs.py | 12 +++++----- test/tensor/test_funcs.py | 30 ++++++++++++------------- test/tensor/test_treetensor.py | 17 ++++---------- treetensor/common/__init__.py | 5 ++--- treetensor/common/base.py | 20 ----------------- treetensor/common/obj.py | 5 ----- treetensor/common/{data.py => trees.py} | 16 +++++++++++-- treetensor/common/wrappers.py | 23 +++++++++++++++++++ treetensor/numpy/funcs.py | 3 ++- treetensor/numpy/numpy.py | 26 ++++++++++----------- treetensor/tensor/funcs.py | 3 ++- treetensor/tensor/tensor.py | 10 +++++++-- 12 files changed, 89 insertions(+), 81 deletions(-) delete mode 100644 treetensor/common/base.py delete mode 100644 treetensor/common/obj.py rename treetensor/common/{data.py => trees.py} (68%) create mode 100644 treetensor/common/wrappers.py diff --git a/test/numpy/test_funcs.py b/test/numpy/test_funcs.py index 01ee93354..c30ddffe2 100644 --- a/test/numpy/test_funcs.py +++ b/test/numpy/test_funcs.py @@ -36,10 +36,10 @@ class TestNumpyFuncs: }) def test__numpy_all(self): - assert not _numpy_all(self._DEMO_1 == self._DEMO_2).all() - assert _numpy_all(self._DEMO_1 == self._DEMO_3).all() - assert not _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 4])).all() - assert _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 3])).all() + assert not _numpy_all(self._DEMO_1 == self._DEMO_2) + assert _numpy_all(self._DEMO_1 == self._DEMO_3) + assert not _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 4])) + assert _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 3])) def test_equal(self): assert _numpy_all( @@ -51,7 +51,7 @@ class TestNumpyFuncs: 'd': np.array([[True, True]]), } }) - ).all() + ) assert _numpy_all( equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({ 'a': np.array([[True, True, True], [True, True, True]]), @@ -61,7 +61,7 @@ class TestNumpyFuncs: 'd': np.array([[True, True]]), } }) - ).all() + ) def test_array_equal(self): assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({ diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py index 103109d17..84a528779 100644 --- a/test/tensor/test_funcs.py +++ b/test/tensor/test_funcs.py @@ -10,7 +10,7 @@ from treetensor.tensor import all as _tensor_all @pytest.mark.unittest class TestTensorFuncs: def test_zeros(self): - assert _tensor_all(zeros((2, 3)) == torch.zeros(2, 3)).all() + assert _tensor_all(zeros((2, 3)) == torch.zeros(2, 3)) assert _tensor_all(zeros({ 'a': (2, 3), 'b': (5, 6), @@ -23,7 +23,7 @@ class TestTensorFuncs: 'x': { 'c': torch.zeros(2, 3, 4), } - })).all() + })) def test_zeros_like(self): assert _tensor_all( @@ -46,7 +46,7 @@ class TestTensorFuncs: 'd': torch.tensor([[[0, 0]]]), } }) - ).all() + ) def test_ones(self): assert _tensor_all(ones((2, 3)) == torch.ones(2, 3)) @@ -62,7 +62,7 @@ class TestTensorFuncs: 'x': { 'c': torch.ones(2, 3, 4), } - })).all() + })) def test_ones_like(self): assert _tensor_all( @@ -85,7 +85,7 @@ class TestTensorFuncs: 'd': torch.tensor([[[1, 1]]]), } }) - ).all() + ) def test_randn(self): _target = randn((200, 300)) @@ -139,8 +139,8 @@ class TestTensorFuncs: 'c': (2, 3, 4), } }, -10, 10) - assert _tensor_all(_target < 10).all() - assert _tensor_all(-10 <= _target).all() + assert _tensor_all(_target < 10) + assert _tensor_all(-10 <= _target) assert _target.shape == TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), @@ -156,8 +156,8 @@ class TestTensorFuncs: 'c': (2, 3, 4), } }, 10) - assert _tensor_all(_target < 10).all() - assert _tensor_all(0 <= _target).all() + assert _tensor_all(_target < 10) + assert _tensor_all(0 <= _target) assert _target.shape == TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), @@ -175,8 +175,8 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]]), } }), -10, 10) - assert _tensor_all(_target < 10).all() - assert _tensor_all(-10 <= _target).all() + assert _tensor_all(_target < 10) + assert _tensor_all(-10 <= _target) assert _target.shape == TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), @@ -194,8 +194,8 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]]), } }), 10) - assert _tensor_all(_target < 10).all() - assert _tensor_all(0 <= _target).all() + assert _tensor_all(_target < 10) + assert _tensor_all(0 <= _target) assert _target.shape == TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), @@ -213,7 +213,7 @@ class TestTensorFuncs: 'c': (2, 3, 4), } }, 233) - assert _tensor_all(_target == 233).all() + assert _tensor_all(_target == 233) assert _target.shape == TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), @@ -231,7 +231,7 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]]), } }), 233) - assert _tensor_all(_target == 233).all() + assert _tensor_all(_target == 233) assert _target.shape == TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), diff --git a/test/tensor/test_treetensor.py b/test/tensor/test_treetensor.py index 55329ea4f..a31dd6197 100644 --- a/test/tensor/test_treetensor.py +++ b/test/tensor/test_treetensor.py @@ -3,7 +3,6 @@ import pytest import torch from treevalue import func_treelize -from treetensor.common import TreeObject from treetensor.numpy import TreeNumpy from treetensor.numpy import all as _numpy_all from treetensor.tensor import TreeTensor @@ -24,15 +23,7 @@ class TestTensorTreetensor: }) def test_numel(self): - assert self._DEMO_1.numel() == TreeObject({ - 'a': 6, - 'b': 4, - 'x': { - 'c': 4, - 'd': 4, - } - }) - assert self._DEMO_1.numel().sum() == 18 + assert self._DEMO_1.numel() == 18 def test_numpy(self): assert _numpy_all(self._DEMO_1.numpy() == TreeNumpy({ @@ -42,10 +33,10 @@ class TestTensorTreetensor: 'c': np.array([3, 5, 6, 7]), 'd': np.array([[[1, 2], [8, 9]]]), } - })).all() + })) def test_cpu(self): - assert _tensor_all(self._DEMO_1.cpu() == self._DEMO_1).all() + assert _tensor_all(self._DEMO_1.cpu() == self._DEMO_1) assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values())) def test_to(self): @@ -56,4 +47,4 @@ class TestTensorTreetensor: 'c': torch.tensor([3, 5, 6, 7], dtype=torch.float32), 'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32), } - })).all() + })) diff --git a/treetensor/common/__init__.py b/treetensor/common/__init__.py index ab0a49cc8..2b15c059d 100644 --- a/treetensor/common/__init__.py +++ b/treetensor/common/__init__.py @@ -1,3 +1,2 @@ -from .base import BaseTreeStruct -from .data import TreeData -from .obj import TreeObject +from .trees import TreeData, TreeObject, BaseTreeStruct +from .wrappers import kwreduce, vreduce diff --git a/treetensor/common/base.py b/treetensor/common/base.py deleted file mode 100644 index b62ba4f0e..000000000 --- a/treetensor/common/base.py +++ /dev/null @@ -1,20 +0,0 @@ -from abc import ABCMeta -from functools import lru_cache - -from treevalue import general_tree_value - - -@lru_cache() -def _merge_func(red): - return lambda **kws: red(kws.values()) - - -class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): - def all(self) -> bool: - return self.reduce(_merge_func(all)) - - def any(self) -> bool: - return self.reduce(_merge_func(any)) - - def sum(self): - return self.reduce(_merge_func(sum)) diff --git a/treetensor/common/obj.py b/treetensor/common/obj.py deleted file mode 100644 index cfd65f748..000000000 --- a/treetensor/common/obj.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base import BaseTreeStruct - - -class TreeObject(BaseTreeStruct): - pass diff --git a/treetensor/common/data.py b/treetensor/common/trees.py similarity index 68% rename from treetensor/common/data.py rename to treetensor/common/trees.py index fcfd77b5e..892d15ce1 100644 --- a/treetensor/common/data.py +++ b/treetensor/common/trees.py @@ -1,8 +1,16 @@ import operator +from abc import ABCMeta -from treevalue import func_treelize +from treevalue import func_treelize, general_tree_value + + +class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): + """ + Overview: + Base structure of all the trees in ``treetensor``. + """ + pass -from .base import BaseTreeStruct _OPERATORS = {} for _op_name in getattr(operator, '__all__'): @@ -27,3 +35,7 @@ class TreeData(BaseTreeStruct): def __ne__(self, other): return _OPERATORS['ne'](self, other) + + +class TreeObject(BaseTreeStruct): + pass diff --git a/treetensor/common/wrappers.py b/treetensor/common/wrappers.py new file mode 100644 index 000000000..b72c80a72 --- /dev/null +++ b/treetensor/common/wrappers.py @@ -0,0 +1,23 @@ +from functools import wraps + +from treevalue import TreeValue +from treevalue import reduce_ as treevalue_reduce + + +def kwreduce(reduce_func): + def _decorator(func): + @wraps(func) + def _new_func(*args, **kwargs): + _result = func(*args, **kwargs) + if isinstance(_result, TreeValue): + return treevalue_reduce(_result, reduce_func) + else: + return _result + + return _new_func + + return _decorator + + +def vreduce(vreduce_func): + return kwreduce(lambda **kws: vreduce_func(kws.values())) diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index e6d76c1d3..6d34ea120 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -4,9 +4,10 @@ import numpy as np from treevalue import func_treelize from .numpy import TreeNumpy +from ..common import vreduce _treelize = partial(func_treelize, return_type=TreeNumpy) -all = _treelize()(np.all) +all = vreduce(all)(_treelize()(np.all)) equal = _treelize()(np.equal) array_equal = _treelize()(np.array_equal) diff --git a/treetensor/numpy/numpy.py b/treetensor/numpy/numpy.py index 9c62cf002..eff7847eb 100644 --- a/treetensor/numpy/numpy.py +++ b/treetensor/numpy/numpy.py @@ -1,7 +1,7 @@ import numpy as np from treevalue import method_treelize -from ..common import TreeObject, TreeData +from ..common import TreeObject, TreeData, vreduce class TreeNumpy(TreeData): @@ -15,18 +15,18 @@ class TreeNumpy(TreeData): return self.tolist() @property - def size(self) -> int: - return self \ - .map(lambda d: d.size) \ - .reduce(lambda **kwargs: sum(kwargs.values())) + @vreduce(sum) + @method_treelize(return_type=TreeObject) + def size(self: np.ndarray) -> int: + return self.size @property - def nbytes(self) -> int: - return self \ - .map(lambda d: d.nbytes) \ - .reduce(lambda **kwargs: sum(kwargs.values())) + @vreduce(sum) + @method_treelize(return_type=TreeObject) + def nbytes(self: np.ndarray) -> int: + return self.nbytes - def sum(self): - return self \ - .map(lambda d: d.sum()) \ - .reduce(lambda **kwargs: sum(kwargs.values())) + @vreduce(sum) + @method_treelize(return_type=TreeObject) + def sum(self: np.ndarray, *args, **kwargs): + return self.sum(*args, **kwargs) diff --git a/treetensor/tensor/funcs.py b/treetensor/tensor/funcs.py index 6dae80af1..9b83f1bbf 100644 --- a/treetensor/tensor/funcs.py +++ b/treetensor/tensor/funcs.py @@ -5,6 +5,7 @@ import torch from treevalue import func_treelize, TreeValue from .tensor import TreeTensor +from ..common import vreduce _treelize = partial(func_treelize, return_type=TreeTensor) _python_all = all @@ -46,6 +47,6 @@ full_like = _treelize()(torch.full_like) empty_like = _treelize()(torch.empty_like) # Tensor operators -all = _treelize()(torch.all) +all = vreduce(all)(_treelize()(torch.all)) eq = _treelize()(torch.eq) equal = _treelize()(torch.equal) diff --git a/treetensor/tensor/tensor.py b/treetensor/tensor/tensor.py index f12671d5e..a25e757ce 100644 --- a/treetensor/tensor/tensor.py +++ b/treetensor/tensor/tensor.py @@ -3,7 +3,7 @@ import torch from treevalue import method_treelize, TreeValue from .size import TreeSize -from ..common import TreeObject, TreeData +from ..common import TreeObject, TreeData, vreduce from ..numpy import TreeNumpy @@ -29,7 +29,7 @@ def _same_merge(eq, hash_, **kwargs): return TreeTensor(kws) -# noinspection PyTypeChecker,PyShadowingBuiltins +# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList class TreeTensor(TreeData): @method_treelize(return_type=TreeNumpy) def numpy(self: torch.Tensor) -> np.ndarray: @@ -51,6 +51,7 @@ class TreeTensor(TreeData): def to(self: torch.Tensor, *args, **kwargs): return self.to(*args, **kwargs) + @vreduce(sum) @method_treelize(return_type=TreeObject) def numel(self: torch.Tensor): return self.numel() @@ -59,3 +60,8 @@ class TreeTensor(TreeData): @method_treelize(return_type=TreeSize) def shape(self: torch.Tensor): return self.shape + + @vreduce(all) + @method_treelize(return_type=TreeObject) + def all(self: torch.Tensor, *args, **kwargs): + return self.all(*args, **kwargs) -- GitLab