提交 97511ca7 编写于 作者: HansBug's avatar HansBug 😆

doc(hansbug): add all the __all__ into source code && change all the inner...

doc(hansbug): add all the __all__ into source code && change all the inner from ... import xxx to import *
上级 c5c230d8
from .trees import TreeData, TreeObject, BaseTreeStruct
from .wrappers import kwreduce, vreduce, ireduce
from .trees import *
from .wrappers import *
......@@ -2,6 +2,10 @@ from abc import ABCMeta
from treevalue import general_tree_value, method_treelize
__all__ = [
'BaseTreeStruct', "TreeData", 'TreeObject',
]
class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
"""
......
......@@ -5,6 +5,10 @@ from itertools import chain
from treevalue import TreeValue
from treevalue import reduce_ as treevalue_reduce
__all__ = [
'kwreduce', 'ireduce', 'vreduce',
]
def kwreduce(rfunc):
def _decorator(func):
......
from .funcs import equal, array_equal, all
from .numpy import TreeNumpy
from .funcs import *
from .numpy import *
......@@ -5,6 +5,11 @@ from .numpy import TreeNumpy
from ..common import ireduce
from ..utils import replaceable_partial
__all__ = [
'all',
'equal', 'array_equal',
]
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNumpy)
......
......@@ -3,6 +3,10 @@ from treevalue import method_treelize
from ..common import TreeObject, TreeData, ireduce
__all__ = [
'TreeNumpy'
]
class TreeNumpy(TreeData):
"""
......
from .funcs import zeros_like, full_like, ones_like, randint_like, randn_like, empty_like, zeros, randn, randint, \
ones, empty, full, all, eq, equal
from .size import TreeSize
from .tensor import TreeTensor
from .funcs import *
from .size import *
from .tensor import *
import torch
from treevalue import func_treelize as original_func_treelize
from .tensor import TreeTensor, _reduce_tensor_wrap
from ..common import TreeObject, ireduce
from .tensor import TreeTensor, tireduce
from ..common import TreeObject
from ..utils import replaceable_partial
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor)
__all__ = [
'zeros', 'zeros_like',
'randn', 'randn_like',
'randint', 'randint_like',
'ones', 'ones_like',
'full', 'full_like',
'empty', 'empty_like',
'all', 'any',
'eq', 'equal',
]
@func_treelize()
def zeros(size, *args, **kwargs):
......@@ -68,12 +79,18 @@ def empty_like(input_, *args, **kwargs):
return torch.empty_like(input_, *args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.all))
@tireduce(torch.all)
@func_treelize(return_type=TreeObject)
def all(input_, *args, **kwargs):
return torch.all(input_, *args, **kwargs)
@tireduce(torch.any)
@func_treelize(return_type=TreeObject)
def any(input_, *args, **kwargs):
return torch.any(input_, *args, **kwargs)
@func_treelize()
def eq(input_, other, *args, **kwargs):
return torch.eq(input_, other, *args, **kwargs)
......
......@@ -6,6 +6,10 @@ from ..utils import replaceable_partial
func_treelize = replaceable_partial(original_func_treelize)
__all__ = [
'TreeSize'
]
# noinspection PyTypeChecker
class TreeSize(TreeObject):
......
......@@ -7,7 +7,12 @@ from .size import TreeSize
from ..common import TreeObject, TreeData, ireduce
from ..numpy import TreeNumpy
__all__ = [
'TreeTensor'
]
_reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce)
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
......@@ -42,27 +47,27 @@ class TreeTensor(TreeData):
def shape(self: torch.Tensor):
return self.shape
@ireduce(_reduce_tensor_wrap(torch.all))
@tireduce(torch.all)
@method_treelize(return_type=TreeObject)
def all(self: torch.Tensor, *args, **kwargs) -> bool:
return self.all(*args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.any))
@tireduce(torch.any)
@method_treelize(return_type=TreeObject)
def any(self: torch.Tensor, *args, **kwargs) -> bool:
return self.any(*args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.max))
@tireduce(torch.max)
@method_treelize(return_type=TreeObject)
def max(self: torch.Tensor, *args, **kwargs):
return self.max(*args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.min))
@tireduce(torch.min)
@method_treelize(return_type=TreeObject)
def min(self: torch.Tensor, *args, **kwargs):
return self.min(*args, **kwargs)
@ireduce(_reduce_tensor_wrap(torch.sum))
@tireduce(torch.sum)
@method_treelize(return_type=TreeObject)
def sum(self: torch.Tensor, *args, **kwargs):
return self.sum(*args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册