Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
treevalue
提交
c38bf233
T
treevalue
项目概览
OpenDILab开源决策智能平台
/
treevalue
9 个月 前同步成功
通知
3
Star
213
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
treevalue
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
c38bf233
编写于
2月 27, 2023
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add generic_mapping function
上级
60d6b058
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
94 addition
and
87 deletion
+94
-87
test/tree/integration/test_general.py
test/tree/integration/test_general.py
+25
-11
treevalue/tree/integration/__init__.py
treevalue/tree/integration/__init__.py
+1
-1
treevalue/tree/integration/general.pxd
treevalue/tree/integration/general.pxd
+4
-3
treevalue/tree/integration/general.pyx
treevalue/tree/integration/general.pyx
+64
-72
未找到文件。
test/tree/integration/test_general.py
浏览文件 @
c38bf233
from
collections
import
namedtuple
from
dataclasses
import
dataclass
import
pytest
from
easydict
import
EasyDict
from
treevalue
import
generic_flatten
,
generic_unflatten
,
FastTreeValue
,
register_integrate_container
@
dataclass
class
DC
:
x
:
int
y
:
str
from
treevalue
import
generic_flatten
,
generic_unflatten
,
FastTreeValue
,
register_integrate_container
,
generic_mapping
nt
=
namedtuple
(
'nt'
,
[
'a'
,
'b'
])
...
...
@@ -28,7 +20,7 @@ class TestTreeIntegrationGeneral:
'b'
:
[
2
,
3
,
'f'
],
'c'
:
(
2
,
5
,
'ds'
,
EasyDict
({
'x'
:
None
,
'z'
:
DC
(
34
,
'1.2'
)
,
'z'
:
[
34
,
'1.2'
]
,
})),
'd'
:
nt
(
'f'
,
100
),
'e'
:
MyTreeValue
({
'x'
:
1
,
'y'
:
'dsfljk'
})
...
...
@@ -40,7 +32,7 @@ class TestTreeIntegrationGeneral:
assert
rv
==
demo_data
assert
isinstance
(
rv
[
'c'
][
-
1
],
EasyDict
)
assert
isinstance
(
rv
[
'd'
],
nt
)
assert
isinstance
(
rv
[
'c'
][
-
1
][
'z'
],
DC
)
assert
isinstance
(
rv
[
'c'
][
-
1
][
'z'
],
list
)
assert
isinstance
(
rv
[
'e'
],
MyTreeValue
)
def
test_register_my_class
(
self
):
...
...
@@ -79,3 +71,25 @@ class TestTreeIntegrationGeneral:
assert
isinstance
(
rv
[
'd'
],
nt
)
assert
isinstance
(
rv
[
'c'
][
-
1
][
'z'
],
MyDC
)
assert
isinstance
(
rv
[
'e'
],
MyTreeValue
)
def
test_generic_mapping
(
self
):
demo_data
=
{
'a'
:
1
,
'b'
:
[
2
,
3
,
'f'
],
'c'
:
(
2
,
5
,
'ds'
,
EasyDict
({
'x'
:
None
,
'z'
:
(
34
,
'1.2'
),
})),
'd'
:
nt
(
'f'
,
100
),
'e'
:
MyTreeValue
({
'x'
:
1
,
'y'
:
'dsfljk'
})
}
assert
generic_mapping
(
demo_data
,
str
)
==
{
'a'
:
'1'
,
'b'
:
[
'2'
,
'3'
,
'f'
],
'c'
:
(
'2'
,
'5'
,
'ds'
,
EasyDict
({
'x'
:
'None'
,
'z'
:
(
'34'
,
'1.2'
),
})),
'd'
:
nt
(
'f'
,
'100'
),
'e'
:
MyTreeValue
({
'x'
:
'1'
,
'y'
:
'dsfljk'
})
}
treevalue/tree/integration/__init__.py
浏览文件 @
c38bf233
from
typing
import
Type
from
.general
import
generic_flatten
,
generic_unflatten
,
register_integrate_container
from
.general
import
generic_flatten
,
generic_unflatten
,
register_integrate_container
,
generic_mapping
from
.jax
import
register_for_jax
from
.torch
import
register_for_torch
from
..tree
import
TreeValue
...
...
treevalue/tree/integration/general.pxd
浏览文件 @
c38bf233
...
...
@@ -12,9 +12,6 @@ cdef object _list_and_tuple_unflatten(list values, object spec)
cdef
tuple
_namedtuple_flatten
(
object
l
)
cdef
object
_namedtuple_unflatten
(
list
values
,
object
spec
)
cdef
tuple
_dataclass_flatten
(
object
l
)
cdef
object
_dataclass_unflatten
(
list
values
,
tuple
spec
)
cdef
tuple
_treevalue_flatten
(
object
l
)
cdef
object
_treevalue_unflatten
(
list
values
,
tuple
spec
)
...
...
@@ -22,5 +19,9 @@ cdef bool _is_namedtuple_instance(pytree) except*
cpdef
void
register_integrate_container
(
object
type_
,
object
flatten_func
,
object
unflatten_func
)
except
*
cdef
tuple
_c_get_flatted_values_and_spec
(
object
v
)
cdef
object
_c_get_object_from_flatted
(
object
values
,
object
type_
,
object
spec
)
cpdef
object
generic_flatten
(
object
v
)
cpdef
object
generic_unflatten
(
object
v
,
tuple
gspec
)
cpdef
object
generic_mapping
(
object
v
,
object
func
)
treevalue/tree/integration/general.pyx
浏览文件 @
c38bf233
...
...
@@ -2,7 +2,6 @@
# cython:language_level=3
from
collections
import
namedtuple
from
dataclasses
import
dataclass
,
is_dataclass
import
cython
from
libcpp
cimport
bool
...
...
@@ -46,23 +45,6 @@ cdef inline tuple _namedtuple_flatten(object l):
cdef
inline
object
_namedtuple_unflatten
(
list
values
,
object
spec
):
return
spec
(
*
values
)
cdef
inline
tuple
_dataclass_flatten
(
object
l
):
cdef
object
type_
=
type
(
l
)
cdef
list
keys
=
[]
cdef
list
values
=
[]
for
key
in
type_
.
__dataclass_fields__
.
keys
():
keys
.
append
(
key
)
values
.
append
(
getattr
(
l
,
key
))
return
values
,
(
type_
,
keys
)
cdef
inline
object
_dataclass_unflatten
(
list
values
,
tuple
spec
):
cdef
object
type_
cdef
list
keys
type_
,
keys
=
spec
return
type_
(
**
{
key
:
value
for
key
,
value
in
zip
(
keys
,
values
)})
cdef
inline
tuple
_treevalue_flatten
(
object
l
):
return
_c_flatten_for_integration
(
l
)
...
...
@@ -132,40 +114,73 @@ cpdef inline void register_integrate_container(object type_, object flatten_func
"""
_REGISTERED_CONTAINERS
[
type_
]
=
(
flatten_func
,
unflatten_func
)
cdef
inline
tuple
_c_get_flatted_values_and_spec
(
object
v
):
cdef
list
values
cdef
object
spec
,
type_
cdef
object
flatten_func
if
isinstance
(
v
,
dict
):
values
,
spec
=
_dict_flatten
(
v
)
type_
=
dict
elif
_is_namedtuple_instance
(
v
):
values
,
spec
=
_namedtuple_flatten
(
v
)
type_
=
namedtuple
elif
isinstance
(
v
,
(
list
,
tuple
)):
values
,
spec
=
_list_and_tuple_flatten
(
v
)
type_
=
list
elif
isinstance
(
v
,
TreeValue
):
values
,
spec
=
_treevalue_flatten
(
v
)
type_
=
TreeValue
elif
type
(
v
)
in
_REGISTERED_CONTAINERS
:
flatten_func
,
_
=
_REGISTERED_CONTAINERS
[
type
(
v
)]
values
,
spec
=
flatten_func
(
v
)
type_
=
type
(
v
)
else
:
return
v
,
None
,
None
return
values
,
type_
,
spec
cdef
inline
object
_c_get_object_from_flatted
(
object
values
,
object
type_
,
object
spec
):
cdef
object
unflatten_func
if
type_
is
dict
:
return
_dict_unflatten
(
values
,
spec
)
elif
type_
is
namedtuple
:
return
_namedtuple_unflatten
(
values
,
spec
)
elif
type_
is
list
:
return
_list_and_tuple_unflatten
(
values
,
spec
)
elif
type_
is
TreeValue
:
return
_treevalue_unflatten
(
values
,
spec
)
elif
type_
in
_REGISTERED_CONTAINERS
:
_
,
unflatten_func
=
_REGISTERED_CONTAINERS
[
type_
]
return
unflatten_func
(
values
,
spec
)
else
:
raise
TypeError
(
f
'Unknown type for unflatten -
{
values
!
r
}
,
{
spec
!
r
}
.'
)
# pragma: no cover
@
cython
.
binding
(
True
)
cpdef
inline
object
generic_flatten
(
object
v
):
"""
Overview:
Flatten generic data, including native objects, ``TreeValue``, namedtuples
and dataclasses
.
Flatten generic data, including native objects, ``TreeValue``, namedtuples.
:param v: Value to be flatted.
:return: Flatted value.
Examples::
>>> from collections import namedtuple
>>> from dataclasses import dataclass
>>> from easydict import EasyDict
>>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten
>>>
>>> class MyTreeValue(FastTreeValue):
... pass
>>>
>>> @dataclass
... class DC:
... x: int
... y: float
...
... def __repr__(self):
... return f'DC({self.x}, {self.y})'
>>>
>>> nt = namedtuple('nt', ['a', 'b'])
>>>
>>> origin = {
... 'a': 1,
... 'b':
[2, 3, 'f', ]
,
... 'b':
(2, 3, 'f',)
,
... 'c': (2, 5, 'ds', EasyDict({ # dict's child class
... 'x': None,
... 'z':
DC(34, '1.2')
, # dataclass
... 'z':
[34, '1.2']
, # dataclass
... })),
... 'd': nt('f', 100), # namedtuple
... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue
...
...
@@ -175,38 +190,17 @@ cpdef inline object generic_flatten(object v):
[1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
>>>
>>> rv = generic_unflatten(v, spec)
>>> rv
# all the data, including types, are recovered
{'a': 1, 'b':
[2, 3, 'f'], 'c': (2, 5, 'ds', {'x': None, 'z': DC(34, 1.2)}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7fba23ef9c5
0>
>>> rv
{'a': 1, 'b':
(2, 3, 'f'), 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}), 'd': nt(a='f', b=100), 'e': <MyTreeValue 0x7fb6026d7b1
0>
├── 'x' --> 1
└── 'y' --> 'dsfljk'
}
>>> type(rv['c'][-1])
<class 'easydict.EasyDict'>
"""
cdef
list
values
cdef
object
spec
,
type_
cdef
object
flatten_func
if
isinstance
(
v
,
dict
):
values
,
spec
=
_dict_flatten
(
v
)
type_
=
dict
elif
_is_namedtuple_instance
(
v
):
values
,
spec
=
_namedtuple_flatten
(
v
)
type_
=
namedtuple
elif
isinstance
(
v
,
(
list
,
tuple
)):
values
,
spec
=
_list_and_tuple_flatten
(
v
)
type_
=
list
elif
is_dataclass
(
v
):
values
,
spec
=
_dataclass_flatten
(
v
)
type_
=
dataclass
elif
isinstance
(
v
,
TreeValue
):
values
,
spec
=
_treevalue_flatten
(
v
)
type_
=
TreeValue
elif
type
(
v
)
in
_REGISTERED_CONTAINERS
:
flatten_func
,
_
=
_REGISTERED_CONTAINERS
[
type
(
v
)]
values
,
spec
=
flatten_func
(
v
)
type_
=
type
(
v
)
else
:
return
v
,
(
None
,
None
,
None
)
values
,
type_
,
spec
=
_c_get_flatted_values_and_spec
(
v
)
if
type_
is
None
:
return
values
,
(
None
,
None
,
None
)
cdef
list
child_values
=
[]
cdef
list
child_specs
=
[]
...
...
@@ -241,19 +235,17 @@ cpdef inline object generic_unflatten(object v, tuple gspec):
for
_i_value
,
_i_spec
in
zip
(
v
,
child_specs
):
values
.
append
(
generic_unflatten
(
_i_value
,
_i_spec
))
cdef
object
unflatten_func
if
type_
is
dict
:
return
_dict_unflatten
(
values
,
spec
)
elif
type_
is
namedtuple
:
return
_namedtuple_unflatten
(
values
,
spec
)
elif
type_
is
list
:
return
_list_and_tuple_unflatten
(
values
,
spec
)
elif
type_
is
dataclass
:
return
_dataclass_unflatten
(
values
,
spec
)
elif
type_
is
TreeValue
:
return
_treevalue_unflatten
(
values
,
spec
)
elif
type_
in
_REGISTERED_CONTAINERS
:
_
,
unflatten_func
=
_REGISTERED_CONTAINERS
[
type_
]
return
unflatten_func
(
values
,
spec
)
else
:
raise
TypeError
(
f
'Unknown type for unflatten -
{
values
!
r
}
,
{
gspec
!
r
}
.'
)
# pragma: no cover
return
_c_get_object_from_flatted
(
values
,
type_
,
spec
)
@
cython
.
binding
(
True
)
cpdef
inline
object
generic_mapping
(
object
v
,
object
func
):
values
,
type_
,
spec
=
_c_get_flatted_values_and_spec
(
v
)
if
type_
is
None
:
return
func
(
values
)
cdef
list
retvals
=
[]
cdef
object
value
for
value
in
values
:
retvals
.
append
(
generic_mapping
(
value
,
func
))
return
_c_get_object_from_flatted
(
retvals
,
type_
,
spec
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录