From e6afc90ee793c7308c4ea807d8464863b800b5fd Mon Sep 17 00:00:00 2001 From: HansBug Date: Sun, 12 Sep 2021 23:31:54 +0800 Subject: [PATCH] doc(hansbug): optimize documentation in treetensor.torch --- docs/source/apidoc_gen.py | 22 +++++++++++--------- treetensor/torch/funcs.py | 43 +++++++++++++++++++++------------------ treetensor/utils/doc.py | 10 ++++++++- 3 files changed, 44 insertions(+), 31 deletions(-) diff --git a/docs/source/apidoc_gen.py b/docs/source/apidoc_gen.py index 939814799..6b0f27c3a 100644 --- a/docs/source/apidoc_gen.py +++ b/docs/source/apidoc_gen.py @@ -3,6 +3,7 @@ import types from typing import List _DOC_TAG = '__doc_names__' +_DIRECT_DOC_TAG = '__direct_doc__' def _is_tagged_name(clazz, name): @@ -38,14 +39,15 @@ if __name__ == '__main__': print() _item = getattr(_module, _name) - if isinstance(_item, types.FunctionType): - print(f'.. autofunction:: {package_name}.{_name}') - print() - elif isinstance(_item, type): - print(f'.. autoclass:: {package_name}.{_name}') - print(f' :members: {", ".join(sorted(_find_class_members(_item)))}') - print() + if getattr(_item, _DIRECT_DOC_TAG, None): + print(_item.__doc__) else: - print(f'.. autodata:: {package_name}.{_name}') - print(f' :annotation:') - print() + if isinstance(_item, types.FunctionType): + print(f'.. autofunction:: {package_name}.{_name}') + elif isinstance(_item, type): + print(f'.. autoclass:: {package_name}.{_name}') + print(f' :members: {", ".join(sorted(_find_class_members(_item)))}') + else: + print(f'.. autodata:: {package_name}.{_name}') + print(f' :annotation:') + print() diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index 253221dd4..a977ab009 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -3,15 +3,18 @@ from typing import List import torch from treevalue import func_treelize as original_func_treelize +from treevalue.utils import post_process from .tensor import TreeTensor, tireduce from ..common import TreeObject, ireduce -from ..utils import inherit_doc as original_inherit_doc -from ..utils import replaceable_partial +from ..utils import replaceable_partial, direct_doc, inherit_doc def _doc_stripper(src, _, lines: List[str]): _name, _version = src.__name__, torch.__version__ + if lines: + lines[0] = f'.. function:: {lines[0]}' + return [ f'.. note::', f'', @@ -20,12 +23,12 @@ def _doc_stripper(src, _, lines: List[str]): f' in `torch v{_version} `_.', f' **Its arguments\' arrangements depend on the version of pytorch you installed**.', f'', - *lines[1:] + *lines, ] func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor) -inherit_doc = replaceable_partial(original_inherit_doc, stripper=_doc_stripper) +docs = post_process(post_process(direct_doc))(replaceable_partial(inherit_doc, stripper=_doc_stripper)) __all__ = [ 'zeros', 'zeros_like', @@ -39,99 +42,99 @@ __all__ = [ ] -@inherit_doc(torch.zeros) +@docs(torch.zeros) @func_treelize() def zeros(*args, **kwargs): return torch.zeros(*args, **kwargs) -@inherit_doc(torch.zeros_like) +@docs(torch.zeros_like) @func_treelize() def zeros_like(input_, *args, **kwargs): return torch.zeros_like(input_, *args, **kwargs) -@inherit_doc(torch.randn) +@docs(torch.randn) @func_treelize() def randn(*args, **kwargs): return torch.randn(*args, **kwargs) -@inherit_doc(torch.randn_like) +@docs(torch.randn_like) @func_treelize() def randn_like(input_, *args, **kwargs): return torch.randn_like(input_, *args, **kwargs) -@inherit_doc(torch.randint) +@docs(torch.randint) @func_treelize() def randint(*args, **kwargs): return torch.randint(*args, **kwargs) -@inherit_doc(torch.randint_like) +@docs(torch.randint_like) @func_treelize() def randint_like(input_, *args, **kwargs): return torch.randint_like(input_, *args, **kwargs) -@inherit_doc(torch.ones) +@docs(torch.ones) @func_treelize() def ones(*args, **kwargs): return torch.ones(*args, **kwargs) -@inherit_doc(torch.ones_like) +@docs(torch.ones_like) @func_treelize() def ones_like(input_, *args, **kwargs): return torch.ones_like(input_, *args, **kwargs) -@inherit_doc(torch.full) +@docs(torch.full) @func_treelize() def full(*args, **kwargs): return torch.full(*args, **kwargs) -@inherit_doc(torch.full_like) +@docs(torch.full_like) @func_treelize() def full_like(input_, *args, **kwargs): return torch.full_like(input_, *args, **kwargs) -@inherit_doc(torch.empty) +@docs(torch.empty) @func_treelize() def empty(*args, **kwargs): return torch.empty(*args, **kwargs) -@inherit_doc(torch.empty_like) +@docs(torch.empty_like) @func_treelize() def empty_like(input_, *args, **kwargs): return torch.empty_like(input_, *args, **kwargs) -@inherit_doc(torch.all) +@docs(torch.all) @tireduce(torch.all) @func_treelize(return_type=TreeObject) def all(input_, *args, **kwargs): return torch.all(input_, *args, **kwargs) -@inherit_doc(torch.any) +@docs(torch.any) @tireduce(torch.any) @func_treelize(return_type=TreeObject) def any(input_, *args, **kwargs): return torch.any(input_, *args, **kwargs) -@inherit_doc(torch.eq) +@docs(torch.eq) @func_treelize() def eq(input_, other, *args, **kwargs): return torch.eq(input_, other, *args, **kwargs) -@inherit_doc(torch.equal) +@docs(torch.equal) @ireduce(builtins.all) @func_treelize() def equal(input_, other, *args, **kwargs): diff --git a/treetensor/utils/doc.py b/treetensor/utils/doc.py index a2782041e..6c0f3698e 100644 --- a/treetensor/utils/doc.py +++ b/treetensor/utils/doc.py @@ -2,7 +2,7 @@ import os from typing import List, Optional, Callable, Any __all__ = [ - 'inherit_doc', + 'inherit_doc', 'direct_doc', ] @@ -37,3 +37,11 @@ def inherit_doc(src, stripper: Optional[Callable[[Any, Any, List[str]], List[str return obj return _decorator + + +_DIRECT_DOC = '__direct_doc__' + + +def direct_doc(obj): + setattr(obj, _DIRECT_DOC, True) + return obj -- GitLab