提交 ec7f9fd0 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add ttorch.as_tensor, update constructor of ttorch.Tensor

上级 37a4f1f4
......@@ -34,6 +34,19 @@ class TestTorchFuncsConstruct:
}
})).all()
@choose_mark()
def test_tensor(self):
assert ttorch.as_tensor(True) == torch.tensor(True)
assert (ttorch.as_tensor([1, 2, 3], dtype=torch.float32) == torch.tensor([1.0, 2.0, 3.0])).all()
assert (ttorch.as_tensor({
'a': torch.tensor([1, 2, 3]),
'b': {'x': [[4, 5], [6, 7]]}
}, dtype=torch.float32) == ttorch.tensor({
'a': [1.0, 2.0, 3.0],
'b': {'x': [[4.0, 5.0], [6.0, 7.0]]},
})).all()
@choose_mark()
def test_clone(self):
t1 = ttorch.clone(torch.tensor([1.0, 2.0, 1.5]))
......
......@@ -49,16 +49,16 @@ class TestTorchFuncsReduction:
assert not r6
r7 = ttorch.all(ttorch.tensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
'a': [True, True, True],
'b': [True, True, False],
}), reduce=False)
assert (r7 == ttorch.tensor({
'a': True, 'b': False
})).all()
r8 = ttorch.all(ttorch.tensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
'a': [True, True, True],
'b': [True, True, False],
}), dim=0)
assert (r8 == ttorch.tensor({
'a': True, 'b': False
......@@ -66,8 +66,8 @@ class TestTorchFuncsReduction:
with pytest.warns(UserWarning):
r9 = ttorch.all(ttorch.tensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
'a': [True, True, True],
'b': [True, True, False],
}), dim=0, reduce=True)
assert (r9 == ttorch.tensor({
'a': True, 'b': False
......@@ -90,41 +90,41 @@ class TestTorchFuncsReduction:
assert r3 == torch.tensor(False)
assert not r3
r4 = ttorch.any({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
})
r4 = ttorch.any(ttorch.tensor({
'a': [True, True, True],
'b': [True, True, True],
}))
assert torch.is_tensor(r4)
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.any({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
})
r5 = ttorch.any(ttorch.tensor({
'a': [True, True, True],
'b': [True, True, False],
}))
assert torch.is_tensor(r5)
assert r5 == torch.tensor(True)
assert r5
r6 = ttorch.any({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
})
r6 = ttorch.any(ttorch.tensor({
'a': [False, False, False],
'b': [False, False, False],
}))
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
r7 = ttorch.any(ttorch.tensor({
'a': torch.tensor([True, True, False]),
'b': torch.tensor([False, False, False]),
'a': [True, True, False],
'b': [False, False, False],
}), reduce=False)
assert (r7 == ttorch.tensor({
'a': True, 'b': False
})).all()
r8 = ttorch.any(ttorch.tensor({
'a': torch.tensor([True, True, False]),
'b': torch.tensor([False, False, False]),
'a': [True, True, False],
'b': [False, False, False],
}), dim=0)
assert (r8 == ttorch.tensor({
'a': True, 'b': False
......@@ -132,8 +132,8 @@ class TestTorchFuncsReduction:
with pytest.warns(UserWarning):
r9 = ttorch.any(ttorch.tensor({
'a': torch.tensor([True, True, False]),
'b': torch.tensor([False, False, False]),
'a': [True, True, False],
'b': [False, False, False],
}), dim=0, reduce=True)
assert (r9 == ttorch.tensor({
'a': True, 'b': False
......
......@@ -3,7 +3,7 @@ import torch
from .base import doc_from_base, func_treelize
__all__ = [
'tensor', 'clone',
'tensor', 'as_tensor', 'clone',
'zeros', 'zeros_like',
'randn', 'randn_like',
'randint', 'randint_like',
......@@ -36,10 +36,36 @@ def tensor(data, *args, **kwargs):
└── c --> tensor([[ True, False],
[False, True]])
"""
if torch.is_tensor(data):
return data
else:
return torch.tensor(data, *args, **kwargs)
return torch.tensor(data, *args, **kwargs)
@doc_from_base()
@func_treelize()
def as_tensor(data, *args, **kwargs):
"""
Convert the data into a :class:`treetensor.torch.Tensor` or :class:`torch.Tensor`.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.as_tensor(True)
tensor(True)
>>> ttorch.as_tensor([1, 2, 3], dtype=torch.float32)
tensor([1., 2., 3.])
>>> ttorch.as_tensor({
... 'a': torch.tensor([1, 2, 3]),
... 'b': {'x': [[4, 5], [6, 7]]}
... }, dtype=torch.float32)
<Tensor 0x7fc2b80c25c0>
├── a --> tensor([1., 2., 3.])
└── b --> <Tensor 0x7fc2b80c24e0>
└── x --> tensor([[4., 5.],
[6., 7.]])
"""
return torch.as_tensor(data, *args, **kwargs)
# noinspection PyShadowingBuiltins
......
......@@ -18,17 +18,7 @@ doc_from_base = replaceable_partial(original_doc_from_base, base=pytorch.Tensor)
_TorchProxy, _InstanceTorchProxy = get_tree_proxy(pytorch.Tensor)
def _to_tensor(*args, **kwargs):
if (len(args) == 1 and not kwargs) or \
(not args and set(kwargs.keys()) == {'data'}):
data = args[0] if len(args) == 1 else kwargs['data']
if isinstance(data, pytorch.Tensor):
return data
return pytorch.tensor(*args, **kwargs)
class _BaseTensorMeta(clsmeta(_to_tensor, allow_dict=True)):
class _BaseTensorMeta(clsmeta(pytorch.as_tensor, allow_dict=True)):
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册