diff --git a/test/tree/tree/test_structural.py b/test/tree/tree/test_structural.py index 60a048152d39d374c3e3d76bb04383fd5aa02d43..90edc782783c498a504cb498f5fb659785185a6c 100644 --- a/test/tree/tree/test_structural.py +++ b/test/tree/tree/test_structural.py @@ -1,3 +1,5 @@ +from collections import namedtuple + import pytest from treevalue.tree import TreeValue, mapping, union, raw, subside, rise, delayed @@ -116,6 +118,18 @@ class TestTreeTreeStructural: assert subside({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}, 'e': [3, 4, 5]}) == \ {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}, 'e': [3, 4, 5]} + nt = namedtuple('nt', ['a', 'b', 'c']) + a = nt( + MyTreeValue({'x': 1, 'y': 2, 'z': {'v': 3}}), + MyTreeValue({'x': 4, 'y': 5, 'z': {'v': 6}}), + MyTreeValue({'x': 7, 'y': 8, 'z': {'v': 9}}), + ) + assert subside(a) == MyTreeValue({ + 'x': nt(1, 4, 7), + 'y': nt(2, 5, 8), + 'z': {'v': nt(3, 6, 9), }, + }) + def test_subside_delayed(self): class MyTreeValue(TreeValue): pass @@ -273,3 +287,15 @@ class TestTreeTreeStructural: MyTreeValue({'v': {'x': 3, 'y': 8}}), ] } + + nt = namedtuple('nt', ['a', 'b', 'c']) + t11 = MyTreeValue({ + 'x': nt(1, 4, 7), + 'y': nt(2, 5, 8), + 'z': {'v': nt(3, 6, 9), }, + }) + assert rise(t11) == nt( + MyTreeValue({'x': 1, 'y': 2, 'z': {'v': 3}}), + MyTreeValue({'x': 4, 'y': 5, 'z': {'v': 6}}), + MyTreeValue({'x': 7, 'y': 8, 'z': {'v': 9}}), + ) diff --git a/treevalue/tree/tree/structural.pyx b/treevalue/tree/tree/structural.pyx index 286da8c98ce577176e56de3a52889235a76dbae1..226ad14c4ac048f4e709667cd8c30c9dde878fea 100644 --- a/treevalue/tree/tree/structural.pyx +++ b/treevalue/tree/tree/structural.pyx @@ -16,6 +16,15 @@ from ..func.modes cimport _c_load_mode MISSING_NOT_ALLOW = SingletonMark("missing_not_allow") +cdef inline bool _c_is_namedtuple(object type_): + return issubclass(type_, tuple) and hasattr(type_, '_fields') and hasattr(type_, '_asdict') + +cdef inline object _c_create_sequence_with_type(object type_, object iters): + if _c_is_namedtuple(type_): + return type_(*iters) + else: + return type_(iters) + cdef object _c_subside_process(tuple value, object it): cdef type type_ cdef list items @@ -29,7 +38,7 @@ cdef object _c_subside_process(tuple value, object it): _l_res = [] for v in items: _l_res.append(_c_subside_process(v, it)) - return type_(_l_res) + return _c_create_sequence_with_type(type_, _l_res) elif issubclass(type_, dict): _d_res = {} for k, v in items: @@ -270,7 +279,7 @@ cdef object _c_rise_struct_builder(tuple p, object it): for v in item: _l_res.append(_c_rise_struct_builder(v, it)) - return type_(_l_res) + return _c_create_sequence_with_type(type_, _l_res) else: return next(it) @@ -298,7 +307,10 @@ cdef tuple _c_rise_struct_process(list objs, object template): raise ValueError(f"At least {repr(_l_temp - 2)} value expected due to template " f"{repr(template)}, but length is {repr(_l_obj_0)}.") - _a_template = type(template)(chain(template[:-2], (template[-2],) * (_l_obj_0 - _l_temp + 2))) + _a_template = _c_create_sequence_with_type( + type(template), + chain(template[:-2], (template[-2],) * (_l_obj_0 - _l_temp + 2)) + ) else: _a_template = template elif template is None: @@ -330,7 +342,7 @@ cdef tuple _c_rise_struct_process(list objs, object template): if failed: _a_template = object else: - _a_template = _t_type(None for _ in range(length)) + _a_template = _c_create_sequence_with_type(_t_type, (None for _ in range(length))) else: _a_template = object