tensor.py 6.5 KB
Newer Older
1 2
import numpy as np
import torch
3 4
from treevalue import method_treelize
from treevalue.utils import pre_process
5

6
from .base import TreeTorch
HansBug's avatar
HansBug 已提交
7
from .size import Size
8
from ..common import Object, ireduce, clsmeta
9 10
from ..numpy import ndarray
from ..utils import current_names, doc_from
11

12
__all__ = [
    'Tensor'
]

# Pre-processing wrapper for reduce functions: the wrapped function is called
# with the incoming iterable of per-leaf results packed into a single
# ``torch.Tensor`` (built with ``torch.tensor``), instead of the raw iterable.
# NOTE(review): this reading of ``pre_process`` (the lambda returns the new
# ``(args, kwargs)`` pair for the wrapped call) is inferred from how it is
# used here -- confirm against the treevalue documentation.
_reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))

# Tensor-aware counterpart of ``ireduce``: ``tireduce(rfunc)`` is equivalent
# to ``ireduce(_reduce_tensor_wrap(rfunc))``, so ``rfunc`` (e.g. ``torch.all``)
# receives the per-leaf results as one tensor.
tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce)


def _to_tensor(*args, **kwargs):
    """
    Build a ``torch.Tensor`` from ``torch.tensor``-style arguments.

    If the call carries exactly one piece of data and that data is already a
    ``torch.Tensor``, it is returned unchanged and no copy is made. The data
    may be given either as one positional argument or as the ``data`` keyword
    alone. Any other call is forwarded as-is to :func:`torch.tensor`.
    """
    only_positional = not kwargs and len(args) == 1
    only_data_keyword = not args and kwargs.keys() == {'data'}

    if only_positional or only_data_keyword:
        payload = kwargs['data'] if only_data_keyword else args[0]
        if isinstance(payload, torch.Tensor):
            # Already a tensor: reuse it rather than copying via torch.tensor.
            return payload

    return torch.tensor(*args, **kwargs)


# noinspection PyTypeChecker
@current_names()
class Tensor(TreeTorch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
    # Tree of tensors. Each method below is written against a single
    # ``torch.Tensor`` leaf (note the ``self: torch.Tensor`` annotations) and
    # is lifted to the whole tree by ``method_treelize``.
    # ``clsmeta`` receives ``_to_tensor``. Presumably that function converts
    # raw leaf data when a ``Tensor`` is constructed; ``allow_dict=True``
    # presumably permits plain (nested) dicts as input -- confirm in ``clsmeta``.

    @doc_from(torch.Tensor.numpy)
    @method_treelize(return_type=ndarray)
    def numpy(self: torch.Tensor) -> np.ndarray:
        """
        Returns ``self`` tree tensor as a tree of NumPy arrays
        (:class:`treetensor.numpy.ndarray`).

        Each leaf is converted with :meth:`torch.Tensor.numpy`, so each leaf
        array shares its underlying storage with the corresponding tensor.
        Changes to one are reflected in the other.
        """
        return self.numpy()

    @doc_from(torch.Tensor.tolist)
    @method_treelize(return_type=Object)
    def tolist(self: torch.Tensor):
        """
        Get the dump result of tree tensor.

        Each leaf is converted with :meth:`torch.Tensor.tolist`.

        Example::

            >>> import torch
            >>> import treetensor.torch as ttorch
            >>> ttorch.tensor({
            ...     'a': [[1, 2], [3, 4]],
            ...     'b': [1, 2, 3],
            ...     'c': True,
            ... }).tolist()
            TreeObject({
                'a': [[1, 2], [3, 4]],
                'b': [1, 2, 3],
                'c': True,
            })
        """
        return self.tolist()

    @doc_from(torch.Tensor.cpu)
    @method_treelize()
    def cpu(self: torch.Tensor, *args, **kwargs):
        """
        Returns a copy of this tree tensor in CPU memory.

        Each leaf is moved with :meth:`torch.Tensor.cpu`. A leaf that is
        already in CPU memory and on the correct device is not copied; the
        original leaf tensor is reused.
        """
        return self.cpu(*args, **kwargs)

    @doc_from(torch.Tensor.cuda)
    @method_treelize()
    def cuda(self: torch.Tensor, *args, **kwargs):
        """
        Returns a copy of this tree tensor in CUDA memory.

        Each leaf is moved with :meth:`torch.Tensor.cuda`. A leaf that is
        already in CUDA memory and on the correct device is not copied; the
        original leaf tensor is reused.
        """
        return self.cuda(*args, **kwargs)

    @doc_from(torch.Tensor.to)
    @method_treelize()
    def to(self: torch.Tensor, *args, **kwargs):
        """
        Turn the original tree tensor to another format.

        All arguments are forwarded to :meth:`torch.Tensor.to` for every leaf.

        Example::

            >>> import torch
            >>> import treetensor.torch as ttorch
            >>> ttorch.tensor({
            ...     'a': [[1, 11], [2, 22], [3, 33]],
            ...     'b': {'x': [[4, 5], [6, 7]]},
            ... }).to(torch.float64)
            <Tensor 0x7ff363bb6518>
            ├── a --> tensor([[ 1., 11.],
            │                 [ 2., 22.],
            │                 [ 3., 33.]], dtype=torch.float64)
            └── b --> <Tensor 0x7ff363bb6ef0>
                └── x --> tensor([[4., 5.],
                                  [6., 7.]], dtype=torch.float64)
        """
        return self.to(*args, **kwargs)

    @doc_from(torch.Tensor.numel)
    @ireduce(sum)
    @method_treelize(return_type=Object)
    def numel(self: torch.Tensor):
        """
        See :func:`treetensor.torch.numel`

        The per-leaf element counts are combined by ``ireduce`` using the
        built-in :func:`sum`.
        """
        return self.numel()

    @property
    @doc_from(torch.Tensor.shape)
    @method_treelize(return_type=Size)
    def shape(self: torch.Tensor):
        """
        Get the size of the tensors in the tree.

        Example::

            >>> import torch
            >>> import treetensor.torch as ttorch
            >>> ttorch.tensor({
            ...     'a': [[1, 11], [2, 22], [3, 33]],
            ...     'b': {'x': [[4, 5], [6, 7]]},
            ... }).shape
            <Size 0x7ff363bbbd68>
            ├── a --> torch.Size([3, 2])
            └── b --> <Size 0x7ff363bbbcf8>
                └── x --> torch.Size([2, 2])
        """
        return self.shape

    # Reductions: each leaf method runs per tensor, then ``tireduce`` combines
    # the per-leaf results with the matching torch function (``torch.all``,
    # ``torch.max``, ...).
    # NOTE(review): extra arguments such as ``dim`` can make per-leaf results
    # non-scalar (or tuples, for max/min). The packing step
    # ``torch.tensor([...])`` would then presumably fail; confirm how
    # ``ireduce`` handles this case.

    @doc_from(torch.Tensor.all)
    @tireduce(torch.all)
    @method_treelize(return_type=Object)
    def all(self: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        See :func:`treetensor.torch.all`
        """
        return self.all(*args, **kwargs)

    @doc_from(torch.Tensor.any)
    @tireduce(torch.any)
    @method_treelize(return_type=Object)
    def any(self: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        See :func:`treetensor.torch.any`
        """
        return self.any(*args, **kwargs)

    @doc_from(torch.Tensor.max)
    @tireduce(torch.max)
    @method_treelize(return_type=Object)
    def max(self: torch.Tensor, *args, **kwargs):
        """
        See :func:`treetensor.torch.max`
        """
        return self.max(*args, **kwargs)

    @doc_from(torch.Tensor.min)
    @tireduce(torch.min)
    @method_treelize(return_type=Object)
    def min(self: torch.Tensor, *args, **kwargs):
        """
        See :func:`treetensor.torch.min`
        """
        return self.min(*args, **kwargs)

    @doc_from(torch.Tensor.sum)
    @tireduce(torch.sum)
    @method_treelize(return_type=Object)
    def sum(self: torch.Tensor, *args, **kwargs):
        """
        See :func:`treetensor.torch.sum`
        """
        return self.sum(*args, **kwargs)

    # Element-wise comparisons: each operator is applied leaf by leaf, so the
    # result is a tree rather than a single bool.
    # Per the Python data model, a class that defines ``__eq__`` without also
    # defining ``__hash__`` gets ``__hash__ = None``, i.e. instances are
    # unhashable. NOTE(review): confirm unhashable tree tensors are intended.

    @method_treelize()
    def __eq__(self, other):
        """
        See :func:`treetensor.torch.eq`.
        """
        return self == other

    @method_treelize()
    def __ne__(self, other):
        """
        See :func:`treetensor.torch.ne`.
        """
        return self != other

    @method_treelize()
    def __lt__(self, other):
        """
        See :func:`treetensor.torch.lt`.
        """
        return self < other

    @method_treelize()
    def __gt__(self, other):
        """
        See :func:`treetensor.torch.gt`.
        """
        return self > other

    @method_treelize()
    def __le__(self, other):
        """
        See :func:`treetensor.torch.le`.
        """
        return self <= other

    @method_treelize()
    def __ge__(self, other):
        """
        See :func:`treetensor.torch.ge`.
        """
        return self >= other