提交 9dd98666 编写于 作者: HansBug's avatar HansBug 😆

dev, test(hansbug): add any in treetensor.numpy && refactor the usage in numpy's test

上级 e5ed1297
import numpy as np
import pytest
from treetensor.numpy import TreeNumpy, equal, array_equal
from treetensor.numpy import all as _numpy_all
import treetensor.numpy as tnp
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestNumpyFuncs:
_DEMO_1 = TreeNumpy({
_DEMO_1 = tnp.TreeNumpy({
'a': np.array([[1, 2, 3], [5, 6, 7]]),
'b': np.array([1, 3, 5, 7]),
'x': {
......@@ -17,7 +16,7 @@ class TestNumpyFuncs:
}
})
_DEMO_2 = TreeNumpy({
_DEMO_2 = tnp.TreeNumpy({
'a': np.array([[1, 2, 3], [5, 6, 8]]),
'b': np.array([1, 3, 5, 7]),
'x': {
......@@ -26,7 +25,7 @@ class TestNumpyFuncs:
}
})
_DEMO_3 = TreeNumpy({
_DEMO_3 = tnp.TreeNumpy({
'a': np.array([[1, 2, 3], [5, 6, 7]]),
'b': np.array([1, 3, 5, 7]),
'x': {
......@@ -36,26 +35,53 @@ class TestNumpyFuncs:
})
def test_all(self):
assert not _numpy_all(np.array([True, True, False]))
assert _numpy_all(np.array([True, True, True]))
assert tnp.all(np.array([True, True, True]))
assert not tnp.all(np.array([True, True, False]))
assert not tnp.all(np.array([False, False, False]))
assert not _numpy_all(self._DEMO_1 == self._DEMO_2)
assert _numpy_all(self._DEMO_1 == self._DEMO_3)
assert not _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 4]))
assert _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 3]))
assert tnp.all(tnp.TreeNumpy({
'a': np.array([True, True, True]),
'b': np.array([True, True, True]),
}))
assert not tnp.all(tnp.TreeNumpy({
'a': np.array([True, True, True]),
'b': np.array([True, True, False]),
}))
assert not tnp.all(tnp.TreeNumpy({
'a': np.array([False, False, False]),
'b': np.array([False, False, False]),
}))
def test_any(self):
assert tnp.any(np.array([True, True, True]))
assert tnp.any(np.array([True, True, False]))
assert not tnp.any(np.array([False, False, False]))
assert tnp.any(tnp.TreeNumpy({
'a': np.array([True, True, True]),
'b': np.array([True, True, True]),
}))
assert tnp.any(tnp.TreeNumpy({
'a': np.array([True, True, True]),
'b': np.array([True, True, False]),
}))
assert not tnp.any(tnp.TreeNumpy({
'a': np.array([False, False, False]),
'b': np.array([False, False, False]),
}))
def test_equal(self):
assert _numpy_all(equal(
assert tnp.all(tnp.equal(
np.array([1, 2, 3]),
np.array([1, 2, 3]),
))
assert not _numpy_all(equal(
assert not tnp.all(tnp.equal(
np.array([1, 2, 3]),
np.array([1, 2, 4]),
))
assert _numpy_all(
equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({
assert tnp.all(
tnp.equal(self._DEMO_1, self._DEMO_2) == tnp.TreeNumpy({
'a': np.array([[True, True, True], [True, True, False]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -64,8 +90,8 @@ class TestNumpyFuncs:
}
})
)
assert _numpy_all(
equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({
assert tnp.all(
tnp.equal(self._DEMO_1, self._DEMO_3) == tnp.TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -76,16 +102,16 @@ class TestNumpyFuncs:
)
def test_array_equal(self):
assert _numpy_all(array_equal(
assert tnp.all(tnp.array_equal(
np.array([1, 2, 3]),
np.array([1, 2, 3]),
))
assert not _numpy_all(array_equal(
assert not tnp.all(tnp.array_equal(
np.array([1, 2, 3]),
np.array([1, 2, 4]),
))
assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({
assert tnp.array_equal(self._DEMO_1, self._DEMO_2) == tnp.TreeNumpy({
'a': False,
'b': True,
'x': {
......@@ -93,7 +119,7 @@ class TestNumpyFuncs:
'd': True,
}
})
assert array_equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({
assert tnp.array_equal(self._DEMO_1, self._DEMO_3) == tnp.TreeNumpy({
'a': True,
'b': True,
'x': {
......
......@@ -2,11 +2,11 @@ import numpy as np
from treevalue import func_treelize as original_func_treelize
from .numpy import TreeNumpy
from ..common import ireduce
from ..common import ireduce, TreeObject
from ..utils import replaceable_partial
__all__ = [
'all',
'all', 'any',
'equal', 'array_equal',
]
......@@ -14,11 +14,17 @@ func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNump
@ireduce(all)
@func_treelize()
@func_treelize(return_type=TreeObject)
def all(a, *args, **kwargs):
return np.all(a, *args, **kwargs)
@ireduce(any)
@func_treelize()
def any(a, *args, **kwargs):
return np.any(a, *args, **kwargs)
@func_treelize()
def equal(x1, x2, *args, **kwargs):
return np.equal(x1, x2, *args, **kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册