提交 61e7ee05 编写于 作者: HansBug's avatar HansBug 😆

release(hansbug): version to 0.1.0

上级 75b7a2e2
from collections import namedtuple
from functools import wraps
from itertools import chain
from operator import itemgetter
from treevalue import TreeValue
from treevalue import reduce_ as treevalue_reduce
from treevalue import TreeValue, walk
__all__ = [
'kwreduce', 'ireduce', 'vreduce',
'ireduce',
'return_self',
]
def kwreduce(rfunc):
def _decorator(func):
@wraps(func)
def _new_func(*args, **kwargs):
_result = func(*args, **kwargs)
if isinstance(_result, TreeValue):
return treevalue_reduce(_result, rfunc)
else:
return _result
return _new_func
return _decorator
def vreduce(rfunc):
return kwreduce(lambda **kws: rfunc(kws.values()))
def ireduce(rfunc, piter=None):
_IterReduceWrapper = namedtuple("_IterReduceWrapper", ['v'])
piter = piter or (lambda x: x)
def _reduce_func(values):
_list = []
for item in values:
if isinstance(item, _IterReduceWrapper):
_list.append(item.v)
else:
_list.append([item])
return _IterReduceWrapper(chain(*_list))
def _decorator(func):
rifunc = vreduce(_reduce_func)(func)
@wraps(func)
def _new_func(*args, **kwargs):
_iw = rifunc(*args, **kwargs)
if isinstance(_iw, _IterReduceWrapper):
return rfunc(piter(_iw.v))
result = func(*args, **kwargs)
if isinstance(result, TreeValue):
it = map(itemgetter(1), walk(result, include_nodes=False))
return rfunc(piter(it))
else:
return _iw
return result
return _new_func
......
......@@ -7,7 +7,7 @@ Overview:
__TITLE__ = "DI-treetensor"
#: Version of this project.
__VERSION__ = "0.0.1"
__VERSION__ = "0.1.0"
#: Short description of the project, will be included in ``setup.py``.
__DESCRIPTION__ = 'A flexible, generalized tree-based tensor structure.'
......
import numpy as np
import torch as pytorch
from hbutils.reflection import post_process
from treevalue import method_treelize, TreeValue
from treevalue import method_treelize, TreeValue, typetrans
from .base import Torch, rmreduce, post_reduce, auto_reduce
from .size import Size
......@@ -15,7 +15,18 @@ __all__ = [
]
doc_from_base = replaceable_partial(original_doc_from_base, base=pytorch.Tensor)
_TorchProxy, _InstanceTorchProxy = get_tree_proxy(pytorch.Tensor)
def _auto_tensor(t):
if isinstance(t, TreeValue):
t = typetrans(t, Object)
if t.map(lambda x, _: pytorch.is_tensor(x)).all():
return typetrans(t, Tensor)
return t
_TorchProxy, _InstanceTorchProxy = get_tree_proxy(pytorch.Tensor, _auto_tensor)
def _to_tensor(data, *args, **kwargs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册