funcs.py 1.9 KB
Newer Older
1 2
import builtins

3
import numpy as np
4
from treevalue import TreeValue
5
from treevalue import func_treelize as original_func_treelize
6
from treevalue.tree.common import TreeStorage
7
from treevalue.utils import post_process
8

9
from .array import ndarray
10
from ..common import ireduce, Object, module_func_loader
11
from ..utils import replaceable_partial, doc_from, args_mapping
12

13
# Names exported by ``from <this module> import *``.
__all__ = ['all', 'any', 'array', 'equal', 'array_equal']

# Module-local ``func_treelize``.
# It is treevalue's ``func_treelize`` with ``return_type`` pre-bound to this
# package's ``ndarray``. Callers can still pass their own ``return_type``:
# ``all`` below uses ``return_type=Object``.
# The lambda wraps ``dict`` / ``TreeStorage`` / ``TreeValue`` values into
# ``TreeValue`` and leaves every other value untouched.
# NOTE(review): presumably ``args_mapping`` applies the lambda to each
# argument (``i`` being its position). The exact point at which it runs
# depends on ``post_process`` / ``args_mapping`` semantics; confirm in
# treevalue / ``..utils``.
func_treelize = post_process(
    post_process(
        args_mapping(
            lambda i, x: (
                x if not isinstance(x, (dict, TreeStorage, TreeValue)) else TreeValue(x)
            )
        )
    )
)(replaceable_partial(original_func_treelize, return_type=ndarray))

# Loader for functions of the ``numpy`` module, built with ``module_func_loader``
# (from ``..common``). It is given this package's ``ndarray`` and the pair list
# ``[(np.ndarray, ndarray)]``. That list presumably maps plain numpy arrays to
# this package's ``ndarray``; TODO confirm against ``module_func_loader``.
# NOTE(review): not referenced in this part of the file; presumably imported
# elsewhere in the package.
get_func_from_numpy = module_func_loader(
    np,
    ndarray,
    [(np.ndarray, ndarray)],
)


@doc_from(np.all)
@ireduce(builtins.all)
@func_treelize(return_type=Object)
def all(a, *args, **kwargs):
    # Per-leaf kernel: numpy's ``all`` on one leaf value.
    # ``func_treelize`` maps it over the tree. ``ireduce(builtins.all)``
    # presumably folds the per-leaf results into one truth value.
    # ``builtins.all`` is spelled out because this function shadows the builtin.
    # NOTE(review): ``any`` in this module uses ``func_treelize()`` without
    # ``return_type=Object``; confirm the asymmetry is intentional.
    leaf_result = np.all(a, *args, **kwargs)
    return leaf_result


@doc_from(np.any)
@ireduce(builtins.any)
@func_treelize()
def any(a, *args, **kwargs):
    # Per-leaf kernel: numpy's ``any`` on one leaf value.
    # ``func_treelize`` maps it over the tree. ``ireduce(builtins.any)``
    # presumably folds the per-leaf results into one truth value.
    # ``builtins.any`` is spelled out because this function shadows the builtin.
    leaf_result = np.any(a, *args, **kwargs)
    return leaf_result


@doc_from(np.equal)
@func_treelize()
def equal(x1, x2, *args, **kwargs):
    # Element-wise comparison of a pair of leaves via ``np.equal``.
    # ``func_treelize`` applies it leaf by leaf. There is no ``ireduce`` here,
    # so the per-leaf results are not folded into a single value.
    comparison = np.equal(x1, x2, *args, **kwargs)
    return comparison


@doc_from(np.array_equal)
@func_treelize()
def array_equal(a1, a2, *args, **kwargs):
    # Whole-array equality of a pair of leaves via ``np.array_equal``.
    # ``func_treelize`` produces one result per leaf pair. There is no
    # ``ireduce``, so the results are not combined across the tree.
    same = np.array_equal(a1, a2, *args, **kwargs)
    return same


@doc_from(np.array)
@func_treelize()
def array(p_object, *args, **kwargs):
    """
    In ``treetensor``, you can create a tree of :class:`np.ndarray` with :func:`array`.

    The ``func_treelize`` decorator applies :func:`np.array` to each leaf of
    ``p_object``, forwarding ``args`` and ``kwargs`` unchanged.

    :param p_object: An array-like value, or a ``dict`` / tree whose leaves
        are array-like.
    :param args: Extra positional arguments forwarded to :func:`np.array`.
    :param kwargs: Extra keyword arguments forwarded to :func:`np.array`.
    :return: A tree of this package's :class:`ndarray` type with the same
        structure as ``p_object``, whose leaves are :class:`np.ndarray`.

    Examples::

        >>> import numpy as np
        >>> import treetensor.numpy as tnp
        >>> tnp.array({
        ...     'a': [1, 2, 3],
        ...     'b': [[4, 5], [5, 6]],
        ...     'c': True,
        ... })
        tnp.ndarray({
            'a': np.array([1, 2, 3]),
            'b': np.array([[4, 5], [5, 6]]),
            'c': np.array(True),
        })
    """
    return np.array(p_object, *args, **kwargs)