提交 c38bf233 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add generic_mapping function

上级 60d6b058
from collections import namedtuple
from dataclasses import dataclass
import pytest
from easydict import EasyDict
from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container
@dataclass
class DC:
x: int
y: str
from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container, generic_mapping
nt = namedtuple('nt', ['a', 'b'])
......@@ -28,7 +20,7 @@ class TestTreeIntegrationGeneral:
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': DC(34, '1.2'),
'z': [34, '1.2'],
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
......@@ -40,7 +32,7 @@ class TestTreeIntegrationGeneral:
assert rv == demo_data
assert isinstance(rv['c'][-1], EasyDict)
assert isinstance(rv['d'], nt)
assert isinstance(rv['c'][-1]['z'], DC)
assert isinstance(rv['c'][-1]['z'], list)
assert isinstance(rv['e'], MyTreeValue)
def test_register_my_class(self):
......@@ -79,3 +71,25 @@ class TestTreeIntegrationGeneral:
assert isinstance(rv['d'], nt)
assert isinstance(rv['c'][-1]['z'], MyDC)
assert isinstance(rv['e'], MyTreeValue)
def test_generic_mapping(self):
demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': (34, '1.2'),
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
assert generic_mapping(demo_data, str) == {
'a': '1',
'b': ['2', '3', 'f'],
'c': ('2', '5', 'ds', EasyDict({
'x': 'None',
'z': ('34', '1.2'),
})),
'd': nt('f', '100'),
'e': MyTreeValue({'x': '1', 'y': 'dsfljk'})
}
from typing import Type
from .general import generic_flatten, generic_unflatten, register_integrate_container
from .general import generic_flatten, generic_unflatten, register_integrate_container, generic_mapping
from .jax import register_for_jax
from .torch import register_for_torch
from ..tree import TreeValue
......
......@@ -12,9 +12,6 @@ cdef object _list_and_tuple_unflatten(list values, object spec)
cdef tuple _namedtuple_flatten(object l)
cdef object _namedtuple_unflatten(list values, object spec)
cdef tuple _dataclass_flatten(object l)
cdef object _dataclass_unflatten(list values, tuple spec)
cdef tuple _treevalue_flatten(object l)
cdef object _treevalue_unflatten(list values, tuple spec)
......@@ -22,5 +19,9 @@ cdef bool _is_namedtuple_instance(pytree) except*
cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*
cdef tuple _c_get_flatted_values_and_spec(object v)
cdef object _c_get_object_from_flatted(object values, object type_, object spec)
cpdef object generic_flatten(object v)
cpdef object generic_unflatten(object v, tuple gspec)
cpdef object generic_mapping(object v, object func)
......@@ -2,7 +2,6 @@
# cython:language_level=3
from collections import namedtuple
from dataclasses import dataclass, is_dataclass
import cython
from libcpp cimport bool
......@@ -46,23 +45,6 @@ cdef inline tuple _namedtuple_flatten(object l):
cdef inline object _namedtuple_unflatten(list values, object spec):
return spec(*values)
cdef inline tuple _dataclass_flatten(object l):
cdef object type_ = type(l)
cdef list keys = []
cdef list values = []
for key in type_.__dataclass_fields__.keys():
keys.append(key)
values.append(getattr(l, key))
return values, (type_, keys)
cdef inline object _dataclass_unflatten(list values, tuple spec):
cdef object type_
cdef list keys
type_, keys = spec
return type_(**{key: value for key, value in zip(keys, values)})
cdef inline tuple _treevalue_flatten(object l):
return _c_flatten_for_integration(l)
......@@ -132,40 +114,73 @@ cpdef inline void register_integrate_container(object type_, object flatten_func
"""
_REGISTERED_CONTAINERS[type_] = (flatten_func, unflatten_func)
cdef inline tuple _c_get_flatted_values_and_spec(object v):
cdef list values
cdef object spec, type_
cdef object flatten_func
if isinstance(v, dict):
values, spec = _dict_flatten(v)
type_ = dict
elif _is_namedtuple_instance(v):
values, spec = _namedtuple_flatten(v)
type_ = namedtuple
elif isinstance(v, (list, tuple)):
values, spec = _list_and_tuple_flatten(v)
type_ = list
elif isinstance(v, TreeValue):
values, spec = _treevalue_flatten(v)
type_ = TreeValue
elif type(v) in _REGISTERED_CONTAINERS:
flatten_func, _ = _REGISTERED_CONTAINERS[type(v)]
values, spec = flatten_func(v)
type_ = type(v)
else:
return v, None, None
return values, type_, spec
cdef inline object _c_get_object_from_flatted(object values, object type_, object spec):
cdef object unflatten_func
if type_ is dict:
return _dict_unflatten(values, spec)
elif type_ is namedtuple:
return _namedtuple_unflatten(values, spec)
elif type_ is list:
return _list_and_tuple_unflatten(values, spec)
elif type_ is TreeValue:
return _treevalue_unflatten(values, spec)
elif type_ in _REGISTERED_CONTAINERS:
_, unflatten_func = _REGISTERED_CONTAINERS[type_]
return unflatten_func(values, spec)
else:
raise TypeError(f'Unknown type for unflatten - {values!r}, {spec!r}.') # pragma: no cover
@cython.binding(True)
cpdef inline object generic_flatten(object v):
"""
Overview:
Flatten generic data, including native objects, ``TreeValue``, namedtuples and dataclasses.
Flatten generic data, including native objects, ``TreeValue``, namedtuples.
:param v: Value to be flatted.
:return: Flatted value.
Examples::
>>> from collections import namedtuple
>>> from dataclasses import dataclass
>>> from easydict import EasyDict
>>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten
>>>
>>> class MyTreeValue(FastTreeValue):
... pass
>>>
>>> @dataclass
... class DC:
... x: int
... y: float
...
... def __repr__(self):
... return f'DC({self.x}, {self.y})'
>>>
>>> nt = namedtuple('nt', ['a', 'b'])
>>>
>>> origin = {
... 'a': 1,
... 'b': [2, 3, 'f', ],
... 'b': (2, 3, 'f',),
... 'c': (2, 5, 'ds', EasyDict({ # dict's child class
... 'x': None,
... 'z': DC(34, '1.2'), # dataclass
... 'z': [34, '1.2'], # dataclass
... })),
... 'd': nt('f', 100), # namedtuple
... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue
......@@ -175,38 +190,17 @@ cpdef inline object generic_flatten(object v):
[1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
>>>
>>> rv = generic_unflatten(v, spec)
>>> rv # all the data, including types, are recovered
{'a': 1, 'b': [2, 3, 'f'], 'c': (2, 5, 'ds', {'x': None, 'z': DC(34, 1.2)}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7fba23ef9c50>
>>> rv
{'a': 1, 'b': (2, 3, 'f'), 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7fb6026d7b10>
├── 'x' --> 1
└── 'y' --> 'dsfljk'
}
>>> type(rv['c'][-1])
<class 'easydict.EasyDict'>
"""
cdef list values
cdef object spec, type_
cdef object flatten_func
if isinstance(v, dict):
values, spec = _dict_flatten(v)
type_ = dict
elif _is_namedtuple_instance(v):
values, spec = _namedtuple_flatten(v)
type_ = namedtuple
elif isinstance(v, (list, tuple)):
values, spec = _list_and_tuple_flatten(v)
type_ = list
elif is_dataclass(v):
values, spec = _dataclass_flatten(v)
type_ = dataclass
elif isinstance(v, TreeValue):
values, spec = _treevalue_flatten(v)
type_ = TreeValue
elif type(v) in _REGISTERED_CONTAINERS:
flatten_func, _ = _REGISTERED_CONTAINERS[type(v)]
values, spec = flatten_func(v)
type_ = type(v)
else:
return v, (None, None, None)
values, type_, spec = _c_get_flatted_values_and_spec(v)
if type_ is None:
return values, (None, None, None)
cdef list child_values = []
cdef list child_specs = []
......@@ -241,19 +235,17 @@ cpdef inline object generic_unflatten(object v, tuple gspec):
for _i_value, _i_spec in zip(v, child_specs):
values.append(generic_unflatten(_i_value, _i_spec))
cdef object unflatten_func
if type_ is dict:
return _dict_unflatten(values, spec)
elif type_ is namedtuple:
return _namedtuple_unflatten(values, spec)
elif type_ is list:
return _list_and_tuple_unflatten(values, spec)
elif type_ is dataclass:
return _dataclass_unflatten(values, spec)
elif type_ is TreeValue:
return _treevalue_unflatten(values, spec)
elif type_ in _REGISTERED_CONTAINERS:
_, unflatten_func = _REGISTERED_CONTAINERS[type_]
return unflatten_func(values, spec)
else:
raise TypeError(f'Unknown type for unflatten - {values!r}, {gspec!r}.') # pragma: no cover
return _c_get_object_from_flatted(values, type_, spec)
@cython.binding(True)
cpdef inline object generic_mapping(object v, object func):
values, type_, spec = _c_get_flatted_values_and_spec(v)
if type_ is None:
return func(values)
cdef list retvals = []
cdef object value
for value in values:
retvals.append(generic_mapping(value, func))
return _c_get_object_from_flatted(retvals, type_, spec)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册