diff --git a/docs/source/api_doc/numpy/funcs.rst.py b/docs/source/api_doc/numpy/funcs.rst.py index 4ac8f4303eaa2e99300ad5df55d94af75d1aaafe..70aa2f6a3f000eec86769714d6fab5056d15eb7e 100644 --- a/docs/source/api_doc/numpy/funcs.rst.py +++ b/docs/source/api_doc/numpy/funcs.rst.py @@ -1,3 +1,4 @@ +import codecs import re import numpy as np @@ -21,34 +22,44 @@ if __name__ == '__main__': print_title(tnp.funcs.__name__, levelc='=') current_module(tnp.funcs.__name__) - for _name in sorted(tnp.funcs.__all__): - _item = getattr(tnp.funcs, _name) - _origin = get_origin(_item) - print_title(_name, levelc='-') - - with print_block('autofunction', value=_name): - pass - - if _origin and (_origin.__doc__ or '').strip(): - with print_block('admonition', value='Numpy Version Related', params={'class': 'tip'}) as f: - print_doc(f""" -This documentation is based on -`numpy.{_name} `_ -in `numpy v{_numpy_version} `_. -**Its arguments\' arrangements depend on the version of numpy you installed**. - -If some arguments listed here are not working properly, please check your numpy's version -with the following command and find its documentation. - -.. code-block:: shell - :linenos: - - python -c 'import numpy as np;print(np.__version__)' - -The arguments and keyword arguments supported in numpy v{_numpy_version} is listed below. - - """, file=f) - print() - - print_doc(_raw_doc_process(_origin.__doc__ or '')) - print() + with print_block('toctree', params=dict(maxdepth=1)) as tf: + for _name in sorted(tnp.funcs.__all__): + _file_tag = f'funcs.{_name}.auto' + _filename = f'{_file_tag}.rst' + print(_file_tag, file=tf) + + _item = getattr(tnp.funcs, _name) + _origin = get_origin(_item) + with codecs.open(_filename, 'w') as sf: + print_title(_name, levelc='=', file=sf) + current_module(tnp.funcs.__name__, file=sf) + + print_title("Documentation", levelc='-', file=sf) + with print_block('autofunction', value=_name, file=sf): + pass + + if _origin and (_origin.__doc__ or '').strip(): + with print_block('admonition', value='Numpy Version Related', + params={'class': 'tip'}, file=sf) as f: + print_doc(f""" + This documentation is based on + `numpy.{_name} `_ + in `numpy v{_numpy_version} `_. + **Its arguments\' arrangements depend on the version of numpy you installed**. + + If some arguments listed here are not working properly, please check your numpy's version + with the following command and find its documentation. + + .. code-block:: shell + :linenos: + + python -c 'import numpy as np;print(np.__version__)' + + The arguments and keyword arguments supported in numpy v{_numpy_version} is listed below. + + """, file=f) + print(file=sf) + + print_title(f"Description For numpy v{_short_version}", levelc='-', file=sf) + print_doc(_raw_doc_process(_origin.__doc__ or ''), file=sf) + print(file=sf) diff --git a/docs/source/api_doc/numpy/index.rst b/docs/source/api_doc/numpy/index.rst index 6072c7d127dbd09579b2005828983ae006b6f33e..9a31491240f8ee6905b3f7d1e93d67ce29d2beed 100644 --- a/docs/source/api_doc/numpy/index.rst +++ b/docs/source/api_doc/numpy/index.rst @@ -2,7 +2,7 @@ treetensor.numpy ===================== .. toctree:: - :maxdepth: 3 + :maxdepth: 2 funcs.auto numpy.auto diff --git a/docs/source/api_doc/torch/index.rst b/docs/source/api_doc/torch/index.rst index bba292bb6cd1119ae84cdd191460ba4d66ba17b8..9fbc6fe13e7fdd1f2961ff043ae0051055204730 100644 --- a/docs/source/api_doc/torch/index.rst +++ b/docs/source/api_doc/torch/index.rst @@ -2,7 +2,7 @@ treetensor.torch ===================== .. toctree:: - :maxdepth: 3 + :maxdepth: 2 funcs.auto size.auto diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index fc14105d0d7a633f00d765fb74a95f1058fd2204..08cc0eea2b11ec704f9b04c47f843d2ec2b9d768 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -1,35 +1,23 @@ import builtins -from typing import List import numpy as np +from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize +from treevalue.utils import post_process from .numpy import TreeNumpy from ..common import ireduce, TreeObject -from ..utils import replaceable_partial, doc_from +from ..utils import replaceable_partial, doc_from, args_mapping __all__ = [ 'all', 'any', 'equal', 'array_equal', ] - -def _doc_stripper(src, _, lines: List[str]): - _name, _version = src.__name__, np.__version__ - _short_version = '.'.join(_version.split('.')[:2]) - return [ - f'.. note::', - f'', - f' This documentation is based on ' - f' `numpy.{_name} `_ ' - f' in `numpy v{_version} `_.', - f' **Its arguments\' arrangements depend on the version of numpy you installed**.', - f'', - *lines, - ] - - -func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNumpy) +func_treelize = post_process(post_process(args_mapping( + lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeValue)) else x)))( + replaceable_partial(original_func_treelize, return_type=TreeNumpy) +) @doc_from(np.all) diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index a0f146e1bf5cb8be9d30395d11f285521c07e748..8d53b96fcd56572debccd025ae839d92524ed17b 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -27,7 +27,7 @@ __all__ = [ ] func_treelize = post_process(post_process(args_mapping( - lambda i, x: Tensor(x) if isinstance(x, (dict, TreeValue)) else x)))( + lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeValue)) else x)))( replaceable_partial(original_func_treelize, return_type=Tensor) )