Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
treevalue
提交
a504c0e6
T
treevalue
项目概览
OpenDILab开源决策智能平台
/
treevalue
9 个月 前同步成功
通知
3
Star
213
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
treevalue
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
a504c0e6
编写于
2月 27, 2023
作者:
HansBug
😆
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add generic_flatten and generic_unflatten
上级
874bd09c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
392 addition
and
1 deletion
+392
-1
docs/source/api_doc/tree/integration.rst
docs/source/api_doc/tree/integration.rst
+25
-0
test/tree/integration/test_general.py
test/tree/integration/test_general.py
+81
-0
treevalue/tree/integration/__init__.py
treevalue/tree/integration/__init__.py
+1
-0
treevalue/tree/integration/base.pyx
treevalue/tree/integration/base.pyx
+0
-1
treevalue/tree/integration/general.pxd
treevalue/tree/integration/general.pxd
+26
-0
treevalue/tree/integration/general.pyx
treevalue/tree/integration/general.pyx
+259
-0
未找到文件。
docs/source/api_doc/tree/integration.rst
浏览文件 @
a504c0e6
...
...
@@ -27,3 +27,28 @@ register_treevalue_class
.. autofunction:: register_treevalue_class
.. _apidoc_tree_integration_register_integrate_container:
register_integrate_container
--------------------------------
.. autofunction:: register_integrate_container
.. _apidoc_tree_integration_generic_flatten:
generic_flatten
--------------------------------
.. autofunction:: generic_flatten
.. _apidoc_tree_integration_generic_unflatten:
generic_unflatten
--------------------------------
.. autofunction:: generic_unflatten
test/tree/integration/test_general.py
0 → 100644
浏览文件 @
a504c0e6
from
collections
import
namedtuple
from
dataclasses
import
dataclass
import
pytest
from
easydict
import
EasyDict
from
treevalue
import
generic_flatten
,
generic_unflatten
,
FastTreeValue
,
register_integrate_container
@
dataclass
class
DC
:
x
:
int
y
:
str
nt
=
namedtuple
(
'nt'
,
[
'a'
,
'b'
])
class
MyTreeValue
(
FastTreeValue
):
pass
@
pytest
.
mark
.
unittest
class
TestTreeIntegrationGeneral
:
def
test_general_flatten_and_unflatten
(
self
):
demo_data
=
{
'a'
:
1
,
'b'
:
[
2
,
3
,
'f'
],
'c'
:
(
2
,
5
,
'ds'
,
EasyDict
({
'x'
:
None
,
'z'
:
DC
(
34
,
'1.2'
),
})),
'd'
:
nt
(
'f'
,
100
),
'e'
:
MyTreeValue
({
'x'
:
1
,
'y'
:
'dsfljk'
})
}
v
,
spec
=
generic_flatten
(
demo_data
)
assert
v
==
[
1
,
[
2
,
3
,
'f'
],
[
2
,
5
,
'ds'
,
[
None
,
[
34
,
'1.2'
]]],
[
'f'
,
100
],
[
1
,
'dsfljk'
]]
rv
=
generic_unflatten
(
v
,
spec
)
assert
rv
==
demo_data
assert
isinstance
(
rv
[
'c'
][
-
1
],
EasyDict
)
assert
isinstance
(
rv
[
'd'
],
nt
)
assert
isinstance
(
rv
[
'c'
][
-
1
][
'z'
],
DC
)
assert
isinstance
(
rv
[
'e'
],
MyTreeValue
)
def
test_register_my_class
(
self
):
class
MyDC
:
def
__init__
(
self
,
x
,
y
):
self
.
x
=
x
self
.
y
=
y
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyDC
)
and
self
.
x
==
other
.
x
and
self
.
y
==
other
.
y
def
_mydc_flatten
(
v
):
return
[
v
.
x
,
v
.
y
],
MyDC
def
_mydc_unflatten
(
v
,
spec
):
return
spec
(
*
v
)
register_integrate_container
(
MyDC
,
_mydc_flatten
,
_mydc_unflatten
)
demo_data
=
{
'a'
:
1
,
'b'
:
[
2
,
3
,
'f'
],
'c'
:
(
2
,
5
,
'ds'
,
EasyDict
({
'x'
:
None
,
'z'
:
MyDC
(
34
,
'1.2'
),
})),
'd'
:
nt
(
'f'
,
100
),
'e'
:
MyTreeValue
({
'x'
:
1
,
'y'
:
'dsfljk'
})
}
v
,
spec
=
generic_flatten
(
demo_data
)
assert
v
==
[
1
,
[
2
,
3
,
'f'
],
[
2
,
5
,
'ds'
,
[
None
,
[
34
,
'1.2'
]]],
[
'f'
,
100
],
[
1
,
'dsfljk'
]]
rv
=
generic_unflatten
(
v
,
spec
)
assert
rv
==
demo_data
assert
isinstance
(
rv
[
'c'
][
-
1
],
EasyDict
)
assert
isinstance
(
rv
[
'd'
],
nt
)
assert
isinstance
(
rv
[
'c'
][
-
1
][
'z'
],
MyDC
)
assert
isinstance
(
rv
[
'e'
],
MyTreeValue
)
treevalue/tree/integration/__init__.py
浏览文件 @
a504c0e6
from
typing
import
Type
from
.general
import
generic_flatten
,
generic_unflatten
,
register_integrate_container
from
.jax
import
register_for_jax
from
.torch
import
register_for_torch
from
..tree
import
TreeValue
...
...
treevalue/tree/integration/base.pyx
浏览文件 @
a504c0e6
...
...
@@ -14,7 +14,6 @@ cdef inline tuple _c_flatten_for_integration(object tv):
values
.
append
(
value
)
return
values
,
(
type
(
tv
),
paths
)
pass
cdef
inline
object
_c_unflatten_for_integration
(
object
values
,
tuple
spec
):
cdef
object
type_
...
...
treevalue/tree/integration/general.pxd
0 → 100644
浏览文件 @
a504c0e6
# distutils:language=c++
# cython:language_level=3
from
libcpp
cimport
bool
cdef
tuple
_dict_flatten
(
object
d
)
cdef
object
_dict_unflatten
(
list
values
,
tuple
spec
)
cdef
tuple
_list_and_tuple_flatten
(
object
l
)
cdef
object
_list_and_tuple_unflatten
(
list
values
,
object
spec
)
cdef
tuple
_namedtuple_flatten
(
object
l
)
cdef
object
_namedtuple_unflatten
(
list
values
,
object
spec
)
cdef
tuple
_dataclass_flatten
(
object
l
)
cdef
object
_dataclass_unflatten
(
list
values
,
tuple
spec
)
cdef
tuple
_treevalue_flatten
(
object
l
)
cdef
object
_treevalue_unflatten
(
list
values
,
tuple
spec
)
cdef
bool
_is_namedtuple_instance
(
pytree
)
except
*
cpdef
void
register_integrate_container
(
object
type_
,
object
flatten_func
,
object
unflatten_func
)
except
*
cpdef
object
generic_flatten
(
object
v
)
cpdef
object
generic_unflatten
(
object
v
,
tuple
gspec
)
treevalue/tree/integration/general.pyx
0 → 100644
浏览文件 @
a504c0e6
# distutils:language=c++
# cython:language_level=3
from
collections
import
namedtuple
from
dataclasses
import
dataclass
,
is_dataclass
import
cython
from
libcpp
cimport
bool
from
.base
cimport
_c_flatten_for_integration
,
_c_unflatten_for_integration
from
..tree.tree
cimport
TreeValue
_REGISTERED_CONTAINERS
=
{}
cdef
inline
tuple
_dict_flatten
(
object
d
):
cdef
list
values
=
[]
cdef
list
keys
=
[]
cdef
object
key
,
value
for
key
,
value
in
d
.
items
():
keys
.
append
(
key
)
values
.
append
(
value
)
return
values
,
(
type
(
d
),
keys
)
cdef
inline
object
_dict_unflatten
(
list
values
,
tuple
spec
):
cdef
object
type_
cdef
list
keys
type_
,
keys
=
spec
cdef
dict
retval
=
{}
for
key
,
value
in
zip
(
keys
,
values
):
retval
[
key
]
=
value
return
type_
(
retval
)
cdef
inline
tuple
_list_and_tuple_flatten
(
object
l
):
return
list
(
l
),
type
(
l
)
cdef
inline
object
_list_and_tuple_unflatten
(
list
values
,
object
spec
):
return
spec
(
values
)
cdef
inline
tuple
_namedtuple_flatten
(
object
l
):
return
list
(
l
),
type
(
l
)
cdef
inline
object
_namedtuple_unflatten
(
list
values
,
object
spec
):
return
spec
(
*
values
)
cdef
inline
tuple
_dataclass_flatten
(
object
l
):
cdef
object
type_
=
type
(
l
)
cdef
list
keys
=
[]
cdef
list
values
=
[]
for
key
in
type_
.
__dataclass_fields__
.
keys
():
keys
.
append
(
key
)
values
.
append
(
getattr
(
l
,
key
))
return
values
,
(
type_
,
keys
)
cdef
inline
object
_dataclass_unflatten
(
list
values
,
tuple
spec
):
cdef
object
type_
cdef
list
keys
type_
,
keys
=
spec
return
type_
(
**
{
key
:
value
for
key
,
value
in
zip
(
keys
,
values
)})
cdef
inline
tuple
_treevalue_flatten
(
object
l
):
return
_c_flatten_for_integration
(
l
)
cdef
inline
object
_treevalue_unflatten
(
list
values
,
tuple
spec
):
return
_c_unflatten_for_integration
(
values
,
spec
)
cdef
inline
bool
_is_namedtuple_instance
(
pytree
)
except
*
:
cdef
object
typ
=
type
(
pytree
)
cdef
tuple
bases
=
typ
.
__bases__
if
len
(
bases
)
!=
1
or
bases
[
0
]
!=
tuple
:
return
False
fields
=
getattr
(
typ
,
'_fields'
,
None
)
if
not
isinstance
(
fields
,
tuple
):
return
False
# pragma: no cover
return
all
(
type
(
entry
)
==
str
for
entry
in
fields
)
@
cython
.
binding
(
True
)
cpdef
inline
void
register_integrate_container
(
object
type_
,
object
flatten_func
,
object
unflatten_func
)
except
*
:
"""
Overview:
Register custom data class for generic flatten and unflatten.
:param type_: Class of data to be registered.
:param flatten_func: Function for flattening.
:param unflatten_func: Function for unflattening.
Examples::
>>> from treevalue import register_integrate_container, generic_flatten, FastTreeValue, generic_unflatten
>>>
>>> class MyDC:
... def __init__(self, x, y):
... self.x = x
... self.y = y
...
... def __eq__(self, other):
... return isinstance(other, MyDC) and self.x == other.x and self.y == other.y
>>>
>>> def _mydc_flatten(v):
... return [v.x, v.y], MyDC
>>>
>>> def _mydc_unflatten(v, spec): # spec will be MyDC
... return spec(*v)
>>>
>>> register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten) # register MyDC
>>>
>>> v, spec = generic_flatten({'a': MyDC(2, 3), 'b': MyDC((4, 5), FastTreeValue({'x': 1, 'y': 'f'}))})
>>> v
[[2, 3], [[4, 5], [1, 'f']]]
>>>
>>> rt=generic_unflatten(v, spec)
>>> rt
{'a': <__main__.MyDC object at 0x7fbda613f9d0>, 'b': <__main__.MyDC object at 0x7fbda6148150>}
>>> rt['a'].x
2
>>> rt['a'].y
3
>>> rt['b'].x
(4, 5)
>>> rt['b'].y
<FastTreeValue 0x7fbda5aed510>
├── 'x' --> 1
└── 'y' --> 'f'
"""
_REGISTERED_CONTAINERS
[
type_
]
=
(
flatten_func
,
unflatten_func
)
@
cython
.
binding
(
True
)
cpdef
inline
object
generic_flatten
(
object
v
):
"""
Overview:
Flatten generic data, including native objects, ``TreeValue``, namedtuples and dataclasses.
:param v: Value to be flatted.
:return: Flatted value.
Examples::
>>> from collections import namedtuple
>>> from dataclasses import dataclass
>>> from easydict import EasyDict
>>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten
>>>
>>> class MyTreeValue(FastTreeValue):
... pass
>>>
>>> @dataclass
... class DC:
... x: int
... y: float
...
... def __repr__(self):
... return f'DC({self.x}, {self.y})'
>>>
>>> nt = namedtuple('nt', ['a', 'b'])
>>>
>>> origin = {
... 'a': 1,
... 'b': [2, 3, 'f', ],
... 'c': (2, 5, 'ds', EasyDict({ # dict's child class
... 'x': None,
... 'z': DC(34, '1.2'), # dataclass
... })),
... 'd': nt('f', 100), # namedtuple
... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue
... }
>>> v, spec = generic_flatten(origin)
>>> v
[1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
>>>
>>> rv = generic_unflatten(v, spec)
>>> rv # all the data, including types, are recovered
{'a': 1, 'b': [2, 3, 'f'], 'c': (2, 5, 'ds', {'x': None, 'z': DC(34, 1.2)}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7fba23ef9c50>
├── 'x' --> 1
└── 'y' --> 'dsfljk'
}
>>> type(rv['c'][-1])
<class 'easydict.EasyDict'>
"""
cdef
list
values
cdef
object
spec
,
type_
cdef
object
flatten_func
if
isinstance
(
v
,
dict
):
values
,
spec
=
_dict_flatten
(
v
)
type_
=
dict
elif
_is_namedtuple_instance
(
v
):
values
,
spec
=
_namedtuple_flatten
(
v
)
type_
=
namedtuple
elif
isinstance
(
v
,
(
list
,
tuple
)):
values
,
spec
=
_list_and_tuple_flatten
(
v
)
type_
=
list
elif
is_dataclass
(
v
):
values
,
spec
=
_dataclass_flatten
(
v
)
type_
=
dataclass
elif
isinstance
(
v
,
TreeValue
):
values
,
spec
=
_treevalue_flatten
(
v
)
type_
=
TreeValue
elif
type
(
v
)
in
_REGISTERED_CONTAINERS
:
flatten_func
,
_
=
_REGISTERED_CONTAINERS
[
type
(
v
)]
values
,
spec
=
flatten_func
(
v
)
type_
=
type
(
v
)
else
:
return
v
,
(
None
,
None
,
None
)
cdef
list
child_values
=
[]
cdef
list
child_specs
=
[]
cdef
object
value
,
cval
,
cspec
for
value
in
values
:
cval
,
cspec
=
generic_flatten
(
value
)
child_values
.
append
(
cval
)
child_specs
.
append
(
cspec
)
return
child_values
,
(
type_
,
spec
,
child_specs
)
@
cython
.
binding
(
True
)
cpdef
inline
object
generic_unflatten
(
object
v
,
tuple
gspec
):
"""
Overview:
Inverse operation of :func:`generic_flatten`.
:param v: Flatted values.
:param gspec: Spec data of original object.
Examples::
See :func:`generic_flatten`.
"""
cdef
object
type_
,
spec
cdef
list
child_specs
type_
,
spec
,
child_specs
=
gspec
if
type_
is
None
:
return
v
cdef
list
values
=
[]
cdef
object
_i_value
,
_i_spec
for
_i_value
,
_i_spec
in
zip
(
v
,
child_specs
):
values
.
append
(
generic_unflatten
(
_i_value
,
_i_spec
))
cdef
object
unflatten_func
if
type_
is
dict
:
return
_dict_unflatten
(
values
,
spec
)
elif
type_
is
namedtuple
:
return
_namedtuple_unflatten
(
values
,
spec
)
elif
type_
is
list
:
return
_list_and_tuple_unflatten
(
values
,
spec
)
elif
type_
is
dataclass
:
return
_dataclass_unflatten
(
values
,
spec
)
elif
type_
is
TreeValue
:
return
_treevalue_unflatten
(
values
,
spec
)
elif
type_
in
_REGISTERED_CONTAINERS
:
_
,
unflatten_func
=
_REGISTERED_CONTAINERS
[
type_
]
return
unflatten_func
(
values
,
spec
)
else
:
raise
TypeError
(
f
'Unknown type for unflatten -
{
values
!
r
}
,
{
gspec
!
r
}
.'
)
# pragma: no cover
HansBug
😆
@HansBug
mentioned in commit
36819a68
·
3月 01, 2023
mentioned in commit
36819a68
mentioned in commit 36819a68a5e9888c350885ae91f3997fa5f779bd
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录