Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-treetensor
提交
a24fa275
D
DI-treetensor
项目概览
OpenDILab开源决策智能平台
/
DI-treetensor
大约 1 年 前同步成功
通知
44
Star
172
Fork
11
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a24fa275
编写于
9月 29, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add auto system for numpy
上级
4ebeb305
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
243 addition
and
110 deletion
+243
-110
treetensor/common/__init__.py
treetensor/common/__init__.py
+2
-0
treetensor/common/module.py
treetensor/common/module.py
+43
-0
treetensor/common/proxy.py
treetensor/common/proxy.py
+56
-0
treetensor/common/trees.py
treetensor/common/trees.py
+17
-2
treetensor/numpy/__init__.py
treetensor/numpy/__init__.py
+47
-0
treetensor/numpy/array.py
treetensor/numpy/array.py
+52
-9
treetensor/numpy/funcs.py
treetensor/numpy/funcs.py
+4
-2
treetensor/torch/base/torch.py
treetensor/torch/base/torch.py
+3
-17
treetensor/torch/funcs/base.py
treetensor/torch/funcs/base.py
+3
-24
treetensor/torch/size.py
treetensor/torch/size.py
+4
-4
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+12
-52
未找到文件。
treetensor/common/__init__.py
浏览文件 @
a24fa275
from
.module
import
*
from
.object
import
*
from
.proxy
import
*
from
.trees
import
*
from
.wrappers
import
*
treetensor/common/module.py
0 → 100644
浏览文件 @
a24fa275
from
functools
import
wraps
from
typing
import
Type
from
treevalue
import
TreeValue
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue.tree.common
import
BaseTree
from
treevalue.utils
import
post_process
from
.trees
import
auto_tree
from
.wrappers
import
return_self
from
..utils
import
doc_from_base
as
original_doc_from_base
from
..utils
import
replaceable_partial
,
args_mapping
__all__
=
[
'module_func_loader'
,
]
def
module_func_loader
(
base
,
cls
:
Type
[
TreeValue
],
module_name
:
str
):
func_treelize
=
post_process
(
post_process
(
args_mapping
(
lambda
i
,
x
:
TreeValue
(
x
)
if
isinstance
(
x
,
(
dict
,
BaseTree
,
TreeValue
))
else
x
)))(
replaceable_partial
(
original_func_treelize
,
return_type
=
cls
)
)
doc_from_base
=
replaceable_partial
(
original_doc_from_base
,
base
=
base
)
auto_tree_cls
=
replaceable_partial
(
auto_tree
,
cls
=
cls
)
def
_load_func
(
name
):
func
=
getattr
(
base
,
name
)
return_self_dec
=
return_self
if
func
.
__name__
.
endswith
(
"_"
)
else
(
lambda
x
:
x
)
@
doc_from_base
()
@
return_self_dec
@
post_process
(
auto_tree_cls
)
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
@
wraps
(
func
,
assigned
=
(
'__name__'
,),
updated
=
())
def
_new_func
(
*
args
,
**
kwargs
):
return
func
(
*
args
,
**
kwargs
)
_new_func
.
__qualname__
=
_new_func
.
__name__
_new_func
.
__module__
=
module_name
return
_new_func
return
_load_func
treetensor/common/proxy.py
0 → 100644
浏览文件 @
a24fa275
from
functools
import
wraps
from
types
import
MethodType
from
treevalue
import
method_treelize
,
TreeValue
from
treevalue.utils
import
post_process
from
.trees
import
auto_tree
from
.wrappers
import
return_self
from
..utils
import
doc_from_base
as
original_doc_from_base
from
..utils
import
replaceable_partial
__all__
=
[
'get_tree_proxy'
,
]
def
get_tree_proxy
(
base
):
doc_from_base
=
replaceable_partial
(
original_doc_from_base
,
base
=
base
)
class
_TreeClassProxy
:
def
__init__
(
self
,
cls
):
self
.
__torch_funcs
=
{}
self
.
__cls
=
cls
def
__getattr__
(
self
,
name
):
if
name
in
self
.
__torch_funcs
.
keys
():
return
self
.
__torch_funcs
[
name
]
elif
hasattr
(
base
,
name
)
and
not
name
.
startswith
(
'_'
)
\
and
callable
(
getattr
(
base
,
name
)):
_origin_func
=
getattr
(
base
,
name
)
return_self_deco
=
return_self
if
name
.
endswith
(
'_'
)
else
(
lambda
x
:
x
)
@
doc_from_base
()
@
return_self_deco
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_tree
,
cls
=
self
.
__cls
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
@
wraps
(
_origin_func
,
assigned
=
(
'__name__'
,),
updated
=
())
def
_new_func
(
*
args
,
**
kwargs
):
return
_origin_func
(
*
args
,
**
kwargs
)
_new_func
.
__qualname__
=
f
'
{
self
.
__cls
.
__name__
}
.
{
name
}
'
_new_func
.
__module__
=
self
.
__cls
.
__module__
self
.
__torch_funcs
[
name
]
=
_new_func
return
_new_func
else
:
raise
AttributeError
(
f
'Function
{
repr
(
name
)
}
not found in
{
repr
(
base
)
}
'
)
class
_TreeInstanceProxy
:
def
__init__
(
self
,
proxy
,
s
):
self
.
__proxy
=
proxy
self
.
__self
=
s
def
__getattr__
(
self
,
name
):
return
MethodType
(
getattr
(
self
.
__proxy
,
name
),
self
.
__self
)
return
_TreeClassProxy
,
_TreeInstanceProxy
treetensor/common/trees.py
浏览文件 @
a24fa275
...
...
@@ -6,7 +6,7 @@ from typing import Optional, Tuple, Callable
from
typing
import
Type
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue
import
general_tree_value
,
TreeValue
from
treevalue
import
general_tree_value
,
TreeValue
,
typetrans
from
treevalue.tree.common
import
BaseTree
from
treevalue.tree.tree.tree
import
get_data_property
from
treevalue.utils
import
post_process
...
...
@@ -15,7 +15,7 @@ from ..utils import replaceable_partial, args_mapping
__all__
=
[
'BaseTreeStruct'
,
'print_tree'
,
'clsmeta'
,
'print_tree'
,
'clsmeta'
,
'auto_tree'
,
]
...
...
@@ -177,3 +177,18 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]:
return
_result
return
_MetaClass
# noinspection PyArgumentList
def
auto_tree
(
v
,
cls
):
if
isinstance
(
cls
,
type
)
and
issubclass
(
cls
,
TreeValue
):
cls
=
partial
(
typetrans
,
return_type
=
cls
)
if
isinstance
(
v
,
TreeValue
):
return
cls
(
v
)
elif
isinstance
(
v
,
(
tuple
,
list
,
set
)):
return
type
(
v
)((
auto_tree
(
item
,
cls
)
for
item
in
v
))
elif
isinstance
(
v
,
dict
):
return
type
(
v
)({
key
:
auto_tree
(
value
,
cls
)
for
key
,
value
in
v
.
items
()})
else
:
return
v
treetensor/numpy/__init__.py
浏览文件 @
a24fa275
import
builtins
from
types
import
ModuleType
,
FunctionType
,
BuiltinFunctionType
from
typing
import
Iterable
import
numpy
as
np
from
.array
import
*
from
.array
import
__all__
as
_array_all
from
.funcs
import
*
from
.funcs
import
__all__
as
_funcs_all
from
.funcs
import
get_func_from_numpy
from
..config.meta
import
__VERSION__
__all__
=
[
*
_funcs_all
,
*
_array_all
,
]
_basic_types
=
(
builtins
.
bool
,
builtins
.
bytearray
,
builtins
.
bytes
,
builtins
.
complex
,
builtins
.
dict
,
builtins
.
float
,
builtins
.
frozenset
,
builtins
.
int
,
builtins
.
list
,
builtins
.
range
,
builtins
.
set
,
builtins
.
slice
,
builtins
.
str
,
builtins
.
tuple
,
)
_np_all
=
set
(
np
.
__all__
)
class
_Module
(
ModuleType
):
def
__init__
(
self
,
module
):
ModuleType
.
__init__
(
self
,
module
.
__name__
)
for
name
in
filter
(
lambda
x
:
x
.
startswith
(
'__'
)
and
x
.
endswith
(
'__'
),
dir
(
module
)):
setattr
(
self
,
name
,
getattr
(
module
,
name
))
self
.
__origin__
=
module
self
.
__numpy_version__
=
np
.
__version__
self
.
__version__
=
__VERSION__
def
__getattr__
(
self
,
name
):
if
(
name
in
self
.
__all__
)
or
\
(
hasattr
(
self
.
__origin__
,
name
)
and
isinstance
(
getattr
(
self
.
__origin__
,
name
),
ModuleType
)):
return
getattr
(
self
.
__origin__
,
name
)
else
:
item
=
getattr
(
np
,
name
)
if
isinstance
(
item
,
(
FunctionType
,
BuiltinFunctionType
))
and
not
name
.
startswith
(
'_'
):
return
get_func_from_numpy
(
name
)
elif
isinstance
(
item
,
_basic_types
)
and
name
in
_np_all
:
return
item
else
:
raise
AttributeError
(
f
'Attribute
{
repr
(
name
)
}
not found in
{
repr
(
__name__
)
}
.'
)
def
__dir__
(
self
)
->
Iterable
[
str
]:
return
self
.
__all__
import
sys
sys
.
modules
[
__name__
]
=
_Module
(
sys
.
modules
[
__name__
])
treetensor/numpy/array.py
浏览文件 @
a24fa275
import
numpy
as
np
import
numpy
from
treevalue
import
method_treelize
from
.base
import
TreeNumpy
from
..common
import
Object
,
ireduce
from
..common
import
Object
,
ireduce
,
clsmeta
,
get_tree_proxy
from
..utils
import
current_names
__all__
=
[
'ndarray'
]
_ArrayProxy
,
_InstanceArrayProxy
=
get_tree_proxy
(
numpy
.
ndarray
)
class
_BaseArrayMeta
(
clsmeta
(
numpy
.
asarray
,
allow_dict
=
True
)):
pass
# noinspection PyMethodParameters
class
_ArrayMeta
(
_BaseArrayMeta
):
def
__init__
(
cls
,
*
args
,
**
kwargs
):
_BaseArrayMeta
.
__init__
(
cls
,
*
args
,
**
kwargs
)
cls
.
__proxy
=
None
@
property
def
np
(
cls
):
if
not
cls
.
__proxy
:
cls
.
__proxy
=
_ArrayProxy
(
cls
)
return
cls
.
__proxy
def
__getattr__
(
cls
,
name
):
try
:
return
cls
.
np
.
__getattr__
(
name
)
except
AttributeError
:
raise
AttributeError
(
f
"type object
{
repr
(
cls
.
__name__
)
}
has no attribute
{
repr
(
name
)
}
"
)
# noinspection PyPep8Naming
@
current_names
()
class
ndarray
(
TreeNumpy
):
class
ndarray
(
TreeNumpy
,
metaclass
=
_ArrayMeta
):
"""
Overview:
Real numpy tree.
"""
@
method_treelize
(
return_type
=
Object
)
def
tolist
(
self
:
np
.
ndarray
):
def
__get_attr
(
self
,
key
):
return
getattr
(
self
,
key
)
def
_attr_extern
(
self
,
name
):
try
:
return
getattr
(
self
.
np
,
name
)
except
AttributeError
:
tree
=
self
.
__get_attr
(
name
)
if
tree
.
map
(
lambda
x
:
isinstance
(
x
,
numpy
.
ndarray
)).
all
():
return
tree
.
type
(
ndarray
)
else
:
return
tree
@
property
def
np
(
self
):
return
_InstanceArrayProxy
(
self
.
__class__
.
np
,
self
)
@
method_treelize
(
return_type
=
Object
)
def
tolist
(
self
:
numpy
.
ndarray
):
return
self
.
tolist
()
@
property
@
ireduce
(
sum
)
@
method_treelize
(
return_type
=
Object
)
def
size
(
self
:
n
p
.
ndarray
)
->
int
:
def
size
(
self
:
n
umpy
.
ndarray
)
->
int
:
return
self
.
size
@
property
@
ireduce
(
sum
)
@
method_treelize
(
return_type
=
Object
)
def
nbytes
(
self
:
n
p
.
ndarray
)
->
int
:
def
nbytes
(
self
:
n
umpy
.
ndarray
)
->
int
:
return
self
.
nbytes
@
ireduce
(
sum
)
@
method_treelize
(
return_type
=
Object
)
def
sum
(
self
:
n
p
.
ndarray
,
*
args
,
**
kwargs
):
def
sum
(
self
:
n
umpy
.
ndarray
,
*
args
,
**
kwargs
):
return
self
.
sum
(
*
args
,
**
kwargs
)
@
ireduce
(
all
)
@
method_treelize
(
return_type
=
Object
)
def
all
(
self
:
n
p
.
ndarray
,
*
args
,
**
kwargs
):
def
all
(
self
:
n
umpy
.
ndarray
,
*
args
,
**
kwargs
):
return
self
.
all
(
*
args
,
**
kwargs
)
@
ireduce
(
any
)
@
method_treelize
(
return_type
=
Object
)
def
any
(
self
:
n
p
.
ndarray
,
*
args
,
**
kwargs
):
def
any
(
self
:
n
umpy
.
ndarray
,
*
args
,
**
kwargs
):
return
self
.
any
(
*
args
,
**
kwargs
)
@
method_treelize
()
...
...
treetensor/numpy/funcs.py
浏览文件 @
a24fa275
...
...
@@ -3,10 +3,11 @@ import builtins
import
numpy
as
np
from
treevalue
import
TreeValue
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue.tree.common
import
BaseTree
from
treevalue.utils
import
post_process
from
.array
import
ndarray
from
..common
import
ireduce
,
Object
from
..common
import
ireduce
,
Object
,
module_func_loader
from
..utils
import
replaceable_partial
,
doc_from
,
args_mapping
__all__
=
[
...
...
@@ -15,9 +16,10 @@ __all__ = [
]
func_treelize
=
post_process
(
post_process
(
args_mapping
(
lambda
i
,
x
:
TreeValue
(
x
)
if
isinstance
(
x
,
(
dict
,
TreeValue
))
else
x
)))(
lambda
i
,
x
:
TreeValue
(
x
)
if
isinstance
(
x
,
(
dict
,
BaseTree
,
TreeValue
))
else
x
)))(
replaceable_partial
(
original_func_treelize
,
return_type
=
ndarray
)
)
get_func_from_numpy
=
module_func_loader
(
np
,
ndarray
,
__name__
)
@
doc_from
(
np
.
all
)
...
...
treetensor/torch/base/torch.py
浏览文件 @
a24fa275
from
typing
import
Type
from
treevalue
import
TreeValue
,
typetrans
from
...common
import
BaseTreeStruct
__all__
=
[
'Torch'
,
'auto_torch'
]
__all__
=
[
'Torch'
]
class
Torch
(
BaseTreeStruct
):
pass
# noinspection PyArgumentList
def
auto_torch
(
v
,
cls
:
Type
[
Torch
]):
if
isinstance
(
v
,
TreeValue
):
return
typetrans
(
v
,
cls
)
elif
isinstance
(
v
,
(
tuple
,
list
,
set
)):
return
type
(
v
)((
auto_torch
(
item
,
cls
)
for
item
in
v
))
elif
isinstance
(
v
,
dict
):
return
type
(
v
)({
key
:
auto_torch
(
value
,
cls
)
for
key
,
value
in
v
.
items
()})
else
:
return
v
treetensor/torch/funcs/base.py
浏览文件 @
a24fa275
from
functools
import
wraps
import
torch
from
treevalue
import
TreeValue
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue.tree.common
import
BaseTree
from
treevalue.utils
import
post_process
from
..base
import
auto_torch
from
..tensor
import
Tensor
from
...common
import
return_self
from
...common
import
auto_tree
,
module_func_loader
from
...utils
import
doc_from_base
as
original_doc_from_base
from
...utils
import
replaceable_partial
,
args_mapping
...
...
@@ -17,23 +14,5 @@ func_treelize = post_process(post_process(args_mapping(
replaceable_partial
(
original_func_treelize
,
return_type
=
Tensor
)
)
doc_from_base
=
replaceable_partial
(
original_doc_from_base
,
base
=
torch
)
auto_tensor
=
replaceable_partial
(
auto_torch
,
cls
=
Tensor
)
_funcs_module
=
'.'
.
join
(
__name__
.
split
(
'.'
)[:
-
1
])
def
get_func_from_torch
(
name
):
func
=
getattr
(
torch
,
name
)
return_self_dec
=
return_self
if
func
.
__name__
.
endswith
(
"_"
)
else
(
lambda
x
:
x
)
@
doc_from_base
()
@
return_self_dec
@
post_process
(
auto_tensor
)
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
@
wraps
(
func
,
assigned
=
(
'__name__'
,),
updated
=
())
def
_new_func
(
*
args
,
**
kwargs
):
return
func
(
*
args
,
**
kwargs
)
_new_func
.
__qualname__
=
_new_func
.
__name__
_new_func
.
__module__
=
_funcs_module
return
_new_func
auto_tensor
=
replaceable_partial
(
auto_tree
,
cls
=
Tensor
)
get_func_from_torch
=
module_func_loader
(
torch
,
Tensor
,
'.'
.
join
(
__name__
.
split
(
'.'
)[:
-
1
]))
treetensor/torch/size.py
浏览文件 @
a24fa275
...
...
@@ -54,7 +54,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
Examples::
>>> import torch
>>> import treetensor.
torch
as ttorch
>>> import treetensor.
numpy
as ttorch
>>> ttorch.Size([1, 2, 3])
torch.Size([1, 2, 3])
...
...
@@ -81,7 +81,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
Example::
>>> import torch
>>> import treetensor.
torch
as ttorch
>>> import treetensor.
numpy
as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
...
...
@@ -99,7 +99,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
Example::
>>> import torch
>>> import treetensor.
torch
as ttorch
>>> import treetensor.
numpy
as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
...
...
@@ -132,7 +132,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
Example::
>>> import torch
>>> import treetensor.
torch
as ttorch
>>> import treetensor.
numpy
as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
...
...
treetensor/torch/tensor.py
浏览文件 @
a24fa275
from
functools
import
wraps
from
types
import
MethodType
import
numpy
as
np
import
torch
as
pytorch
from
treevalue
import
method_treelize
,
TreeValue
from
treevalue.utils
import
post_process
from
.base
import
Torch
,
auto_torch
,
rmreduce
,
post_reduce
,
auto_reduce
from
.base
import
Torch
,
rmreduce
,
post_reduce
,
auto_reduce
from
.size
import
Size
from
..common
import
Object
,
ireduce
,
clsmeta
,
return_self
from
..common
import
Object
,
ireduce
,
clsmeta
,
return_self
,
auto_tree
,
get_tree_proxy
from
..numpy
import
ndarray
from
..utils
import
current_names
,
class_autoremove
,
replaceable_partial
from
..utils
import
doc_from_base
as
original_doc_from_base
...
...
@@ -18,6 +15,7 @@ __all__ = [
]
doc_from_base
=
replaceable_partial
(
original_doc_from_base
,
base
=
pytorch
.
Tensor
)
_TorchProxy
,
_InstanceTorchProxy
=
get_tree_proxy
(
pytorch
.
Tensor
)
def
_to_tensor
(
*
args
,
**
kwargs
):
...
...
@@ -30,44 +28,6 @@ def _to_tensor(*args, **kwargs):
return
pytorch
.
tensor
(
*
args
,
**
kwargs
)
class
_TorchProxy
:
def
__init__
(
self
,
cls
):
self
.
__torch_funcs
=
{}
self
.
__cls
=
cls
def
__getattr__
(
self
,
name
):
if
name
in
self
.
__torch_funcs
.
keys
():
return
self
.
__torch_funcs
[
name
]
elif
hasattr
(
pytorch
.
Tensor
,
name
)
and
not
name
.
startswith
(
'_'
)
\
and
callable
(
getattr
(
pytorch
.
Tensor
,
name
)):
_origin_func
=
getattr
(
pytorch
.
Tensor
,
name
)
return_self_deco
=
return_self
if
name
.
endswith
(
'_'
)
else
(
lambda
x
:
x
)
@
doc_from_base
()
@
return_self_deco
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_torch
,
cls
=
self
.
__cls
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
@
wraps
(
_origin_func
,
assigned
=
(
'__name__'
,),
updated
=
())
def
_new_func
(
*
args
,
**
kwargs
):
return
_origin_func
(
*
args
,
**
kwargs
)
_new_func
.
__qualname__
=
f
'
{
self
.
__cls
.
__name__
}
.
{
name
}
'
_new_func
.
__module__
=
self
.
__cls
.
__module__
self
.
__torch_funcs
[
name
]
=
_new_func
return
_new_func
else
:
raise
AttributeError
(
f
'Function
{
repr
(
name
)
}
not found in
{
repr
(
pytorch
)
}
'
)
class
_InstanceTorchProxy
:
def
__init__
(
self
,
proxy
,
s
):
self
.
__proxy
=
proxy
self
.
__self
=
s
def
__getattr__
(
self
,
name
):
return
MethodType
(
getattr
(
self
.
__proxy
,
name
),
self
.
__self
)
class
_BaseTensorMeta
(
clsmeta
(
_to_tensor
,
allow_dict
=
True
)):
pass
...
...
@@ -76,13 +36,13 @@ class _BaseTensorMeta(clsmeta(_to_tensor, allow_dict=True)):
class
_TensorMeta
(
_BaseTensorMeta
):
def
__init__
(
cls
,
*
args
,
**
kwargs
):
_BaseTensorMeta
.
__init__
(
cls
,
*
args
,
**
kwargs
)
cls
.
__
torch_
proxy
=
None
cls
.
__proxy
=
None
@
property
def
torch
(
cls
):
if
not
cls
.
__
torch_
proxy
:
cls
.
__
torch_
proxy
=
_TorchProxy
(
cls
)
return
cls
.
__
torch_
proxy
if
not
cls
.
__proxy
:
cls
.
__proxy
=
_TorchProxy
(
cls
)
return
cls
.
__proxy
def
__getattr__
(
cls
,
name
):
try
:
...
...
@@ -439,7 +399,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return
self
# noinspection PyShadowingBuiltins
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
orch
,
cls
=
Tensor
)(
r
))
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
ree
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
__max_nr
(
self
,
*
args
,
**
kwargs
):
return
pytorch
.
max
(
self
,
*
args
,
**
kwargs
)
...
...
@@ -459,7 +419,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return
self
# noinspection PyShadowingBuiltins
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
orch
,
cls
=
Tensor
)(
r
))
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
ree
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
__min_nr
(
self
,
*
args
,
**
kwargs
):
return
pytorch
.
min
(
self
,
*
args
,
**
kwargs
)
...
...
@@ -479,7 +439,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return
self
# noinspection PyShadowingBuiltins
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
orch
,
cls
=
Tensor
)(
r
))
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
ree
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
__sum_nr
(
self
,
*
args
,
**
kwargs
):
return
pytorch
.
sum
(
self
,
*
args
,
**
kwargs
)
...
...
@@ -922,7 +882,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return
self
.
log10_
(
*
args
,
**
kwargs
)
@
doc_from_base
()
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
orch
,
cls
=
Tensor
)(
r
))
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
ree
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
split
(
self
,
split_size
,
*
args
,
**
kwargs
):
"""
...
...
@@ -931,7 +891,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return
self
.
split
(
split_size
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
orch
,
cls
=
Tensor
)(
r
))
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_t
ree
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
chunk
(
self
,
chunks
,
*
args
,
**
kwargs
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录