# tensor.py
import numpy as np
import torch
3 4
from treevalue import method_treelize
from treevalue.utils import pre_process
5 6

from .size import TreeSize
7
from ..common import TreeObject, TreeData, ireduce
8
from ..numpy import TreeNumpy
9
from ..utils import inherit_names, current_names, doc_from
10

11 12 13 14
# Public API of this module: only the tree-structured tensor class is exported.
__all__ = ['TreeTensor']

15
def _pack_iterable_as_tensor(it):
    """Build the ``(args, kwargs)`` pair holding one tensor made from ``it``."""
    # Materialise the iterable first, then hand the resulting list to
    # ``torch.tensor`` as the sole positional argument (no keyword arguments).
    stacked = torch.tensor(list(it))
    return (stacked,), {}


# Decorator for reducing functions. NOTE(review): ``pre_process`` is expected
# to run the processor above on the incoming arguments and forward the returned
# (args, kwargs) pair to the decorated function -- confirm against treevalue.
_reduce_tensor_wrap = pre_process(_pack_iterable_as_tensor)
16
def _wrap_reducer_for_tensors(rfunc):
    """Build the ``(args, kwargs)`` pair holding the tensor-wrapped ``rfunc``."""
    wrapped = _reduce_tensor_wrap(rfunc)
    return (wrapped,), {}


# Tensor-aware variant of ``ireduce``: the reducing function passed in is first
# wrapped with ``_reduce_tensor_wrap`` and the result is given to ``ireduce``.
# NOTE(review): relies on ``pre_process`` rewriting call arguments through the
# processor above -- confirm against treevalue.
tireduce = pre_process(_wrap_reducer_for_tensors)(ireduce)
17 18


19
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
@current_names()
@inherit_names(TreeData)
class TreeTensor(TreeData):
    """
    Tree-structured counterpart of ``torch.Tensor``.

    Every method body below is written against a single ``torch.Tensor``
    (hence the ``self: torch.Tensor`` annotations) and simply forwards to the
    tensor method of the same name.  The ``method_treelize`` decorator is what
    exposes it on the tree; its ``return_type`` argument names the tree class
    used for the result (no argument: presumably the current class -- confirm
    against treevalue).
    """

    # NOTE: methods deliberately carry no docstring of their own.  Each one is
    # decorated with ``doc_from(torch.Tensor.<name>)``, which presumably takes
    # the documentation from the corresponding torch attribute -- verify in
    # ``..utils`` before adding inline docstrings.

    # --- conversions -------------------------------------------------------

    # Result is wrapped as a TreeNumpy.
    @doc_from(torch.Tensor.numpy)
    @method_treelize(return_type=TreeNumpy)
    def numpy(self: torch.Tensor) -> np.ndarray:
        return self.numpy()

    # Result is wrapped as a generic TreeObject.
    @doc_from(torch.Tensor.tolist)
    @method_treelize(return_type=TreeObject)
    def tolist(self: torch.Tensor):
        return self.tolist()

    # --- device / dtype movement ------------------------------------------
    # Arguments are passed through unchanged to the torch method.

    @doc_from(torch.Tensor.cpu)
    @method_treelize()
    def cpu(self: torch.Tensor, *args, **kwargs):
        return self.cpu(*args, **kwargs)

    @doc_from(torch.Tensor.cuda)
    @method_treelize()
    def cuda(self: torch.Tensor, *args, **kwargs):
        return self.cuda(*args, **kwargs)

    @doc_from(torch.Tensor.to)
    @method_treelize()
    def to(self: torch.Tensor, *args, **kwargs):
        return self.to(*args, **kwargs)

    # --- size information --------------------------------------------------

    # ``ireduce(sum)`` is applied on top of the treelized result; presumably
    # it folds the per-tensor element counts with the builtin ``sum`` --
    # confirm in ``..common``.
    @doc_from(torch.Tensor.numel)
    @ireduce(sum)
    @method_treelize(return_type=TreeObject)
    def numel(self: torch.Tensor):
        return self.numel()

    # Exposed as a read-only property; result is wrapped as a TreeSize.
    @property
    @doc_from(torch.Tensor.shape)
    @method_treelize(return_type=TreeSize)
    def shape(self: torch.Tensor):
        return self.shape

    # --- reductions --------------------------------------------------------
    # Each reduction is additionally decorated with ``tireduce(torch.<op>)``,
    # the tensor-aware reducer built at module level from ``ireduce``.

    @doc_from(torch.Tensor.all)
    @tireduce(torch.all)
    @method_treelize(return_type=TreeObject)
    def all(self: torch.Tensor, *args, **kwargs) -> bool:
        return self.all(*args, **kwargs)

    @doc_from(torch.Tensor.any)
    @tireduce(torch.any)
    @method_treelize(return_type=TreeObject)
    def any(self: torch.Tensor, *args, **kwargs) -> bool:
        return self.any(*args, **kwargs)

    @doc_from(torch.Tensor.max)
    @tireduce(torch.max)
    @method_treelize(return_type=TreeObject)
    def max(self: torch.Tensor, *args, **kwargs):
        return self.max(*args, **kwargs)

    @doc_from(torch.Tensor.min)
    @tireduce(torch.min)
    @method_treelize(return_type=TreeObject)
    def min(self: torch.Tensor, *args, **kwargs):
        return self.min(*args, **kwargs)

    @doc_from(torch.Tensor.sum)
    @tireduce(torch.sum)
    @method_treelize(return_type=TreeObject)
    def sum(self: torch.Tensor, *args, **kwargs):
        return self.sum(*args, **kwargs)