diff --git a/test/numpy/test_funcs.py b/test/numpy/test_funcs.py index 74ac1164687823583af8143ad7d4cbea234f6f84..f4c4b4c2982978a71be51aa745c6e81f67b5fca3 100644 --- a/test/numpy/test_funcs.py +++ b/test/numpy/test_funcs.py @@ -1,14 +1,13 @@ import numpy as np import pytest -from treetensor.numpy import TreeNumpy, equal, array_equal -from treetensor.numpy import all as _numpy_all +import treetensor.numpy as tnp # noinspection DuplicatedCode @pytest.mark.unittest class TestNumpyFuncs: - _DEMO_1 = TreeNumpy({ + _DEMO_1 = tnp.TreeNumpy({ 'a': np.array([[1, 2, 3], [5, 6, 7]]), 'b': np.array([1, 3, 5, 7]), 'x': { @@ -17,7 +16,7 @@ class TestNumpyFuncs: } }) - _DEMO_2 = TreeNumpy({ + _DEMO_2 = tnp.TreeNumpy({ 'a': np.array([[1, 2, 3], [5, 6, 8]]), 'b': np.array([1, 3, 5, 7]), 'x': { @@ -26,7 +25,7 @@ class TestNumpyFuncs: } }) - _DEMO_3 = TreeNumpy({ + _DEMO_3 = tnp.TreeNumpy({ 'a': np.array([[1, 2, 3], [5, 6, 7]]), 'b': np.array([1, 3, 5, 7]), 'x': { @@ -36,26 +35,53 @@ class TestNumpyFuncs: }) def test_all(self): - assert not _numpy_all(np.array([True, True, False])) - assert _numpy_all(np.array([True, True, True])) + assert tnp.all(np.array([True, True, True])) + assert not tnp.all(np.array([True, True, False])) + assert not tnp.all(np.array([False, False, False])) - assert not _numpy_all(self._DEMO_1 == self._DEMO_2) - assert _numpy_all(self._DEMO_1 == self._DEMO_3) - assert not _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 4])) - assert _numpy_all(np.array([1, 2, 3]) == np.array([1, 2, 3])) + assert tnp.all(tnp.TreeNumpy({ + 'a': np.array([True, True, True]), + 'b': np.array([True, True, True]), + })) + assert not tnp.all(tnp.TreeNumpy({ + 'a': np.array([True, True, True]), + 'b': np.array([True, True, False]), + })) + assert not tnp.all(tnp.TreeNumpy({ + 'a': np.array([False, False, False]), + 'b': np.array([False, False, False]), + })) + + def test_any(self): + assert tnp.any(np.array([True, True, True])) + assert tnp.any(np.array([True, True, False])) + assert not tnp.any(np.array([False, False, False])) + + assert tnp.any(tnp.TreeNumpy({ + 'a': np.array([True, True, True]), + 'b': np.array([True, True, True]), + })) + assert tnp.any(tnp.TreeNumpy({ + 'a': np.array([True, True, True]), + 'b': np.array([True, True, False]), + })) + assert not tnp.any(tnp.TreeNumpy({ + 'a': np.array([False, False, False]), + 'b': np.array([False, False, False]), + })) def test_equal(self): - assert _numpy_all(equal( + assert tnp.all(tnp.equal( np.array([1, 2, 3]), np.array([1, 2, 3]), )) - assert not _numpy_all(equal( + assert not tnp.all(tnp.equal( np.array([1, 2, 3]), np.array([1, 2, 4]), )) - assert _numpy_all( - equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({ + assert tnp.all( + tnp.equal(self._DEMO_1, self._DEMO_2) == tnp.TreeNumpy({ 'a': np.array([[True, True, True], [True, True, False]]), 'b': np.array([True, True, True, True]), 'x': { @@ -64,8 +90,8 @@ class TestNumpyFuncs: } }) ) - assert _numpy_all( - equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({ + assert tnp.all( + tnp.equal(self._DEMO_1, self._DEMO_3) == tnp.TreeNumpy({ 'a': np.array([[True, True, True], [True, True, True]]), 'b': np.array([True, True, True, True]), 'x': { @@ -76,16 +102,16 @@ class TestNumpyFuncs: ) def test_array_equal(self): - assert _numpy_all(array_equal( + assert tnp.all(tnp.array_equal( np.array([1, 2, 3]), np.array([1, 2, 3]), )) - assert not _numpy_all(array_equal( + assert not tnp.all(tnp.array_equal( np.array([1, 2, 3]), np.array([1, 2, 4]), )) - assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({ + assert tnp.array_equal(self._DEMO_1, self._DEMO_2) == tnp.TreeNumpy({ 'a': False, 'b': True, 'x': { @@ -93,7 +119,7 @@ class TestNumpyFuncs: 'd': True, } }) - assert array_equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({ + assert tnp.array_equal(self._DEMO_1, self._DEMO_3) == tnp.TreeNumpy({ 'a': True, 'b': True, 'x': { diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index b67a9e8f78e07c1a8cce8c9e958d788ad950f8e3..6322783f227fcc580f3a741537f8882e0f2e3ed9 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -2,11 +2,11 @@ import numpy as np from treevalue import func_treelize as original_func_treelize from .numpy import TreeNumpy -from ..common import ireduce +from ..common import ireduce, TreeObject from ..utils import replaceable_partial __all__ = [ - 'all', + 'all', 'any', 'equal', 'array_equal', ] @@ -14,11 +14,17 @@ func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNump @ireduce(all) -@func_treelize() +@func_treelize(return_type=TreeObject) def all(a, *args, **kwargs): return np.all(a, *args, **kwargs) +@ireduce(any) +@func_treelize() +def any(a, *args, **kwargs): + return np.any(a, *args, **kwargs) + + @func_treelize() def equal(x1, x2, *args, **kwargs): return np.equal(x1, x2, *args, **kwargs)