提交 683b07eb 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): update get_or_default && add test for __hash__ of TreeStorage

上级 00896958
......@@ -5,7 +5,7 @@ import pytest
from treevalue.tree.common import create_storage, raw, TreeStorage, delayed_partial
# noinspection PyArgumentList,DuplicatedCode
# noinspection PyArgumentList,DuplicatedCode,PyTypeChecker
@pytest.mark.unittest
class TestTreeStorage:
def test_init(self):
......@@ -80,6 +80,23 @@ class TestTreeStorage:
assert t.get_or_default('fff', 233) == 233
t1 = create_storage({
'a': delayed_partial(lambda: t.get('a')),
'b': delayed_partial(lambda: t.get('b')),
'c': delayed_partial(lambda: t.get('c')),
'd': delayed_partial(lambda: t.get('d')),
})
assert t1.get_or_default('a', 233) == 1
assert t1.get_or_default('b', 233) == 2
assert t1.get_or_default('c', 233) == {'x': 3, 'y': 4}
assert isinstance(t1.get_or_default('d', 233), TreeStorage)
assert t1.get_or_default('d', 233).get_or_default('x', 233) == 3
assert t1.get_or_default('d', 233).get_or_default('y', 233) == 4
assert t1.get_or_default('fff', 233) == 233
assert t1.get_or_default('fff', delayed_partial(lambda: 2345)) == 2345
assert not t1.contains('fff')
def test_set(self):
t = create_storage({})
t.set('a', 1)
......@@ -415,3 +432,24 @@ class TestTreeStorage:
assert v == h1
else:
pytest.fail('Should not reach here.')
def test_hash(self):
h = {}
h1 = {'x': 3, 'y': 4}
t = create_storage({'a': 1, 'b': 2, 'd': h1})
t1 = create_storage({
'a': delayed_partial(lambda: t.get('a')),
'b': delayed_partial(lambda: t.get('b')),
'd': delayed_partial(lambda: t.get('d')),
})
t2 = create_storage({
'a': delayed_partial(lambda: t.get('a')),
'b': delayed_partial(lambda: 3),
'd': delayed_partial(lambda: t.get('d')),
})
h[t] = 1
assert t1 in h
assert h[t1] == 1
assert t2 not in h
......@@ -49,7 +49,15 @@ cdef class TreeStorage:
raise KeyError(f"Key {repr(key)} not found in this tree.")
cpdef public object get_or_default(self, str key, object default):
return self.map.get(key, default)
cdef object v, nv
v = self.map.get(key, default)
nv = unwrap_proxy(v)
if nv is not v:
v = nv
if key in self.map:
self.map[key] = v
return v
cpdef public void del_(self, str key) except *:
try:
......@@ -189,7 +197,7 @@ cdef class TreeStorage:
cdef str k
cdef object v
cdef list _items = []
for k, v in sorted(self.map.items(), key=lambda x: x[0]):
for k, v in sorted(self.items(), key=lambda x: x[0]):
_items.append((k, v))
return hash(tuple(_items))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册