diff --git a/docs/source/_libs/docs.py b/docs/source/_libs/docs.py index ecf0722ee3f141386a630d53b8544dae7e37697c..b2ecf8702581c1414d24759b8c8e150f8f326919 100644 --- a/docs/source/_libs/docs.py +++ b/docs/source/_libs/docs.py @@ -1,3 +1,5 @@ +import io +from contextlib import contextmanager from functools import partial from typing import Optional, Tuple, List @@ -18,8 +20,8 @@ def strip_docs(doc: Optional[str]) -> Tuple[str, List[str]]: l, r = 0, min(map(len, _exist_lines)) while l < r: m = (l + r + 1) // 2 - _prefixes = set(map(lambda x: x[:m], _exist_lines)) - l, r = (m, r) if len(_prefixes) <= 1 else (l, m - 1) + _prefixes = list(map(lambda x: x[:m], _exist_lines)) + l, r = (m, r) if len(set(_prefixes)) <= 1 and not _prefixes[0].strip() else (l, m - 1) _indent = list(map(lambda x: x[:l], _exist_lines))[0] _stripped_lines = list(map(lambda x: x[len(_indent):] if x.strip() else '', _lines)) @@ -52,18 +54,33 @@ def print_doc(doc: str, strip: bool = True, indent: str = '', file=None): _print() -def print_block(doc: str, name: str, value: Optional[str] = None, +@contextmanager +def print_block(name: str, value: Optional[str] = None, params: Optional[dict] = None, file=None): _print = partial(print, file=file) - _print(f'.. {name}:: {str(value) if value is not None else ""}') + _print(f'.. {name + "::" if name else ""} {str(value) if value is not None else ""}') for k, v in (params or {}).items(): _print(f' :{k}: {str(v) if v is not None else ""}') _print() - print_doc(doc, strip=True, indent=' ', file=file) + with io.StringIO() as bf: + try: + yield bf + finally: + bf.flush() + print_doc(bf.getvalue(), strip=True, indent=' ', file=file) def current_module(module: str, file=None): _print = partial(print, file=file) _print(f'.. currentmodule:: {module}') _print() + + +class _TempClazz: + @property + def prop(self): + return None + + +PropertyType = type(_TempClazz.prop) diff --git a/docs/source/api_doc/numpy/funcs.rst.py b/docs/source/api_doc/numpy/funcs.rst.py index a5f3d75efb809f4e8ab0edff9a9ed8427d2d4a9c..4ac8f4303eaa2e99300ad5df55d94af75d1aaafe 100644 --- a/docs/source/api_doc/numpy/funcs.rst.py +++ b/docs/source/api_doc/numpy/funcs.rst.py @@ -7,10 +7,17 @@ from docs import print_title, current_module, get_origin, print_block, print_doc _DOC_FROM_TAG = '__doc_from__' _H2_PATTERN = re.compile('-{3,}') +_numpy_version = np.__version__ +_short_version = '.'.join(_numpy_version.split('.')[:2]) + + +def _raw_doc_process(doc: str) -> str: + _doc = _H2_PATTERN.sub(lambda x: '~' * len(x.group(0)), doc) + _doc = _doc.replace(' : ', ' \\: ') + return _doc + if __name__ == '__main__': - _numpy_version = np.__version__ - _short_version = '.'.join(_numpy_version.split('.')[:2]) print_title(tnp.funcs.__name__, levelc='=') current_module(tnp.funcs.__name__) @@ -19,10 +26,12 @@ if __name__ == '__main__': _origin = get_origin(_item) print_title(_name, levelc='-') - print_block('', 'autofunction', value=_name) + with print_block('autofunction', value=_name): + pass if _origin and (_origin.__doc__ or '').strip(): - print_block(f""" + with print_block('admonition', value='Numpy Version Related', params={'class': 'tip'}) as f: + print_doc(f""" This documentation is based on `numpy.{_name} `_ in `numpy v{_numpy_version} `_. @@ -38,8 +47,8 @@ with the following command and find its documentation. The arguments and keyword arguments supported in numpy v{_numpy_version} is listed below. - """, 'admonition', value='Numpy Version Related', params={'class': 'tip'}) + """, file=f) print() - print_doc(_H2_PATTERN.sub(lambda x: '~' * len(x.group(0)), _origin.__doc__ or '')) - print() + print_doc(_raw_doc_process(_origin.__doc__ or '')) + print() diff --git a/docs/source/api_doc/torch/funcs.rst.py b/docs/source/api_doc/torch/funcs.rst.py index b1dae64f6c9524f83f4280084065498df13f25ce..cb378690341dcbb6cc4cedc2d5454dde6f00c0bd 100644 --- a/docs/source/api_doc/torch/funcs.rst.py +++ b/docs/source/api_doc/torch/funcs.rst.py @@ -4,21 +4,25 @@ import treetensor.torch as ttorch from docs import print_title, current_module, get_origin, print_block, print_doc _DOC_FROM_TAG = '__doc_from__' +_torch_version = torch.__version__ if __name__ == '__main__': - _torch_version = torch.__version__ print_title(ttorch.funcs.__name__, levelc='=') current_module(ttorch.__name__) + with print_block('automodule', value=ttorch.funcs.__name__): + pass for _name in sorted(ttorch.funcs.__all__): _item = getattr(ttorch.funcs, _name) _origin = get_origin(_item) print_title(_name, levelc='-') - print_block('', 'autofunction', value=_name) + with print_block('autofunction', value=_name): + pass if _origin and (_origin.__doc__ or '').strip(): - print_block(f""" + with print_block('admonition', value='Torch Version Related', params={'class': 'tip'}) as f: + print_doc(f""" This documentation is based on `torch.{_name} `_ in `torch v{_torch_version} `_. @@ -34,7 +38,10 @@ with the following command and find its documentation. The arguments and keyword arguments supported in torch v{_torch_version} is listed below. - """, 'admonition', value='Torch Version Related', params={'class': 'tip'}) - print_doc(f'.. function:: {_origin.__doc__.lstrip()}') + """, file=f) + print() + + with print_block('') as f: + print_doc(f'.. function:: {_origin.__doc__.lstrip()}', file=f) print() diff --git a/docs/source/api_doc/torch/tensor.rst.py b/docs/source/api_doc/torch/tensor.rst.py new file mode 100644 index 0000000000000000000000000000000000000000..7463cc8f7962eaee1bdba187cc97485dc0e74fbe --- /dev/null +++ b/docs/source/api_doc/torch/tensor.rst.py @@ -0,0 +1,56 @@ +import importlib + +import torch + +import treetensor.torch as ttorch +from docs import print_title, current_module, print_block, get_origin, PropertyType, print_doc + +_DOC_FROM_TAG = '__doc_from__' +_DOC_TAG = '__doc_names__' +_ttorch_tensor = importlib.import_module('treetensor.torch.tensor') +_torch_version = torch.__version__ + +if __name__ == '__main__': + print_title(_ttorch_tensor.__name__, levelc='=') + current_module(ttorch.__name__) + with print_block('automodule', value=_ttorch_tensor.__name__): + pass + + print_title(ttorch.Tensor.__name__, levelc='-') + + with print_block('autoclass', value=ttorch.Tensor.__name__) as tf: + for _name in sorted(getattr(ttorch.Tensor, _DOC_TAG, [])): + _item = getattr(ttorch.Tensor, _name) + _origin = get_origin(_item) + + with print_block( + 'autoproperty' if isinstance(_item, PropertyType) else 'automethod', + value=_name, file=tf + ) as f: + pass + + if _origin and (_origin.__doc__ or '').strip(): + with print_block('admonition', value='Torch Version Related', params={'class': 'tip'}, file=tf) as f: + print_doc(f""" + This documentation is based on + `torch.Tensor.{_name} `_ + in `torch v{_torch_version} `_. + **Its arguments\' arrangements depend on the version of pytorch you installed**. + + If some arguments listed here are not working properly, please check your pytorch's version + with the following command and find its documentation. + + .. code-block:: shell + :linenos: + + python -c 'import torch;print(torch.__version__)' + + The arguments and keyword arguments supported in torch v{_torch_version} is listed below. + + """, file=f) + print(file=tf) + + with print_block('', file=tf) as f: + print_doc(f'.. function:: {_origin.__doc__.lstrip()}', file=f) + + print(file=tf) diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index 302fa100d3af3da439c4cdc1cb204a65062d70c4..a0f146e1bf5cb8be9d30395d11f285521c07e748 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -1,3 +1,8 @@ +""" +Overview: + Common functions, based on ``torch`` module. +""" + import builtins import torch diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index aadb345103d88219a9de1d4eca5772e6a3164c73..86af29f5de20a3550958d42630daad119b15d74f 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -1,3 +1,8 @@ +""" +Overview: + ``Tensor`` class, based on ``torch`` module. +""" + import numpy as np import torch from treevalue import method_treelize diff --git a/treetensor/utils/clazz.py b/treetensor/utils/clazz.py index 181a4627dcad39e2adccd614645acd58b542a316..ed3b0cabec2d03f1069700d903fbe8166160c9fc 100644 --- a/treetensor/utils/clazz.py +++ b/treetensor/utils/clazz.py @@ -47,10 +47,17 @@ class _TempClazz: PropertyType = type(_TempClazz.prop) +def _is_property(clazz, name): + prop = getattr(clazz, name) + return isinstance(prop, PropertyType) and ( + not hasattr(clazz.__base__, name) or getattr(clazz.__base__, name) is not prop + ) + + # noinspection PyTypeChecker -def _is_func_property(clazz, name): +def _is_func(clazz, name): func = getattr(clazz, name) - return isinstance(func, (types.FunctionType, PropertyType)) and ( + return isinstance(func, types.FunctionType) and ( not hasattr(clazz.__base__, name) or getattr(clazz.__base__, name) is not func ) @@ -67,8 +74,8 @@ def current_names(keep: bool = True): members = set() for name in dir(cls): item = getattr(cls, name) - if (_is_func_property(cls, name) or _is_classmethod(cls, name)) and \ - getattr(item, '__name__', None) == name: # should be public or protected + if ((_is_func(cls, name) or _is_classmethod(cls, name)) and getattr(item, '__name__', None) == name) or \ + (_is_property(cls, name)): members.add(name) _old_names = _get_names(cls) if keep else set()