"""Unit tests for the tree-level numpy helpers exported by ``treetensor.numpy``."""
import numpy as np
import pytest

from treetensor.numpy import TreeNumpy, all_array_equal, equal, array_equal


def _demo_tree(corner):
    """Build the fixed demo tree used by the tests below.

    The bottom-right element of leaf ``a`` is ``corner``; every other leaf
    value is constant. Each call allocates fresh arrays, so trees returned
    by separate calls never share data.
    """
    return TreeNumpy({
        'a': np.array([[1, 2, 3], [5, 6, corner]]),
        'b': np.array([1, 3, 5, 7]),
        'x': {
            'c': np.array([3, 5, 7]),
            'd': np.array([[7, 9]]),
        },
    })


# noinspection DuplicatedCode
@pytest.mark.unittest
class TestNumpyFuncs:
    """Exercise ``all_array_equal``, ``equal`` and ``array_equal``."""

    # _DEMO_1 and _DEMO_3 hold identical values; _DEMO_2 differs from them
    # only in the last element of leaf ``a`` (8 instead of 7).
    _DEMO_1 = _demo_tree(7)
    _DEMO_2 = _demo_tree(8)
    _DEMO_3 = _demo_tree(7)

    def test_all_array_equal(self):
        """``all_array_equal`` reports whole-structure equality for trees and plain arrays."""
        first, second, third = self._DEMO_1, self._DEMO_2, self._DEMO_3
        assert not all_array_equal(first, second)
        assert all_array_equal(first, third)

        # The same helper must also accept bare numpy arrays.
        assert not all_array_equal(np.array([1, 2, 3]), np.array([1, 2, 4]))
        assert all_array_equal(np.array([1, 2, 3]), np.array([1, 2, 3]))

    def test_equal(self):
        """``equal`` yields an element-wise boolean tree mirroring its inputs."""
        expected_mismatch = TreeNumpy({
            'a': np.array([[True, True, True], [True, True, False]]),
            'b': np.array([True, True, True, True]),
            'x': {
                'c': np.array([True, True, True]),
                'd': np.array([[True, True]]),
            },
        })
        assert all_array_equal(equal(self._DEMO_1, self._DEMO_2), expected_mismatch)

        expected_match = TreeNumpy({
            'a': np.array([[True, True, True], [True, True, True]]),
            'b': np.array([True, True, True, True]),
            'x': {
                'c': np.array([True, True, True]),
                'd': np.array([[True, True]]),
            },
        })
        assert all_array_equal(equal(self._DEMO_1, self._DEMO_3), expected_match)

    def test_array_equal(self):
        """``array_equal`` yields one boolean per leaf of the compared trees."""
        expected_mismatch = TreeNumpy({
            'a': False,
            'b': True,
            'x': {
                'c': True,
                'd': True,
            },
        })
        assert array_equal(self._DEMO_1, self._DEMO_2) == expected_mismatch

        expected_match = TreeNumpy({
            'a': True,
            'b': True,
            'x': {
                'c': True,
                'd': True,
            },
        })
        assert array_equal(self._DEMO_1, self._DEMO_3) == expected_match