提交 adf83c6a 编写于 作者: HansBug's avatar HansBug 😆

doc(hansbug): complete auto documentation for Tensor class

上级 33d88ece
import io
from contextlib import contextmanager
from functools import partial from functools import partial
from typing import Optional, Tuple, List from typing import Optional, Tuple, List
...@@ -18,8 +20,8 @@ def strip_docs(doc: Optional[str]) -> Tuple[str, List[str]]: ...@@ -18,8 +20,8 @@ def strip_docs(doc: Optional[str]) -> Tuple[str, List[str]]:
l, r = 0, min(map(len, _exist_lines)) l, r = 0, min(map(len, _exist_lines))
while l < r: while l < r:
m = (l + r + 1) // 2 m = (l + r + 1) // 2
_prefixes = set(map(lambda x: x[:m], _exist_lines)) _prefixes = list(map(lambda x: x[:m], _exist_lines))
l, r = (m, r) if len(_prefixes) <= 1 else (l, m - 1) l, r = (m, r) if len(set(_prefixes)) <= 1 and not _prefixes[0].strip() else (l, m - 1)
_indent = list(map(lambda x: x[:l], _exist_lines))[0] _indent = list(map(lambda x: x[:l], _exist_lines))[0]
_stripped_lines = list(map(lambda x: x[len(_indent):] if x.strip() else '', _lines)) _stripped_lines = list(map(lambda x: x[len(_indent):] if x.strip() else '', _lines))
...@@ -52,18 +54,33 @@ def print_doc(doc: str, strip: bool = True, indent: str = '', file=None): ...@@ -52,18 +54,33 @@ def print_doc(doc: str, strip: bool = True, indent: str = '', file=None):
_print() _print()
def print_block(doc: str, name: str, value: Optional[str] = None, @contextmanager
def print_block(name: str, value: Optional[str] = None,
params: Optional[dict] = None, file=None): params: Optional[dict] = None, file=None):
_print = partial(print, file=file) _print = partial(print, file=file)
_print(f'.. {name}:: {str(value) if value is not None else ""}') _print(f'.. {name + "::" if name else ""} {str(value) if value is not None else ""}')
for k, v in (params or {}).items(): for k, v in (params or {}).items():
_print(f' :{k}: {str(v) if v is not None else ""}') _print(f' :{k}: {str(v) if v is not None else ""}')
_print() _print()
print_doc(doc, strip=True, indent=' ', file=file) with io.StringIO() as bf:
try:
yield bf
finally:
bf.flush()
print_doc(bf.getvalue(), strip=True, indent=' ', file=file)
def current_module(module: str, file=None): def current_module(module: str, file=None):
_print = partial(print, file=file) _print = partial(print, file=file)
_print(f'.. currentmodule:: {module}') _print(f'.. currentmodule:: {module}')
_print() _print()
class _TempClazz:
@property
def prop(self):
return None
PropertyType = type(_TempClazz.prop)
...@@ -7,10 +7,17 @@ from docs import print_title, current_module, get_origin, print_block, print_doc ...@@ -7,10 +7,17 @@ from docs import print_title, current_module, get_origin, print_block, print_doc
_DOC_FROM_TAG = '__doc_from__' _DOC_FROM_TAG = '__doc_from__'
_H2_PATTERN = re.compile('-{3,}') _H2_PATTERN = re.compile('-{3,}')
_numpy_version = np.__version__
_short_version = '.'.join(_numpy_version.split('.')[:2])
def _raw_doc_process(doc: str) -> str:
_doc = _H2_PATTERN.sub(lambda x: '~' * len(x.group(0)), doc)
_doc = _doc.replace(' : ', ' \\: ')
return _doc
if __name__ == '__main__': if __name__ == '__main__':
_numpy_version = np.__version__
_short_version = '.'.join(_numpy_version.split('.')[:2])
print_title(tnp.funcs.__name__, levelc='=') print_title(tnp.funcs.__name__, levelc='=')
current_module(tnp.funcs.__name__) current_module(tnp.funcs.__name__)
...@@ -19,10 +26,12 @@ if __name__ == '__main__': ...@@ -19,10 +26,12 @@ if __name__ == '__main__':
_origin = get_origin(_item) _origin = get_origin(_item)
print_title(_name, levelc='-') print_title(_name, levelc='-')
print_block('', 'autofunction', value=_name) with print_block('autofunction', value=_name):
pass
if _origin and (_origin.__doc__ or '').strip(): if _origin and (_origin.__doc__ or '').strip():
print_block(f""" with print_block('admonition', value='Numpy Version Related', params={'class': 'tip'}) as f:
print_doc(f"""
This documentation is based on This documentation is based on
`numpy.{_name} <https://numpy.org/doc/{_short_version}/reference/generated/numpy.{_name}.html>`_ `numpy.{_name} <https://numpy.org/doc/{_short_version}/reference/generated/numpy.{_name}.html>`_
in `numpy v{_numpy_version} <https://numpy.org/doc/{_short_version}/>`_. in `numpy v{_numpy_version} <https://numpy.org/doc/{_short_version}/>`_.
...@@ -38,8 +47,8 @@ with the following command and find its documentation. ...@@ -38,8 +47,8 @@ with the following command and find its documentation.
The arguments and keyword arguments supported in numpy v{_numpy_version} is listed below. The arguments and keyword arguments supported in numpy v{_numpy_version} is listed below.
""", 'admonition', value='Numpy Version Related', params={'class': 'tip'}) """, file=f)
print() print()
print_doc(_H2_PATTERN.sub(lambda x: '~' * len(x.group(0)), _origin.__doc__ or ''))
print() print_doc(_raw_doc_process(_origin.__doc__ or ''))
print()
...@@ -4,21 +4,25 @@ import treetensor.torch as ttorch ...@@ -4,21 +4,25 @@ import treetensor.torch as ttorch
from docs import print_title, current_module, get_origin, print_block, print_doc from docs import print_title, current_module, get_origin, print_block, print_doc
_DOC_FROM_TAG = '__doc_from__' _DOC_FROM_TAG = '__doc_from__'
_torch_version = torch.__version__
if __name__ == '__main__': if __name__ == '__main__':
_torch_version = torch.__version__
print_title(ttorch.funcs.__name__, levelc='=') print_title(ttorch.funcs.__name__, levelc='=')
current_module(ttorch.__name__) current_module(ttorch.__name__)
with print_block('automodule', value=ttorch.funcs.__name__):
pass
for _name in sorted(ttorch.funcs.__all__): for _name in sorted(ttorch.funcs.__all__):
_item = getattr(ttorch.funcs, _name) _item = getattr(ttorch.funcs, _name)
_origin = get_origin(_item) _origin = get_origin(_item)
print_title(_name, levelc='-') print_title(_name, levelc='-')
print_block('', 'autofunction', value=_name) with print_block('autofunction', value=_name):
pass
if _origin and (_origin.__doc__ or '').strip(): if _origin and (_origin.__doc__ or '').strip():
print_block(f""" with print_block('admonition', value='Torch Version Related', params={'class': 'tip'}) as f:
print_doc(f"""
This documentation is based on This documentation is based on
`torch.{_name} <https://pytorch.org/docs/{_torch_version}/generated/torch.{_name}.html>`_ `torch.{_name} <https://pytorch.org/docs/{_torch_version}/generated/torch.{_name}.html>`_
in `torch v{_torch_version} <https://pytorch.org/docs/{_torch_version}/>`_. in `torch v{_torch_version} <https://pytorch.org/docs/{_torch_version}/>`_.
...@@ -34,7 +38,10 @@ with the following command and find its documentation. ...@@ -34,7 +38,10 @@ with the following command and find its documentation.
The arguments and keyword arguments supported in torch v{_torch_version} is listed below. The arguments and keyword arguments supported in torch v{_torch_version} is listed below.
""", 'admonition', value='Torch Version Related', params={'class': 'tip'}) """, file=f)
print_doc(f'.. function:: {_origin.__doc__.lstrip()}') print()
with print_block('') as f:
print_doc(f'.. function:: {_origin.__doc__.lstrip()}', file=f)
print() print()
import importlib
import torch
import treetensor.torch as ttorch
from docs import print_title, current_module, print_block, get_origin, PropertyType, print_doc
_DOC_FROM_TAG = '__doc_from__'
_DOC_TAG = '__doc_names__'
_ttorch_tensor = importlib.import_module('treetensor.torch.tensor')
_torch_version = torch.__version__
if __name__ == '__main__':
print_title(_ttorch_tensor.__name__, levelc='=')
current_module(ttorch.__name__)
with print_block('automodule', value=_ttorch_tensor.__name__):
pass
print_title(ttorch.Tensor.__name__, levelc='-')
with print_block('autoclass', value=ttorch.Tensor.__name__) as tf:
for _name in sorted(getattr(ttorch.Tensor, _DOC_TAG, [])):
_item = getattr(ttorch.Tensor, _name)
_origin = get_origin(_item)
with print_block(
'autoproperty' if isinstance(_item, PropertyType) else 'automethod',
value=_name, file=tf
) as f:
pass
if _origin and (_origin.__doc__ or '').strip():
with print_block('admonition', value='Torch Version Related', params={'class': 'tip'}, file=tf) as f:
print_doc(f"""
This documentation is based on
`torch.Tensor.{_name} <https://pytorch.org/docs/{_torch_version}/tensors.html#torch.Tensor.{_name}>`_
in `torch v{_torch_version} <https://pytorch.org/docs/{_torch_version}/>`_.
**Its arguments\' arrangements depend on the version of pytorch you installed**.
If some arguments listed here are not working properly, please check your pytorch's version
with the following command and find its documentation.
.. code-block:: shell
:linenos:
python -c 'import torch;print(torch.__version__)'
The arguments and keyword arguments supported in torch v{_torch_version} is listed below.
""", file=f)
print(file=tf)
with print_block('', file=tf) as f:
print_doc(f'.. function:: {_origin.__doc__.lstrip()}', file=f)
print(file=tf)
"""
Overview:
Common functions, based on ``torch`` module.
"""
import builtins import builtins
import torch import torch
......
"""
Overview:
``Tensor`` class, based on ``torch`` module.
"""
import numpy as np import numpy as np
import torch import torch
from treevalue import method_treelize from treevalue import method_treelize
......
...@@ -47,10 +47,17 @@ class _TempClazz: ...@@ -47,10 +47,17 @@ class _TempClazz:
PropertyType = type(_TempClazz.prop) PropertyType = type(_TempClazz.prop)
def _is_property(clazz, name):
prop = getattr(clazz, name)
return isinstance(prop, PropertyType) and (
not hasattr(clazz.__base__, name) or getattr(clazz.__base__, name) is not prop
)
# noinspection PyTypeChecker # noinspection PyTypeChecker
def _is_func_property(clazz, name): def _is_func(clazz, name):
func = getattr(clazz, name) func = getattr(clazz, name)
return isinstance(func, (types.FunctionType, PropertyType)) and ( return isinstance(func, types.FunctionType) and (
not hasattr(clazz.__base__, name) or getattr(clazz.__base__, name) is not func not hasattr(clazz.__base__, name) or getattr(clazz.__base__, name) is not func
) )
...@@ -67,8 +74,8 @@ def current_names(keep: bool = True): ...@@ -67,8 +74,8 @@ def current_names(keep: bool = True):
members = set() members = set()
for name in dir(cls): for name in dir(cls):
item = getattr(cls, name) item = getattr(cls, name)
if (_is_func_property(cls, name) or _is_classmethod(cls, name)) and \ if ((_is_func(cls, name) or _is_classmethod(cls, name)) and getattr(item, '__name__', None) == name) or \
getattr(item, '__name__', None) == name: # should be public or protected (_is_property(cls, name)):
members.add(name) members.add(name)
_old_names = _get_names(cls) if keep else set() _old_names = _get_names(cls) if keep else set()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册