提交 fbfdb128 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add ttorch.tensor

上级 eba7817e
from .test_funcs import TestTensorFuncs
from .test_treetensor import TestTensorTreetensor
from .test_funcs import TestTorchFuncs
from .test_tensor import TestTorchTensor
......@@ -7,7 +7,34 @@ import treetensor.torch as ttorch
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTensorFuncs:
class TestTorchFuncs:
def test_tensor(self):
t1 = ttorch.tensor(True)
assert isinstance(t1, torch.Tensor)
assert t1
t2 = ttorch.tensor([[1, 2, 3], [4, 5, 6]])
assert isinstance(t2, torch.Tensor)
assert (t2 == torch.tensor([[1, 2, 3], [4, 5, 6]])).all()
t3 = ttorch.tensor({
'a': [1, 2],
'b': [[3, 4], [5, 6.2]],
'x': {
'c': True,
'd': [False, True],
}
})
assert isinstance(t3, ttorch.Tensor)
assert (t3 == ttorch.Tensor({
'a': torch.tensor([1, 2]),
'b': torch.tensor([[3, 4], [5, 6.2]]),
'x': {
'c': torch.tensor(True),
'd': torch.tensor([False, True]),
}
})).all()
def test_zeros(self):
assert ttorch.all(ttorch.zeros((2, 3)) == torch.zeros(2, 3))
assert ttorch.all(ttorch.zeros(TreeValue({
......@@ -16,7 +43,7 @@ class TestTensorFuncs:
'x': {
'c': (2, 3, 4),
}
})) == ttorch.TreeTensor({
})) == ttorch.Tensor({
'a': torch.zeros(2, 3),
'b': torch.zeros(5, 6),
'x': {
......@@ -30,14 +57,14 @@ class TestTensorFuncs:
torch.tensor([[0, 0, 0], [0, 0, 0]]),
)
assert ttorch.all(
ttorch.zeros_like(ttorch.TreeTensor({
ttorch.zeros_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})) == ttorch.TreeTensor({
})) == ttorch.Tensor({
'a': torch.tensor([[0, 0, 0], [0, 0, 0]]),
'b': torch.tensor([0, 0, 0, 0]),
'x': {
......@@ -55,7 +82,7 @@ class TestTensorFuncs:
'x': {
'c': (2, 3, 4),
}
})) == ttorch.TreeTensor({
})) == ttorch.Tensor({
'a': torch.ones(2, 3),
'b': torch.ones(5, 6),
'x': {
......@@ -69,14 +96,14 @@ class TestTensorFuncs:
torch.tensor([[1, 1, 1], [1, 1, 1]])
)
assert ttorch.all(
ttorch.ones_like(ttorch.TreeTensor({
ttorch.ones_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})) == ttorch.TreeTensor({
})) == ttorch.Tensor({
'a': torch.tensor([[1, 1, 1], [1, 1, 1]]),
'b': torch.tensor([1, 1, 1, 1]),
'x': {
......@@ -99,7 +126,7 @@ class TestTensorFuncs:
'c': (2, 3, 4),
}
}))
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -113,7 +140,7 @@ class TestTensorFuncs:
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = ttorch.randn_like(ttorch.TreeTensor({
_target = ttorch.randn_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32),
'x': {
......@@ -121,7 +148,7 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
}
}))
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -140,7 +167,7 @@ class TestTensorFuncs:
}))
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -157,7 +184,7 @@ class TestTensorFuncs:
}))
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -166,7 +193,7 @@ class TestTensorFuncs:
})
def test_randint_like(self):
_target = ttorch.randint_like(ttorch.TreeTensor({
_target = ttorch.randint_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
......@@ -176,7 +203,7 @@ class TestTensorFuncs:
}), -10, 10)
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -185,7 +212,7 @@ class TestTensorFuncs:
}
})
_target = ttorch.randint_like(ttorch.TreeTensor({
_target = ttorch.randint_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
......@@ -195,7 +222,7 @@ class TestTensorFuncs:
}), 10)
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -213,7 +240,7 @@ class TestTensorFuncs:
}
}), 233)
assert ttorch.all(_target == 233)
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -222,7 +249,7 @@ class TestTensorFuncs:
})
def test_full_like(self):
_target = ttorch.full_like(ttorch.TreeTensor({
_target = ttorch.full_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
......@@ -231,7 +258,7 @@ class TestTensorFuncs:
}
}), 233)
assert ttorch.all(_target == 233)
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -248,7 +275,7 @@ class TestTensorFuncs:
'c': (2, 3, 4),
}
}))
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -257,7 +284,7 @@ class TestTensorFuncs:
})
def test_empty_like(self):
_target = ttorch.empty_like(ttorch.TreeTensor({
_target = ttorch.empty_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
......@@ -265,7 +292,7 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]]),
}
}))
assert _target.shape == ttorch.TreeSize({
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -290,7 +317,7 @@ class TestTensorFuncs:
assert r3 == torch.tensor(False)
assert not r3
r4 = ttorch.all(ttorch.TreeTensor({
r4 = ttorch.all(({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
})).all()
......@@ -298,7 +325,7 @@ class TestTensorFuncs:
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.all(ttorch.TreeTensor({
r5 = ttorch.all(({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
})).all()
......@@ -306,7 +333,7 @@ class TestTensorFuncs:
assert r5 == torch.tensor(False)
assert not r5
r6 = ttorch.all(ttorch.TreeTensor({
r6 = ttorch.all(({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
})).all()
......@@ -330,7 +357,7 @@ class TestTensorFuncs:
assert r3 == torch.tensor(False)
assert not r3
r4 = ttorch.any(ttorch.TreeTensor({
r4 = ttorch.any(({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
})).all()
......@@ -338,7 +365,7 @@ class TestTensorFuncs:
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.any(ttorch.TreeTensor({
r5 = ttorch.any(({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
})).all()
......@@ -346,7 +373,7 @@ class TestTensorFuncs:
assert r5 == torch.tensor(True)
assert r5
r6 = ttorch.any(ttorch.TreeTensor({
r6 = ttorch.any(({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
})).all()
......@@ -360,17 +387,17 @@ class TestTensorFuncs:
assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all()
assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all()
assert ttorch.eq(ttorch.TreeTensor({
assert ttorch.eq(({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({
}), ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
})).all()
assert not ttorch.eq(ttorch.TreeTensor({
assert not ttorch.eq(({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({
}), ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
})).all()
......@@ -384,20 +411,20 @@ class TestTensorFuncs:
assert isinstance(p2, bool)
assert not p2
p3 = ttorch.equal(ttorch.TreeTensor({
p3 = ttorch.equal(({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({
}), ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}))
assert isinstance(p3, bool)
assert p3
p4 = ttorch.equal(ttorch.TreeTensor({
p4 = ttorch.equal(({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({
}), ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
}))
......
......@@ -6,12 +6,12 @@ from treevalue import func_treelize
import treetensor.numpy as tnp
import treetensor.torch as ttorch
_all_is = func_treelize(return_type=ttorch.TreeTensor)(lambda x, y: x is y)
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
@pytest.mark.unittest
class TestTensorTreetensor:
_DEMO_1 = ttorch.TreeTensor({
class TestTorchTensor:
_DEMO_1 = ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 6]]),
'x': {
......@@ -20,7 +20,7 @@ class TestTensorTreetensor:
}
})
_DEMO_2 = ttorch.TreeTensor({
_DEMO_2 = ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 60]]),
'x': {
......@@ -47,7 +47,7 @@ class TestTensorTreetensor:
assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values()))
def test_to(self):
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.TreeTensor({
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32),
'x': {
......
from .common import TreeObject
from .numpy import TreeNumpy
from .torch import TreeTensor
from .torch import Tensor
import builtins
import torch
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize
from treevalue.utils import post_process
from .tensor import TreeTensor, tireduce
from .tensor import Tensor, tireduce
from ..common import TreeObject, ireduce
from ..utils import replaceable_partial, doc_from
from ..utils import replaceable_partial, doc_from, args_mapping
__all__ = [
'zeros', 'zeros_like',
......@@ -16,9 +18,13 @@ __all__ = [
'empty', 'empty_like',
'all', 'any',
'eq', 'equal',
'tensor',
]
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor)
func_treelize = post_process(post_process(args_mapping(
lambda i, x: Tensor(x) if isinstance(x, (dict, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=Tensor)
)
@doc_from(torch.zeros)
......@@ -102,18 +108,20 @@ def all(input_, *args, **kwargs):
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> all(torch.tensor([True, True])) # the same as torch.all
torch.tensor(True)
>>> all(TreeTensor({
>>> 'a': torch.tensor([True, True]),
>>> 'b': torch.tensor([True, True]),
>>> all(ttorch.tensor({
>>> 'a': [True, True],
>>> 'b': [True, True],
>>> }))
torch.tensor(True)
>>> all(TreeTensor({
>>> 'a': torch.tensor([True, True]),
>>> 'b': torch.tensor([True, False]),
>>> all(Tensor({
>>> 'a': [True, True],
>>> 'b': [True, False],
>>> }))
torch.tensor(False)
......@@ -139,3 +147,30 @@ def eq(input_, other, *args, **kwargs):
@func_treelize()
def equal(input_, other, *args, **kwargs):
return torch.equal(input_, other, *args, **kwargs)
@doc_from(torch.tensor)
@func_treelize()
def tensor(*args, **kwargs):
"""
In ``treetensor``, you can create a tree tensor with simple data structure.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor(True) # the same as torch.tensor(True)
torch.tensor(True)
>>> ttorch.tensor([1, 2, 3]) # the same as torch.tensor([1, 2, 3])
torch.tensor([1, 2, 3])
>>> ttorch.tensor({'a': 1, 'b': [1, 2, 3], 'c': [[True, False], [False, True]]})
ttorch.Tensor({
'a': torch.tensor(1),
'b': torch.tensor([1, 2, 3]),
'c': torch.tensor([[True, False], [False, True]]),
})
"""
return torch.tensor(*args, **kwargs)
......@@ -7,12 +7,12 @@ from ..utils import replaceable_partial
func_treelize = replaceable_partial(original_func_treelize)
__all__ = [
'TreeSize'
'Size'
]
# noinspection PyTypeChecker
class TreeSize(TreeObject):
class Size(TreeObject):
@func_treelize(return_type=TreeObject)
def numel(self: torch.Size) -> TreeObject:
return self.numel()
......
......@@ -3,13 +3,13 @@ import torch
from treevalue import method_treelize
from treevalue.utils import pre_process
from .size import TreeSize
from .size import Size
from ..common import TreeObject, TreeData, ireduce
from ..numpy import TreeNumpy
from ..utils import inherit_names, current_names, doc_from
__all__ = [
'TreeTensor'
'Tensor'
]
_reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
......@@ -19,7 +19,7 @@ tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduc
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
@current_names()
@inherit_names(TreeData)
class TreeTensor(TreeData):
class Tensor(TreeData):
@doc_from(torch.Tensor.numpy)
@method_treelize(return_type=TreeNumpy)
def numpy(self: torch.Tensor) -> np.ndarray:
......@@ -53,7 +53,7 @@ class TreeTensor(TreeData):
@property
@doc_from(torch.Tensor.shape)
@method_treelize(return_type=TreeSize)
@method_treelize(return_type=Size)
def shape(self: torch.Tensor):
return self.shape
......
from functools import wraps
from typing import Callable, Union, Any
__all__ = [
'replaceable_partial',
'args_mapping',
]
def replaceable_partial(func, **kws):
@wraps(func)
def _new_func(*args, **kwargs):
return func(*args, **{**kws, **kwargs})
return _new_func
def args_mapping(mapper: Callable[[Union[int, str], Any], Any]):
def _decorator(func):
@wraps(func)
def _new_func(*args, **kwargs):
return func(
*(mapper(i, x) for i, x in enumerate(args)),
**{k: mapper(k, v) for k, v in kwargs.items()},
)
return _new_func
return _decorator
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册