funcs.py 1.6 KB
Newer Older
1 2 3 4 5 6
import builtins
from functools import partial, wraps
from typing import Tuple

import torch
from treevalue import func_treelize, TreeValue

7
from .tensor import TreeTensor
8
from ..common import vreduce
9 10

# ``func_treelize`` pre-configured so that treelized functions return ``TreeTensor``.
_treelize = partial(func_treelize, return_type=TreeTensor)

# Handle on Python's builtin ``all``; this module later binds its own tree-aware ``all``.
# It is read from ``builtins`` rather than the module namespace so it stays the real
# builtin even if this module is re-executed (e.g. ``importlib.reload``), where the
# global name ``all`` would already refer to the tree-aware version.
_python_all = builtins.all


def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **kwargs_):
    """
    Build a decorator that lets a size-based torch factory (``torch.zeros``,
    ``torch.full``, ...) accept either a plain size or a tree of sizes.

    :param args_: Positional arguments forwarded to ``_treelize``.
    :param prefix: If ``True``, ``size`` is placed *after* the other positional
        arguments (``torch.randint(high, size)`` style); otherwise it comes first.
    :param tuple_: If ``True``, ``size`` is passed as a single positional argument
        (``torch.full(size, fill_value)`` style); otherwise it is unpacked into
        separate positional arguments (``torch.zeros(*size)`` style).
    :param kwargs_: Keyword arguments forwarded to ``_treelize``.
    :return: A decorator for ``func``. The wrapped callable takes ``size`` as its
        first argument; a ``TreeValue``/``dict`` size is converted to ``TreeTensor``
        before the call is dispatched through ``_treelize``.
    """
    def _decorator(func):
        @_treelize(*args_, **kwargs_)
        def _sub_func(size, *args, **kwargs):
            # A bare int is a valid 1-D size for torch factories (``torch.zeros(3)``),
            # but it cannot be star-unpacked below; normalise it to a 1-tuple.
            if isinstance(size, int):
                size = (size,)
            _size_args = (size,) if tuple_ else size
            _args = (*args, *_size_args) if prefix else (*_size_args, *args)
            return func(*_args, **kwargs)

        @wraps(func)
        def _new_func(size, *args, **kwargs):
            # Raw dicts / TreeValues of sizes are wrapped as TreeTensor so that
            # ``_treelize`` (presumably) dispatches ``_sub_func`` per leaf.
            if isinstance(size, (TreeValue, dict)):
                size = TreeTensor(size)
            return _sub_func(size, *args, **kwargs)

        return _new_func

    return _decorator


# Tensor generation based on shapes.
# Default layout: ``size`` is unpacked into positional args, e.g. ``torch.zeros(*size)``.
zeros = _size_based_treelize(prefix=False, tuple_=False)(torch.zeros)
randn = _size_based_treelize(prefix=False, tuple_=False)(torch.randn)
# ``size`` goes last, as one argument: ``torch.randint(high, size)``.
randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint)
ones = _size_based_treelize(prefix=False, tuple_=False)(torch.ones)
# ``size`` goes first, as one argument: ``torch.full(size, fill_value)``.
full = _size_based_treelize(prefix=False, tuple_=True)(torch.full)
empty = _size_based_treelize(prefix=False, tuple_=False)(torch.empty)

# Tensor generation based on another tensor, applied leaf-by-leaf via ``_treelize``.

# Deterministic fills.
zeros_like = _treelize()(torch.zeros_like)
ones_like = _treelize()(torch.ones_like)
full_like = _treelize()(torch.full_like)
empty_like = _treelize()(torch.empty_like)

# Random fills.
randn_like = _treelize()(torch.randn_like)
randint_like = _treelize()(torch.randint_like)

# Tensor operators

# Tree-aware ``torch.all``: ``torch.all`` is treelized, and the result is passed through
# ``vreduce`` with Python's builtin ``all`` (presumably to fold the per-leaf results into
# one value — confirm against ``..common.vreduce``). ``_python_all`` is used explicitly
# so the reducer never depends on what the module-global name ``all`` is bound to.
all = vreduce(_python_all)(_treelize()(torch.all))

# Leaf-wise comparison operators over trees.
eq = _treelize()(torch.eq)
equal = _treelize()(torch.equal)