diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8176be2683e8ddd7a2165592e900ed832275c99e..1cb9facbfc2a46712b2e0108ccbfbd6c8d5e3770 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,48 +17,31 @@ jobs: - '3.7' - '3.8' - '3.9' + numpy-version: + - '1.18.0' + - '1.20.0' + - '1.22.0' torch-version: - - '1.1.0' - '1.2.0' - - '1.3.0' - '1.4.0' - - '1.5.0' - '1.6.0' - - '1.7.0' - '1.8.0' - - '1.9.0' - '1.10.0' exclude: - - os: 'ubuntu-18.04' - python-version: '3.8' - torch-version: '1.1.0' - - os: 'ubuntu-18.04' - python-version: '3.8' + - python-version: '3.6' + numpy-version: '1.20.0' + - python-version: '3.6' + numpy-version: '1.22.0' + - python-version: '3.7' + numpy-version: '1.22.0' + - python-version: '3.8' torch-version: '1.2.0' - - os: 'ubuntu-18.04' - python-version: '3.8' - torch-version: '1.3.0' - - os: 'ubuntu-18.04' - python-version: '3.9' - torch-version: '1.1.0' - - os: 'ubuntu-18.04' - python-version: '3.9' + - python-version: '3.9' torch-version: '1.2.0' - - os: 'ubuntu-18.04' - python-version: '3.9' - torch-version: '1.3.0' - - os: 'ubuntu-18.04' - python-version: '3.9' + - python-version: '3.9' torch-version: '1.4.0' - - os: 'ubuntu-18.04' - python-version: '3.9' - torch-version: '1.5.0' - - os: 'ubuntu-18.04' - python-version: '3.9' + - python-version: '3.9' torch-version: '1.6.0' - - os: 'ubuntu-18.04' - python-version: '3.9' - torch-version: '1.7.0' steps: - name: Checkout code @@ -80,6 +63,14 @@ jobs: run: | python -m pip install --upgrade pip pip install --upgrade flake8 setuptools wheel twine + - name: Install latest numpy + if: ${{ matrix.numpy-version == 'latest' }} + run: | + pip install 'numpy' + - name: Install numpy v${{ matrix.numpy-version }} + if: ${{ matrix.numpy-version != 'latest' }} + run: | + pip install 'numpy==${{ matrix.numpy-version }}' - name: Install latest pytorch if: ${{ matrix.torch-version == 'latest' }} run: | diff --git a/docs/source/api_doc/numpy/funcs.rst.py b/docs/source/api_doc/numpy/funcs.rst.py index f8d9f4bec5c0e2d2dcf5382201a20a9cdd6ed3d1..2ab8bee53cda5cec927b92279c95de161b47423c 100644 --- a/docs/source/api_doc/numpy/funcs.rst.py +++ b/docs/source/api_doc/numpy/funcs.rst.py @@ -73,7 +73,7 @@ with the following command and find its documentation. print_title(f"Description From Numpy v{_short_version}", levelc='-', file=p_func) current_module(np.__name__, file=p_func) - _origin_doc = _doc_process(_origin.__doc__ or "") + _origin_doc = _doc_process(_origin.__doc__ or "").lstrip() _doc_lines = _origin_doc.splitlines() _first_line, _other_lines = _doc_lines[0], _doc_lines[1:] if _first_line.strip(): diff --git a/test/numpy/test_funcs.py b/test/numpy/test_funcs.py index 6330922940ed646135b43f6ed324255ff3f20a95..230b0b54417234e03f38b46020341de6a07fa84a 100644 --- a/test/numpy/test_funcs.py +++ b/test/numpy/test_funcs.py @@ -127,3 +127,97 @@ class TestNumpyFuncs: 'd': True, } }) + + def test_zeros(self): + zs = tnp.zeros((2, 3)) + assert isinstance(zs, np.ndarray) + assert np.allclose(zs, np.zeros((2, 3))) + + zs = tnp.zeros({'a': (2, 3), 'c': {'x': (3, 4)}}) + assert tnp.allclose(zs, tnp.ndarray({ + 'a': np.zeros((2, 3)), + 'c': {'x': np.zeros((3, 4))} + })) + + def test_ones(self): + zs = tnp.ones((2, 3)) + assert isinstance(zs, np.ndarray) + assert np.allclose(zs, np.ones((2, 3))) + + zs = tnp.ones({'a': (2, 3), 'c': {'x': (3, 4)}}) + assert tnp.allclose(zs, tnp.ndarray({ + 'a': np.ones((2, 3)), + 'c': {'x': np.zeros((3, 4))} + })) + + def test_stack(self): + a = np.array([1, 2, 3]) + b = np.array([2, 3, 4]) + nd = tnp.stack((a, b)) + assert isinstance(nd, np.ndarray) + assert np.allclose(nd, np.array([[1, 2, 3], + [2, 3, 4]])) + + a = tnp.array({ + 'a': [1, 2, 3], + 'c': {'x': [11, 22, 33]}, + }) + b = tnp.array({ + 'a': [2, 3, 4], + 'c': {'x': [22, 33, 44]}, + }) + nd = tnp.stack((a, b)) + assert tnp.allclose(nd, tnp.array({ + 'a': [[1, 2, 3], [2, 3, 4]], + 'c': {'x': [[11, 22, 33], [22, 33, 44]]}, + })) + + def test_concatenate(self): + a = np.array([[1, 2], [3, 4]]) + b = np.array([[5, 6]]) + nd = tnp.concatenate((a, b), axis=0) + assert isinstance(nd, np.ndarray) + assert np.allclose(nd, np.array([[1, 2], + [3, 4], + [5, 6]])) + + a = tnp.array({ + 'a': [[1, 2], [3, 4]], + 'c': {'x': [[11, 22], [33, 44]]}, + }) + b = tnp.array({ + 'a': [[5, 6]], + 'c': {'x': [[55, 66]]}, + }) + nd = tnp.concatenate((a, b), axis=0) + assert tnp.allclose(nd, tnp.array({ + 'a': [[1, 2], [3, 4], [5, 6]], + 'c': {'x': [[11, 22], [33, 44], [55, 66]]}, + })) + + def test_split(self): + x = np.arange(9.0) + ns = tnp.split(x, 3) + assert len(ns) == 3 + assert isinstance(ns[0], np.ndarray) + assert np.allclose(ns[0], np.array([0.0, 1.0, 2.0])) + assert isinstance(ns[1], np.ndarray) + assert np.allclose(ns[1], np.array([3.0, 4.0, 5.0])) + assert isinstance(ns[2], np.ndarray) + assert np.allclose(ns[2], np.array([6.0, 7.0, 8.0])) + + xx = tnp.arange(tnp.ndarray({'a': 9.0, 'c': {'x': 18.0}})) + ns = tnp.split(xx, 3) + assert len(ns) == 3 + assert tnp.allclose(ns[0], tnp.array({ + 'a': [0.0, 1.0, 2.0], + 'c': {'x': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]}, + })) + assert tnp.allclose(ns[1], tnp.array({ + 'a': [3.0, 4.0, 5.0], + 'c': {'x': [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]}, + })) + assert tnp.allclose(ns[2], tnp.array({ + 'a': [6.0, 7.0, 8.0], + 'c': {'x': [12.0, 13.0, 14.0, 15.0, 16.0, 17.0]}, + })) diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index b627bf8cfbb113533a45483e46eaa6284cf27afb..ba1d0ece2e522c8211937726d2e2031e0351d1f3 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -13,6 +13,8 @@ from ..utils import replaceable_partial, doc_from, args_mapping __all__ = [ 'all', 'any', 'array', 'equal', 'array_equal', + 'stack', 'concatenate', 'split', + 'zeros', 'ones', ] func_treelize = post_process(post_process(args_mapping( @@ -71,3 +73,33 @@ def array(p_object, *args, **kwargs): }) """ return np.array(p_object, *args, **kwargs) + + +@doc_from(np.stack) +@func_treelize(subside=True) +def stack(arrays, *args, **kwargs): + return np.stack(arrays, *args, **kwargs) + + +@doc_from(np.concatenate) +@func_treelize(subside=True) +def concatenate(arrays, *args, **kwargs): + return np.concatenate(arrays, *args, **kwargs) + + +@doc_from(np.split) +@func_treelize(rise=True) +def split(ary, *args, **kwargs): + return np.split(ary, *args, **kwargs) + + +@doc_from(np.zeros) +@func_treelize() +def zeros(shape, *args, **kwargs): + return np.zeros(shape, *args, **kwargs) + + +@doc_from(np.ones) +@func_treelize() +def ones(shape, *args, **kwargs): + return np.ones(shape, *args, **kwargs)