未验证 提交 3054a629 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #19 from opendilab/dev/flatten

dev(hansbug): create flatten_values and flatten_keys
......@@ -52,6 +52,22 @@ flatten
.. autofunction:: flatten
.. _apidoc_tree_tree_flatten_values:
flatten_values
-------------------
.. autofunction:: flatten_values
.. _apidoc_tree_tree_flatten_keys:
flatten_keys
-------------------
.. autofunction:: flatten_keys
.. _apidoc_tree_tree_unflatten:
unflatten
......
import pytest
from treevalue.tree import TreeValue, raw, flatten, unflatten
from treevalue.tree import TreeValue, raw, flatten, unflatten, flatten_values, flatten_keys
class MyTreeValue(TreeValue):
......@@ -22,6 +22,26 @@ class TestTreeTreeFlatten:
(('d', 'y'), 4)
]
def test_flatten_values(self):
t = TreeValue({'a': 1, 'b': 5, 'c': {'x': 3, 'y': 4}, 'd': {'x': 3, 'y': 4}})
flatted_values = sorted(flatten_values(t))
assert flatted_values == [1, 3, 3, 4, 4, 5]
def test_flatten_keys(self):
t = TreeValue({'a': 1, 'd': {'x': 3, 'y': 4}, 'e': raw({'x': 3, 'y': 4}), 'b': 5, 'c': {'x': 3, 'y': 4}})
flatted_keys = sorted(flatten_keys(t))
assert flatted_keys == [
('a',),
('b',),
('c', 'x',),
('c', 'y',),
('d', 'x',),
('d', 'y',),
('e',),
]
def test_unflatten(self):
flatted = [
(('a',), 1),
......
from .flatten import flatten, unflatten
from .flatten import flatten, unflatten, flatten_values, flatten_keys
from .functional import mapping, filter_, mask, reduce_
from .graph import graphics
from .io import loads, load, dumps, dump
......
......@@ -9,5 +9,11 @@ from ..common.storage cimport TreeStorage
cdef void _c_flatten(TreeStorage st, tuple path, list res) except *
cpdef list flatten(TreeValue tree)
cdef void _c_flatten_values(TreeStorage st, list res) except *
cpdef list flatten_values(TreeValue tree)
cdef void _c_flatten_keys(TreeStorage st, tuple path, list res) except *
cpdef list flatten_keys(TreeValue tree)
cdef TreeStorage _c_unflatten(object pairs)
cpdef TreeValue unflatten(object pairs, object return_type= *)
......@@ -25,10 +25,10 @@ cdef void _c_flatten(TreeStorage st, tuple path, list res) except *:
cpdef list flatten(TreeValue tree):
r"""
Overview:
Flatten the values in the tree.
Flatten the key-value pairs in the tree.
Arguments:
- tree (:obj:`TreeValue`): Tree object to be flatten.
- tree (:obj:`TreeValue`): Tree object to be flatten.
Returns:
- flatted (:obj:`list`): Flatted tree, a list of tuple with (path, value).
......@@ -43,6 +43,62 @@ cpdef list flatten(TreeValue tree):
_c_flatten(tree._detach(), (), result)
return result
cdef void _c_flatten_values(TreeStorage st, list res) except *:
cdef dict data = st.detach()
cdef str k
cdef object v
for k, v in data.items():
if isinstance(v, TreeStorage):
_c_flatten_values(v, res)
else:
res.append(v)
@cython.binding(True)
cpdef list flatten_values(TreeValue tree):
r"""
Overview:
Flatten the values in the tree.
Arguments:
- tree (:obj:`TreeValue`): Tree object to be flatten.
Returns:
- flatted (:obj:`list`): Flatted tree, a list of values.
"""
cdef list result = []
_c_flatten_values(tree._detach(), result)
return result
cdef void _c_flatten_keys(TreeStorage st, tuple path, list res) except *:
cdef dict data = st.detach()
cdef tuple curpath
cdef str k
cdef object v
for k, v in data.items():
curpath = path + (k,)
if isinstance(v, TreeStorage):
_c_flatten_keys(v, curpath, res)
else:
res.append(curpath)
@cython.binding(True)
cpdef list flatten_keys(TreeValue tree):
r"""
Overview:
Flatten the keys in the tree.
Arguments:
- tree (:obj:`TreeValue`): Tree object to be flatten.
Returns:
- flatted (:obj:`list`): Flatted tree, a list of paths.
"""
cdef list result = []
_c_flatten_keys(tree._detach(), (), result)
return result
cdef TreeStorage _c_unflatten(object pairs):
cdef dict raw_data = {}
cdef TreeStorage result = TreeStorage(raw_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册