提交 64f9f118 编写于 作者: HansBug's avatar HansBug 😆

doc(hansbug): optimize documentation for treetensor.torch.funcs

上级 798d9183
from functools import partial
from typing import Optional, Tuple, List
def strip_docs(doc: Optional[str]) -> Tuple[str, List[str]]:
_lines = (doc or '').splitlines()
_non_empty_lines = sorted(filter(lambda t: t[1].strip(), enumerate(_lines)))
if _non_empty_lines:
_first_line_id, _ = _non_empty_lines[0]
_lines = _lines[_first_line_id:]
else:
_lines = []
_exist_lines = list(filter(str.strip, _lines))
if not _exist_lines:
_indent = ''
else:
l, r = 0, min(map(len, _exist_lines))
while l < r:
m = (l + r + 1) // 2
_prefixes = set(map(lambda x: x[:m], _exist_lines))
l, r = (m, r) if len(_prefixes) <= 1 else (l, m - 1)
_indent = list(map(lambda x: x[:l], _exist_lines))[0]
_stripped_lines = list(map(lambda x: x[len(_indent):] if x.strip() else '', _lines))
return _indent, _stripped_lines
_DOC_FROM_TAG = '__doc_from__'
def get_origin(obj):
return getattr(obj, _DOC_FROM_TAG, None)
def print_title(title: str, levelc='=', file=None):
_print = partial(print, file=file)
_print(title)
_print(levelc * (len(title) + 5))
_print()
def print_doc(doc: str, strip: bool = True, indent: str = '', file=None):
_print = partial(print, indent, file=file, sep='')
if strip:
_, _lines = strip_docs(doc or '')
else:
_lines = (doc or '').splitlines()
for _line in _lines:
_print(_line)
_print()
def print_block(doc: str, name: str, value: Optional[str] = None,
params: Optional[dict] = None, file=None):
_print = partial(print, file=file)
_print(f'.. {name}:: {str(value) if value is not None else ""}')
for k, v in (params or {}).items():
_print(f' :{k}: {str(v) if v is not None else ""}')
_print()
print_doc(doc, strip=True, indent=' ', file=file)
def current_module(module: str, file=None):
_print = partial(print, file=file)
_print(f'.. currentmodule:: {module}')
_print()
import re
import numpy as np
import treetensor.numpy as tnp
from docs import print_title, current_module, get_origin, print_block, print_doc
_DOC_FROM_TAG = '__doc_from__'
_H2_PATTERN = re.compile('-{3,}')
if __name__ == '__main__':
_numpy_version = np.__version__
_short_version = '.'.join(_numpy_version.split('.')[:2])
print_title(tnp.funcs.__name__, levelc='=')
current_module(tnp.funcs.__name__)
for _name in sorted(tnp.funcs.__all__):
_item = getattr(tnp.funcs, _name)
_origin = get_origin(_item)
print_title(_name, levelc='-')
print_block('', 'autofunction', value=_name)
if _origin and (_origin.__doc__ or '').strip():
print_block(f"""
This documentation is based on
`numpy.{_name} <https://numpy.org/doc/{_short_version}/reference/generated/numpy.{_name}.html>`_
in `numpy v{_numpy_version} <https://numpy.org/doc/{_short_version}/>`_.
**Its arguments\' arrangements depend on the version of numpy you installed**.
If some arguments listed here are not working properly, please check your numpy's version
with the following command and find its documentation.
.. code-block:: shell
:linenos:
python -c 'import numpy as np;print(np.__version__)'
The arguments and keyword arguments supported in numpy v{_numpy_version} is listed below.
""", 'note')
print()
print_doc(_H2_PATTERN.sub(lambda x: '~' * len(x.group(0)), _origin.__doc__ or ''))
print()
treetensor.numpy.funcs
\ No newline at end of file
treetensor.numpy.numpy
\ No newline at end of file
import torch
import treetensor.torch as ttorch
from docs import print_title, current_module, get_origin, print_block, print_doc
_DOC_FROM_TAG = '__doc_from__'
if __name__ == '__main__':
_torch_version = torch.__version__
print_title(ttorch.funcs.__name__, levelc='=')
current_module(ttorch.funcs.__name__)
for _name in sorted(ttorch.funcs.__all__):
_item = getattr(ttorch.funcs, _name)
_origin = get_origin(_item)
print_title(_name, levelc='-')
print_block('', 'autofunction', value=_name)
if _origin and (_origin.__doc__ or '').strip():
print_block(f"""
This documentation is based on
`torch.{_name} <https://pytorch.org/docs/{_torch_version}/generated/torch.{_name}.html>`_
in `torch v{_torch_version} <https://pytorch.org/docs/{_torch_version}/>`_.
**Its arguments\' arrangements depend on the version of pytorch you installed**.
If some arguments listed here are not working properly, please check your pytorch's version
with the following command and find its documentation.
.. code-block:: shell
:linenos:
python -c 'import torch;print(torch.__version__)'
The arguments and keyword arguments supported in torch v{_torch_version} is listed below.
""", 'note')
print_doc(f'.. function:: {_origin.__doc__.lstrip()}')
print()
treetensor.torch.funcs
\ No newline at end of file
PYTHON := $(shell which python)
PYTHON := $(shell which python)
SOURCE ?= .
RSTC_FILES := $(shell find ${SOURCE} -name *.rstc)
RST_RESULTS := $(addsuffix .auto.rst, $(basename ${RSTC_FILES}))
SOURCE ?= .
PYTHON_SCRIPTS := $(shell find ${SOURCE} -name *.rst.py)
PYTHON_RESULTS := $(addsuffix .auto.rst, $(basename $(basename ${PYTHON_SCRIPTS})))
APIDOC_GEN_PY := $(shell readlink -f ${SOURCE}/apidoc_gen.py)
%.auto.rst: %.rstc ${APIDOC_GEN_PY}
%.auto.rst: %.rst.py
cd "$(shell dirname $(shell readlink -f $<))" && \
PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
cat "$(shell readlink -f $<)" | $(PYTHON) "${APIDOC_GEN_PY}" > "$(shell readlink -f $@)"
$(PYTHON) "$(shell readlink -f $<)" > "$(shell readlink -f $@)"
build: ${RST_RESULTS}
build: ${PYTHON_RESULTS}
all: build
......
import importlib
import types
from typing import List
_DOC_TAG = '__doc_names__'
_DIRECT_DOC_TAG = '__direct_doc__'
def _is_tagged_name(clazz, name):
return name in set(getattr(clazz, _DOC_TAG, set()))
def _find_class_members(clazz: type) -> List[str]:
members = []
for name in dir(clazz):
item = getattr(clazz, name)
if _is_tagged_name(clazz, name) and \
getattr(item, '__name__', None) == name: # should be public or protected
members.append(name)
return members
if __name__ == '__main__':
package_name = input().strip()
_module = importlib.import_module(package_name)
_alls = getattr(_module, '__all__')
print(package_name)
print('=' * (len(package_name) + 5))
print()
print(f'.. automodule:: {package_name}')
print()
for _name in sorted(_alls):
print(_name)
print('-' * (len(_name) + 5))
print()
_item = getattr(_module, _name)
if getattr(_item, _DIRECT_DOC_TAG, None):
print(_item.__doc__)
else:
if isinstance(_item, types.FunctionType):
print(f'.. autofunction:: {package_name}.{_name}')
elif isinstance(_item, type):
print(f'.. autoclass:: {package_name}.{_name}')
print(f' :members: {", ".join(sorted(_find_class_members(_item)))}')
else:
print(f'.. autodata:: {package_name}.{_name}')
print(f' :annotation:')
print()
......@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize
from .numpy import TreeNumpy
from ..common import ireduce, TreeObject
from ..utils import replaceable_partial, inherit_doc
from ..utils import replaceable_partial, doc_from
__all__ = [
'all', 'any',
......@@ -30,30 +30,29 @@ def _doc_stripper(src, _, lines: List[str]):
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNumpy)
docs = replaceable_partial(inherit_doc, stripper=_doc_stripper)
@docs(np.all)
@doc_from(np.all)
@ireduce(builtins.all)
@func_treelize(return_type=TreeObject)
def all(a, *args, **kwargs):
return np.all(a, *args, **kwargs)
@docs(np.any)
@doc_from(np.any)
@ireduce(builtins.any)
@func_treelize()
def any(a, *args, **kwargs):
return np.any(a, *args, **kwargs)
@docs(np.equal)
@doc_from(np.equal)
@func_treelize()
def equal(x1, x2, *args, **kwargs):
return np.equal(x1, x2, *args, **kwargs)
@docs(np.array_equal)
@doc_from(np.array_equal)
@func_treelize()
def array_equal(a1, a2, *args, **kwargs):
return np.array_equal(a1, a2, *args, **kwargs)
import builtins
from typing import List
import torch
from treevalue import func_treelize as original_func_treelize
from treevalue.utils import post_process
from .tensor import TreeTensor, tireduce
from ..common import TreeObject, ireduce
from ..utils import replaceable_partial, direct_doc, inherit_doc
from ..utils import replaceable_partial, doc_from
__all__ = [
'zeros', 'zeros_like',
......@@ -20,121 +18,123 @@ __all__ = [
'eq', 'equal',
]
def _doc_stripper(src, _, lines: List[str]):
_name, _version = src.__name__, torch.__version__
if lines:
lines[0] = f'.. function:: {lines[0]}'
return [
f'.. note::',
f'',
f' This documentation is based on '
f' `torch.{_name} <https://pytorch.org/docs/{_version}/generated/torch.{_name}.html>`_ '
f' in `torch v{_version} <https://pytorch.org/docs/{_version}/>`_.',
f' **Its arguments\' arrangements depend on the version of pytorch you installed**.',
f'',
*lines,
]
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor)
docs = post_process(post_process(direct_doc))(replaceable_partial(inherit_doc, stripper=_doc_stripper))
@docs(torch.zeros)
@doc_from(torch.zeros)
@func_treelize()
def zeros(*args, **kwargs):
return torch.zeros(*args, **kwargs)
@docs(torch.zeros_like)
@doc_from(torch.zeros_like)
@func_treelize()
def zeros_like(input_, *args, **kwargs):
return torch.zeros_like(input_, *args, **kwargs)
@docs(torch.randn)
@doc_from(torch.randn)
@func_treelize()
def randn(*args, **kwargs):
return torch.randn(*args, **kwargs)
@docs(torch.randn_like)
@doc_from(torch.randn_like)
@func_treelize()
def randn_like(input_, *args, **kwargs):
return torch.randn_like(input_, *args, **kwargs)
@docs(torch.randint)
@doc_from(torch.randint)
@func_treelize()
def randint(*args, **kwargs):
return torch.randint(*args, **kwargs)
@docs(torch.randint_like)
@doc_from(torch.randint_like)
@func_treelize()
def randint_like(input_, *args, **kwargs):
return torch.randint_like(input_, *args, **kwargs)
@docs(torch.ones)
@doc_from(torch.ones)
@func_treelize()
def ones(*args, **kwargs):
return torch.ones(*args, **kwargs)
@docs(torch.ones_like)
@doc_from(torch.ones_like)
@func_treelize()
def ones_like(input_, *args, **kwargs):
return torch.ones_like(input_, *args, **kwargs)
@docs(torch.full)
@doc_from(torch.full)
@func_treelize()
def full(*args, **kwargs):
return torch.full(*args, **kwargs)
@docs(torch.full_like)
@doc_from(torch.full_like)
@func_treelize()
def full_like(input_, *args, **kwargs):
return torch.full_like(input_, *args, **kwargs)
@docs(torch.empty)
@doc_from(torch.empty)
@func_treelize()
def empty(*args, **kwargs):
return torch.empty(*args, **kwargs)
@docs(torch.empty_like)
@doc_from(torch.empty_like)
@func_treelize()
def empty_like(input_, *args, **kwargs):
return torch.empty_like(input_, *args, **kwargs)
@docs(torch.all)
@doc_from(torch.all)
@tireduce(torch.all)
@func_treelize(return_type=TreeObject)
def all(input_, *args, **kwargs):
"""
In ``treetensor``, you can get the ``all`` result of a whole tree with this function.
Example::
>>> all(torch.tensor([True, True])) # the same as torch.all
torch.tensor(True)
>>> all(TreeTensor({
>>> 'a': torch.tensor([True, True]),
>>> 'b': torch.tensor([True, True]),
>>> }))
torch.tensor(True)
>>> all(TreeTensor({
>>> 'a': torch.tensor([True, True]),
>>> 'b': torch.tensor([True, False]),
>>> }))
torch.tensor(False)
"""
return torch.all(input_, *args, **kwargs)
@docs(torch.any)
@doc_from(torch.any)
@tireduce(torch.any)
@func_treelize(return_type=TreeObject)
def any(input_, *args, **kwargs):
return torch.any(input_, *args, **kwargs)
@docs(torch.eq)
@doc_from(torch.eq)
@func_treelize()
def eq(input_, other, *args, **kwargs):
return torch.eq(input_, other, *args, **kwargs)
@docs(torch.equal)
@doc_from(torch.equal)
@ireduce(builtins.all)
@func_treelize()
def equal(input_, other, *args, **kwargs):
......
......@@ -6,7 +6,7 @@ from treevalue.utils import pre_process
from .size import TreeSize
from ..common import TreeObject, TreeData, ireduce
from ..numpy import TreeNumpy
from ..utils import inherit_names, current_names
from ..utils import inherit_names, current_names, doc_from
__all__ = [
'TreeTensor'
......@@ -20,56 +20,68 @@ tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduc
@current_names()
@inherit_names(TreeData)
class TreeTensor(TreeData):
@doc_from(torch.Tensor.numpy)
@method_treelize(return_type=TreeNumpy)
def numpy(self: torch.Tensor) -> np.ndarray:
return self.numpy()
@doc_from(torch.Tensor.tolist)
@method_treelize(return_type=TreeObject)
def tolist(self: torch.Tensor):
return self.tolist()
@doc_from(torch.Tensor.cpu)
@method_treelize()
def cpu(self: torch.Tensor, *args, **kwargs):
return self.cpu(*args, **kwargs)
@doc_from(torch.Tensor.cuda)
@method_treelize()
def cuda(self: torch.Tensor, *args, **kwargs):
return self.cuda(*args, **kwargs)
@doc_from(torch.Tensor.to)
@method_treelize()
def to(self: torch.Tensor, *args, **kwargs):
return self.to(*args, **kwargs)
@doc_from(torch.Tensor.numel)
@ireduce(sum)
@method_treelize(return_type=TreeObject)
def numel(self: torch.Tensor):
return self.numel()
@property
@doc_from(torch.Tensor.shape)
@method_treelize(return_type=TreeSize)
def shape(self: torch.Tensor):
return self.shape
@doc_from(torch.Tensor.all)
@tireduce(torch.all)
@method_treelize(return_type=TreeObject)
def all(self: torch.Tensor, *args, **kwargs) -> bool:
return self.all(*args, **kwargs)
@doc_from(torch.Tensor.any)
@tireduce(torch.any)
@method_treelize(return_type=TreeObject)
def any(self: torch.Tensor, *args, **kwargs) -> bool:
return self.any(*args, **kwargs)
@doc_from(torch.Tensor.max)
@tireduce(torch.max)
@method_treelize(return_type=TreeObject)
def max(self: torch.Tensor, *args, **kwargs):
return self.max(*args, **kwargs)
@doc_from(torch.Tensor.min)
@tireduce(torch.min)
@method_treelize(return_type=TreeObject)
def min(self: torch.Tensor, *args, **kwargs):
return self.min(*args, **kwargs)
@doc_from(torch.Tensor.sum)
@tireduce(torch.sum)
@method_treelize(return_type=TreeObject)
def sum(self: torch.Tensor, *args, **kwargs):
......
import os
from typing import List, Optional, Callable, Any
__all__ = [
'inherit_doc', 'direct_doc',
'doc_from',
]
def _strip_lines(doc: Optional[str]):
_lines = (doc or '').splitlines()
_first_line_id, _ = sorted(filter(lambda t: t[1].strip(), enumerate(_lines)))[0]
_lines = _lines[_first_line_id:]
_exist_lines = list(filter(str.strip, _lines))
if not _exist_lines:
_indent = ''
else:
l, r = 0, min(map(len, _exist_lines))
while l < r:
m = (l + r + 1) // 2
_prefixes = set(map(lambda x: x[:m], _exist_lines))
l, r = (m, r) if len(_prefixes) <= 1 else (l, m - 1)
_indent = list(map(lambda x: x[:l], _exist_lines))[0]
_stripped_lines = list(map(lambda x: x[len(_indent):] if x.strip() else '', _lines))
return _indent, _stripped_lines
def _unstrip_lines(indent: str, stripped_lines: List[str]) -> str:
return os.linesep.join(map(lambda x: indent + x, stripped_lines))
_DOC_FROM_TAG = '__doc_from__'
def inherit_doc(src, stripper: Optional[Callable[[Any, Any, List[str]], List[str]]] = None):
_indent, _stripped_lines = _strip_lines(src.__doc__)
def doc_from(src):
def _decorator(obj):
_lines = (stripper or (lambda s, o, x: x))(src, obj, _stripped_lines)
obj.__doc__ = _unstrip_lines(_indent, _lines)
setattr(obj, _DOC_FROM_TAG, src)
return obj
return _decorator
_DIRECT_DOC = '__direct_doc__'
def direct_doc(obj):
setattr(obj, _DIRECT_DOC, True)
return obj
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册