diff --git a/test/torch/funcs/test_construct.py b/test/torch/funcs/test_construct.py index a8e8d3ece18764f71d20c79f3a63dc3376f4e63e..6ea120b858b85121a571bb0b0196c68545e71232 100644 --- a/test/torch/funcs/test_construct.py +++ b/test/torch/funcs/test_construct.py @@ -34,6 +34,19 @@ class TestTorchFuncsConstruct: } })).all() + @choose_mark() + def test_tensor(self): + assert ttorch.as_tensor(True) == torch.tensor(True) + assert (ttorch.as_tensor([1, 2, 3], dtype=torch.float32) == torch.tensor([1.0, 2.0, 3.0])).all() + + assert (ttorch.as_tensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': {'x': [[4, 5], [6, 7]]} + }, dtype=torch.float32) == ttorch.tensor({ + 'a': [1.0, 2.0, 3.0], + 'b': {'x': [[4.0, 5.0], [6.0, 7.0]]}, + })).all() + @choose_mark() def test_clone(self): t1 = ttorch.clone(torch.tensor([1.0, 2.0, 1.5])) diff --git a/test/torch/funcs/test_reduction.py b/test/torch/funcs/test_reduction.py index 2cfdfa313ec8e4697a467514d2657a12bbd53835..58dd5ab88fb6130fb2e9083fef7ba79f89d08cc7 100644 --- a/test/torch/funcs/test_reduction.py +++ b/test/torch/funcs/test_reduction.py @@ -49,16 +49,16 @@ class TestTorchFuncsReduction: assert not r6 r7 = ttorch.all(ttorch.tensor({ - 'a': torch.tensor([True, True, True]), - 'b': torch.tensor([True, True, False]), + 'a': [True, True, True], + 'b': [True, True, False], }), reduce=False) assert (r7 == ttorch.tensor({ 'a': True, 'b': False })).all() r8 = ttorch.all(ttorch.tensor({ - 'a': torch.tensor([True, True, True]), - 'b': torch.tensor([True, True, False]), + 'a': [True, True, True], + 'b': [True, True, False], }), dim=0) assert (r8 == ttorch.tensor({ 'a': True, 'b': False @@ -66,8 +66,8 @@ class TestTorchFuncsReduction: with pytest.warns(UserWarning): r9 = ttorch.all(ttorch.tensor({ - 'a': torch.tensor([True, True, True]), - 'b': torch.tensor([True, True, False]), + 'a': [True, True, True], + 'b': [True, True, False], }), dim=0, reduce=True) assert (r9 == ttorch.tensor({ 'a': True, 'b': False @@ -90,41 +90,41 @@ class TestTorchFuncsReduction: assert r3 == torch.tensor(False) assert not r3 - r4 = ttorch.any({ - 'a': torch.tensor([True, True, True]), - 'b': torch.tensor([True, True, True]), - }) + r4 = ttorch.any(ttorch.tensor({ + 'a': [True, True, True], + 'b': [True, True, True], + })) assert torch.is_tensor(r4) assert r4 == torch.tensor(True) assert r4 - r5 = ttorch.any({ - 'a': torch.tensor([True, True, True]), - 'b': torch.tensor([True, True, False]), - }) + r5 = ttorch.any(ttorch.tensor({ + 'a': [True, True, True], + 'b': [True, True, False], + })) assert torch.is_tensor(r5) assert r5 == torch.tensor(True) assert r5 - r6 = ttorch.any({ - 'a': torch.tensor([False, False, False]), - 'b': torch.tensor([False, False, False]), - }) + r6 = ttorch.any(ttorch.tensor({ + 'a': [False, False, False], + 'b': [False, False, False], + })) assert torch.is_tensor(r6) assert r6 == torch.tensor(False) assert not r6 r7 = ttorch.any(ttorch.tensor({ - 'a': torch.tensor([True, True, False]), - 'b': torch.tensor([False, False, False]), + 'a': [True, True, False], + 'b': [False, False, False], }), reduce=False) assert (r7 == ttorch.tensor({ 'a': True, 'b': False })).all() r8 = ttorch.any(ttorch.tensor({ - 'a': torch.tensor([True, True, False]), - 'b': torch.tensor([False, False, False]), + 'a': [True, True, False], + 'b': [False, False, False], }), dim=0) assert (r8 == ttorch.tensor({ 'a': True, 'b': False @@ -132,8 +132,8 @@ class TestTorchFuncsReduction: with pytest.warns(UserWarning): r9 = ttorch.any(ttorch.tensor({ - 'a': torch.tensor([True, True, False]), - 'b': torch.tensor([False, False, False]), + 'a': [True, True, False], + 'b': [False, False, False], }), dim=0, reduce=True) assert (r9 == ttorch.tensor({ 'a': True, 'b': False diff --git a/treetensor/torch/funcs/construct.py b/treetensor/torch/funcs/construct.py index f8a7d10a9c4c32f34f8474614931e507f50aa823..5e8156c34fb86b0c99b4b832082c7b87e33b34db 100644 --- a/treetensor/torch/funcs/construct.py +++ b/treetensor/torch/funcs/construct.py @@ -3,7 +3,7 @@ import torch from .base import doc_from_base, func_treelize __all__ = [ - 'tensor', 'clone', + 'tensor', 'as_tensor', 'clone', 'zeros', 'zeros_like', 'randn', 'randn_like', 'randint', 'randint_like', @@ -36,10 +36,36 @@ def tensor(data, *args, **kwargs): └── c --> tensor([[ True, False], [False, True]]) """ - if torch.is_tensor(data): - return data - else: - return torch.tensor(data, *args, **kwargs) + return torch.tensor(data, *args, **kwargs) + + +@doc_from_base() +@func_treelize() +def as_tensor(data, *args, **kwargs): + """ + Convert the data into a :class:`treetensor.torch.Tensor` or :class:`torch.Tensor`. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.as_tensor(True) + tensor(True) + + >>> ttorch.as_tensor([1, 2, 3], dtype=torch.float32) + tensor([1., 2., 3.]) + + >>> ttorch.as_tensor({ + ... 'a': torch.tensor([1, 2, 3]), + ... 'b': {'x': [[4, 5], [6, 7]]} + ... }, dtype=torch.float32) + + ├── a --> tensor([1., 2., 3.]) + └── b --> + └── x --> tensor([[4., 5.], + [6., 7.]]) + """ + return torch.as_tensor(data, *args, **kwargs) # noinspection PyShadowingBuiltins diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index b0d2360f6a93b461a142dd47593b4c8acaef59ed..f7288cd3e49c92d2105af8bdc859a73188b194c0 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -18,17 +18,7 @@ doc_from_base = replaceable_partial(original_doc_from_base, base=pytorch.Tensor) _TorchProxy, _InstanceTorchProxy = get_tree_proxy(pytorch.Tensor) -def _to_tensor(*args, **kwargs): - if (len(args) == 1 and not kwargs) or \ - (not args and set(kwargs.keys()) == {'data'}): - data = args[0] if len(args) == 1 else kwargs['data'] - if isinstance(data, pytorch.Tensor): - return data - - return pytorch.tensor(*args, **kwargs) - - -class _BaseTensorMeta(clsmeta(_to_tensor, allow_dict=True)): +class _BaseTensorMeta(clsmeta(pytorch.as_tensor, allow_dict=True)): pass