From a8d1488e6c995920656e3fccee96ac658c487295 Mon Sep 17 00:00:00 2001 From: HansBug Date: Sun, 12 Sep 2021 23:00:21 +0800 Subject: [PATCH] doc(hansbug): add doc for treetensor.tensor.funcs --- test/tensor/test_funcs.py | 8 ++--- treetensor/tensor/funcs.py | 58 ++++++++++++++++++++++++++++-------- treetensor/utils/__init__.py | 1 + treetensor/utils/doc.py | 39 ++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 16 deletions(-) create mode 100644 treetensor/utils/doc.py diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py index 53be6d572..6145b1cf0 100644 --- a/test/tensor/test_funcs.py +++ b/test/tensor/test_funcs.py @@ -131,13 +131,13 @@ class TestTensorFuncs: }) def test_randint(self): - _target = ttorch.randint(TreeValue({ + _target = ttorch.randint(-10, 10, TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - }), -10, 10) + })) assert ttorch.all(_target < 10) assert ttorch.all(-10 <= _target) assert _target.shape == ttorch.TreeSize({ @@ -148,13 +148,13 @@ class TestTensorFuncs: } }) - _target = ttorch.randint(TreeValue({ + _target = ttorch.randint(10, TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - }), 10) + })) assert ttorch.all(_target < 10) assert ttorch.all(0 <= _target) assert _target.shape == ttorch.TreeSize({ diff --git a/treetensor/tensor/funcs.py b/treetensor/tensor/funcs.py index 71e46c5bf..253221dd4 100644 --- a/treetensor/tensor/funcs.py +++ b/treetensor/tensor/funcs.py @@ -1,13 +1,31 @@ import builtins +from typing import List import torch from treevalue import func_treelize as original_func_treelize from .tensor import TreeTensor, tireduce from ..common import TreeObject, ireduce +from ..utils import inherit_doc as original_inherit_doc from ..utils import replaceable_partial + +def _doc_stripper(src, _, lines: List[str]): + _name, _version = src.__name__, torch.__version__ + return [ + f'.. note::', + f'', + f' This documentation is based on ' + f' `torch.{_name} `_ ' + f' in `torch v{_version} `_.', + f' **Its arguments\' arrangements depend on the version of pytorch you installed**.', + f'', + *lines[1:] + ] + + func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor) +inherit_doc = replaceable_partial(original_inherit_doc, stripper=_doc_stripper) __all__ = [ 'zeros', 'zeros_like', @@ -21,83 +39,99 @@ __all__ = [ ] +@inherit_doc(torch.zeros) @func_treelize() -def zeros(size, *args, **kwargs): - return torch.zeros(*size, *args, **kwargs) +def zeros(*args, **kwargs): + return torch.zeros(*args, **kwargs) +@inherit_doc(torch.zeros_like) @func_treelize() def zeros_like(input_, *args, **kwargs): return torch.zeros_like(input_, *args, **kwargs) +@inherit_doc(torch.randn) @func_treelize() -def randn(size, *args, **kwargs): - return torch.randn(*size, *args, **kwargs) +def randn(*args, **kwargs): + return torch.randn(*args, **kwargs) +@inherit_doc(torch.randn_like) @func_treelize() def randn_like(input_, *args, **kwargs): return torch.randn_like(input_, *args, **kwargs) +@inherit_doc(torch.randint) @func_treelize() -def randint(size, *args, **kwargs): - return torch.randint(*args, size, **kwargs) +def randint(*args, **kwargs): + return torch.randint(*args, **kwargs) +@inherit_doc(torch.randint_like) @func_treelize() def randint_like(input_, *args, **kwargs): return torch.randint_like(input_, *args, **kwargs) +@inherit_doc(torch.ones) @func_treelize() -def ones(size, *args, **kwargs): - return torch.ones(*size, *args, **kwargs) +def ones(*args, **kwargs): + return torch.ones(*args, **kwargs) +@inherit_doc(torch.ones_like) @func_treelize() def ones_like(input_, *args, **kwargs): return torch.ones_like(input_, *args, **kwargs) +@inherit_doc(torch.full) @func_treelize() -def full(size, *args, **kwargs): - return torch.full(size, *args, **kwargs) +def full(*args, **kwargs): + return torch.full(*args, **kwargs) +@inherit_doc(torch.full_like) @func_treelize() def full_like(input_, *args, **kwargs): return torch.full_like(input_, *args, **kwargs) +@inherit_doc(torch.empty) @func_treelize() -def empty(size, *args, **kwargs): - return torch.empty(size, *args, **kwargs) +def empty(*args, **kwargs): + return torch.empty(*args, **kwargs) +@inherit_doc(torch.empty_like) @func_treelize() def empty_like(input_, *args, **kwargs): return torch.empty_like(input_, *args, **kwargs) +@inherit_doc(torch.all) @tireduce(torch.all) @func_treelize(return_type=TreeObject) def all(input_, *args, **kwargs): return torch.all(input_, *args, **kwargs) +@inherit_doc(torch.any) @tireduce(torch.any) @func_treelize(return_type=TreeObject) def any(input_, *args, **kwargs): return torch.any(input_, *args, **kwargs) +@inherit_doc(torch.eq) @func_treelize() def eq(input_, other, *args, **kwargs): return torch.eq(input_, other, *args, **kwargs) +@inherit_doc(torch.equal) @ireduce(builtins.all) @func_treelize() def equal(input_, other, *args, **kwargs): diff --git a/treetensor/utils/__init__.py b/treetensor/utils/__init__.py index 336167bf0..45e90ec28 100644 --- a/treetensor/utils/__init__.py +++ b/treetensor/utils/__init__.py @@ -1,2 +1,3 @@ from .clazz import * +from .doc import * from .func import * diff --git a/treetensor/utils/doc.py b/treetensor/utils/doc.py new file mode 100644 index 000000000..10261f36f --- /dev/null +++ b/treetensor/utils/doc.py @@ -0,0 +1,39 @@ +import os +from typing import List, Optional, Callable, Any + +__all__ = [ + 'inherit_doc', +] + + +def _strip_lines(doc: str): + _lines = doc.strip().splitlines() + _exist_lines = list(filter(str.strip, _lines)) + + if not _exist_lines: + _indent = '' + else: + l, r = 0, min(map(len, _exist_lines)) + while l < r: + m = (l + r + 1) // 2 + _prefixes = set(map(lambda x: x[:m], _exist_lines)) + l, r = (m, r) if len(_prefixes) <= 1 else (l, m - 1) + _indent = list(map(lambda x: x[:l], _exist_lines))[0] + + _stripped_lines = list(map(lambda x: x[len(_indent):] if x.strip() else '', _lines)) + return _indent, _stripped_lines + + +def _unstrip_lines(indent: str, stripped_lines: List[str]) -> str: + return os.linesep.join(map(lambda x: indent + x, stripped_lines)) + + +def inherit_doc(src, stripper: Optional[Callable[[Any, Any, List[str]], List[str]]] = None): + _indent, _stripped_lines = _strip_lines(src.__doc__) + + def _decorator(obj): + _lines = (stripper or (lambda s, o, x: x))(src, obj, _stripped_lines) + obj.__doc__ = _unstrip_lines(_indent, _lines) + return obj + + return _decorator -- GitLab