funcs.py 4.1 KB
Newer Older
1 2
import builtins

3
import torch
HansBug's avatar
HansBug 已提交
4
from treevalue import TreeValue
5
from treevalue import func_treelize as original_func_treelize
HansBug's avatar
HansBug 已提交
6
from treevalue.utils import post_process
7

HansBug's avatar
HansBug 已提交
8
from .tensor import Tensor, tireduce
9
from ..common import TreeObject, ireduce
HansBug's avatar
HansBug 已提交
10
from ..utils import replaceable_partial, doc_from, args_mapping
11

12 13 14 15 16 17 18 19 20
# Public names exported by this module. Every entry corresponds to a function
# defined further down in this file; the order mirrors the definition order.
__all__ = [
    'zeros', 'zeros_like',
    'randn', 'randn_like',
    'randint', 'randint_like',
    'ones', 'ones_like',
    'full', 'full_like',
    'empty', 'empty_like',
    'all', 'any',
    'eq', 'equal',
    'tensor',
]

HansBug's avatar
HansBug 已提交
24 25 26 27
# Module-local replacement for treevalue's ``func_treelize`` (imported above as
# ``original_func_treelize``); every wrapper below is decorated with it.
#
# * ``replaceable_partial(original_func_treelize, return_type=Tensor)`` pre-binds
#   ``return_type=Tensor``. Callers in this module still pass their own value
#   (``func_treelize(return_type=TreeObject)`` on ``all`` / ``any``), so the
#   pre-bound keyword is presumably overridable -- confirm in ``..utils``.
# * The lambda receives ``(i, x)`` and wraps ``x`` in ``Tensor`` when it is a
#   ``dict`` or a ``TreeValue``; any other value is passed through unchanged.
#
# NOTE(review): the two nested ``post_process`` calls presumably push that
# mapping through both call levels (decorator factory -> decorator -> decorated
# function) so it ends up applied to the final function's arguments. Verify
# against ``treevalue.utils.post_process`` and ``..utils.args_mapping``.
func_treelize = post_process(post_process(args_mapping(
    lambda i, x: Tensor(x) if isinstance(x, (dict, TreeValue)) else x)))(
    replaceable_partial(original_func_treelize, return_type=Tensor)
)
28 29


30
# Tree-aware counterpart of ``torch.zeros``.
# The body only forwards every argument to ``torch.zeros``; the tree handling
# comes from this module's ``func_treelize()`` decorator (default return type
# ``Tensor``). ``doc_from`` is handed ``torch.zeros``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.zeros)
@func_treelize()
def zeros(*args, **kwargs):
    return torch.zeros(*args, **kwargs)
34 35


36
# Tree-aware counterpart of ``torch.zeros_like``.
# ``input_`` and all remaining arguments are forwarded unchanged to
# ``torch.zeros_like``; the tree handling comes from this module's
# ``func_treelize()`` decorator (default return type ``Tensor``).
# ``doc_from`` is handed ``torch.zeros_like``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.zeros_like)
@func_treelize()
def zeros_like(input_, *args, **kwargs):
    return torch.zeros_like(input_, *args, **kwargs)


42
# Tree-aware counterpart of ``torch.randn``.
# The body only forwards every argument to ``torch.randn``; the tree handling
# comes from this module's ``func_treelize()`` decorator (default return type
# ``Tensor``). ``doc_from`` is handed ``torch.randn``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.randn)
@func_treelize()
def randn(*args, **kwargs):
    return torch.randn(*args, **kwargs)
46 47


48
# Tree-aware counterpart of ``torch.randn_like``.
# ``input_`` and all remaining arguments are forwarded unchanged to
# ``torch.randn_like``; the tree handling comes from this module's
# ``func_treelize()`` decorator (default return type ``Tensor``).
# ``doc_from`` is handed ``torch.randn_like``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.randn_like)
@func_treelize()
def randn_like(input_, *args, **kwargs):
    return torch.randn_like(input_, *args, **kwargs)


54
# Tree-aware counterpart of ``torch.randint``.
# The body only forwards every argument to ``torch.randint``; the tree handling
# comes from this module's ``func_treelize()`` decorator (default return type
# ``Tensor``). ``doc_from`` is handed ``torch.randint``, presumably to reuse
# its documentation -- confirm in ``..utils``.
@doc_from(torch.randint)
@func_treelize()
def randint(*args, **kwargs):
    return torch.randint(*args, **kwargs)
58 59


60
# Tree-aware counterpart of ``torch.randint_like``.
# ``input_`` and all remaining arguments are forwarded unchanged to
# ``torch.randint_like``; the tree handling comes from this module's
# ``func_treelize()`` decorator (default return type ``Tensor``).
# ``doc_from`` is handed ``torch.randint_like``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.randint_like)
@func_treelize()
def randint_like(input_, *args, **kwargs):
    return torch.randint_like(input_, *args, **kwargs)


66
# Tree-aware counterpart of ``torch.ones``.
# The body only forwards every argument to ``torch.ones``; the tree handling
# comes from this module's ``func_treelize()`` decorator (default return type
# ``Tensor``). ``doc_from`` is handed ``torch.ones``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.ones)
@func_treelize()
def ones(*args, **kwargs):
    return torch.ones(*args, **kwargs)
70 71


72
# Tree-aware counterpart of ``torch.ones_like``.
# ``input_`` and all remaining arguments are forwarded unchanged to
# ``torch.ones_like``; the tree handling comes from this module's
# ``func_treelize()`` decorator (default return type ``Tensor``).
# ``doc_from`` is handed ``torch.ones_like``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.ones_like)
@func_treelize()
def ones_like(input_, *args, **kwargs):
    return torch.ones_like(input_, *args, **kwargs)


78
# Tree-aware counterpart of ``torch.full``.
# The body only forwards every argument to ``torch.full``; the tree handling
# comes from this module's ``func_treelize()`` decorator (default return type
# ``Tensor``). ``doc_from`` is handed ``torch.full``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.full)
@func_treelize()
def full(*args, **kwargs):
    return torch.full(*args, **kwargs)
82 83


84
# Tree-aware counterpart of ``torch.full_like``.
# ``input_`` and all remaining arguments are forwarded unchanged to
# ``torch.full_like``; the tree handling comes from this module's
# ``func_treelize()`` decorator (default return type ``Tensor``).
# ``doc_from`` is handed ``torch.full_like``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.full_like)
@func_treelize()
def full_like(input_, *args, **kwargs):
    return torch.full_like(input_, *args, **kwargs)


90
# Tree-aware counterpart of ``torch.empty``.
# The body only forwards every argument to ``torch.empty``; the tree handling
# comes from this module's ``func_treelize()`` decorator (default return type
# ``Tensor``). ``doc_from`` is handed ``torch.empty``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.empty)
@func_treelize()
def empty(*args, **kwargs):
    return torch.empty(*args, **kwargs)
94 95


96
# Tree-aware counterpart of ``torch.empty_like``.
# ``input_`` and all remaining arguments are forwarded unchanged to
# ``torch.empty_like``; the tree handling comes from this module's
# ``func_treelize()`` decorator (default return type ``Tensor``).
# ``doc_from`` is handed ``torch.empty_like``, presumably to reuse its
# documentation -- confirm in ``..utils``.
@doc_from(torch.empty_like)
@func_treelize()
def empty_like(input_, *args, **kwargs):
    return torch.empty_like(input_, *args, **kwargs)


102
# Decorators apply bottom-up: ``func_treelize(return_type=TreeObject)`` wraps
# the plain ``torch.all`` call first (using ``TreeObject`` instead of the
# module default ``Tensor`` as return type), then ``tireduce(torch.all)`` is
# applied on top of it -- presumably folding the per-node results into one
# value, as the docstring describes; confirm in ``.tensor``.
# NOTE: this function shadows the builtin ``all`` inside this module.
@doc_from(torch.all)
@tireduce(torch.all)
@func_treelize(return_type=TreeObject)
def all(input_, *args, **kwargs):
    """
    In ``treetensor``, you can get the ``all`` result of a whole tree with this function.

    Example::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.all(torch.tensor([True, True]))  # the same as torch.all
        torch.tensor(True)

        >>> ttorch.all(ttorch.tensor({
        >>>     'a': [True, True],
        >>>     'b': [True, True],
        >>> }))
        torch.tensor(True)

        >>> ttorch.all(ttorch.tensor({
        >>>     'a': [True, True],
        >>>     'b': [True, False],
        >>> }))
        torch.tensor(False)

    """
    return torch.all(input_, *args, **kwargs)


132
# Same decorator stack as ``all`` in this module, with ``torch.any`` in place
# of ``torch.all``: ``func_treelize(return_type=TreeObject)`` wraps the plain
# ``torch.any`` call, then ``tireduce(torch.any)`` is applied on top of it --
# presumably folding the per-node results into one value; confirm in
# ``.tensor``.
# NOTE: this function shadows the builtin ``any`` inside this module.
@doc_from(torch.any)
@tireduce(torch.any)
@func_treelize(return_type=TreeObject)
def any(input_, *args, **kwargs):
    return torch.any(input_, *args, **kwargs)


139
# Tree-aware counterpart of ``torch.eq``.
# ``input_``, ``other`` and all remaining arguments are forwarded unchanged to
# ``torch.eq``; the tree handling comes from this module's ``func_treelize()``
# decorator (default return type ``Tensor``). No reduction decorator is
# applied here, unlike ``equal`` below.
@doc_from(torch.eq)
@func_treelize()
def eq(input_, other, *args, **kwargs):
    return torch.eq(input_, other, *args, **kwargs)


145
# Tree-aware counterpart of ``torch.equal``.
# ``func_treelize()`` wraps the plain ``torch.equal`` call, then
# ``ireduce(builtins.all)`` is applied on top of it -- presumably combining the
# per-node results into a single truth value; confirm in ``..common``.
# ``builtins.all`` is spelled out explicitly because this module defines its
# own ``all`` function, which shadows the builtin name here.
@doc_from(torch.equal)
@ireduce(builtins.all)
@func_treelize()
def equal(input_, other, *args, **kwargs):
    return torch.equal(input_, other, *args, **kwargs)
HansBug's avatar
HansBug 已提交
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176


@doc_from(torch.tensor)
@func_treelize()
def tensor(*args, **kwargs):
    """
    In ``treetensor``, a tree tensor can be built directly from a simple data structure.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> ttorch.tensor(True)  # the same as torch.tensor(True)
        torch.tensor(True)

        >>> ttorch.tensor([1, 2, 3])  # the same as torch.tensor([1, 2, 3])
        torch.tensor([1, 2, 3])

        >>> ttorch.tensor({'a': 1, 'b': [1, 2, 3], 'c': [[True, False], [False, True]]})
        ttorch.Tensor({
            'a': torch.tensor(1),
            'b': torch.tensor([1, 2, 3]),
            'c': torch.tensor([[True, False], [False, True]]),
        })

    """
    # Plain pass-through to torch; the tree handling is supplied entirely by
    # the ``func_treelize()`` decorator above.
    built = torch.tensor(*args, **kwargs)
    return built