提交 e6afc90e 编写于 作者: HansBug's avatar HansBug 😆

doc(hansbug): optimize documentation in treetensor.torch

上级 ec55fa86
......@@ -3,6 +3,7 @@ import types
from typing import List
_DOC_TAG = '__doc_names__'
_DIRECT_DOC_TAG = '__direct_doc__'
def _is_tagged_name(clazz, name):
......@@ -38,14 +39,15 @@ if __name__ == '__main__':
print()
_item = getattr(_module, _name)
if isinstance(_item, types.FunctionType):
print(f'.. autofunction:: {package_name}.{_name}')
print()
elif isinstance(_item, type):
print(f'.. autoclass:: {package_name}.{_name}')
print(f' :members: {", ".join(sorted(_find_class_members(_item)))}')
print()
if getattr(_item, _DIRECT_DOC_TAG, None):
print(_item.__doc__)
else:
print(f'.. autodata:: {package_name}.{_name}')
print(f' :annotation:')
print()
if isinstance(_item, types.FunctionType):
print(f'.. autofunction:: {package_name}.{_name}')
elif isinstance(_item, type):
print(f'.. autoclass:: {package_name}.{_name}')
print(f' :members: {", ".join(sorted(_find_class_members(_item)))}')
else:
print(f'.. autodata:: {package_name}.{_name}')
print(f' :annotation:')
print()
......@@ -3,15 +3,18 @@ from typing import List
import torch
from treevalue import func_treelize as original_func_treelize
from treevalue.utils import post_process
from .tensor import TreeTensor, tireduce
from ..common import TreeObject, ireduce
from ..utils import inherit_doc as original_inherit_doc
from ..utils import replaceable_partial
from ..utils import replaceable_partial, direct_doc, inherit_doc
def _doc_stripper(src, _, lines: List[str]):
_name, _version = src.__name__, torch.__version__
if lines:
lines[0] = f'.. function:: {lines[0]}'
return [
f'.. note::',
f'',
......@@ -20,12 +23,12 @@ def _doc_stripper(src, _, lines: List[str]):
f' in `torch v{_version} <https://pytorch.org/docs/{_version}/>`_.',
f' **Its arguments\' arrangements depend on the version of pytorch you installed**.',
f'',
*lines[1:]
*lines,
]
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor)
inherit_doc = replaceable_partial(original_inherit_doc, stripper=_doc_stripper)
docs = post_process(post_process(direct_doc))(replaceable_partial(inherit_doc, stripper=_doc_stripper))
__all__ = [
'zeros', 'zeros_like',
......@@ -39,99 +42,99 @@ __all__ = [
]
@inherit_doc(torch.zeros)
@docs(torch.zeros)
@func_treelize()
def zeros(*args, **kwargs):
return torch.zeros(*args, **kwargs)
@inherit_doc(torch.zeros_like)
@docs(torch.zeros_like)
@func_treelize()
def zeros_like(input_, *args, **kwargs):
return torch.zeros_like(input_, *args, **kwargs)
@inherit_doc(torch.randn)
@docs(torch.randn)
@func_treelize()
def randn(*args, **kwargs):
return torch.randn(*args, **kwargs)
@inherit_doc(torch.randn_like)
@docs(torch.randn_like)
@func_treelize()
def randn_like(input_, *args, **kwargs):
return torch.randn_like(input_, *args, **kwargs)
@inherit_doc(torch.randint)
@docs(torch.randint)
@func_treelize()
def randint(*args, **kwargs):
return torch.randint(*args, **kwargs)
@inherit_doc(torch.randint_like)
@docs(torch.randint_like)
@func_treelize()
def randint_like(input_, *args, **kwargs):
return torch.randint_like(input_, *args, **kwargs)
@inherit_doc(torch.ones)
@docs(torch.ones)
@func_treelize()
def ones(*args, **kwargs):
return torch.ones(*args, **kwargs)
@inherit_doc(torch.ones_like)
@docs(torch.ones_like)
@func_treelize()
def ones_like(input_, *args, **kwargs):
return torch.ones_like(input_, *args, **kwargs)
@inherit_doc(torch.full)
@docs(torch.full)
@func_treelize()
def full(*args, **kwargs):
return torch.full(*args, **kwargs)
@inherit_doc(torch.full_like)
@docs(torch.full_like)
@func_treelize()
def full_like(input_, *args, **kwargs):
return torch.full_like(input_, *args, **kwargs)
@inherit_doc(torch.empty)
@docs(torch.empty)
@func_treelize()
def empty(*args, **kwargs):
return torch.empty(*args, **kwargs)
@inherit_doc(torch.empty_like)
@docs(torch.empty_like)
@func_treelize()
def empty_like(input_, *args, **kwargs):
return torch.empty_like(input_, *args, **kwargs)
@inherit_doc(torch.all)
@docs(torch.all)
@tireduce(torch.all)
@func_treelize(return_type=TreeObject)
def all(input_, *args, **kwargs):
return torch.all(input_, *args, **kwargs)
@inherit_doc(torch.any)
@docs(torch.any)
@tireduce(torch.any)
@func_treelize(return_type=TreeObject)
def any(input_, *args, **kwargs):
return torch.any(input_, *args, **kwargs)
@inherit_doc(torch.eq)
@docs(torch.eq)
@func_treelize()
def eq(input_, other, *args, **kwargs):
return torch.eq(input_, other, *args, **kwargs)
@inherit_doc(torch.equal)
@docs(torch.equal)
@ireduce(builtins.all)
@func_treelize()
def equal(input_, other, *args, **kwargs):
......
......@@ -2,7 +2,7 @@ import os
from typing import List, Optional, Callable, Any
__all__ = [
'inherit_doc',
'inherit_doc', 'direct_doc',
]
......@@ -37,3 +37,11 @@ def inherit_doc(src, stripper: Optional[Callable[[Any, Any, List[str]], List[str
return obj
return _decorator
_DIRECT_DOC = '__direct_doc__'
def direct_doc(obj):
setattr(obj, _DIRECT_DOC, True)
return obj
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册