Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
64f9f118
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
64f9f118
编写于
9月 13, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
doc(hansbug): optimize documentation for treetensor.torch.funcs
上级
798d9183
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
221 addition
and
150 deletion
+221
-150
docs/source/_libs/docs.py
docs/source/_libs/docs.py
+69
-0
docs/source/api_doc/numpy/funcs.rst.py
docs/source/api_doc/numpy/funcs.rst.py
+45
-0
docs/source/api_doc/numpy/funcs.rstc
docs/source/api_doc/numpy/funcs.rstc
+0
-1
docs/source/api_doc/numpy/numpy.rstc
docs/source/api_doc/numpy/numpy.rstc
+0
-1
docs/source/api_doc/torch/funcs.rst.py
docs/source/api_doc/torch/funcs.rst.py
+40
-0
docs/source/api_doc/torch/funcs.rstc
docs/source/api_doc/torch/funcs.rstc
+0
-1
docs/source/apidoc.mk
docs/source/apidoc.mk
+7
-9
docs/source/apidoc_gen.py
docs/source/apidoc_gen.py
+0
-53
treetensor/numpy/funcs.py
treetensor/numpy/funcs.py
+5
-6
treetensor/torch/funcs.py
treetensor/torch/funcs.py
+38
-38
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+13
-1
treetensor/utils/doc.py
treetensor/utils/doc.py
+4
-40
未找到文件。
docs/source/_libs/docs.py
0 → 100644
浏览文件 @
64f9f118
from
functools
import
partial
from
typing
import
Optional
,
Tuple
,
List
def
strip_docs
(
doc
:
Optional
[
str
])
->
Tuple
[
str
,
List
[
str
]]:
_lines
=
(
doc
or
''
).
splitlines
()
_non_empty_lines
=
sorted
(
filter
(
lambda
t
:
t
[
1
].
strip
(),
enumerate
(
_lines
)))
if
_non_empty_lines
:
_first_line_id
,
_
=
_non_empty_lines
[
0
]
_lines
=
_lines
[
_first_line_id
:]
else
:
_lines
=
[]
_exist_lines
=
list
(
filter
(
str
.
strip
,
_lines
))
if
not
_exist_lines
:
_indent
=
''
else
:
l
,
r
=
0
,
min
(
map
(
len
,
_exist_lines
))
while
l
<
r
:
m
=
(
l
+
r
+
1
)
//
2
_prefixes
=
set
(
map
(
lambda
x
:
x
[:
m
],
_exist_lines
))
l
,
r
=
(
m
,
r
)
if
len
(
_prefixes
)
<=
1
else
(
l
,
m
-
1
)
_indent
=
list
(
map
(
lambda
x
:
x
[:
l
],
_exist_lines
))[
0
]
_stripped_lines
=
list
(
map
(
lambda
x
:
x
[
len
(
_indent
):]
if
x
.
strip
()
else
''
,
_lines
))
return
_indent
,
_stripped_lines
_DOC_FROM_TAG
=
'__doc_from__'
def
get_origin
(
obj
):
return
getattr
(
obj
,
_DOC_FROM_TAG
,
None
)
def
print_title
(
title
:
str
,
levelc
=
'='
,
file
=
None
):
_print
=
partial
(
print
,
file
=
file
)
_print
(
title
)
_print
(
levelc
*
(
len
(
title
)
+
5
))
_print
()
def
print_doc
(
doc
:
str
,
strip
:
bool
=
True
,
indent
:
str
=
''
,
file
=
None
):
_print
=
partial
(
print
,
indent
,
file
=
file
,
sep
=
''
)
if
strip
:
_
,
_lines
=
strip_docs
(
doc
or
''
)
else
:
_lines
=
(
doc
or
''
).
splitlines
()
for
_line
in
_lines
:
_print
(
_line
)
_print
()
def
print_block
(
doc
:
str
,
name
:
str
,
value
:
Optional
[
str
]
=
None
,
params
:
Optional
[
dict
]
=
None
,
file
=
None
):
_print
=
partial
(
print
,
file
=
file
)
_print
(
f
'..
{
name
}
::
{
str
(
value
)
if
value
is
not
None
else
""
}
'
)
for
k
,
v
in
(
params
or
{}).
items
():
_print
(
f
' :
{
k
}
:
{
str
(
v
)
if
v
is
not
None
else
""
}
'
)
_print
()
print_doc
(
doc
,
strip
=
True
,
indent
=
' '
,
file
=
file
)
def
current_module
(
module
:
str
,
file
=
None
):
_print
=
partial
(
print
,
file
=
file
)
_print
(
f
'.. currentmodule::
{
module
}
'
)
_print
()
docs/source/api_doc/numpy/funcs.rst.py
0 → 100644
浏览文件 @
64f9f118
import
re
import
numpy
as
np
import
treetensor.numpy
as
tnp
from
docs
import
print_title
,
current_module
,
get_origin
,
print_block
,
print_doc
_DOC_FROM_TAG
=
'__doc_from__'
_H2_PATTERN
=
re
.
compile
(
'-{3,}'
)
if
__name__
==
'__main__'
:
_numpy_version
=
np
.
__version__
_short_version
=
'.'
.
join
(
_numpy_version
.
split
(
'.'
)[:
2
])
print_title
(
tnp
.
funcs
.
__name__
,
levelc
=
'='
)
current_module
(
tnp
.
funcs
.
__name__
)
for
_name
in
sorted
(
tnp
.
funcs
.
__all__
):
_item
=
getattr
(
tnp
.
funcs
,
_name
)
_origin
=
get_origin
(
_item
)
print_title
(
_name
,
levelc
=
'-'
)
print_block
(
''
,
'autofunction'
,
value
=
_name
)
if
_origin
and
(
_origin
.
__doc__
or
''
).
strip
():
print_block
(
f
"""
This documentation is based on
`numpy.
{
_name
}
<https://numpy.org/doc/
{
_short_version
}
/reference/generated/numpy.
{
_name
}
.html>`_
in `numpy v
{
_numpy_version
}
<https://numpy.org/doc/
{
_short_version
}
/>`_.
**Its arguments
\'
arrangements depend on the version of numpy you installed**.
If some arguments listed here are not working properly, please check your numpy's version
with the following command and find its documentation.
.. code-block:: shell
:linenos:
python -c 'import numpy as np;print(np.__version__)'
The arguments and keyword arguments supported in numpy v
{
_numpy_version
}
is listed below.
"""
,
'note'
)
print
()
print_doc
(
_H2_PATTERN
.
sub
(
lambda
x
:
'~'
*
len
(
x
.
group
(
0
)),
_origin
.
__doc__
or
''
))
print
()
docs/source/api_doc/numpy/funcs.rstc
已删除
100644 → 0
浏览文件 @
798d9183
treetensor.numpy.funcs
\ No newline at end of file
docs/source/api_doc/numpy/numpy.rstc
已删除
100644 → 0
浏览文件 @
798d9183
treetensor.numpy.numpy
\ No newline at end of file
docs/source/api_doc/torch/funcs.rst.py
0 → 100644
浏览文件 @
64f9f118
import
torch
import
treetensor.torch
as
ttorch
from
docs
import
print_title
,
current_module
,
get_origin
,
print_block
,
print_doc
_DOC_FROM_TAG
=
'__doc_from__'
if
__name__
==
'__main__'
:
_torch_version
=
torch
.
__version__
print_title
(
ttorch
.
funcs
.
__name__
,
levelc
=
'='
)
current_module
(
ttorch
.
funcs
.
__name__
)
for
_name
in
sorted
(
ttorch
.
funcs
.
__all__
):
_item
=
getattr
(
ttorch
.
funcs
,
_name
)
_origin
=
get_origin
(
_item
)
print_title
(
_name
,
levelc
=
'-'
)
print_block
(
''
,
'autofunction'
,
value
=
_name
)
if
_origin
and
(
_origin
.
__doc__
or
''
).
strip
():
print_block
(
f
"""
This documentation is based on
`torch.
{
_name
}
<https://pytorch.org/docs/
{
_torch_version
}
/generated/torch.
{
_name
}
.html>`_
in `torch v
{
_torch_version
}
<https://pytorch.org/docs/
{
_torch_version
}
/>`_.
**Its arguments
\'
arrangements depend on the version of pytorch you installed**.
If some arguments listed here are not working properly, please check your pytorch's version
with the following command and find its documentation.
.. code-block:: shell
:linenos:
python -c 'import torch;print(torch.__version__)'
The arguments and keyword arguments supported in torch v
{
_torch_version
}
is listed below.
"""
,
'note'
)
print_doc
(
f
'.. function::
{
_origin
.
__doc__
.
lstrip
()
}
'
)
print
()
docs/source/api_doc/torch/funcs.rstc
已删除
100644 → 0
浏览文件 @
798d9183
treetensor.torch.funcs
\ No newline at end of file
docs/source/apidoc.mk
浏览文件 @
64f9f118
PYTHON
:=
$(
shell
which python
)
PYTHON
:=
$(
shell
which python
)
SOURCE
?=
.
RSTC_FILES
:=
$(
shell
find
${SOURCE}
-name
*
.rstc
)
RST_RESULTS
:=
$(
addsuffix
.auto.rst,
$(
basename
${RSTC_FILES}
))
SOURCE
?=
.
PYTHON_SCRIPTS
:=
$(
shell
find
${SOURCE}
-name
*
.rst.py
)
PYTHON_RESULTS
:=
$(
addsuffix
.auto.rst,
$(
basename
$(
basename
${PYTHON_SCRIPTS}
)
))
APIDOC_GEN_PY
:=
$(
shell
readlink
-f
${SOURCE}
/apidoc_gen.py
)
%.auto.rst
:
%.rstc ${APIDOC_GEN_PY}
%.auto.rst
:
%.rst.py
cd
"
$(
shell
dirname
$(
shell
readlink -f
$<
))
"
&&
\
PYTHONPATH
=
"
$(
shell
dirname
$(
shell
readlink -f
$<
))
:
${PYTHONPATH}
"
\
cat
"
$(
shell
readlink -f
$<
)
"
|
$(PYTHON)
"
${APIDOC_GEN_PY}
"
>
"
$(
shell
readlink -f
$@
)
"
$(PYTHON)
"
$(
shell
readlink -f
$<
)
"
>
"
$(
shell
readlink -f
$@
)
"
build
:
${
RST
_RESULTS}
build
:
${
PYTHON
_RESULTS}
all
:
build
...
...
docs/source/apidoc_gen.py
已删除
100644 → 0
浏览文件 @
798d9183
import
importlib
import
types
from
typing
import
List
_DOC_TAG
=
'__doc_names__'
_DIRECT_DOC_TAG
=
'__direct_doc__'
def
_is_tagged_name
(
clazz
,
name
):
return
name
in
set
(
getattr
(
clazz
,
_DOC_TAG
,
set
()))
def
_find_class_members
(
clazz
:
type
)
->
List
[
str
]:
members
=
[]
for
name
in
dir
(
clazz
):
item
=
getattr
(
clazz
,
name
)
if
_is_tagged_name
(
clazz
,
name
)
and
\
getattr
(
item
,
'__name__'
,
None
)
==
name
:
# should be public or protected
members
.
append
(
name
)
return
members
if
__name__
==
'__main__'
:
package_name
=
input
().
strip
()
_module
=
importlib
.
import_module
(
package_name
)
_alls
=
getattr
(
_module
,
'__all__'
)
print
(
package_name
)
print
(
'='
*
(
len
(
package_name
)
+
5
))
print
()
print
(
f
'.. automodule::
{
package_name
}
'
)
print
()
for
_name
in
sorted
(
_alls
):
print
(
_name
)
print
(
'-'
*
(
len
(
_name
)
+
5
))
print
()
_item
=
getattr
(
_module
,
_name
)
if
getattr
(
_item
,
_DIRECT_DOC_TAG
,
None
):
print
(
_item
.
__doc__
)
else
:
if
isinstance
(
_item
,
types
.
FunctionType
):
print
(
f
'.. autofunction::
{
package_name
}
.
{
_name
}
'
)
elif
isinstance
(
_item
,
type
):
print
(
f
'.. autoclass::
{
package_name
}
.
{
_name
}
'
)
print
(
f
' :members:
{
", "
.
join
(
sorted
(
_find_class_members
(
_item
)))
}
'
)
else
:
print
(
f
'.. autodata::
{
package_name
}
.
{
_name
}
'
)
print
(
f
' :annotation:'
)
print
()
treetensor/numpy/funcs.py
浏览文件 @
64f9f118
...
...
@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize
from
.numpy
import
TreeNumpy
from
..common
import
ireduce
,
TreeObject
from
..utils
import
replaceable_partial
,
inherit_doc
from
..utils
import
replaceable_partial
,
doc_from
__all__
=
[
'all'
,
'any'
,
...
...
@@ -30,30 +30,29 @@ def _doc_stripper(src, _, lines: List[str]):
func_treelize
=
replaceable_partial
(
original_func_treelize
,
return_type
=
TreeNumpy
)
docs
=
replaceable_partial
(
inherit_doc
,
stripper
=
_doc_stripper
)
@
doc
s
(
np
.
all
)
@
doc
_from
(
np
.
all
)
@
ireduce
(
builtins
.
all
)
@
func_treelize
(
return_type
=
TreeObject
)
def
all
(
a
,
*
args
,
**
kwargs
):
return
np
.
all
(
a
,
*
args
,
**
kwargs
)
@
doc
s
(
np
.
any
)
@
doc
_from
(
np
.
any
)
@
ireduce
(
builtins
.
any
)
@
func_treelize
()
def
any
(
a
,
*
args
,
**
kwargs
):
return
np
.
any
(
a
,
*
args
,
**
kwargs
)
@
doc
s
(
np
.
equal
)
@
doc
_from
(
np
.
equal
)
@
func_treelize
()
def
equal
(
x1
,
x2
,
*
args
,
**
kwargs
):
return
np
.
equal
(
x1
,
x2
,
*
args
,
**
kwargs
)
@
doc
s
(
np
.
array_equal
)
@
doc
_from
(
np
.
array_equal
)
@
func_treelize
()
def
array_equal
(
a1
,
a2
,
*
args
,
**
kwargs
):
return
np
.
array_equal
(
a1
,
a2
,
*
args
,
**
kwargs
)
treetensor/torch/funcs.py
浏览文件 @
64f9f118
import
builtins
from
typing
import
List
import
torch
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue.utils
import
post_process
from
.tensor
import
TreeTensor
,
tireduce
from
..common
import
TreeObject
,
ireduce
from
..utils
import
replaceable_partial
,
d
irect_doc
,
inherit_doc
from
..utils
import
replaceable_partial
,
d
oc_from
__all__
=
[
'zeros'
,
'zeros_like'
,
...
...
@@ -20,121 +18,123 @@ __all__ = [
'eq'
,
'equal'
,
]
def
_doc_stripper
(
src
,
_
,
lines
:
List
[
str
]):
_name
,
_version
=
src
.
__name__
,
torch
.
__version__
if
lines
:
lines
[
0
]
=
f
'.. function::
{
lines
[
0
]
}
'
return
[
f
'.. note::'
,
f
''
,
f
' This documentation is based on '
f
' `torch.
{
_name
}
<https://pytorch.org/docs/
{
_version
}
/generated/torch.
{
_name
}
.html>`_ '
f
' in `torch v
{
_version
}
<https://pytorch.org/docs/
{
_version
}
/>`_.'
,
f
' **Its arguments
\'
arrangements depend on the version of pytorch you installed**.'
,
f
''
,
*
lines
,
]
func_treelize
=
replaceable_partial
(
original_func_treelize
,
return_type
=
TreeTensor
)
docs
=
post_process
(
post_process
(
direct_doc
))(
replaceable_partial
(
inherit_doc
,
stripper
=
_doc_stripper
))
@
doc
s
(
torch
.
zeros
)
@
doc
_from
(
torch
.
zeros
)
@
func_treelize
()
def
zeros
(
*
args
,
**
kwargs
):
return
torch
.
zeros
(
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
zeros_like
)
@
doc
_from
(
torch
.
zeros_like
)
@
func_treelize
()
def
zeros_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
zeros_like
(
input_
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
randn
)
@
doc
_from
(
torch
.
randn
)
@
func_treelize
()
def
randn
(
*
args
,
**
kwargs
):
return
torch
.
randn
(
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
randn_like
)
@
doc
_from
(
torch
.
randn_like
)
@
func_treelize
()
def
randn_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
randn_like
(
input_
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
randint
)
@
doc
_from
(
torch
.
randint
)
@
func_treelize
()
def
randint
(
*
args
,
**
kwargs
):
return
torch
.
randint
(
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
randint_like
)
@
doc
_from
(
torch
.
randint_like
)
@
func_treelize
()
def
randint_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
randint_like
(
input_
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
ones
)
@
doc
_from
(
torch
.
ones
)
@
func_treelize
()
def
ones
(
*
args
,
**
kwargs
):
return
torch
.
ones
(
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
ones_like
)
@
doc
_from
(
torch
.
ones_like
)
@
func_treelize
()
def
ones_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
ones_like
(
input_
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
full
)
@
doc
_from
(
torch
.
full
)
@
func_treelize
()
def
full
(
*
args
,
**
kwargs
):
return
torch
.
full
(
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
full_like
)
@
doc
_from
(
torch
.
full_like
)
@
func_treelize
()
def
full_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
full_like
(
input_
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
empty
)
@
doc
_from
(
torch
.
empty
)
@
func_treelize
()
def
empty
(
*
args
,
**
kwargs
):
return
torch
.
empty
(
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
empty_like
)
@
doc
_from
(
torch
.
empty_like
)
@
func_treelize
()
def
empty_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
empty_like
(
input_
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
all
)
@
doc
_from
(
torch
.
all
)
@
tireduce
(
torch
.
all
)
@
func_treelize
(
return_type
=
TreeObject
)
def
all
(
input_
,
*
args
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``all`` result of a whole tree with this function.
Example::
>>> all(torch.tensor([True, True])) # the same as torch.all
torch.tensor(True)
>>> all(TreeTensor({
>>> 'a': torch.tensor([True, True]),
>>> 'b': torch.tensor([True, True]),
>>> }))
torch.tensor(True)
>>> all(TreeTensor({
>>> 'a': torch.tensor([True, True]),
>>> 'b': torch.tensor([True, False]),
>>> }))
torch.tensor(False)
"""
return
torch
.
all
(
input_
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
any
)
@
doc
_from
(
torch
.
any
)
@
tireduce
(
torch
.
any
)
@
func_treelize
(
return_type
=
TreeObject
)
def
any
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
any
(
input_
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
eq
)
@
doc
_from
(
torch
.
eq
)
@
func_treelize
()
def
eq
(
input_
,
other
,
*
args
,
**
kwargs
):
return
torch
.
eq
(
input_
,
other
,
*
args
,
**
kwargs
)
@
doc
s
(
torch
.
equal
)
@
doc
_from
(
torch
.
equal
)
@
ireduce
(
builtins
.
all
)
@
func_treelize
()
def
equal
(
input_
,
other
,
*
args
,
**
kwargs
):
...
...
treetensor/torch/tensor.py
浏览文件 @
64f9f118
...
...
@@ -6,7 +6,7 @@ from treevalue.utils import pre_process
from
.size
import
TreeSize
from
..common
import
TreeObject
,
TreeData
,
ireduce
from
..numpy
import
TreeNumpy
from
..utils
import
inherit_names
,
current_names
from
..utils
import
inherit_names
,
current_names
,
doc_from
__all__
=
[
'TreeTensor'
...
...
@@ -20,56 +20,68 @@ tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduc
@
current_names
()
@
inherit_names
(
TreeData
)
class
TreeTensor
(
TreeData
):
@
doc_from
(
torch
.
Tensor
.
numpy
)
@
method_treelize
(
return_type
=
TreeNumpy
)
def
numpy
(
self
:
torch
.
Tensor
)
->
np
.
ndarray
:
return
self
.
numpy
()
@
doc_from
(
torch
.
Tensor
.
tolist
)
@
method_treelize
(
return_type
=
TreeObject
)
def
tolist
(
self
:
torch
.
Tensor
):
return
self
.
tolist
()
@
doc_from
(
torch
.
Tensor
.
cpu
)
@
method_treelize
()
def
cpu
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
cpu
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
cuda
)
@
method_treelize
()
def
cuda
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
cuda
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
to
)
@
method_treelize
()
def
to
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
to
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
numel
)
@
ireduce
(
sum
)
@
method_treelize
(
return_type
=
TreeObject
)
def
numel
(
self
:
torch
.
Tensor
):
return
self
.
numel
()
@
property
@
doc_from
(
torch
.
Tensor
.
shape
)
@
method_treelize
(
return_type
=
TreeSize
)
def
shape
(
self
:
torch
.
Tensor
):
return
self
.
shape
@
doc_from
(
torch
.
Tensor
.
all
)
@
tireduce
(
torch
.
all
)
@
method_treelize
(
return_type
=
TreeObject
)
def
all
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
bool
:
return
self
.
all
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
any
)
@
tireduce
(
torch
.
any
)
@
method_treelize
(
return_type
=
TreeObject
)
def
any
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
bool
:
return
self
.
any
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
max
)
@
tireduce
(
torch
.
max
)
@
method_treelize
(
return_type
=
TreeObject
)
def
max
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
max
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
min
)
@
tireduce
(
torch
.
min
)
@
method_treelize
(
return_type
=
TreeObject
)
def
min
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
min
(
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
Tensor
.
sum
)
@
tireduce
(
torch
.
sum
)
@
method_treelize
(
return_type
=
TreeObject
)
def
sum
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
...
...
treetensor/utils/doc.py
浏览文件 @
64f9f118
import
os
from
typing
import
List
,
Optional
,
Callable
,
Any
__all__
=
[
'
inherit_doc'
,
'direct_doc
'
,
'
doc_from
'
,
]
def
_strip_lines
(
doc
:
Optional
[
str
]):
_lines
=
(
doc
or
''
).
splitlines
()
_first_line_id
,
_
=
sorted
(
filter
(
lambda
t
:
t
[
1
].
strip
(),
enumerate
(
_lines
)))[
0
]
_lines
=
_lines
[
_first_line_id
:]
_exist_lines
=
list
(
filter
(
str
.
strip
,
_lines
))
if
not
_exist_lines
:
_indent
=
''
else
:
l
,
r
=
0
,
min
(
map
(
len
,
_exist_lines
))
while
l
<
r
:
m
=
(
l
+
r
+
1
)
//
2
_prefixes
=
set
(
map
(
lambda
x
:
x
[:
m
],
_exist_lines
))
l
,
r
=
(
m
,
r
)
if
len
(
_prefixes
)
<=
1
else
(
l
,
m
-
1
)
_indent
=
list
(
map
(
lambda
x
:
x
[:
l
],
_exist_lines
))[
0
]
_stripped_lines
=
list
(
map
(
lambda
x
:
x
[
len
(
_indent
):]
if
x
.
strip
()
else
''
,
_lines
))
return
_indent
,
_stripped_lines
def
_unstrip_lines
(
indent
:
str
,
stripped_lines
:
List
[
str
])
->
str
:
return
os
.
linesep
.
join
(
map
(
lambda
x
:
indent
+
x
,
stripped_lines
))
_DOC_FROM_TAG
=
'__doc_from__'
def
inherit_doc
(
src
,
stripper
:
Optional
[
Callable
[[
Any
,
Any
,
List
[
str
]],
List
[
str
]]]
=
None
):
_indent
,
_stripped_lines
=
_strip_lines
(
src
.
__doc__
)
def
doc_from
(
src
):
def
_decorator
(
obj
):
_lines
=
(
stripper
or
(
lambda
s
,
o
,
x
:
x
))(
src
,
obj
,
_stripped_lines
)
obj
.
__doc__
=
_unstrip_lines
(
_indent
,
_lines
)
setattr
(
obj
,
_DOC_FROM_TAG
,
src
)
return
obj
return
_decorator
_DIRECT_DOC
=
'__direct_doc__'
def
direct_doc
(
obj
):
setattr
(
obj
,
_DIRECT_DOC
,
True
)
return
obj
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录