提交 fbfdb128 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add ttorch.tensor

上级 eba7817e
from .test_funcs import TestTensorFuncs from .test_funcs import TestTorchFuncs
from .test_treetensor import TestTensorTreetensor from .test_tensor import TestTorchTensor
...@@ -7,7 +7,34 @@ import treetensor.torch as ttorch ...@@ -7,7 +7,34 @@ import treetensor.torch as ttorch
# noinspection DuplicatedCode # noinspection DuplicatedCode
@pytest.mark.unittest @pytest.mark.unittest
class TestTensorFuncs: class TestTorchFuncs:
def test_tensor(self):
t1 = ttorch.tensor(True)
assert isinstance(t1, torch.Tensor)
assert t1
t2 = ttorch.tensor([[1, 2, 3], [4, 5, 6]])
assert isinstance(t2, torch.Tensor)
assert (t2 == torch.tensor([[1, 2, 3], [4, 5, 6]])).all()
t3 = ttorch.tensor({
'a': [1, 2],
'b': [[3, 4], [5, 6.2]],
'x': {
'c': True,
'd': [False, True],
}
})
assert isinstance(t3, ttorch.Tensor)
assert (t3 == ttorch.Tensor({
'a': torch.tensor([1, 2]),
'b': torch.tensor([[3, 4], [5, 6.2]]),
'x': {
'c': torch.tensor(True),
'd': torch.tensor([False, True]),
}
})).all()
def test_zeros(self): def test_zeros(self):
assert ttorch.all(ttorch.zeros((2, 3)) == torch.zeros(2, 3)) assert ttorch.all(ttorch.zeros((2, 3)) == torch.zeros(2, 3))
assert ttorch.all(ttorch.zeros(TreeValue({ assert ttorch.all(ttorch.zeros(TreeValue({
...@@ -16,7 +43,7 @@ class TestTensorFuncs: ...@@ -16,7 +43,7 @@ class TestTensorFuncs:
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) == ttorch.TreeTensor({ })) == ttorch.Tensor({
'a': torch.zeros(2, 3), 'a': torch.zeros(2, 3),
'b': torch.zeros(5, 6), 'b': torch.zeros(5, 6),
'x': { 'x': {
...@@ -30,14 +57,14 @@ class TestTensorFuncs: ...@@ -30,14 +57,14 @@ class TestTensorFuncs:
torch.tensor([[0, 0, 0], [0, 0, 0]]), torch.tensor([[0, 0, 0], [0, 0, 0]]),
) )
assert ttorch.all( assert ttorch.all(
ttorch.zeros_like(ttorch.TreeTensor({ ttorch.zeros_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
'c': torch.tensor([5, 6, 7]), 'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
})) == ttorch.TreeTensor({ })) == ttorch.Tensor({
'a': torch.tensor([[0, 0, 0], [0, 0, 0]]), 'a': torch.tensor([[0, 0, 0], [0, 0, 0]]),
'b': torch.tensor([0, 0, 0, 0]), 'b': torch.tensor([0, 0, 0, 0]),
'x': { 'x': {
...@@ -55,7 +82,7 @@ class TestTensorFuncs: ...@@ -55,7 +82,7 @@ class TestTensorFuncs:
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) == ttorch.TreeTensor({ })) == ttorch.Tensor({
'a': torch.ones(2, 3), 'a': torch.ones(2, 3),
'b': torch.ones(5, 6), 'b': torch.ones(5, 6),
'x': { 'x': {
...@@ -69,14 +96,14 @@ class TestTensorFuncs: ...@@ -69,14 +96,14 @@ class TestTensorFuncs:
torch.tensor([[1, 1, 1], [1, 1, 1]]) torch.tensor([[1, 1, 1], [1, 1, 1]])
) )
assert ttorch.all( assert ttorch.all(
ttorch.ones_like(ttorch.TreeTensor({ ttorch.ones_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
'c': torch.tensor([5, 6, 7]), 'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
})) == ttorch.TreeTensor({ })) == ttorch.Tensor({
'a': torch.tensor([[1, 1, 1], [1, 1, 1]]), 'a': torch.tensor([[1, 1, 1], [1, 1, 1]]),
'b': torch.tensor([1, 1, 1, 1]), 'b': torch.tensor([1, 1, 1, 1]),
'x': { 'x': {
...@@ -99,7 +126,7 @@ class TestTensorFuncs: ...@@ -99,7 +126,7 @@ class TestTensorFuncs:
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) }))
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]), 'b': torch.Size([5, 6]),
'x': { 'x': {
...@@ -113,7 +140,7 @@ class TestTensorFuncs: ...@@ -113,7 +140,7 @@ class TestTensorFuncs:
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300]) assert _target.shape == torch.Size([200, 300])
_target = ttorch.randn_like(ttorch.TreeTensor({ _target = ttorch.randn_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32), 'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32),
'x': { 'x': {
...@@ -121,7 +148,7 @@ class TestTensorFuncs: ...@@ -121,7 +148,7 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]], dtype=torch.float32), 'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
} }
})) }))
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([4]), 'b': torch.Size([4]),
'x': { 'x': {
...@@ -140,7 +167,7 @@ class TestTensorFuncs: ...@@ -140,7 +167,7 @@ class TestTensorFuncs:
})) }))
assert ttorch.all(_target < 10) assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target) assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]), 'b': torch.Size([5, 6]),
'x': { 'x': {
...@@ -157,7 +184,7 @@ class TestTensorFuncs: ...@@ -157,7 +184,7 @@ class TestTensorFuncs:
})) }))
assert ttorch.all(_target < 10) assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target) assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]), 'b': torch.Size([5, 6]),
'x': { 'x': {
...@@ -166,7 +193,7 @@ class TestTensorFuncs: ...@@ -166,7 +193,7 @@ class TestTensorFuncs:
}) })
def test_randint_like(self): def test_randint_like(self):
_target = ttorch.randint_like(ttorch.TreeTensor({ _target = ttorch.randint_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
...@@ -176,7 +203,7 @@ class TestTensorFuncs: ...@@ -176,7 +203,7 @@ class TestTensorFuncs:
}), -10, 10) }), -10, 10)
assert ttorch.all(_target < 10) assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target) assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([4]), 'b': torch.Size([4]),
'x': { 'x': {
...@@ -185,7 +212,7 @@ class TestTensorFuncs: ...@@ -185,7 +212,7 @@ class TestTensorFuncs:
} }
}) })
_target = ttorch.randint_like(ttorch.TreeTensor({ _target = ttorch.randint_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
...@@ -195,7 +222,7 @@ class TestTensorFuncs: ...@@ -195,7 +222,7 @@ class TestTensorFuncs:
}), 10) }), 10)
assert ttorch.all(_target < 10) assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target) assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([4]), 'b': torch.Size([4]),
'x': { 'x': {
...@@ -213,7 +240,7 @@ class TestTensorFuncs: ...@@ -213,7 +240,7 @@ class TestTensorFuncs:
} }
}), 233) }), 233)
assert ttorch.all(_target == 233) assert ttorch.all(_target == 233)
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]), 'b': torch.Size([5, 6]),
'x': { 'x': {
...@@ -222,7 +249,7 @@ class TestTensorFuncs: ...@@ -222,7 +249,7 @@ class TestTensorFuncs:
}) })
def test_full_like(self): def test_full_like(self):
_target = ttorch.full_like(ttorch.TreeTensor({ _target = ttorch.full_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
...@@ -231,7 +258,7 @@ class TestTensorFuncs: ...@@ -231,7 +258,7 @@ class TestTensorFuncs:
} }
}), 233) }), 233)
assert ttorch.all(_target == 233) assert ttorch.all(_target == 233)
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([4]), 'b': torch.Size([4]),
'x': { 'x': {
...@@ -248,7 +275,7 @@ class TestTensorFuncs: ...@@ -248,7 +275,7 @@ class TestTensorFuncs:
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) }))
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]), 'b': torch.Size([5, 6]),
'x': { 'x': {
...@@ -257,7 +284,7 @@ class TestTensorFuncs: ...@@ -257,7 +284,7 @@ class TestTensorFuncs:
}) })
def test_empty_like(self): def test_empty_like(self):
_target = ttorch.empty_like(ttorch.TreeTensor({ _target = ttorch.empty_like(({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
...@@ -265,7 +292,7 @@ class TestTensorFuncs: ...@@ -265,7 +292,7 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
})) }))
assert _target.shape == ttorch.TreeSize({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([4]), 'b': torch.Size([4]),
'x': { 'x': {
...@@ -290,7 +317,7 @@ class TestTensorFuncs: ...@@ -290,7 +317,7 @@ class TestTensorFuncs:
assert r3 == torch.tensor(False) assert r3 == torch.tensor(False)
assert not r3 assert not r3
r4 = ttorch.all(ttorch.TreeTensor({ r4 = ttorch.all(({
'a': torch.tensor([True, True, True]), 'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, True]),
})).all() })).all()
...@@ -298,7 +325,7 @@ class TestTensorFuncs: ...@@ -298,7 +325,7 @@ class TestTensorFuncs:
assert r4 == torch.tensor(True) assert r4 == torch.tensor(True)
assert r4 assert r4
r5 = ttorch.all(ttorch.TreeTensor({ r5 = ttorch.all(({
'a': torch.tensor([True, True, True]), 'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]), 'b': torch.tensor([True, True, False]),
})).all() })).all()
...@@ -306,7 +333,7 @@ class TestTensorFuncs: ...@@ -306,7 +333,7 @@ class TestTensorFuncs:
assert r5 == torch.tensor(False) assert r5 == torch.tensor(False)
assert not r5 assert not r5
r6 = ttorch.all(ttorch.TreeTensor({ r6 = ttorch.all(({
'a': torch.tensor([False, False, False]), 'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]), 'b': torch.tensor([False, False, False]),
})).all() })).all()
...@@ -330,7 +357,7 @@ class TestTensorFuncs: ...@@ -330,7 +357,7 @@ class TestTensorFuncs:
assert r3 == torch.tensor(False) assert r3 == torch.tensor(False)
assert not r3 assert not r3
r4 = ttorch.any(ttorch.TreeTensor({ r4 = ttorch.any(({
'a': torch.tensor([True, True, True]), 'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, True]),
})).all() })).all()
...@@ -338,7 +365,7 @@ class TestTensorFuncs: ...@@ -338,7 +365,7 @@ class TestTensorFuncs:
assert r4 == torch.tensor(True) assert r4 == torch.tensor(True)
assert r4 assert r4
r5 = ttorch.any(ttorch.TreeTensor({ r5 = ttorch.any(({
'a': torch.tensor([True, True, True]), 'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]), 'b': torch.tensor([True, True, False]),
})).all() })).all()
...@@ -346,7 +373,7 @@ class TestTensorFuncs: ...@@ -346,7 +373,7 @@ class TestTensorFuncs:
assert r5 == torch.tensor(True) assert r5 == torch.tensor(True)
assert r5 assert r5
r6 = ttorch.any(ttorch.TreeTensor({ r6 = ttorch.any(({
'a': torch.tensor([False, False, False]), 'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]), 'b': torch.tensor([False, False, False]),
})).all() })).all()
...@@ -360,17 +387,17 @@ class TestTensorFuncs: ...@@ -360,17 +387,17 @@ class TestTensorFuncs:
assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all() assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all()
assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all() assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all()
assert ttorch.eq(ttorch.TreeTensor({ assert ttorch.eq(({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({ }), ({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
})).all() })).all()
assert not ttorch.eq(ttorch.TreeTensor({ assert not ttorch.eq(({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({ }), ({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]), 'b': torch.tensor([4, 5, 5]),
})).all() })).all()
...@@ -384,20 +411,20 @@ class TestTensorFuncs: ...@@ -384,20 +411,20 @@ class TestTensorFuncs:
assert isinstance(p2, bool) assert isinstance(p2, bool)
assert not p2 assert not p2
p3 = ttorch.equal(ttorch.TreeTensor({ p3 = ttorch.equal(({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({ }), ({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
})) }))
assert isinstance(p3, bool) assert isinstance(p3, bool)
assert p3 assert p3
p4 = ttorch.equal(ttorch.TreeTensor({ p4 = ttorch.equal(({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({ }), ({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]), 'b': torch.tensor([4, 5, 5]),
})) }))
......
...@@ -6,12 +6,12 @@ from treevalue import func_treelize ...@@ -6,12 +6,12 @@ from treevalue import func_treelize
import treetensor.numpy as tnp import treetensor.numpy as tnp
import treetensor.torch as ttorch import treetensor.torch as ttorch
_all_is = func_treelize(return_type=ttorch.TreeTensor)(lambda x, y: x is y) _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
@pytest.mark.unittest @pytest.mark.unittest
class TestTensorTreetensor: class TestTorchTensor:
_DEMO_1 = ttorch.TreeTensor({ _DEMO_1 = ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 6]]), 'b': torch.tensor([[1, 2], [5, 6]]),
'x': { 'x': {
...@@ -20,7 +20,7 @@ class TestTensorTreetensor: ...@@ -20,7 +20,7 @@ class TestTensorTreetensor:
} }
}) })
_DEMO_2 = ttorch.TreeTensor({ _DEMO_2 = ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 60]]), 'b': torch.tensor([[1, 2], [5, 60]]),
'x': { 'x': {
...@@ -47,7 +47,7 @@ class TestTensorTreetensor: ...@@ -47,7 +47,7 @@ class TestTensorTreetensor:
assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values())) assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values()))
def test_to(self): def test_to(self):
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.TreeTensor({ assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32), 'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32),
'x': { 'x': {
......
from .common import TreeObject from .common import TreeObject
from .numpy import TreeNumpy from .numpy import TreeNumpy
from .torch import TreeTensor from .torch import Tensor
import builtins import builtins
import torch import torch
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize from treevalue import func_treelize as original_func_treelize
from treevalue.utils import post_process
from .tensor import TreeTensor, tireduce from .tensor import Tensor, tireduce
from ..common import TreeObject, ireduce from ..common import TreeObject, ireduce
from ..utils import replaceable_partial, doc_from from ..utils import replaceable_partial, doc_from, args_mapping
__all__ = [ __all__ = [
'zeros', 'zeros_like', 'zeros', 'zeros_like',
...@@ -16,9 +18,13 @@ __all__ = [ ...@@ -16,9 +18,13 @@ __all__ = [
'empty', 'empty_like', 'empty', 'empty_like',
'all', 'any', 'all', 'any',
'eq', 'equal', 'eq', 'equal',
'tensor',
] ]
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor) func_treelize = post_process(post_process(args_mapping(
lambda i, x: Tensor(x) if isinstance(x, (dict, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=Tensor)
)
@doc_from(torch.zeros) @doc_from(torch.zeros)
...@@ -102,18 +108,20 @@ def all(input_, *args, **kwargs): ...@@ -102,18 +108,20 @@ def all(input_, *args, **kwargs):
Example:: Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> all(torch.tensor([True, True])) # the same as torch.all >>> all(torch.tensor([True, True])) # the same as torch.all
torch.tensor(True) torch.tensor(True)
>>> all(TreeTensor({ >>> all(ttorch.tensor({
>>> 'a': torch.tensor([True, True]), >>> 'a': [True, True],
>>> 'b': torch.tensor([True, True]), >>> 'b': [True, True],
>>> })) >>> }))
torch.tensor(True) torch.tensor(True)
>>> all(TreeTensor({ >>> all(Tensor({
>>> 'a': torch.tensor([True, True]), >>> 'a': [True, True],
>>> 'b': torch.tensor([True, False]), >>> 'b': [True, False],
>>> })) >>> }))
torch.tensor(False) torch.tensor(False)
...@@ -139,3 +147,30 @@ def eq(input_, other, *args, **kwargs): ...@@ -139,3 +147,30 @@ def eq(input_, other, *args, **kwargs):
@func_treelize() @func_treelize()
def equal(input_, other, *args, **kwargs): def equal(input_, other, *args, **kwargs):
return torch.equal(input_, other, *args, **kwargs) return torch.equal(input_, other, *args, **kwargs)
@doc_from(torch.tensor)
@func_treelize()
def tensor(*args, **kwargs):
"""
In ``treetensor``, you can create a tree tensor with simple data structure.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor(True) # the same as torch.tensor(True)
torch.tensor(True)
>>> ttorch.tensor([1, 2, 3]) # the same as torch.tensor([1, 2, 3])
torch.tensor([1, 2, 3])
>>> ttorch.tensor({'a': 1, 'b': [1, 2, 3], 'c': [[True, False], [False, True]]})
ttorch.Tensor({
'a': torch.tensor(1),
'b': torch.tensor([1, 2, 3]),
'c': torch.tensor([[True, False], [False, True]]),
})
"""
return torch.tensor(*args, **kwargs)
...@@ -7,12 +7,12 @@ from ..utils import replaceable_partial ...@@ -7,12 +7,12 @@ from ..utils import replaceable_partial
func_treelize = replaceable_partial(original_func_treelize) func_treelize = replaceable_partial(original_func_treelize)
__all__ = [ __all__ = [
'TreeSize' 'Size'
] ]
# noinspection PyTypeChecker # noinspection PyTypeChecker
class TreeSize(TreeObject): class Size(TreeObject):
@func_treelize(return_type=TreeObject) @func_treelize(return_type=TreeObject)
def numel(self: torch.Size) -> TreeObject: def numel(self: torch.Size) -> TreeObject:
return self.numel() return self.numel()
......
...@@ -3,13 +3,13 @@ import torch ...@@ -3,13 +3,13 @@ import torch
from treevalue import method_treelize from treevalue import method_treelize
from treevalue.utils import pre_process from treevalue.utils import pre_process
from .size import TreeSize from .size import Size
from ..common import TreeObject, TreeData, ireduce from ..common import TreeObject, TreeData, ireduce
from ..numpy import TreeNumpy from ..numpy import TreeNumpy
from ..utils import inherit_names, current_names, doc_from from ..utils import inherit_names, current_names, doc_from
__all__ = [ __all__ = [
'TreeTensor' 'Tensor'
] ]
_reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {})) _reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
...@@ -19,7 +19,7 @@ tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduc ...@@ -19,7 +19,7 @@ tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduc
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList # noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
@current_names() @current_names()
@inherit_names(TreeData) @inherit_names(TreeData)
class TreeTensor(TreeData): class Tensor(TreeData):
@doc_from(torch.Tensor.numpy) @doc_from(torch.Tensor.numpy)
@method_treelize(return_type=TreeNumpy) @method_treelize(return_type=TreeNumpy)
def numpy(self: torch.Tensor) -> np.ndarray: def numpy(self: torch.Tensor) -> np.ndarray:
...@@ -53,7 +53,7 @@ class TreeTensor(TreeData): ...@@ -53,7 +53,7 @@ class TreeTensor(TreeData):
@property @property
@doc_from(torch.Tensor.shape) @doc_from(torch.Tensor.shape)
@method_treelize(return_type=TreeSize) @method_treelize(return_type=Size)
def shape(self: torch.Tensor): def shape(self: torch.Tensor):
return self.shape return self.shape
......
from functools import wraps
from typing import Callable, Union, Any
__all__ = [ __all__ = [
'replaceable_partial', 'replaceable_partial',
'args_mapping',
] ]
def replaceable_partial(func, **kws): def replaceable_partial(func, **kws):
@wraps(func)
def _new_func(*args, **kwargs): def _new_func(*args, **kwargs):
return func(*args, **{**kws, **kwargs}) return func(*args, **{**kws, **kwargs})
return _new_func return _new_func
def args_mapping(mapper: Callable[[Union[int, str], Any], Any]):
def _decorator(func):
@wraps(func)
def _new_func(*args, **kwargs):
return func(
*(mapper(i, x) for i, x in enumerate(args)),
**{k: mapper(k, v) for k, v in kwargs.items()},
)
return _new_func
return _decorator
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册