提交 adf83c6a 编写于 作者: HansBug's avatar HansBug 😆

doc(hansbug): complete auto documentation for Tensor class

上级 33d88ece
import io
from contextlib import contextmanager
from functools import partial
from typing import Optional, Tuple, List
......@@ -18,8 +20,8 @@ def strip_docs(doc: Optional[str]) -> Tuple[str, List[str]]:
l, r = 0, min(map(len, _exist_lines))
while l < r:
m = (l + r + 1) // 2
_prefixes = set(map(lambda x: x[:m], _exist_lines))
l, r = (m, r) if len(_prefixes) <= 1 else (l, m - 1)
_prefixes = list(map(lambda x: x[:m], _exist_lines))
l, r = (m, r) if len(set(_prefixes)) <= 1 and not _prefixes[0].strip() else (l, m - 1)
_indent = list(map(lambda x: x[:l], _exist_lines))[0]
_stripped_lines = list(map(lambda x: x[len(_indent):] if x.strip() else '', _lines))
......@@ -52,18 +54,33 @@ def print_doc(doc: str, strip: bool = True, indent: str = '', file=None):
_print()
def print_block(doc: str, name: str, value: Optional[str] = None,
@contextmanager
def print_block(name: str, value: Optional[str] = None,
params: Optional[dict] = None, file=None):
_print = partial(print, file=file)
_print(f'.. {name}:: {str(value) if value is not None else ""}')
_print(f'.. {name + "::" if name else ""} {str(value) if value is not None else ""}')
for k, v in (params or {}).items():
_print(f' :{k}: {str(v) if v is not None else ""}')
_print()
print_doc(doc, strip=True, indent=' ', file=file)
with io.StringIO() as bf:
try:
yield bf
finally:
bf.flush()
print_doc(bf.getvalue(), strip=True, indent=' ', file=file)
def current_module(module: str, file=None):
_print = partial(print, file=file)
_print(f'.. currentmodule:: {module}')
_print()
class _TempClazz:
@property
def prop(self):
return None
PropertyType = type(_TempClazz.prop)
......@@ -7,10 +7,17 @@ from docs import print_title, current_module, get_origin, print_block, print_doc
_DOC_FROM_TAG = '__doc_from__'
_H2_PATTERN = re.compile('-{3,}')
_numpy_version = np.__version__
_short_version = '.'.join(_numpy_version.split('.')[:2])
def _raw_doc_process(doc: str) -> str:
_doc = _H2_PATTERN.sub(lambda x: '~' * len(x.group(0)), doc)
_doc = _doc.replace(' : ', ' \\: ')
return _doc
if __name__ == '__main__':
_numpy_version = np.__version__
_short_version = '.'.join(_numpy_version.split('.')[:2])
print_title(tnp.funcs.__name__, levelc='=')
current_module(tnp.funcs.__name__)
......@@ -19,10 +26,12 @@ if __name__ == '__main__':
_origin = get_origin(_item)
print_title(_name, levelc='-')
print_block('', 'autofunction', value=_name)
with print_block('autofunction', value=_name):
pass
if _origin and (_origin.__doc__ or '').strip():
print_block(f"""
with print_block('admonition', value='Numpy Version Related', params={'class': 'tip'}) as f:
print_doc(f"""
This documentation is based on
`numpy.{_name} <https://numpy.org/doc/{_short_version}/reference/generated/numpy.{_name}.html>`_
in `numpy v{_numpy_version} <https://numpy.org/doc/{_short_version}/>`_.
......@@ -38,8 +47,8 @@ with the following command and find its documentation.
The arguments and keyword arguments supported in numpy v{_numpy_version} is listed below.
""", 'admonition', value='Numpy Version Related', params={'class': 'tip'})
""", file=f)
print()
print_doc(_H2_PATTERN.sub(lambda x: '~' * len(x.group(0)), _origin.__doc__ or ''))
print()
print_doc(_raw_doc_process(_origin.__doc__ or ''))
print()
......@@ -4,21 +4,25 @@ import treetensor.torch as ttorch
from docs import print_title, current_module, get_origin, print_block, print_doc
_DOC_FROM_TAG = '__doc_from__'
_torch_version = torch.__version__
if __name__ == '__main__':
_torch_version = torch.__version__
print_title(ttorch.funcs.__name__, levelc='=')
current_module(ttorch.__name__)
with print_block('automodule', value=ttorch.funcs.__name__):
pass
for _name in sorted(ttorch.funcs.__all__):
_item = getattr(ttorch.funcs, _name)
_origin = get_origin(_item)
print_title(_name, levelc='-')
print_block('', 'autofunction', value=_name)
with print_block('autofunction', value=_name):
pass
if _origin and (_origin.__doc__ or '').strip():
print_block(f"""
with print_block('admonition', value='Torch Version Related', params={'class': 'tip'}) as f:
print_doc(f"""
This documentation is based on
`torch.{_name} <https://pytorch.org/docs/{_torch_version}/generated/torch.{_name}.html>`_
in `torch v{_torch_version} <https://pytorch.org/docs/{_torch_version}/>`_.
......@@ -34,7 +38,10 @@ with the following command and find its documentation.
The arguments and keyword arguments supported in torch v{_torch_version} is listed below.
""", 'admonition', value='Torch Version Related', params={'class': 'tip'})
print_doc(f'.. function:: {_origin.__doc__.lstrip()}')
""", file=f)
print()
with print_block('') as f:
print_doc(f'.. function:: {_origin.__doc__.lstrip()}', file=f)
print()
import importlib
import torch
import treetensor.torch as ttorch
from docs import print_title, current_module, print_block, get_origin, PropertyType, print_doc
_DOC_FROM_TAG = '__doc_from__'
_DOC_TAG = '__doc_names__'
_ttorch_tensor = importlib.import_module('treetensor.torch.tensor')
_torch_version = torch.__version__
if __name__ == '__main__':
print_title(_ttorch_tensor.__name__, levelc='=')
current_module(ttorch.__name__)
with print_block('automodule', value=_ttorch_tensor.__name__):
pass
print_title(ttorch.Tensor.__name__, levelc='-')
with print_block('autoclass', value=ttorch.Tensor.__name__) as tf:
for _name in sorted(getattr(ttorch.Tensor, _DOC_TAG, [])):
_item = getattr(ttorch.Tensor, _name)
_origin = get_origin(_item)
with print_block(
'autoproperty' if isinstance(_item, PropertyType) else 'automethod',
value=_name, file=tf
) as f:
pass
if _origin and (_origin.__doc__ or '').strip():
with print_block('admonition', value='Torch Version Related', params={'class': 'tip'}, file=tf) as f:
print_doc(f"""
This documentation is based on
`torch.Tensor.{_name} <https://pytorch.org/docs/{_torch_version}/tensors.html#torch.Tensor.{_name}>`_
in `torch v{_torch_version} <https://pytorch.org/docs/{_torch_version}/>`_.
**Its arguments\' arrangements depend on the version of pytorch you installed**.
If some arguments listed here are not working properly, please check your pytorch's version
with the following command and find its documentation.
.. code-block:: shell
:linenos:
python -c 'import torch;print(torch.__version__)'
The arguments and keyword arguments supported in torch v{_torch_version} is listed below.
""", file=f)
print(file=tf)
with print_block('', file=tf) as f:
print_doc(f'.. function:: {_origin.__doc__.lstrip()}', file=f)
print(file=tf)
"""
Overview:
Common functions, based on ``torch`` module.
"""
import builtins
import torch
......
"""
Overview:
``Tensor`` class, based on ``torch`` module.
"""
import numpy as np
import torch
from treevalue import method_treelize
......
......@@ -47,10 +47,17 @@ class _TempClazz:
PropertyType = type(_TempClazz.prop)
def _is_property(clazz, name):
prop = getattr(clazz, name)
return isinstance(prop, PropertyType) and (
not hasattr(clazz.__base__, name) or getattr(clazz.__base__, name) is not prop
)
# noinspection PyTypeChecker
def _is_func_property(clazz, name):
def _is_func(clazz, name):
func = getattr(clazz, name)
return isinstance(func, (types.FunctionType, PropertyType)) and (
return isinstance(func, types.FunctionType) and (
not hasattr(clazz.__base__, name) or getattr(clazz.__base__, name) is not func
)
......@@ -67,8 +74,8 @@ def current_names(keep: bool = True):
members = set()
for name in dir(cls):
item = getattr(cls, name)
if (_is_func_property(cls, name) or _is_classmethod(cls, name)) and \
getattr(item, '__name__', None) == name: # should be public or protected
if ((_is_func(cls, name) or _is_classmethod(cls, name)) and getattr(item, '__name__', None) == name) or \
(_is_property(cls, name)):
members.add(name)
_old_names = _get_names(cls) if keep else set()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册