From c5c230d8f1956c394144682a6a6cc68d6b8bd5ed Mon Sep 17 00:00:00 2001 From: HansBug Date: Fri, 10 Sep 2021 20:06:21 +0800 Subject: [PATCH] dev(hansbug): refactor the current code --- test/numpy/test_funcs.py | 23 +++++- test/tensor/test_funcs.py | 66 +++++++++++++--- test/tensor/test_treetensor.py | 13 ++++ treetensor/common/__init__.py | 2 +- treetensor/common/trees.py | 39 +++++----- treetensor/common/wrappers.py | 38 +++++++++- treetensor/numpy/funcs.py | 27 +++++-- treetensor/numpy/numpy.py | 8 +- treetensor/tensor/funcs.py | 134 ++++++++++++++++++++------------- treetensor/tensor/size.py | 9 ++- treetensor/tensor/tensor.py | 53 ++++++------- treetensor/utils/__init__.py | 1 + treetensor/utils/func.py | 5 ++ 13 files changed, 291 insertions(+), 127 deletions(-) create mode 100644 treetensor/utils/__init__.py create mode 100644 treetensor/utils/func.py diff --git a/test/numpy/test_funcs.py b/test/numpy/test_funcs.py index c30ddffe2..74ac11646 100644 --- a/test/numpy/test_funcs.py +++ b/test/numpy/test_funcs.py @@ -35,13 +35,25 @@ class TestNumpyFuncs: } }) - def test__numpy_all(self): + def test_all(self): + assert not _numpy_all(np.array([True, True, False])) + assert _numpy_all(np.array([True, True, True])) + assert not _numpy_all(self._DEMO_1 == self._DEMO_2) assert _numpy_all(self._DEMO_1 == self._DEMO_3) assert not _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 4])) assert _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 3])) def test_equal(self): + assert _numpy_all(equal( + np.array([1, 2, 3]), + np.array([1, 2, 3]), + )) + assert not _numpy_all(equal( + np.array([1, 2, 3]), + np.array([1, 2, 4]), + )) + assert _numpy_all( equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({ 'a': np.array([[True, True, True], [True, True, False]]), @@ -64,6 +76,15 @@ class TestNumpyFuncs: ) def test_array_equal(self): + assert _numpy_all(array_equal( + np.array([1, 2, 3]), + np.array([1, 2, 3]), + )) + assert not _numpy_all(array_equal( + np.array([1, 2, 3]), + np.array([1, 2, 4]), + )) + assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({ 'a': False, 'b': True, diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py index 84a528779..f4d5745a3 100644 --- a/test/tensor/test_funcs.py +++ b/test/tensor/test_funcs.py @@ -1,5 +1,6 @@ import pytest import torch +from treevalue import TreeValue from treetensor.tensor import TreeTensor, zeros, zeros_like, ones, ones_like, randint, randint_like, randn, \ randn_like, full, full_like, TreeSize @@ -11,13 +12,13 @@ from treetensor.tensor import all as _tensor_all class TestTensorFuncs: def test_zeros(self): assert _tensor_all(zeros((2, 3)) == torch.zeros(2, 3)) - assert _tensor_all(zeros({ + assert _tensor_all(zeros(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - }) == TreeTensor({ + })) == TreeTensor({ 'a': torch.zeros(2, 3), 'b': torch.zeros(5, 6), 'x': { @@ -50,13 +51,13 @@ class TestTensorFuncs: def test_ones(self): assert _tensor_all(ones((2, 3)) == torch.ones(2, 3)) - assert _tensor_all(ones({ + assert _tensor_all(ones(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - }) == TreeTensor({ + })) == TreeTensor({ 'a': torch.ones(2, 3), 'b': torch.ones(5, 6), 'x': { @@ -93,13 +94,13 @@ class TestTensorFuncs: assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 assert _target.shape == torch.Size([200, 300]) - _target = randn({ + _target = randn(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - }) + })) assert _target.shape == TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), @@ -132,13 +133,13 @@ class TestTensorFuncs: }) def test_randint(self): - _target = randint({ + _target = randint(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - }, -10, 10) + }), -10, 10) assert _tensor_all(_target < 10) assert _tensor_all(-10 <= _target) assert _target.shape == TreeSize({ @@ -149,13 +150,13 @@ class TestTensorFuncs: } }) - _target = randint({ + _target = randint(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - }, 10) + }), 10) assert _tensor_all(_target < 10) assert _tensor_all(0 <= _target) assert _target.shape == TreeSize({ @@ -206,13 +207,13 @@ class TestTensorFuncs: }) def test_full(self): - _target = full({ + _target = full(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - }, 233) + }), 233) assert _tensor_all(_target == 233) assert _target.shape == TreeSize({ 'a': torch.Size([2, 3]), @@ -240,3 +241,44 @@ class TestTensorFuncs: 'd': torch.Size([1, 1, 2]), } }) + + def test_all(self): + r1 = _tensor_all(torch.tensor([1, 1, 1]) == 1) + assert torch.is_tensor(r1) + assert r1 == torch.tensor(True) + + r2 = _tensor_all(torch.tensor([1, 1, 2]) == 1) + assert torch.is_tensor(r2) + assert r2 == torch.tensor(False) + + r3 = _tensor_all(TreeTensor({ + 'a': torch.Tensor([1, 2, 3]), + 'b': torch.Tensor([4, 5, 6]), + 'x': { + 'c': torch.Tensor([7, 8, 9]) + } + }) == TreeTensor({ + 'a': torch.Tensor([1, 2, 3]), + 'b': torch.Tensor([4, 5, 6]), + 'x': { + 'c': torch.Tensor([7, 8, 9]) + } + })) + assert torch.is_tensor(r3) + assert r3 == torch.tensor(True) + + r4 = _tensor_all(TreeTensor({ + 'a': torch.Tensor([1, 2, 3]), + 'b': torch.Tensor([4, 5, 6]), + 'x': { + 'c': torch.Tensor([7, 8, 9]) + } + }) == TreeTensor({ + 'a': torch.Tensor([1, 2, 3]), + 'b': torch.Tensor([4, 5, 6]), + 'x': { + 'c': torch.Tensor([7, 8, 8]) + } + })) + assert torch.is_tensor(r4) + assert r4 == torch.tensor(False) diff --git a/test/tensor/test_treetensor.py b/test/tensor/test_treetensor.py index a31dd6197..85f15e30b 100644 --- a/test/tensor/test_treetensor.py +++ b/test/tensor/test_treetensor.py @@ -22,6 +22,15 @@ class TestTensorTreetensor: } }) + _DEMO_2 = TreeTensor({ + 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), + 'b': torch.tensor([[1, 2], [5, 60]]), + 'x': { + 'c': torch.tensor([3, 5, 6, 7]), + 'd': torch.tensor([[[1, 2], [8, 9]]]), + } + }) + def test_numel(self): assert self._DEMO_1.numel() == 18 @@ -48,3 +57,7 @@ class TestTensorTreetensor: 'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32), } })) + + def test_all(self): + assert (self._DEMO_1 == self._DEMO_1).all() + assert not (self._DEMO_1 == self._DEMO_2).all() diff --git a/treetensor/common/__init__.py b/treetensor/common/__init__.py index 2b15c059d..65d1187f1 100644 --- a/treetensor/common/__init__.py +++ b/treetensor/common/__init__.py @@ -1,2 +1,2 @@ from .trees import TreeData, TreeObject, BaseTreeStruct -from .wrappers import kwreduce, vreduce +from .wrappers import kwreduce, vreduce, ireduce diff --git a/treetensor/common/trees.py b/treetensor/common/trees.py index 892d15ce1..3bdc87656 100644 --- a/treetensor/common/trees.py +++ b/treetensor/common/trees.py @@ -1,7 +1,6 @@ -import operator from abc import ABCMeta -from treevalue import func_treelize, general_tree_value +from treevalue import general_tree_value, method_treelize class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): @@ -12,29 +11,35 @@ class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): pass -_OPERATORS = {} -for _op_name in getattr(operator, '__all__'): - _OPERATORS[_op_name] = func_treelize()(getattr(operator, _op_name)) +class TreeData(BaseTreeStruct, metaclass=ABCMeta): + """ + Overview: + In ``TreeData`` class, all the comparison operators will be override. + """ + @method_treelize() + def __eq__(self, other): + return self == other -class TreeData(BaseTreeStruct): - def __le__(self, other): - return _OPERATORS['le'](self, other) + @method_treelize() + def __ne__(self, other): + return self != other + @method_treelize() def __lt__(self, other): - return _OPERATORS['lt'](self, other) + return self < other - def __ge__(self, other): - return _OPERATORS['ge'](self, other) + @method_treelize() + def __le__(self, other): + return self <= other + @method_treelize() def __gt__(self, other): - return _OPERATORS['gt'](self, other) - - def __eq__(self, other): - return _OPERATORS['eq'](self, other) + return self > other - def __ne__(self, other): - return _OPERATORS['ne'](self, other) + @method_treelize() + def __ge__(self, other): + return self >= other class TreeObject(BaseTreeStruct): diff --git a/treetensor/common/wrappers.py b/treetensor/common/wrappers.py index b72c80a72..e7e510150 100644 --- a/treetensor/common/wrappers.py +++ b/treetensor/common/wrappers.py @@ -1,16 +1,18 @@ +from collections import namedtuple from functools import wraps +from itertools import chain from treevalue import TreeValue from treevalue import reduce_ as treevalue_reduce -def kwreduce(reduce_func): +def kwreduce(rfunc): def _decorator(func): @wraps(func) def _new_func(*args, **kwargs): _result = func(*args, **kwargs) if isinstance(_result, TreeValue): - return treevalue_reduce(_result, reduce_func) + return treevalue_reduce(_result, rfunc) else: return _result @@ -19,5 +21,33 @@ def kwreduce(reduce_func): return _decorator -def vreduce(vreduce_func): - return kwreduce(lambda **kws: vreduce_func(kws.values())) +def vreduce(rfunc): + return kwreduce(lambda **kws: rfunc(kws.values())) + + +def ireduce(rfunc): + _IterReduceWrapper = namedtuple("_IterReduceWrapper", ['v']) + + def _reduce_func(values): + _list = [] + for item in values: + if isinstance(item, _IterReduceWrapper): + _list.append(item.v) + else: + _list.append([item]) + return _IterReduceWrapper(chain(*_list)) + + def _decorator(func): + rifunc = vreduce(_reduce_func)(func) + + @wraps(func) + def _new_func(*args, **kwargs): + _iw = rifunc(*args, **kwargs) + if isinstance(_iw, _IterReduceWrapper): + return rfunc(_iw.v) + else: + return _iw + + return _new_func + + return _decorator diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index 6d34ea120..10fdc8c2f 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -1,13 +1,24 @@ -from functools import partial - import numpy as np -from treevalue import func_treelize +from treevalue import func_treelize as original_func_treelize from .numpy import TreeNumpy -from ..common import vreduce +from ..common import ireduce +from ..utils import replaceable_partial + +func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNumpy) + + +@ireduce(all) +@func_treelize() +def all(a, *args, **kwargs): + return np.all(a, *args, **kwargs) + + +@func_treelize() +def equal(x1, x2, *args, **kwargs): + return np.equal(x1, x2, *args, **kwargs) -_treelize = partial(func_treelize, return_type=TreeNumpy) -all = vreduce(all)(_treelize()(np.all)) -equal = _treelize()(np.equal) -array_equal = _treelize()(np.array_equal) +@func_treelize() +def array_equal(a1, a2, *args, **kwargs): + return np.array_equal(a1, a2, *args, **kwargs) diff --git a/treetensor/numpy/numpy.py b/treetensor/numpy/numpy.py index eff7847eb..227db7e8d 100644 --- a/treetensor/numpy/numpy.py +++ b/treetensor/numpy/numpy.py @@ -1,7 +1,7 @@ import numpy as np from treevalue import method_treelize -from ..common import TreeObject, TreeData, vreduce +from ..common import TreeObject, TreeData, ireduce class TreeNumpy(TreeData): @@ -15,18 +15,18 @@ class TreeNumpy(TreeData): return self.tolist() @property - @vreduce(sum) + @ireduce(sum) @method_treelize(return_type=TreeObject) def size(self: np.ndarray) -> int: return self.size @property - @vreduce(sum) + @ireduce(sum) @method_treelize(return_type=TreeObject) def nbytes(self: np.ndarray) -> int: return self.nbytes - @vreduce(sum) + @ireduce(sum) @method_treelize(return_type=TreeObject) def sum(self: np.ndarray, *args, **kwargs): return self.sum(*args, **kwargs) diff --git a/treetensor/tensor/funcs.py b/treetensor/tensor/funcs.py index 9b83f1bbf..0a1cee923 100644 --- a/treetensor/tensor/funcs.py +++ b/treetensor/tensor/funcs.py @@ -1,52 +1,84 @@ -from functools import partial, wraps -from typing import Tuple - import torch -from treevalue import func_treelize, TreeValue - -from .tensor import TreeTensor -from ..common import vreduce - -_treelize = partial(func_treelize, return_type=TreeTensor) -_python_all = all - - -def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **kwargs_): - def _decorator(func): - @_treelize(*args_, **kwargs_) - def _sub_func(size: Tuple[int, ...], *args, **kwargs): - _size_args = (size,) if tuple_ else size - _args = (*args, *_size_args) if prefix else (*_size_args, *args) - return func(*_args, **kwargs) - - @wraps(func) - def _new_func(size, *args, **kwargs): - if isinstance(size, (TreeValue, dict)): - size = TreeTensor(size) - return _sub_func(size, *args, **kwargs) - - return _new_func - - return _decorator - - -# Tensor generation based on shapes -zeros = _size_based_treelize()(torch.zeros) -randn = _size_based_treelize()(torch.randn) -randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint) -ones = _size_based_treelize()(torch.ones) -full = _size_based_treelize(tuple_=True)(torch.full) -empty = _size_based_treelize()(torch.empty) - -# Tensor generation based on another tensor -zeros_like = _treelize()(torch.zeros_like) -randn_like = _treelize()(torch.randn_like) -randint_like = _treelize()(torch.randint_like) -ones_like = _treelize()(torch.ones_like) -full_like = _treelize()(torch.full_like) -empty_like = _treelize()(torch.empty_like) - -# Tensor operators -all = vreduce(all)(_treelize()(torch.all)) -eq = _treelize()(torch.eq) -equal = _treelize()(torch.equal) +from treevalue import func_treelize as original_func_treelize + +from .tensor import TreeTensor, _reduce_tensor_wrap +from ..common import TreeObject, ireduce +from ..utils import replaceable_partial + +func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor) + + +@func_treelize() +def zeros(size, *args, **kwargs): + return torch.zeros(*size, *args, **kwargs) + + +@func_treelize() +def zeros_like(input_, *args, **kwargs): + return torch.zeros_like(input_, *args, **kwargs) + + +@func_treelize() +def randn(size, *args, **kwargs): + return torch.randn(*size, *args, **kwargs) + + +@func_treelize() +def randn_like(input_, *args, **kwargs): + return torch.randn_like(input_, *args, **kwargs) + + +@func_treelize() +def randint(size, *args, **kwargs): + return torch.randint(*args, size, **kwargs) + + +@func_treelize() +def randint_like(input_, *args, **kwargs): + return torch.randint_like(input_, *args, **kwargs) + + +@func_treelize() +def ones(size, *args, **kwargs): + return torch.ones(*size, *args, **kwargs) + + +@func_treelize() +def ones_like(input_, *args, **kwargs): + return torch.ones_like(input_, *args, **kwargs) + + +@func_treelize() +def full(size, *args, **kwargs): + return torch.full(size, *args, **kwargs) + + +@func_treelize() +def full_like(input_, *args, **kwargs): + return torch.full_like(input_, *args, **kwargs) + + +@func_treelize() +def empty(size, *args, **kwargs): + return torch.empty(size, *args, **kwargs) + + +@func_treelize() +def empty_like(input_, *args, **kwargs): + return torch.empty_like(input_, *args, **kwargs) + + +@ireduce(_reduce_tensor_wrap(torch.all)) +@func_treelize(return_type=TreeObject) +def all(input_, *args, **kwargs): + return torch.all(input_, *args, **kwargs) + + +@func_treelize() +def eq(input_, other, *args, **kwargs): + return torch.eq(input_, other, *args, **kwargs) + + +@func_treelize() +def equal(input_, other, *args, **kwargs): + return torch.equal(input_, other, *args, **kwargs) diff --git a/treetensor/tensor/size.py b/treetensor/tensor/size.py index d52c249b1..4fd71c930 100644 --- a/treetensor/tensor/size.py +++ b/treetensor/tensor/size.py @@ -1,11 +1,14 @@ import torch -from treevalue import func_treelize +from treevalue import func_treelize as original_func_treelize -from ..common import BaseTreeStruct, TreeObject +from ..common import TreeObject +from ..utils import replaceable_partial + +func_treelize = replaceable_partial(original_func_treelize) # noinspection PyTypeChecker -class TreeSize(BaseTreeStruct): +class TreeSize(TreeObject): @func_treelize(return_type=TreeObject) def numel(self: torch.Size) -> TreeObject: return self.numel() diff --git a/treetensor/tensor/tensor.py b/treetensor/tensor/tensor.py index a25e757ce..ab2cf1a05 100644 --- a/treetensor/tensor/tensor.py +++ b/treetensor/tensor/tensor.py @@ -1,32 +1,13 @@ import numpy as np import torch -from treevalue import method_treelize, TreeValue +from treevalue import method_treelize +from treevalue.utils import pre_process from .size import TreeSize -from ..common import TreeObject, TreeData, vreduce +from ..common import TreeObject, TreeData, ireduce from ..numpy import TreeNumpy - -def _same_merge(eq, hash_, **kwargs): - kws = { - key: value for key, value in kwargs.items() - if not (isinstance(value, TreeValue) and not value) - } - - class _Wrapper: - def __init__(self, v): - self.v = v - - def __hash__(self): - return hash_(self.v) - - def __eq__(self, other): - return eq(self.v, other.v) - - if len(set(_Wrapper(v) for v in kws.values())) == 1: - return list(kws.values())[0] - else: - return TreeTensor(kws) +_reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {})) # noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList @@ -51,7 +32,7 @@ class TreeTensor(TreeData): def to(self: torch.Tensor, *args, **kwargs): return self.to(*args, **kwargs) - @vreduce(sum) + @ireduce(sum) @method_treelize(return_type=TreeObject) def numel(self: torch.Tensor): return self.numel() @@ -61,7 +42,27 @@ class TreeTensor(TreeData): def shape(self: torch.Tensor): return self.shape - @vreduce(all) + @ireduce(_reduce_tensor_wrap(torch.all)) @method_treelize(return_type=TreeObject) - def all(self: torch.Tensor, *args, **kwargs): + def all(self: torch.Tensor, *args, **kwargs) -> bool: return self.all(*args, **kwargs) + + @ireduce(_reduce_tensor_wrap(torch.any)) + @method_treelize(return_type=TreeObject) + def any(self: torch.Tensor, *args, **kwargs) -> bool: + return self.any(*args, **kwargs) + + @ireduce(_reduce_tensor_wrap(torch.max)) + @method_treelize(return_type=TreeObject) + def max(self: torch.Tensor, *args, **kwargs): + return self.max(*args, **kwargs) + + @ireduce(_reduce_tensor_wrap(torch.min)) + @method_treelize(return_type=TreeObject) + def min(self: torch.Tensor, *args, **kwargs): + return self.min(*args, **kwargs) + + @ireduce(_reduce_tensor_wrap(torch.sum)) + @method_treelize(return_type=TreeObject) + def sum(self: torch.Tensor, *args, **kwargs): + return self.sum(*args, **kwargs) diff --git a/treetensor/utils/__init__.py b/treetensor/utils/__init__.py new file mode 100644 index 000000000..171999312 --- /dev/null +++ b/treetensor/utils/__init__.py @@ -0,0 +1 @@ +from .func import replaceable_partial diff --git a/treetensor/utils/func.py b/treetensor/utils/func.py new file mode 100644 index 000000000..b7ec1dd5f --- /dev/null +++ b/treetensor/utils/func.py @@ -0,0 +1,5 @@ +def replaceable_partial(func, **kws): + def _new_func(*args, **kwargs): + return func(*args, **{**kws, **kwargs}) + + return _new_func -- GitLab