提交 a8d1488e 编写于 作者: HansBug's avatar HansBug 😆

doc(hansbug): add doc for treetensor.tensor.funcs

上级 7cb29d8e
......@@ -131,13 +131,13 @@ class TestTensorFuncs:
})
def test_randint(self):
_target = ttorch.randint(TreeValue({
_target = ttorch.randint(-10, 10, TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}), -10, 10)
}))
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.TreeSize({
......@@ -148,13 +148,13 @@ class TestTensorFuncs:
}
})
_target = ttorch.randint(TreeValue({
_target = ttorch.randint(10, TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}), 10)
}))
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.TreeSize({
......
import builtins
from typing import List
import torch
from treevalue import func_treelize as original_func_treelize
from .tensor import TreeTensor, tireduce
from ..common import TreeObject, ireduce
from ..utils import inherit_doc as original_inherit_doc
from ..utils import replaceable_partial
def _doc_stripper(src, _, lines: List[str]):
_name, _version = src.__name__, torch.__version__
return [
f'.. note::',
f'',
f' This documentation is based on '
f' `torch.{_name} <https://pytorch.org/docs/{_version}/generated/torch.{_name}.html>`_ '
f' in `torch v{_version} <https://pytorch.org/docs/{_version}/>`_.',
f' **Its arguments\' arrangements depend on the version of pytorch you installed**.',
f'',
*lines[1:]
]
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor)
inherit_doc = replaceable_partial(original_inherit_doc, stripper=_doc_stripper)
__all__ = [
'zeros', 'zeros_like',
......@@ -21,83 +39,99 @@ __all__ = [
]
@inherit_doc(torch.zeros)
@func_treelize()
def zeros(size, *args, **kwargs):
return torch.zeros(*size, *args, **kwargs)
def zeros(*args, **kwargs):
return torch.zeros(*args, **kwargs)
@inherit_doc(torch.zeros_like)
@func_treelize()
def zeros_like(input_, *args, **kwargs):
return torch.zeros_like(input_, *args, **kwargs)
@inherit_doc(torch.randn)
@func_treelize()
def randn(size, *args, **kwargs):
return torch.randn(*size, *args, **kwargs)
def randn(*args, **kwargs):
return torch.randn(*args, **kwargs)
@inherit_doc(torch.randn_like)
@func_treelize()
def randn_like(input_, *args, **kwargs):
return torch.randn_like(input_, *args, **kwargs)
@inherit_doc(torch.randint)
@func_treelize()
def randint(size, *args, **kwargs):
return torch.randint(*args, size, **kwargs)
def randint(*args, **kwargs):
return torch.randint(*args, **kwargs)
@inherit_doc(torch.randint_like)
@func_treelize()
def randint_like(input_, *args, **kwargs):
return torch.randint_like(input_, *args, **kwargs)
@inherit_doc(torch.ones)
@func_treelize()
def ones(size, *args, **kwargs):
return torch.ones(*size, *args, **kwargs)
def ones(*args, **kwargs):
return torch.ones(*args, **kwargs)
@inherit_doc(torch.ones_like)
@func_treelize()
def ones_like(input_, *args, **kwargs):
return torch.ones_like(input_, *args, **kwargs)
@inherit_doc(torch.full)
@func_treelize()
def full(size, *args, **kwargs):
return torch.full(size, *args, **kwargs)
def full(*args, **kwargs):
return torch.full(*args, **kwargs)
@inherit_doc(torch.full_like)
@func_treelize()
def full_like(input_, *args, **kwargs):
return torch.full_like(input_, *args, **kwargs)
@inherit_doc(torch.empty)
@func_treelize()
def empty(size, *args, **kwargs):
return torch.empty(size, *args, **kwargs)
def empty(*args, **kwargs):
return torch.empty(*args, **kwargs)
@inherit_doc(torch.empty_like)
@func_treelize()
def empty_like(input_, *args, **kwargs):
return torch.empty_like(input_, *args, **kwargs)
@inherit_doc(torch.all)
@tireduce(torch.all)
@func_treelize(return_type=TreeObject)
def all(input_, *args, **kwargs):
return torch.all(input_, *args, **kwargs)
@inherit_doc(torch.any)
@tireduce(torch.any)
@func_treelize(return_type=TreeObject)
def any(input_, *args, **kwargs):
return torch.any(input_, *args, **kwargs)
@inherit_doc(torch.eq)
@func_treelize()
def eq(input_, other, *args, **kwargs):
return torch.eq(input_, other, *args, **kwargs)
@inherit_doc(torch.equal)
@ireduce(builtins.all)
@func_treelize()
def equal(input_, other, *args, **kwargs):
......
from .clazz import *
from .doc import *
from .func import *
import os
from typing import List, Optional, Callable, Any
__all__ = [
'inherit_doc',
]
def _strip_lines(doc: str):
_lines = doc.strip().splitlines()
_exist_lines = list(filter(str.strip, _lines))
if not _exist_lines:
_indent = ''
else:
l, r = 0, min(map(len, _exist_lines))
while l < r:
m = (l + r + 1) // 2
_prefixes = set(map(lambda x: x[:m], _exist_lines))
l, r = (m, r) if len(_prefixes) <= 1 else (l, m - 1)
_indent = list(map(lambda x: x[:l], _exist_lines))[0]
_stripped_lines = list(map(lambda x: x[len(_indent):] if x.strip() else '', _lines))
return _indent, _stripped_lines
def _unstrip_lines(indent: str, stripped_lines: List[str]) -> str:
return os.linesep.join(map(lambda x: indent + x, stripped_lines))
def inherit_doc(src, stripper: Optional[Callable[[Any, Any, List[str]], List[str]]] = None):
_indent, _stripped_lines = _strip_lines(src.__doc__)
def _decorator(obj):
_lines = (stripper or (lambda s, o, x: x))(src, obj, _stripped_lines)
obj.__doc__ = _unstrip_lines(_indent, _lines)
return obj
return _decorator
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册