Commit c5c230d8 authored by HansBug 😆

dev(hansbug): refactor the current code

Parent daa01e40
@@ -35,13 +35,25 @@ class TestNumpyFuncs:
}
})
def test__numpy_all(self):
def test_all(self):
assert not _numpy_all(np.array([True, True, False]))
assert _numpy_all(np.array([True, True, True]))
assert not _numpy_all(self._DEMO_1 == self._DEMO_2)
assert _numpy_all(self._DEMO_1 == self._DEMO_3)
assert not _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 4]))
assert _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 3]))
def test_equal(self):
assert _numpy_all(equal(
np.array([1, 2, 3]),
np.array([1, 2, 3]),
))
assert not _numpy_all(equal(
np.array([1, 2, 3]),
np.array([1, 2, 4]),
))
assert _numpy_all(
equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({
'a': np.array([[True, True, True], [True, True, False]]),
@@ -64,6 +76,15 @@ class TestNumpyFuncs:
)
def test_array_equal(self):
assert _numpy_all(array_equal(
np.array([1, 2, 3]),
np.array([1, 2, 3]),
))
assert not _numpy_all(array_equal(
np.array([1, 2, 3]),
np.array([1, 2, 4]),
))
assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({
'a': False,
'b': True,
......
import pytest
import torch
from treevalue import TreeValue
from treetensor.tensor import TreeTensor, zeros, zeros_like, ones, ones_like, randint, randint_like, randn, \
randn_like, full, full_like, TreeSize
@@ -11,13 +12,13 @@ from treetensor.tensor import all as _tensor_all
class TestTensorFuncs:
def test_zeros(self):
assert _tensor_all(zeros((2, 3)) == torch.zeros(2, 3))
assert _tensor_all(zeros({
assert _tensor_all(zeros(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}) == TreeTensor({
})) == TreeTensor({
'a': torch.zeros(2, 3),
'b': torch.zeros(5, 6),
'x': {
@@ -50,13 +51,13 @@ class TestTensorFuncs:
def test_ones(self):
assert _tensor_all(ones((2, 3)) == torch.ones(2, 3))
assert _tensor_all(ones({
assert _tensor_all(ones(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}) == TreeTensor({
})) == TreeTensor({
'a': torch.ones(2, 3),
'b': torch.ones(5, 6),
'x': {
@@ -93,13 +94,13 @@ class TestTensorFuncs:
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = randn({
_target = randn(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})
}))
assert _target.shape == TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
@@ -132,13 +133,13 @@ class TestTensorFuncs:
})
def test_randint(self):
_target = randint({
_target = randint(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}, -10, 10)
}), -10, 10)
assert _tensor_all(_target < 10)
assert _tensor_all(-10 <= _target)
assert _target.shape == TreeSize({
@@ -149,13 +150,13 @@ class TestTensorFuncs:
}
})
_target = randint({
_target = randint(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}, 10)
}), 10)
assert _tensor_all(_target < 10)
assert _tensor_all(0 <= _target)
assert _target.shape == TreeSize({
@@ -206,13 +207,13 @@ class TestTensorFuncs:
})
def test_full(self):
_target = full({
_target = full(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}, 233)
}), 233)
assert _tensor_all(_target == 233)
assert _target.shape == TreeSize({
'a': torch.Size([2, 3]),
@@ -240,3 +241,44 @@ class TestTensorFuncs:
'd': torch.Size([1, 1, 2]),
}
})
def test_all(self):
r1 = _tensor_all(torch.tensor([1, 1, 1]) == 1)
assert torch.is_tensor(r1)
assert r1 == torch.tensor(True)
r2 = _tensor_all(torch.tensor([1, 1, 2]) == 1)
assert torch.is_tensor(r2)
assert r2 == torch.tensor(False)
r3 = _tensor_all(TreeTensor({
'a': torch.Tensor([1, 2, 3]),
'b': torch.Tensor([4, 5, 6]),
'x': {
'c': torch.Tensor([7, 8, 9])
}
}) == TreeTensor({
'a': torch.Tensor([1, 2, 3]),
'b': torch.Tensor([4, 5, 6]),
'x': {
'c': torch.Tensor([7, 8, 9])
}
}))
assert torch.is_tensor(r3)
assert r3 == torch.tensor(True)
r4 = _tensor_all(TreeTensor({
'a': torch.Tensor([1, 2, 3]),
'b': torch.Tensor([4, 5, 6]),
'x': {
'c': torch.Tensor([7, 8, 9])
}
}) == TreeTensor({
'a': torch.Tensor([1, 2, 3]),
'b': torch.Tensor([4, 5, 6]),
'x': {
'c': torch.Tensor([7, 8, 8])
}
}))
assert torch.is_tensor(r4)
assert r4 == torch.tensor(False)
@@ -22,6 +22,15 @@ class TestTensorTreetensor:
}
})
_DEMO_2 = TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 60]]),
'x': {
'c': torch.tensor([3, 5, 6, 7]),
'd': torch.tensor([[[1, 2], [8, 9]]]),
}
})
def test_numel(self):
assert self._DEMO_1.numel() == 18
@@ -48,3 +57,7 @@ class TestTensorTreetensor:
'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32),
}
}))
def test_all(self):
assert (self._DEMO_1 == self._DEMO_1).all()
assert not (self._DEMO_1 == self._DEMO_2).all()
from .trees import TreeData, TreeObject, BaseTreeStruct
from .wrappers import kwreduce, vreduce
from .wrappers import kwreduce, vreduce, ireduce
import operator
from abc import ABCMeta
from treevalue import func_treelize, general_tree_value
from treevalue import general_tree_value, method_treelize
class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
@@ -12,29 +11,35 @@ class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
pass
_OPERATORS = {}
for _op_name in getattr(operator, '__all__'):
_OPERATORS[_op_name] = func_treelize()(getattr(operator, _op_name))
class TreeData(BaseTreeStruct, metaclass=ABCMeta):
"""
Overview:
In the ``TreeData`` class, all of the comparison operators are overridden and applied leaf by leaf (a brief usage sketch follows this file's diff).
"""
@method_treelize()
def __eq__(self, other):
return self == other
class TreeData(BaseTreeStruct):
def __le__(self, other):
return _OPERATORS['le'](self, other)
@method_treelize()
def __ne__(self, other):
return self != other
@method_treelize()
def __lt__(self, other):
return _OPERATORS['lt'](self, other)
return self < other
def __ge__(self, other):
return _OPERATORS['ge'](self, other)
@method_treelize()
def __le__(self, other):
return self <= other
@method_treelize()
def __gt__(self, other):
return _OPERATORS['gt'](self, other)
def __eq__(self, other):
return _OPERATORS['eq'](self, other)
return self > other
def __ne__(self, other):
return _OPERATORS['ne'](self, other)
@method_treelize()
def __ge__(self, other):
return self >= other
class TreeObject(BaseTreeStruct):
......
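The ``TreeData`` docstring above says the comparison operators are overridden; with ``method_treelize`` each operator runs leaf by leaf, so comparing two trees yields a tree of element-wise results rather than a single boolean. A minimal sketch of that behaviour, assuming ``TreeNumpy`` (a ``TreeData`` subclass) is importable as ``treetensor.numpy.TreeNumpy`` and using made-up keys and arrays:

import numpy as np
from treetensor.numpy import TreeNumpy  # assumed import path

t1 = TreeNumpy({'a': np.array([1, 2, 3]), 'x': {'c': np.array([4, 5])}})
t2 = TreeNumpy({'a': np.array([1, 2, 9]), 'x': {'c': np.array([4, 5])}})

# '==' runs on every leaf, so the result is again a tree whose leaves
# hold the element-wise boolean arrays.
mask = t1 == t2
# roughly: TreeNumpy({'a': [True, True, False], 'x': {'c': [True, True]}})

The tree-level ``all`` helpers touched in this commit (``_numpy_all`` / ``_tensor_all`` in the tests) then collapse such a boolean tree into a single scalar.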
from collections import namedtuple
from functools import wraps
from itertools import chain
from treevalue import TreeValue
from treevalue import reduce_ as treevalue_reduce
def kwreduce(reduce_func):
def kwreduce(rfunc):
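# If the wrapped function returns a TreeValue, reduce it with treevalue's
# ``reduce_`` using ``rfunc`` (which receives each node's children as keyword
# arguments); non-tree results are passed through unchanged.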
def _decorator(func):
@wraps(func)
def _new_func(*args, **kwargs):
_result = func(*args, **kwargs)
if isinstance(_result, TreeValue):
return treevalue_reduce(_result, reduce_func)
return treevalue_reduce(_result, rfunc)
else:
return _result
@@ -19,5 +21,33 @@ def kwreduce(reduce_func):
return _decorator
def vreduce(vreduce_func):
return kwreduce(lambda **kws: vreduce_func(kws.values()))
def vreduce(rfunc):
return kwreduce(lambda **kws: rfunc(kws.values()))
def ireduce(rfunc):
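# Collect every leaf value of a tree result into one flat iterable (nested
# subtrees are chained together bottom-up via the wrapper below) and apply
# ``rfunc`` once to the whole sequence; non-tree results pass through as-is.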
_IterReduceWrapper = namedtuple("_IterReduceWrapper", ['v'])
def _reduce_func(values):
_list = []
for item in values:
if isinstance(item, _IterReduceWrapper):
_list.append(item.v)
else:
_list.append([item])
return _IterReduceWrapper(chain(*_list))
def _decorator(func):
rifunc = vreduce(_reduce_func)(func)
@wraps(func)
def _new_func(*args, **kwargs):
_iw = rifunc(*args, **kwargs)
if isinstance(_iw, _IterReduceWrapper):
return rfunc(_iw.v)
else:
return _iw
return _new_func
return _decorator
from functools import partial
import numpy as np
from treevalue import func_treelize
from treevalue import func_treelize as original_func_treelize
from .numpy import TreeNumpy
from ..common import vreduce
from ..common import ireduce
from ..utils import replaceable_partial
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNumpy)
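# Note: the ``all`` handed to ``ireduce`` below is still Python's builtin ``all``;
# the treelized ``all`` defined next only shadows the builtin afterwards.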
@ireduce(all)
@func_treelize()
def all(a, *args, **kwargs):
return np.all(a, *args, **kwargs)
@func_treelize()
def equal(x1, x2, *args, **kwargs):
return np.equal(x1, x2, *args, **kwargs)
_treelize = partial(func_treelize, return_type=TreeNumpy)
all = vreduce(all)(_treelize()(np.all))
equal = _treelize()(np.equal)
array_equal = _treelize()(np.array_equal)
@func_treelize()
def array_equal(a1, a2, *args, **kwargs):
return np.array_equal(a1, a2, *args, **kwargs)
import numpy as np
from treevalue import method_treelize
from ..common import TreeObject, TreeData, vreduce
from ..common import TreeObject, TreeData, ireduce
class TreeNumpy(TreeData):
@@ -15,18 +15,18 @@ class TreeNumpy(TreeData):
return self.tolist()
@property
@vreduce(sum)
@ireduce(sum)
@method_treelize(return_type=TreeObject)
def size(self: np.ndarray) -> int:
return self.size
@property
@vreduce(sum)
@ireduce(sum)
@method_treelize(return_type=TreeObject)
def nbytes(self: np.ndarray) -> int:
return self.nbytes
@vreduce(sum)
@ireduce(sum)
@method_treelize(return_type=TreeObject)
def sum(self: np.ndarray, *args, **kwargs):
return self.sum(*args, **kwargs)
from functools import partial, wraps
from typing import Tuple
import torch
from treevalue import func_treelize, TreeValue
from .tensor import TreeTensor
from ..common import vreduce
_treelize = partial(func_treelize, return_type=TreeTensor)
_python_all = all
def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **kwargs_):
def _decorator(func):
@_treelize(*args_, **kwargs_)
def _sub_func(size: Tuple[int, ...], *args, **kwargs):
_size_args = (size,) if tuple_ else size
_args = (*args, *_size_args) if prefix else (*_size_args, *args)
return func(*_args, **kwargs)
@wraps(func)
def _new_func(size, *args, **kwargs):
if isinstance(size, (TreeValue, dict)):
size = TreeTensor(size)
return _sub_func(size, *args, **kwargs)
return _new_func
return _decorator
# Tensor generation based on shapes
zeros = _size_based_treelize()(torch.zeros)
randn = _size_based_treelize()(torch.randn)
randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint)
ones = _size_based_treelize()(torch.ones)
full = _size_based_treelize(tuple_=True)(torch.full)
empty = _size_based_treelize()(torch.empty)
# Tensor generation based on another tensor
zeros_like = _treelize()(torch.zeros_like)
randn_like = _treelize()(torch.randn_like)
randint_like = _treelize()(torch.randint_like)
ones_like = _treelize()(torch.ones_like)
full_like = _treelize()(torch.full_like)
empty_like = _treelize()(torch.empty_like)
# Tensor operators
all = vreduce(all)(_treelize()(torch.all))
eq = _treelize()(torch.eq)
equal = _treelize()(torch.equal)
from treevalue import func_treelize as original_func_treelize
from .tensor import TreeTensor, _reduce_tensor_wrap
from ..common import TreeObject, ireduce
from ..utils import replaceable_partial
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor)
@func_treelize()
def zeros(size, *args, **kwargs):
return torch.zeros(*size, *args, **kwargs)
@func_treelize()
def zeros_like(input_, *args, **kwargs):
return torch.zeros_like(input_, *args, **kwargs)
@func_treelize()
def randn(size, *args, **kwargs):
return torch.randn(*size, *args, **kwargs)
@func_treelize()
def randn_like(input_, *args, **kwargs):
return torch.randn_like(input_, *args, **kwargs)
@func_treelize()
def randint(size, *args, **kwargs):
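# torch.randint expects its bounds before the size (randint(low, high, size)
# or randint(high, size)), so the extra positional args are placed first.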
return torch.randint(*args, size, **kwargs)
@func_treelize()
def randint_like(input_, *args, **kwargs):
return torch.randint_like(input_, *args, **kwargs)
@func_treelize()
def ones(size, *args, **kwargs):
return torch.ones(*size, *args, **kwargs)
@func_treelize()
def ones_like(input_, *args, **kwargs):
return torch.ones_like(input_, *args, **kwargs)
@func_treelize()
def full(size, *args, **kwargs):
return torch.full(size, *args, **kwargs)
@func_treelize()
def full_like(input_, *args, **kwargs):
return torch.full_like(input_, *args, **kwargs)
@func_treelize()
def empty(size, *args, **kwargs):
return torch.empty(size, *args, **kwargs)
@func_treelize()
def empty_like(input_, *args, **kwargs):
return torch.empty_like(input_, *args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.all))
@func_treelize(return_type=TreeObject)
def all(input_, *args, **kwargs):
return torch.all(input_, *args, **kwargs)
@func_treelize()
def eq(input_, other, *args, **kwargs):
return torch.eq(input_, other, *args, **kwargs)
@func_treelize()
def equal(input_, other, *args, **kwargs):
return torch.equal(input_, other, *args, **kwargs)
import torch
from treevalue import func_treelize
from treevalue import func_treelize as original_func_treelize
from ..common import BaseTreeStruct, TreeObject
from ..common import TreeObject
from ..utils import replaceable_partial
func_treelize = replaceable_partial(original_func_treelize)
# noinspection PyTypeChecker
class TreeSize(BaseTreeStruct):
class TreeSize(TreeObject):
@func_treelize(return_type=TreeObject)
def numel(self: torch.Size) -> TreeObject:
return self.numel()
......
import numpy as np
import torch
from treevalue import method_treelize, TreeValue
from treevalue import method_treelize
from treevalue.utils import pre_process
from .size import TreeSize
from ..common import TreeObject, TreeData, vreduce
from ..common import TreeObject, TreeData, ireduce
from ..numpy import TreeNumpy
def _same_merge(eq, hash_, **kwargs):
kws = {
key: value for key, value in kwargs.items()
if not (isinstance(value, TreeValue) and not value)
}
class _Wrapper:
def __init__(self, v):
self.v = v
def __hash__(self):
return hash_(self.v)
def __eq__(self, other):
return eq(self.v, other.v)
if len(set(_Wrapper(v) for v in kws.values())) == 1:
return list(kws.values())[0]
else:
return TreeTensor(kws)
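# ``pre_process`` rewrites the call arguments, so the wrapped reduction receives
# the iterable of per-leaf results packed into a single tensor; for example,
# ``_reduce_tensor_wrap(torch.all)`` applies ``torch.all`` to that packed tensor.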
_reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
@@ -51,7 +32,7 @@ class TreeTensor(TreeData):
def to(self: torch.Tensor, *args, **kwargs):
return self.to(*args, **kwargs)
@vreduce(sum)
@ireduce(sum)
@method_treelize(return_type=TreeObject)
def numel(self: torch.Tensor):
return self.numel()
@@ -61,7 +42,27 @@ class TreeTensor(TreeData):
def shape(self: torch.Tensor):
return self.shape
@vreduce(all)
@ireduce(_reduce_tensor_wrap(torch.all))
@method_treelize(return_type=TreeObject)
def all(self: torch.Tensor, *args, **kwargs):
def all(self: torch.Tensor, *args, **kwargs) -> bool:
return self.all(*args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.any))
@method_treelize(return_type=TreeObject)
def any(self: torch.Tensor, *args, **kwargs) -> bool:
return self.any(*args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.max))
@method_treelize(return_type=TreeObject)
def max(self: torch.Tensor, *args, **kwargs):
return self.max(*args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.min))
@method_treelize(return_type=TreeObject)
def min(self: torch.Tensor, *args, **kwargs):
return self.min(*args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.sum))
@method_treelize(return_type=TreeObject)
def sum(self: torch.Tensor, *args, **kwargs):
return self.sum(*args, **kwargs)
from .func import replaceable_partial
def replaceable_partial(func, **kws):
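# Like functools.partial, but the preset keyword arguments act only as
# defaults: callers may override any of them on each call.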
def _new_func(*args, **kwargs):
return func(*args, **{**kws, **kwargs})
return _new_func