# tensor.py
import numpy as np
import torch
from treevalue import method_treelize
from treevalue.utils import pre_process
from .size import TreeSize
from ..common import TreeObject, TreeData, ireduce
from ..numpy import TreeNumpy
from ..utils import inherit_names, current_names


__all__ = ['TreeTensor']


def _pack_as_tensor(it):
    """
    Gather every value yielded by ``it`` into a single ``torch.tensor`` and
    hand it back as an ``(args, kwargs)`` pair for ``pre_process``.
    """
    return (torch.tensor([*it]),), {}


# Presumably (treevalue's pre_process convention, TODO confirm):
# _reduce_tensor_wrap(rfunc) yields a callable that, given an iterable of
# values, calls rfunc(torch.tensor([...those values...])).
_reduce_tensor_wrap = pre_process(_pack_as_tensor)


def _wrap_reducer(rfunc):
    """
    Replace the reducer ``rfunc`` with its tensor-packing wrapper, returned
    as an ``(args, kwargs)`` pair for ``pre_process``.
    """
    return (_reduce_tensor_wrap(rfunc),), {}


# tireduce(rfunc) forwards to ireduce(_reduce_tensor_wrap(rfunc)): a variant
# of ``ireduce`` whose reducer receives the collected values as one tensor.
tireduce = pre_process(_wrap_reducer)(ireduce)


# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
@current_names()
@inherit_names(TreeData)
class TreeTensor(TreeData):
    """
    Tree-structured container whose leaves are expected to be ``torch.Tensor``.

    Each method below is written for a single leaf: ``self`` is annotated as
    ``torch.Tensor`` and every body forwards to the tensor method of the same
    name. ``method_treelize`` is presumably what maps that per-leaf function
    over the whole tree and wraps the per-leaf results in ``return_type``
    (TODO confirm against the treevalue documentation).
    """

    @method_treelize(return_type=TreeNumpy)
    def numpy(self: torch.Tensor) -> np.ndarray:
        """
        Convert every leaf with ``torch.Tensor.numpy``; the per-leaf arrays
        are collected into a ``TreeNumpy``.
        """
        # NOTE(review): per the torch docs, Tensor.numpy() refuses tensors that
        # require grad or are not on the CPU; callers presumably need
        # ``.cpu()`` (and a detach) first in those cases.
        return self.numpy()

    @method_treelize(return_type=TreeObject)
    def tolist(self: torch.Tensor):
        """Convert every leaf with ``torch.Tensor.tolist``; results go into a ``TreeObject``."""
        return self.tolist()

    @method_treelize()
    def cpu(self: torch.Tensor, *args, **kwargs):
        """Apply ``torch.Tensor.cpu(*args, **kwargs)`` to every leaf."""
        return self.cpu(*args, **kwargs)

    @method_treelize()
    def cuda(self: torch.Tensor, *args, **kwargs):
        """Apply ``torch.Tensor.cuda(*args, **kwargs)`` to every leaf."""
        return self.cuda(*args, **kwargs)

    @method_treelize()
    def to(self: torch.Tensor, *args, **kwargs):
        """Apply ``torch.Tensor.to(*args, **kwargs)`` to every leaf."""
        return self.to(*args, **kwargs)

    @ireduce(sum)
    @method_treelize(return_type=TreeObject)
    def numel(self: torch.Tensor):
        """
        Count elements per leaf with ``torch.Tensor.numel``; ``ireduce(sum)``
        then folds those counts with the builtin ``sum`` (presumably giving
        the total number of elements across the tree).
        """
        return self.numel()

    @property
    @method_treelize(return_type=TreeSize)
    def shape(self: torch.Tensor):
        """Per-leaf ``torch.Tensor.shape``, collected into a ``TreeSize``."""
        return self.shape

    # The reductions below are decorated with ``tireduce(torch.<op>)``: the
    # per-leaf results are packed with ``torch.tensor([...])`` and reduced
    # again with the same torch op across the whole tree.
    # NOTE(review): ``torch.tensor([...])`` can only pack single-element
    # values, so extra arguments that make a leaf return a multi-element
    # tensor or a (values, indices) pair -- e.g. ``dim=`` -- will likely fail
    # in the tree-level reduction. Confirm before relying on them.

    @tireduce(torch.all)
    @method_treelize(return_type=TreeObject)
    def all(self: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """``torch.Tensor.all`` on each leaf, then ``torch.all`` across the tree."""
        return self.all(*args, **kwargs)

    @tireduce(torch.any)
    @method_treelize(return_type=TreeObject)
    def any(self: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """``torch.Tensor.any`` on each leaf, then ``torch.any`` across the tree."""
        return self.any(*args, **kwargs)

    @tireduce(torch.max)
    @method_treelize(return_type=TreeObject)
    def max(self: torch.Tensor, *args, **kwargs):
        """``torch.Tensor.max`` on each leaf, then ``torch.max`` across the tree."""
        return self.max(*args, **kwargs)

    @tireduce(torch.min)
    @method_treelize(return_type=TreeObject)
    def min(self: torch.Tensor, *args, **kwargs):
        """``torch.Tensor.min`` on each leaf, then ``torch.min`` across the tree."""
        return self.min(*args, **kwargs)

    @tireduce(torch.sum)
    @method_treelize(return_type=TreeObject)
    def sum(self: torch.Tensor, *args, **kwargs):
        """``torch.Tensor.sum`` on each leaf, then ``torch.sum`` across the tree."""
        return self.sum(*args, **kwargs)