...
 
Commits (13)
    https://gitcode.net/opendilab/treevalue/-/commit/874bd09c1193792eec4037f93e8807d3e1d2d424 dev(hansug): add support for torch integration 2023-02-27T15:17:29+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/28c99a4d8b931516896f49f75c9de4a8bca8f06d dev(hansbug): add unpack 2023-02-27T15:47:14+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/a504c0e637699ddcdc75c63d9297bbe848e754a4 dev(hansbug): add generic_flatten and generic_unflatten 2023-02-27T17:46:36+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/b2b01bf338e9233f86661a9c2600a5bb0e4a8758 dev(hansbug): fix this bug 2023-02-27T18:10:39+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/ac90658759a4a138a4d02b3935124e85313e829e Merge pull request #80 from opendilab/dev/unpack 2023-02-27T18:18:11+08:00 Hankson Bradley hansbug@buaa.edu.cn dev(hansbug): add unpack https://gitcode.net/opendilab/treevalue/-/commit/adc645eaabe1da8bc54c47c54885dcd44466396a Merge branch 'main' into dev/torch 2023-02-27T18:19:18+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/704a8309cb8dba43ab86f039618be651f1ae1d95 Merge pull request #79 from opendilab/dev/torch 2023-02-27T18:31:32+08:00 Hankson Bradley hansbug@buaa.edu.cn dev(hansug): add support for torch integration https://gitcode.net/opendilab/treevalue/-/commit/60d6b0582400739c991d5a7985ba74dc060285f6 Merge branch 'main' into dev/int 2023-02-27T18:32:45+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/c38bf2339f7b707faa5443820a6a224033fc165f dev(hansbug): add generic_mapping function 2023-02-27T19:00:04+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/dd9ff51038b3b06f3fdce24e0d0fd3534cdb4ca0 dev(hansbug): remove support for dataclasses 2023-02-27T19:05:57+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/c9a27c36255a01bf831d3c66acc3c3548ae52dfd release(hansbug): use version 1.4.7 2023-02-27T19:21:22+08:00 HansBug hansbug@buaa.edu.cn https://gitcode.net/opendilab/treevalue/-/commit/fbb7ad1854f8455286e2cedd40b3312212c7f17a Merge pull request #81 from opendilab/dev/int 2023-02-27T20:16:01+08:00 Hankson Bradley hansbug@buaa.edu.cn dev(hansbug): add generic_flatten, generic_unflatten and register_integrate_container for integration module https://gitcode.net/opendilab/treevalue/-/commit/15debc52f39afd538030c70c475202e5fea06e32 doc(hansbug): remove useless docs 2023-02-28T11:23:34+08:00 HansBug hansbug@buaa.edu.cn
......@@ -7,7 +7,56 @@ treevalue.tree.integration
.. _apidoc_tree_integration_register_for_jax:
register_for_jax
------------------------
---------------------------
.. autofunction:: register_for_jax
.. _apidoc_tree_integration_register_for_torch:
register_for_torch
---------------------------
.. autofunction:: register_for_torch
.. _apidoc_tree_integration_register_treevalue_class:
register_treevalue_class
---------------------------
.. autofunction:: register_treevalue_class
.. _apidoc_tree_integration_register_integrate_container:
register_integrate_container
--------------------------------
.. autofunction:: register_integrate_container
.. _apidoc_tree_integration_generic_flatten:
generic_flatten
--------------------------------
.. autofunction:: generic_flatten
.. _apidoc_tree_integration_generic_unflatten:
generic_unflatten
--------------------------------
.. autofunction:: generic_unflatten
.. _apidoc_tree_integration_generic_mapping:
generic_mapping
--------------------------------
.. autofunction:: generic_mapping
......@@ -9,7 +9,7 @@ TreeValue
---------------
.. autoclass:: TreeValue
:members: __init__, __getattribute__, __setattr__, __delattr__, __contains__, __repr__, __iter__, __hash__, __eq__, _attr_extern, __len__, __bool__, __str__, __getstate__, __setstate__, get, pop, keys, values, items, __getitem__, __setitem__, __delitem__, _getitem_extern, _setitem_extern, _delitem_extern, popitem, clear, update, setdefault, __reversed__, _detach
:members: __init__, __getattribute__, __setattr__, __delattr__, __contains__, __repr__, __iter__, __hash__, __eq__, _attr_extern, __len__, __bool__, __str__, __getstate__, __setstate__, get, pop, keys, values, items, __getitem__, __setitem__, __delitem__, _getitem_extern, _setitem_extern, _delitem_extern, popitem, clear, update, setdefault, __reversed__, _detach, unpack
.. _apidoc_tree_tree_delayed:
......
......@@ -787,4 +787,26 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
with pytest.raises(TypeError):
_ = tv3[0]
def test_unpack(self):
t1 = treevalue_class({
'a': 21, 'b': {'x': 'f-49', 'y': 7.7},
})
a, b = t1.unpack('a', 'b')
assert a == 21
assert isinstance(b, treevalue_class)
assert b.x == 'f-49'
assert b.y == pytest.approx(7.7)
x, y = b.unpack('x', 'y')
assert x == 'f-49'
assert y == pytest.approx(7.7)
with pytest.raises(KeyError):
_ = b.unpack('x', 'y', 'z')
x, y, z = b.unpack('x', 'y', 'z', default=None)
assert x == 'f-49'
assert y == pytest.approx(7.7)
assert z is None
return _TestClass
from collections import namedtuple
import pytest
from easydict import EasyDict
from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container, generic_mapping
nt = namedtuple('nt', ['a', 'b'])
class MyTreeValue(FastTreeValue):
pass
@pytest.mark.unittest
class TestTreeIntegrationGeneral:
def test_general_flatten_and_unflatten(self):
demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': [34, '1.2'],
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
v, spec = generic_flatten(demo_data)
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
rv = generic_unflatten(v, spec)
assert rv == demo_data
assert isinstance(rv['c'][-1], EasyDict)
assert isinstance(rv['d'], nt)
assert isinstance(rv['c'][-1]['z'], list)
assert isinstance(rv['e'], MyTreeValue)
def test_register_my_class(self):
class MyDC:
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
return isinstance(other, MyDC) and self.x == other.x and self.y == other.y
def _mydc_flatten(v):
return [v.x, v.y], MyDC
def _mydc_unflatten(v, spec):
return spec(*v)
register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten)
demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': MyDC(34, '1.2'),
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
v, spec = generic_flatten(demo_data)
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
rv = generic_unflatten(v, spec)
assert rv == demo_data
assert isinstance(rv['c'][-1], EasyDict)
assert isinstance(rv['d'], nt)
assert isinstance(rv['c'][-1]['z'], MyDC)
assert isinstance(rv['e'], MyTreeValue)
def test_generic_mapping(self):
demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': (34, '1.2'),
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
assert generic_mapping(demo_data, str) == {
'a': '1',
'b': ['2', '3', 'f'],
'c': ('2', '5', 'ds', EasyDict({
'x': 'None',
'z': ('34', '1.2'),
})),
'd': nt('f', '100'),
'e': MyTreeValue({'x': '1', 'y': 'dsfljk'})
}
from unittest import skipUnless
import pytest
from treevalue import register_treevalue_class, FastTreeValue
try:
import torch
except (ImportError, ModuleNotFoundError):
torch = None
try:
import jax
except (ModuleNotFoundError, ImportError):
jax = None
@pytest.mark.unittest
class TestTreeIntegrationInit:
@skipUnless(torch and jax, 'Torch and jax required.')
def test_register_custom_class_all(self):
class MyTreeValue(FastTreeValue):
pass
with pytest.warns(None):
register_treevalue_class(MyTreeValue)
@skipUnless(not torch or not jax, 'Not all torch and jax required.')
def test_register_custom_class_some(self):
class MyTreeValue(FastTreeValue):
pass
with pytest.warns(UserWarning):
register_treevalue_class(MyTreeValue)
......@@ -3,7 +3,7 @@ from unittest import skipUnless
import numpy as np
import pytest
from treevalue import FastTreeValue
from treevalue import FastTreeValue, register_for_jax
try:
import jax
......@@ -26,5 +26,39 @@ class TestTreeTreeIntegration:
'y': np.random.randn(2, 3)
}
})
assert FastTreeValue.func()(np.isclose)(double(t1), t1 * 2 + 1.5).all() == \
r1 = double(t1)
assert type(r1) is FastTreeValue
assert FastTreeValue.func()(np.isclose)(r1, t1 * 2 + 1.5).all() == \
FastTreeValue({'a': True, 'b': {'x': True, 'y': True}})
class MyTreeValue(FastTreeValue):
pass
register_for_jax(MyTreeValue)
t2 = MyTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3)
}
})
r2 = double(t2)
assert type(r2) is MyTreeValue
assert MyTreeValue.func()(np.isclose)(r2, t2 * 2 + 1.5).all() == \
MyTreeValue({'a': True, 'b': {'x': True, 'y': True}})
@skipUnless(jax, 'Jax required.')
def test_error_register(self):
with pytest.raises(TypeError):
register_for_jax(None)
with pytest.raises(TypeError):
register_for_jax(list)
@skipUnless(not jax, 'No jax required')
def test_ignored_register(self):
class MyTreeValueX(FastTreeValue):
pass
with pytest.warns(UserWarning):
register_for_jax(MyTreeValueX)
from unittest import skipUnless
import pytest
from treevalue import FastTreeValue, register_for_torch
try:
import torch
except (ImportError, ModuleNotFoundError):
torch = None
@pytest.mark.unittest
class TestTreeIntegrationTorch:
@skipUnless(torch, 'Torch required.')
def test_flatten_and_unflatten(self):
arr1 = torch.randint(0, 10, (2, 3))
arr2 = torch.randn(2, 3)
t1 = FastTreeValue({'a': arr1, 'b': {'x': torch.asarray(233.0), 'y': arr2}})
flatted, spec = torch.utils._pytree.tree_flatten(t1)
assert isinstance(flatted, list)
assert len(flatted) == 3
assert torch.isclose(flatted[0], arr1).all()
assert torch.isclose(flatted[1], torch.asarray(233.0)).all()
assert torch.isclose(flatted[2], arr2).all()
newt = torch.utils._pytree.tree_unflatten(flatted, spec)
assert type(newt) == FastTreeValue
assert FastTreeValue.func()(torch.isclose)(t1, newt).all()
class MyTreeValue(FastTreeValue):
pass
register_for_torch(MyTreeValue)
t2 = MyTreeValue({'a': arr1, 'b': {'x': torch.asarray(233.0), 'y': arr2}})
flatted, spec = torch.utils._pytree.tree_flatten(t2)
assert isinstance(flatted, list)
assert len(flatted) == 3
assert torch.isclose(flatted[0], arr1).all()
assert torch.isclose(flatted[1], torch.asarray(233.0)).all()
assert torch.isclose(flatted[2], arr2).all()
newt2 = torch.utils._pytree.tree_unflatten(flatted, spec)
assert type(newt2) == MyTreeValue
assert MyTreeValue.func()(torch.isclose)(t2, newt2).all()
@skipUnless(torch, 'Torch required.')
def test_error_register(self):
with pytest.raises(TypeError):
register_for_torch(None)
with pytest.raises(TypeError):
register_for_torch(list)
@skipUnless(not torch, 'No torch required')
def test_ignored_register(self):
class MyTreeValueX(FastTreeValue):
pass
with pytest.warns(UserWarning):
register_for_torch(MyTreeValueX)
......@@ -687,6 +687,26 @@ def get_treevalue_test(treevalue_class: Type[TreeValue]):
assert a_found and b_found, f'Key {"a"!r} or {"b"!r} not found in {t1!r}.'
def test_unpack(self):
t1 = get_demo_constraint_tree()
a, b = t1.unpack('a', 'b')
assert a == 21
assert isinstance(b, treevalue_class)
assert b.x == 'f-49'
assert b.y == pytest.approx(7.7)
x, y = b.unpack('x', 'y')
assert x == 'f-49'
assert y == pytest.approx(7.7)
with pytest.raises(KeyError):
_ = b.unpack('x', 'y', 'z')
x, y, z = b.unpack('x', 'y', 'z', default=None)
assert x == 'f-49'
assert y == pytest.approx(7.7)
assert z is None
def test_with_constraints(self):
t1 = get_demo_constraint_tree()
t2 = t1.with_constraints(GreaterThanConstraint(10))
......
......@@ -7,7 +7,7 @@ Overview:
__TITLE__ = "treevalue"
#: Version of this project.
__VERSION__ = "1.4.6"
__VERSION__ = "1.4.7"
#: Short description of the project, will be included in ``setup.py``.
__DESCRIPTION__ = 'A flexible, generalized tree-based data structure.'
......
from typing import Type
from .general import generic_flatten, generic_unflatten, register_integrate_container, generic_mapping
from .jax import register_for_jax
from .torch import register_for_torch
from ..tree import TreeValue
def register_treevalue_class(cls: Type[TreeValue], r_jax: bool = True, r_torch: bool = True):
"""
Overview:
Register treevalue class into all existing types.
:param cls: TreeValue class.
:param r_jax: Register for jax, default is `True`.
:param r_torch: Register for torch, default is `True`.
"""
if r_jax:
register_for_jax(cls)
if r_torch:
register_for_torch(cls)
# distutils:language=c++
# cython:language_level=3
cdef tuple _c_flatten_for_integration(object tv)
cdef object _c_unflatten_for_integration(object values, tuple spec)
# distutils:language=c++
# cython:language_level=3
from ..tree.flatten cimport _c_flatten, _c_unflatten
cdef inline tuple _c_flatten_for_integration(object tv):
cdef list result = []
_c_flatten(tv._detach(), (), result)
cdef list paths = []
cdef list values = []
for path, value in result:
paths.append(path)
values.append(value)
return values, (type(tv), paths)
cdef inline object _c_unflatten_for_integration(object values, tuple spec):
cdef object type_
cdef list paths
type_, paths = spec
return type_(_c_unflatten(zip(paths, values)))
......@@ -3,26 +3,14 @@
import cython
from ..tree.flatten cimport _c_flatten, _c_unflatten
from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration
from ..tree.tree cimport TreeValue
cdef inline tuple _c_flatten_for_jax(object tv):
cdef list result = []
_c_flatten(tv._detach(), (), result)
cdef list paths = []
cdef list values = []
for path, value in result:
paths.append(path)
values.append(value)
return values, (type(tv), paths)
return _c_flatten_for_integration(tv)
cdef inline object _c_unflatten_for_jax(tuple aux, tuple values):
cdef object type_
cdef list paths
type_, paths = aux
return type_(_c_unflatten(zip(paths, values)))
return _c_unflatten_for_integration(values, aux)
@cython.binding(True)
cpdef void register_for_jax(object cls) except*:
......
# distutils:language=c++
# cython:language_level=3
cdef tuple _c_flatten_for_torch(object tv)
cdef object _c_unflatten_for_torch(list values, tuple context)
cpdef void register_for_torch(object cls) except*
\ No newline at end of file
# distutils:language=c++
# cython:language_level=3
import cython
from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration
from ..tree.tree cimport TreeValue
cdef inline tuple _c_flatten_for_torch(object tv):
return _c_flatten_for_integration(tv)
cdef inline object _c_unflatten_for_torch(list values, tuple context):
return _c_unflatten_for_integration(values, context)
@cython.binding(True)
cpdef void register_for_torch(object cls) except*:
"""
Overview:
Register treevalue class for torch's pytree library.
:param cls: TreeValue class.
Examples::
>>> from treevalue import FastTreeValue, TreeValue, register_for_torch
>>> register_for_torch(TreeValue)
>>> register_for_torch(FastTreeValue)
.. warning::
This method will put a warning message and then do nothing when torch is not installed.
"""
if isinstance(cls, type) and issubclass(cls, TreeValue):
import torch
torch.utils._pytree._register_pytree_node(cls, _c_flatten_for_torch, _c_unflatten_for_torch)
else:
raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.')
# distutils:language=c++
# cython:language_level=3
from libcpp cimport bool
cdef tuple _dict_flatten(object d)
cdef object _dict_unflatten(list values, tuple spec)
cdef tuple _list_and_tuple_flatten(object l)
cdef object _list_and_tuple_unflatten(list values, object spec)
cdef tuple _namedtuple_flatten(object l)
cdef object _namedtuple_unflatten(list values, object spec)
cdef tuple _treevalue_flatten(object l)
cdef object _treevalue_unflatten(list values, tuple spec)
cdef bool _is_namedtuple_instance(pytree) except*
cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*
cdef tuple _c_get_flatted_values_and_spec(object v)
cdef object _c_get_object_from_flatted(object values, object type_, object spec)
cpdef object generic_flatten(object v)
cpdef object generic_unflatten(object v, tuple gspec)
cpdef object generic_mapping(object v, object func)
# distutils:language=c++
# cython:language_level=3
from collections import namedtuple
import cython
from libcpp cimport bool
from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration
from ..tree.tree cimport TreeValue
_REGISTERED_CONTAINERS = {}
cdef inline tuple _dict_flatten(object d):
cdef list values = []
cdef list keys = []
cdef object key, value
for key, value in d.items():
keys.append(key)
values.append(value)
return values, (type(d), keys)
cdef inline object _dict_unflatten(list values, tuple spec):
cdef object type_
cdef list keys
type_, keys = spec
cdef dict retval = {}
for key, value in zip(keys, values):
retval[key] = value
return type_(retval)
cdef inline tuple _list_and_tuple_flatten(object l):
return list(l), type(l)
cdef inline object _list_and_tuple_unflatten(list values, object spec):
return spec(values)
cdef inline tuple _namedtuple_flatten(object l):
return list(l), type(l)
cdef inline object _namedtuple_unflatten(list values, object spec):
return spec(*values)
cdef inline tuple _treevalue_flatten(object l):
return _c_flatten_for_integration(l)
cdef inline object _treevalue_unflatten(list values, tuple spec):
return _c_unflatten_for_integration(values, spec)
cdef inline bool _is_namedtuple_instance(pytree) except*:
cdef object typ = type(pytree)
cdef tuple bases = typ.__bases__
if len(bases) != 1 or bases[0] != tuple:
return False
fields = getattr(typ, '_fields', None)
if not isinstance(fields, tuple):
return False # pragma: no cover
return all(type(entry) == str for entry in fields)
@cython.binding(True)
cpdef inline void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*:
"""
Overview:
Register custom data class for generic flatten and unflatten.
:param type_: Class of data to be registered.
:param flatten_func: Function for flattening.
:param unflatten_func: Function for unflattening.
Examples::
>>> from treevalue import register_integrate_container, generic_flatten, FastTreeValue, generic_unflatten
>>>
>>> class MyDC:
... def __init__(self, x, y):
... self.x = x
... self.y = y
...
... def __eq__(self, other):
... return isinstance(other, MyDC) and self.x == other.x and self.y == other.y
>>>
>>> def _mydc_flatten(v):
... return [v.x, v.y], MyDC
>>>
>>> def _mydc_unflatten(v, spec): # spec will be MyDC
... return spec(*v)
>>>
>>> register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten) # register MyDC
>>>
>>> v, spec = generic_flatten({'a': MyDC(2, 3), 'b': MyDC((4, 5), FastTreeValue({'x': 1, 'y': 'f'}))})
>>> v
[[2, 3], [[4, 5], [1, 'f']]]
>>>
>>> rt=generic_unflatten(v, spec)
>>> rt
{'a': <__main__.MyDC object at 0x7fbda613f9d0>, 'b': <__main__.MyDC object at 0x7fbda6148150>}
>>> rt['a'].x
2
>>> rt['a'].y
3
>>> rt['b'].x
(4, 5)
>>> rt['b'].y
<FastTreeValue 0x7fbda5aed510>
├── 'x' --> 1
└── 'y' --> 'f'
"""
_REGISTERED_CONTAINERS[type_] = (flatten_func, unflatten_func)
cdef inline tuple _c_get_flatted_values_and_spec(object v):
cdef list values
cdef object spec, type_
cdef object flatten_func
if isinstance(v, dict):
values, spec = _dict_flatten(v)
type_ = dict
elif _is_namedtuple_instance(v):
values, spec = _namedtuple_flatten(v)
type_ = namedtuple
elif isinstance(v, (list, tuple)):
values, spec = _list_and_tuple_flatten(v)
type_ = list
elif isinstance(v, TreeValue):
values, spec = _treevalue_flatten(v)
type_ = TreeValue
elif type(v) in _REGISTERED_CONTAINERS:
flatten_func, _ = _REGISTERED_CONTAINERS[type(v)]
values, spec = flatten_func(v)
type_ = type(v)
else:
return v, None, None
return values, type_, spec
cdef inline object _c_get_object_from_flatted(object values, object type_, object spec):
cdef object unflatten_func
if type_ is dict:
return _dict_unflatten(values, spec)
elif type_ is namedtuple:
return _namedtuple_unflatten(values, spec)
elif type_ is list:
return _list_and_tuple_unflatten(values, spec)
elif type_ is TreeValue:
return _treevalue_unflatten(values, spec)
elif type_ in _REGISTERED_CONTAINERS:
_, unflatten_func = _REGISTERED_CONTAINERS[type_]
return unflatten_func(values, spec)
else:
raise TypeError(f'Unknown type for unflatten - {values!r}, {spec!r}.') # pragma: no cover
@cython.binding(True)
cpdef inline object generic_flatten(object v):
"""
Overview:
Flatten generic data, including native objects, ``TreeValue``, namedtuples and custom classes \
(see :func:`register_integrate_container`).
:param v: Value to be flatted.
:return: Flatted value.
Examples::
>>> from collections import namedtuple
>>> from easydict import EasyDict
>>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten
>>>
>>> class MyTreeValue(FastTreeValue):
... pass
>>>
>>> nt = namedtuple('nt', ['a', 'b'])
>>>
>>> origin = {
... 'a': 1,
... 'b': (2, 3, 'f',),
... 'c': (2, 5, 'ds', EasyDict({ # dict's child class
... 'x': None,
... 'z': [34, '1.2'],
... })),
... 'd': nt('f', 100), # namedtuple
... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue
... }
>>> v, spec = generic_flatten(origin)
>>> v
[1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
>>>
>>> rv = generic_unflatten(v, spec)
>>> rv
{'a': 1, 'b': (2, 3, 'f'), 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7fb6026d7b10>
├── 'x' --> 1
└── 'y' --> 'dsfljk'
}
>>> type(rv['c'][-1])
<class 'easydict.EasyDict'>
"""
values, type_, spec = _c_get_flatted_values_and_spec(v)
if type_ is None:
return values, (None, None, None)
cdef list child_values = []
cdef list child_specs = []
cdef object value, cval, cspec
for value in values:
cval, cspec = generic_flatten(value)
child_values.append(cval)
child_specs.append(cspec)
return child_values, (type_, spec, child_specs)
@cython.binding(True)
cpdef inline object generic_unflatten(object v, tuple gspec):
"""
Overview:
Inverse operation of :func:`generic_flatten`.
:param v: Flatted values.
:param gspec: Spec data of original object.
Examples::
See :func:`generic_flatten`.
"""
cdef object type_, spec
cdef list child_specs
type_, spec, child_specs = gspec
if type_ is None:
return v
cdef list values = []
cdef object _i_value, _i_spec
for _i_value, _i_spec in zip(v, child_specs):
values.append(generic_unflatten(_i_value, _i_spec))
return _c_get_object_from_flatted(values, type_, spec)
@cython.binding(True)
cpdef inline object generic_mapping(object v, object func):
"""
Overview:
Generic map all the values, including native objects, ``TreeValue``, namedtuples and custom classes \
(see :func:`register_integrate_container`)
:param v: Original value, nested structure is supported.
:param func: Function to operate.
Examples::
>>> from collections import namedtuple
>>> from easydict import EasyDict
>>> from treevalue import FastTreeValue, generic_mapping
>>>
>>> class MyTreeValue(FastTreeValue):
... pass
>>>
>>> nt = namedtuple('nt', ['a', 'b'])
>>>
>>> origin = {
... 'a': 1,
... 'b': (2, 3, 'f',),
... 'c': (2, 5, 'ds', EasyDict({ # dict's child class
... 'x': None,
... 'z': [34, '1.2'],
... })),
... 'd': nt('f', 100), # namedtuple
... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue
... }
>>> generic_mapping(origin, str)
{'a': '1', 'b': ('2', '3', 'f'), 'c': ('2', '5', 'ds', {'x': 'None', 'z': ['34', '1.2']}), 'd': nt(a='f', b='100'), 'e': <MyTreeValue 0x7f72e4d4ac90>
├── 'x' --> '1'
└── 'y' --> 'dsfljk'
}
"""
values, type_, spec = _c_get_flatted_values_and_spec(v)
if type_ is None:
return func(values)
cdef list retvals = []
cdef object value
for value in values:
retvals.append(generic_mapping(value, func))
return _c_get_object_from_flatted(retvals, type_, spec)
import warnings
from functools import wraps
try:
import torch
except (ModuleNotFoundError, ImportError):
from .ctorch import register_for_torch as _original_register_for_torch
@wraps(_original_register_for_torch)
def register_for_torch(cls):
warnings.warn(f'Torch is not installed, registration of {cls!r} will be ignored.')
else:
from .ctorch import register_for_torch
from ..tree import TreeValue
from ..general import FastTreeValue
register_for_torch(TreeValue)
register_for_torch(FastTreeValue)
......@@ -45,6 +45,8 @@ cdef class TreeValue:
cpdef public treevalue_values values(self)
cpdef public treevalue_items items(self)
cdef tuple _unpack(self, tuple keys, object default=*)
cpdef void validate(self) except*
cdef str _prefix_fix(object text, object prefix)
......
......@@ -913,6 +913,46 @@ cdef class TreeValue:
"""
return treevalue_items(self, self._st)
cdef tuple _unpack(self, tuple keys, object default=_GET_NO_DEFAULT):
cdef object key
cdef list result = []
for key in keys:
if self._st.contains(key) or default is not _GET_NO_DEFAULT:
result.append(self.get(key, default))
else:
raise KeyError(f'Key not found - {key!r}.')
return tuple(result)
@cython.binding(True)
def unpack(self, *keys, default=_GET_NO_DEFAULT):
"""
Unpack values from current treevalue object.
:param keys: Keys to unpack.
:param default: Default value when key not found. Raise `KeyError` when not given.
Examples::
>>> from treevalue import FastTreeValue
>>> t1 = FastTreeValue({
... 'a': 21, 'b': {'x': 'f-49', 'y': 7.7},
... })
>>>
>>> t1.unpack('a', 'b')
(21, <FastTreeValue 0x7ffb77fc2850>
├── 'x' --> 'f-49'
└── 'y' --> 7.7
)
>>> t1.b.unpack('x', 'y')
('f-49', 7.7)
>>> t1.b.unpack('x', 'y', 'z')
KeyError: "Key not found - 'z'."
>>> t1.b.unpack('x', 'y', 'z', default=None)
('f-49', 7.7, None)
"""
return self._unpack(keys, default)
@cython.binding(True)
cpdef void validate(self) except*:
cdef bool retval
......