Commit daa01e40 authored by HansBug

dev(hansbug): add kwreduce and vreduce decorators && refactor mergeable and conclusion functions

Parent fb36ba81
......@@ -36,10 +36,10 @@ class TestNumpyFuncs:
})
def test__numpy_all(self):
assert not _numpy_all(self._DEMO_1 == self._DEMO_2).all()
assert _numpy_all(self._DEMO_1 == self._DEMO_3).all()
assert not _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 4])).all()
assert _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 3])).all()
assert not _numpy_all(self._DEMO_1 == self._DEMO_2)
assert _numpy_all(self._DEMO_1 == self._DEMO_3)
assert not _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 4]))
assert _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 3]))
def test_equal(self):
assert _numpy_all(
......@@ -51,7 +51,7 @@ class TestNumpyFuncs:
'd': np.array([[True, True]]),
}
})
).all()
)
assert _numpy_all(
equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
......@@ -61,7 +61,7 @@ class TestNumpyFuncs:
'd': np.array([[True, True]]),
}
})
).all()
)
def test_array_equal(self):
assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({
......
......@@ -10,7 +10,7 @@ from treetensor.tensor import all as _tensor_all
@pytest.mark.unittest
class TestTensorFuncs:
def test_zeros(self):
assert _tensor_all(zeros((2, 3)) == torch.zeros(2, 3)).all()
assert _tensor_all(zeros((2, 3)) == torch.zeros(2, 3))
assert _tensor_all(zeros({
'a': (2, 3),
'b': (5, 6),
......@@ -23,7 +23,7 @@ class TestTensorFuncs:
'x': {
'c': torch.zeros(2, 3, 4),
}
})).all()
}))
def test_zeros_like(self):
assert _tensor_all(
......@@ -46,7 +46,7 @@ class TestTensorFuncs:
'd': torch.tensor([[[0, 0]]]),
}
})
).all()
)
def test_ones(self):
assert _tensor_all(ones((2, 3)) == torch.ones(2, 3))
......@@ -62,7 +62,7 @@ class TestTensorFuncs:
'x': {
'c': torch.ones(2, 3, 4),
}
})).all()
}))
def test_ones_like(self):
assert _tensor_all(
......@@ -85,7 +85,7 @@ class TestTensorFuncs:
'd': torch.tensor([[[1, 1]]]),
}
})
).all()
)
def test_randn(self):
_target = randn((200, 300))
......@@ -139,8 +139,8 @@ class TestTensorFuncs:
'c': (2, 3, 4),
}
}, -10, 10)
assert _tensor_all(_target < 10).all()
assert _tensor_all(-10 <= _target).all()
assert _tensor_all(_target < 10)
assert _tensor_all(-10 <= _target)
assert _target.shape == TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
......@@ -156,8 +156,8 @@ class TestTensorFuncs:
'c': (2, 3, 4),
}
}, 10)
assert _tensor_all(_target < 10).all()
assert _tensor_all(0 <= _target).all()
assert _tensor_all(_target < 10)
assert _tensor_all(0 <= _target)
assert _target.shape == TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
......@@ -175,8 +175,8 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]]),
}
}), -10, 10)
assert _tensor_all(_target < 10).all()
assert _tensor_all(-10 <= _target).all()
assert _tensor_all(_target < 10)
assert _tensor_all(-10 <= _target)
assert _target.shape == TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
......@@ -194,8 +194,8 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]]),
}
}), 10)
assert _tensor_all(_target < 10).all()
assert _tensor_all(0 <= _target).all()
assert _tensor_all(_target < 10)
assert _tensor_all(0 <= _target)
assert _target.shape == TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
......@@ -213,7 +213,7 @@ class TestTensorFuncs:
'c': (2, 3, 4),
}
}, 233)
assert _tensor_all(_target == 233).all()
assert _tensor_all(_target == 233)
assert _target.shape == TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
......@@ -231,7 +231,7 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]]),
}
}), 233)
assert _tensor_all(_target == 233).all()
assert _tensor_all(_target == 233)
assert _target.shape == TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
......
......@@ -3,7 +3,6 @@ import pytest
import torch
from treevalue import func_treelize
from treetensor.common import TreeObject
from treetensor.numpy import TreeNumpy
from treetensor.numpy import all as _numpy_all
from treetensor.tensor import TreeTensor
......@@ -24,15 +23,7 @@ class TestTensorTreetensor:
})
def test_numel(self):
assert self._DEMO_1.numel() == TreeObject({
'a': 6,
'b': 4,
'x': {
'c': 4,
'd': 4,
}
})
assert self._DEMO_1.numel().sum() == 18
assert self._DEMO_1.numel() == 18
def test_numpy(self):
assert _numpy_all(self._DEMO_1.numpy() == TreeNumpy({
......@@ -42,10 +33,10 @@ class TestTensorTreetensor:
'c': np.array([3, 5, 6, 7]),
'd': np.array([[[1, 2], [8, 9]]]),
}
})).all()
}))
def test_cpu(self):
assert _tensor_all(self._DEMO_1.cpu() == self._DEMO_1).all()
assert _tensor_all(self._DEMO_1.cpu() == self._DEMO_1)
assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values()))
def test_to(self):
......@@ -56,4 +47,4 @@ class TestTensorTreetensor:
'c': torch.tensor([3, 5, 6, 7], dtype=torch.float32),
'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32),
}
})).all()
}))
from .base import BaseTreeStruct
from .data import TreeData
from .obj import TreeObject
from .trees import TreeData, TreeObject, BaseTreeStruct
from .wrappers import kwreduce, vreduce
from abc import ABCMeta
from functools import lru_cache
from treevalue import general_tree_value
@lru_cache()
def _merge_func(red):
return lambda **kws: red(kws.values())
class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
def all(self) -> bool:
return self.reduce(_merge_func(all))
def any(self) -> bool:
return self.reduce(_merge_func(any))
def sum(self):
return self.reduce(_merge_func(sum))
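For orientation, the reduce-based helpers above fold a tree bottom-up: reduce passes each node's already-reduced children in as keyword arguments, which is the mechanism the new decorators below generalize. A minimal sketch with a hypothetical tree (keys and values are illustrative only):

from treetensor.common import TreeObject

_t = TreeObject({'a': 1, 'b': 2, 'x': {'c': 3}})
# the 'x' subtree folds to 3 first, then the root folds 1 + 2 + 3
assert _t.reduce(lambda **kws: sum(kws.values())) == 6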
from .base import BaseTreeStruct
class TreeObject(BaseTreeStruct):
pass
import operator
from abc import ABCMeta
from treevalue import func_treelize
from treevalue import func_treelize, general_tree_value
class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
"""
Overview:
Base structure of all the trees in ``treetensor``.
"""
pass
from .base import BaseTreeStruct
_OPERATORS = {}
for _op_name in getattr(operator, '__all__'):
......@@ -27,3 +35,7 @@ class TreeData(BaseTreeStruct):
def __ne__(self, other):
return _OPERATORS['ne'](self, other)
class TreeObject(BaseTreeStruct):
pass
from functools import wraps
from treevalue import TreeValue
from treevalue import reduce_ as treevalue_reduce
def kwreduce(reduce_func):
def _decorator(func):
@wraps(func)
def _new_func(*args, **kwargs):
_result = func(*args, **kwargs)
if isinstance(_result, TreeValue):
return treevalue_reduce(_result, reduce_func)
else:
return _result
return _new_func
return _decorator
def vreduce(vreduce_func):
return kwreduce(lambda **kws: vreduce_func(kws.values()))
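These two wrappers are the core of the commit: kwreduce post-processes a function's return value, folding it with treevalue's reduce_ whenever it is a TreeValue (non-tree results pass through unchanged), and vreduce is the values-only convenience form. A minimal usage sketch under that assumption; the function _double and the tree _t are hypothetical, not part of the library:

from treevalue import TreeValue, func_treelize
from treetensor.common import vreduce

# Hypothetical treelized function, shown only to illustrate the decorator.
@vreduce(sum)            # fold the resulting tree's values with sum()
@func_treelize()         # apply per leaf, returning a TreeValue
def _double(x):
    return x * 2

_t = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3}})
assert _double(_t) == 12          # leaf results 2, 4 and 6 are reduced into one number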
......@@ -4,9 +4,10 @@ import numpy as np
from treevalue import func_treelize
from .numpy import TreeNumpy
from ..common import vreduce
_treelize = partial(func_treelize, return_type=TreeNumpy)
all = _treelize()(np.all)
all = vreduce(all)(_treelize()(np.all))
equal = _treelize()(np.equal)
array_equal = _treelize()(np.array_equal)
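With all wrapped in vreduce(all), a tree comparison now folds into a single truthy value instead of a TreeNumpy of per-leaf booleans, which is why the tests above drop the trailing .all(). A hedged sketch; tree keys and array contents are illustrative only:

import numpy as np
from treetensor.numpy import TreeNumpy, all as _numpy_all

_t = TreeNumpy({'a': np.array([1, 2]), 'x': {'b': np.array([3])}})
assert _numpy_all(_t == _t)       # per-leaf np.all results folded into one boolean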
import numpy as np
from treevalue import method_treelize
from ..common import TreeObject, TreeData
from ..common import TreeObject, TreeData, vreduce
class TreeNumpy(TreeData):
......@@ -15,18 +15,18 @@ class TreeNumpy(TreeData):
return self.tolist()
@property
def size(self) -> int:
return self \
.map(lambda d: d.size) \
.reduce(lambda **kwargs: sum(kwargs.values()))
@vreduce(sum)
@method_treelize(return_type=TreeObject)
def size(self: np.ndarray) -> int:
return self.size
@property
def nbytes(self) -> int:
return self \
.map(lambda d: d.nbytes) \
.reduce(lambda **kwargs: sum(kwargs.values()))
@vreduce(sum)
@method_treelize(return_type=TreeObject)
def nbytes(self: np.ndarray) -> int:
return self.nbytes
def sum(self):
return self \
.map(lambda d: d.sum()) \
.reduce(lambda **kwargs: sum(kwargs.values()))
@vreduce(sum)
@method_treelize(return_type=TreeObject)
def sum(self: np.ndarray, *args, **kwargs):
return self.sum(*args, **kwargs)
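Stacking @vreduce(sum) on top of @method_treelize(return_type=TreeObject) means the per-leaf results are first collected into a TreeObject and then folded with Python's sum, replacing the explicit map/reduce chains. A hedged sketch for sum; tree keys and array contents are illustrative only:

import numpy as np
from treetensor.numpy import TreeNumpy

_t = TreeNumpy({'a': np.array([1.0, 2.0]), 'x': {'b': np.array([3.0, 4.0])}})
assert _t.sum() == 10.0           # per-leaf sums 3.0 and 7.0 are folded into one number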
......@@ -5,6 +5,7 @@ import torch
from treevalue import func_treelize, TreeValue
from .tensor import TreeTensor
from ..common import vreduce
_treelize = partial(func_treelize, return_type=TreeTensor)
_python_all = all
......@@ -46,6 +47,6 @@ full_like = _treelize()(torch.full_like)
empty_like = _treelize()(torch.empty_like)
# Tensor operators
all = _treelize()(torch.all)
all = vreduce(all)(_treelize()(torch.all))
eq = _treelize()(torch.eq)
equal = _treelize()(torch.equal)
......@@ -3,7 +3,7 @@ import torch
from treevalue import method_treelize, TreeValue
from .size import TreeSize
from ..common import TreeObject, TreeData
from ..common import TreeObject, TreeData, vreduce
from ..numpy import TreeNumpy
......@@ -29,7 +29,7 @@ def _same_merge(eq, hash_, **kwargs):
return TreeTensor(kws)
# noinspection PyTypeChecker,PyShadowingBuiltins
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
class TreeTensor(TreeData):
@method_treelize(return_type=TreeNumpy)
def numpy(self: torch.Tensor) -> np.ndarray:
......@@ -51,6 +51,7 @@ class TreeTensor(TreeData):
def to(self: torch.Tensor, *args, **kwargs):
return self.to(*args, **kwargs)
@vreduce(sum)
@method_treelize(return_type=TreeObject)
def numel(self: torch.Tensor):
return self.numel()
......@@ -59,3 +60,8 @@ class TreeTensor(TreeData):
@method_treelize(return_type=TreeSize)
def shape(self: torch.Tensor):
return self.shape
@vreduce(all)
@method_treelize(return_type=TreeObject)
def all(self: torch.Tensor, *args, **kwargs):
return self.all(*args, **kwargs)
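The same stacking gives TreeTensor scalar-returning numel and all methods, which is what allows the simplified assertions in the tests above. A hedged sketch; tree keys and tensor shapes are illustrative only:

import torch
from treetensor.tensor import TreeTensor

_t = TreeTensor({'a': torch.zeros(2, 3), 'x': {'b': torch.zeros(4)}})
assert _t.numel() == 10           # 6 + 4, per-leaf numel() folded with sum

_b = TreeTensor({'a': torch.tensor([True, True]), 'x': {'c': torch.tensor([True])}})
assert _b.all()                   # per-leaf .all() results folded into one bool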