# Source code for treetensor.common.trees
from collections import defaultdict
from functools import partial
from typing import Type

from hbutils.reflection import post_process
from treevalue import func_treelize as original_func_treelize
from treevalue import general_tree_value, TreeValue, typetrans
from treevalue.tree.common import TreeStorage

from ..utils import replaceable_partial, args_mapping
# Public API of this module (the names exported by ``import *``).
__all__ = ['BaseTreeStruct', 'clsmeta', 'auto_tree']
class BaseTreeStruct(general_tree_value()):
    """
    Overview:
        Common base class shared by every tree structure defined in ``treetensor``.
    """
def clsmeta(func, allow_dict: bool = False) -> Type[type]:
    """
    Overview:
        Build a metaclass whose constructor is driven by a generating function.
        It is used by :py:class:`treetensor.common.Object`,
        :py:class:`treetensor.torch.Tensor` and :py:class:`treetensor.torch.Size`
        to customize how instances of those classes are constructed.

    Arguments:
        - func: Generating function, treelized and applied to the constructor arguments.
        - allow_dict (:obj:`bool`): Convert plain ``dict`` arguments into \
            :py:class:`treevalue.TreeValue` before calling ``func``, default is ``False``.

    Returns:
        - metaclass (:obj:`Type[type]`): Metaclass for creating a new class.
    """

    class _TempTreeValue(TreeValue):
        # Private marker type used as the treelized function's return type, so
        # ``__call__`` can tell tree results apart and rewrap them in ``cls``.
        pass

    def _to_tree_arg(_, x):
        # Mapper handed to ``args_mapping``. The first parameter is unused here;
        # presumably it is the argument's position/keyword — verify in ``..utils``.
        if isinstance(x, TreeValue):
            return x
        if isinstance(x, TreeStorage) or (allow_dict and isinstance(x, dict)):
            return TreeValue(x)
        return x

    # Two levels of ``post_process``: the outer one post-processes the decorator
    # produced by ``func_treelize(...)``, the inner one post-processes the
    # function that decorator returns, wrapping it with the argument mapper.
    arg_wrapper = post_process(args_mapping(_to_tree_arg))
    treelize = post_process(arg_wrapper)(
        replaceable_partial(original_func_treelize, return_type=_TempTreeValue)
    )
    tree_func = treelize()(func)

    class _MetaClass(type):
        def __call__(cls, data, *args, **kwargs):
            # Raw storage is wrapped directly; ``func`` is not called and any
            # extra positional/keyword arguments are ignored.
            if isinstance(data, TreeStorage):
                return type.__call__(cls, data)
            # An existing instance passed alone is returned as-is (no copy).
            if isinstance(data, cls) and not (args or kwargs):
                return data

            result = tree_func(data, *args, **kwargs)
            # Non-tree results from ``func`` are passed straight through.
            if not isinstance(result, _TempTreeValue):
                return result
            return type.__call__(cls, result)

    return _MetaClass
def _build_predicate(key):
    """
    Overview:
        Build a single-argument leaf predicate from a prediction ``key``.

    Arguments:
        - key: A type, a non-empty tuple of types, or a callable taking one value.

    Returns:
        - A function ``predict(x)`` returning the prediction for one leaf value.

    Raises:
        - TypeError: When ``key`` is none of the accepted forms.
    """
    is_type_tuple = (
        isinstance(key, tuple)
        and len(key) > 0
        and all(isinstance(k, type) for k in key)
    )
    # Types are checked before ``callable`` because classes are callable too.
    if isinstance(key, type) or is_type_tuple:
        def _predict(x):
            return isinstance(x, key)
    elif callable(key):
        # NOTE(review): deliberately wrapped in a one-argument function instead
        # of handing ``key`` to ``map`` directly, in case ``map`` adapts its
        # call to the mapper's signature — confirm against treevalue.
        def _predict(x):
            return key(x)
    else:
        raise TypeError(f'Unknown type of prediction - {repr(key)}.')
    return _predict


def _auto_tree_func(t, cls):
    """
    Overview:
        Convert ``t`` to the first tree class whose predicate holds for every leaf.

    Arguments:
        - t: Tree to convert; it is first converted to :py:class:`treetensor.common.Object`.
        - cls: Iterable of ``(key, value)`` pairs. ``key`` is a type, a tuple of types \
            or a one-argument callable applied to each leaf; ``value`` is the tree \
            class to convert to when the prediction holds for all leaves.

    Returns:
        - ``t`` converted to the first matching ``value`` class, or the \
            :py:class:`treetensor.common.Object` form of ``t`` when nothing matches.

    Raises:
        - TypeError: When a ``key`` reached during matching is not a type, a tuple \
            of types or a callable (keys after the first match are not checked).
    """
    # Local import; presumably avoids a circular import with ``.object`` — verify.
    from .object import Object

    t = typetrans(t, return_type=Object)
    for key, value in cls:
        # The predicate binds ``key`` at build time, so there is no late binding
        # on the loop variable.
        if t.map(_build_predicate(key)).all():
            return typetrans(t, return_type=value)
    return t
# noinspection PyArgumentList
def auto_tree(v, cls):
    """
    Overview:
        Recursively convert every :py:class:`treevalue.TreeValue` found in ``v``.

    Arguments:
        - v: Value to process. Trees are converted; ``tuple``, ``list``, ``set`` and \
            ``dict`` containers (including subclasses) are rebuilt with their items \
            processed recursively; anything else is returned unchanged.
        - cls: Conversion to apply to each tree, one of:
            - a subclass of :py:class:`treevalue.TreeValue`: converted with ``typetrans``;
            - a ``list`` or ``tuple`` of ``(key, tree_class)`` pairs: the first \
                ``tree_class`` whose ``key`` holds for every leaf is used;
            - any other callable: called with the tree.

    Returns:
        - ``v`` with every tree it contains converted.

    Raises:
        - TypeError: When ``cls`` is none of the accepted forms.
    """
    if isinstance(cls, type) and issubclass(cls, TreeValue):
        cls = partial(typetrans, return_type=cls)
    elif isinstance(cls, (list, tuple)):
        cls = partial(_auto_tree_func, cls=cls)
    elif not callable(cls):
        raise TypeError(f'Unknown type of cls - {repr(cls)}.')

    if isinstance(v, TreeValue):
        return cls(v)
    elif isinstance(v, (tuple, list, set)):
        items = (auto_tree(item, cls) for item in v)
        vtype = type(v)
        if isinstance(v, tuple) and hasattr(vtype, '_fields') and hasattr(vtype, '_make'):
            # namedtuple constructors take fields positionally; ``_make`` takes an iterable.
            return vtype._make(items)
        return vtype(items)
    elif isinstance(v, dict):
        items = {key: auto_tree(value, cls) for key, value in v.items()}
        if isinstance(v, defaultdict):
            # defaultdict's first positional argument is its default factory.
            return type(v)(v.default_factory, items)
        return type(v)(items)
    else:
        return v