未验证 提交 716b4b91 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #8 from opendilab/dev/numpy

doc(hansbug): add documentation for np.stack, np.split and other 3 functions.
......@@ -17,48 +17,31 @@ jobs:
- '3.7'
- '3.8'
- '3.9'
numpy-version:
- '1.18.0'
- '1.20.0'
- '1.22.0'
torch-version:
- '1.1.0'
- '1.2.0'
- '1.3.0'
- '1.4.0'
- '1.5.0'
- '1.6.0'
- '1.7.0'
- '1.8.0'
- '1.9.0'
- '1.10.0'
exclude:
- os: 'ubuntu-18.04'
python-version: '3.8'
torch-version: '1.1.0'
- os: 'ubuntu-18.04'
python-version: '3.8'
- python-version: '3.6'
numpy-version: '1.20.0'
- python-version: '3.6'
numpy-version: '1.22.0'
- python-version: '3.7'
numpy-version: '1.22.0'
- python-version: '3.8'
torch-version: '1.2.0'
- os: 'ubuntu-18.04'
python-version: '3.8'
torch-version: '1.3.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.1.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
- python-version: '3.9'
torch-version: '1.2.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.3.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
- python-version: '3.9'
torch-version: '1.4.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.5.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
- python-version: '3.9'
torch-version: '1.6.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.7.0'
steps:
- name: Checkout code
......@@ -80,6 +63,14 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install --upgrade flake8 setuptools wheel twine
- name: Install latest numpy
if: ${{ matrix.numpy-version == 'latest' }}
run: |
pip install 'numpy'
- name: Install numpy v${{ matrix.numpy-version }}
if: ${{ matrix.numpy-version != 'latest' }}
run: |
pip install 'numpy==${{ matrix.numpy-version }}'
- name: Install latest pytorch
if: ${{ matrix.torch-version == 'latest' }}
run: |
......
......@@ -73,7 +73,7 @@ with the following command and find its documentation.
print_title(f"Description From Numpy v{_short_version}", levelc='-', file=p_func)
current_module(np.__name__, file=p_func)
_origin_doc = _doc_process(_origin.__doc__ or "")
_origin_doc = _doc_process(_origin.__doc__ or "").lstrip()
_doc_lines = _origin_doc.splitlines()
_first_line, _other_lines = _doc_lines[0], _doc_lines[1:]
if _first_line.strip():
......
......@@ -127,3 +127,97 @@ class TestNumpyFuncs:
'd': True,
}
})
def test_zeros(self):
zs = tnp.zeros((2, 3))
assert isinstance(zs, np.ndarray)
assert np.allclose(zs, np.zeros((2, 3)))
zs = tnp.zeros({'a': (2, 3), 'c': {'x': (3, 4)}})
assert tnp.allclose(zs, tnp.ndarray({
'a': np.zeros((2, 3)),
'c': {'x': np.zeros((3, 4))}
}))
def test_ones(self):
zs = tnp.ones((2, 3))
assert isinstance(zs, np.ndarray)
assert np.allclose(zs, np.ones((2, 3)))
zs = tnp.ones({'a': (2, 3), 'c': {'x': (3, 4)}})
assert tnp.allclose(zs, tnp.ndarray({
'a': np.ones((2, 3)),
'c': {'x': np.zeros((3, 4))}
}))
def test_stack(self):
a = np.array([1, 2, 3])
b = np.array([2, 3, 4])
nd = tnp.stack((a, b))
assert isinstance(nd, np.ndarray)
assert np.allclose(nd, np.array([[1, 2, 3],
[2, 3, 4]]))
a = tnp.array({
'a': [1, 2, 3],
'c': {'x': [11, 22, 33]},
})
b = tnp.array({
'a': [2, 3, 4],
'c': {'x': [22, 33, 44]},
})
nd = tnp.stack((a, b))
assert tnp.allclose(nd, tnp.array({
'a': [[1, 2, 3], [2, 3, 4]],
'c': {'x': [[11, 22, 33], [22, 33, 44]]},
}))
def test_concatenate(self):
a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6]])
nd = tnp.concatenate((a, b), axis=0)
assert isinstance(nd, np.ndarray)
assert np.allclose(nd, np.array([[1, 2],
[3, 4],
[5, 6]]))
a = tnp.array({
'a': [[1, 2], [3, 4]],
'c': {'x': [[11, 22], [33, 44]]},
})
b = tnp.array({
'a': [[5, 6]],
'c': {'x': [[55, 66]]},
})
nd = tnp.concatenate((a, b), axis=0)
assert tnp.allclose(nd, tnp.array({
'a': [[1, 2], [3, 4], [5, 6]],
'c': {'x': [[11, 22], [33, 44], [55, 66]]},
}))
def test_split(self):
x = np.arange(9.0)
ns = tnp.split(x, 3)
assert len(ns) == 3
assert isinstance(ns[0], np.ndarray)
assert np.allclose(ns[0], np.array([0.0, 1.0, 2.0]))
assert isinstance(ns[1], np.ndarray)
assert np.allclose(ns[1], np.array([3.0, 4.0, 5.0]))
assert isinstance(ns[2], np.ndarray)
assert np.allclose(ns[2], np.array([6.0, 7.0, 8.0]))
xx = tnp.arange(tnp.ndarray({'a': 9.0, 'c': {'x': 18.0}}))
ns = tnp.split(xx, 3)
assert len(ns) == 3
assert tnp.allclose(ns[0], tnp.array({
'a': [0.0, 1.0, 2.0],
'c': {'x': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]},
}))
assert tnp.allclose(ns[1], tnp.array({
'a': [3.0, 4.0, 5.0],
'c': {'x': [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]},
}))
assert tnp.allclose(ns[2], tnp.array({
'a': [6.0, 7.0, 8.0],
'c': {'x': [12.0, 13.0, 14.0, 15.0, 16.0, 17.0]},
}))
......@@ -13,6 +13,8 @@ from ..utils import replaceable_partial, doc_from, args_mapping
__all__ = [
'all', 'any', 'array',
'equal', 'array_equal',
'stack', 'concatenate', 'split',
'zeros', 'ones',
]
func_treelize = post_process(post_process(args_mapping(
......@@ -71,3 +73,33 @@ def array(p_object, *args, **kwargs):
})
"""
return np.array(p_object, *args, **kwargs)
@doc_from(np.stack)
@func_treelize(subside=True)
def stack(arrays, *args, **kwargs):
return np.stack(arrays, *args, **kwargs)
@doc_from(np.concatenate)
@func_treelize(subside=True)
def concatenate(arrays, *args, **kwargs):
return np.concatenate(arrays, *args, **kwargs)
@doc_from(np.split)
@func_treelize(rise=True)
def split(ary, *args, **kwargs):
return np.split(ary, *args, **kwargs)
@doc_from(np.zeros)
@func_treelize()
def zeros(shape, *args, **kwargs):
return np.zeros(shape, *args, **kwargs)
@doc_from(np.ones)
@func_treelize()
def ones(shape, *args, **kwargs):
return np.ones(shape, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册