Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
4086271e
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4086271e
编写于
9月 19, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev, doc(hansbug): add new function for Size class && add plenty of new documentations
上级
7ff4109d
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
186 addition
and
68 deletion
+186
-68
test/numpy/test_array.py
test/numpy/test_array.py
+4
-4
test/torch/test_tensor.py
test/torch/test_tensor.py
+12
-12
treetensor/__init__.py
treetensor/__init__.py
+1
-1
treetensor/common/trees.py
treetensor/common/trees.py
+36
-4
treetensor/numpy/array.py
treetensor/numpy/array.py
+7
-7
treetensor/numpy/funcs.py
treetensor/numpy/funcs.py
+2
-2
treetensor/torch/funcs.py
treetensor/torch/funcs.py
+8
-12
treetensor/torch/size.py
treetensor/torch/size.py
+96
-11
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+20
-15
未找到文件。
test/numpy/test_array.py
浏览文件 @
4086271e
...
...
@@ -2,7 +2,7 @@ import numpy as np
import
pytest
import
treetensor.numpy
as
tnp
from
treetensor.common
import
Tree
Object
from
treetensor.common
import
Object
# noinspection DuplicatedCode
...
...
@@ -209,7 +209,7 @@ class TestNumpyArray:
})).
all
()
def
test_tolist
(
self
):
assert
self
.
_DEMO_1
.
tolist
()
==
Tree
Object
({
assert
self
.
_DEMO_1
.
tolist
()
==
Object
({
'a'
:
[[
1
,
2
,
3
],
[
4
,
5
,
6
]],
'b'
:
[
1
,
3
,
5
,
7
],
'x'
:
{
...
...
@@ -217,7 +217,7 @@ class TestNumpyArray:
'd'
:
[
3
,
9
,
11.0
],
}
})
assert
self
.
_DEMO_2
.
tolist
()
==
Tree
Object
({
assert
self
.
_DEMO_2
.
tolist
()
==
Object
({
'a'
:
[[
1
,
22
,
3
],
[
4
,
5
,
6
]],
'b'
:
[
1
,
3
,
5
,
7
],
'x'
:
{
...
...
@@ -225,7 +225,7 @@ class TestNumpyArray:
'd'
:
[
3
,
9
,
11.0
],
}
})
assert
self
.
_DEMO_3
.
tolist
()
==
Tree
Object
({
assert
self
.
_DEMO_3
.
tolist
()
==
Object
({
'a'
:
[[
0
,
0
,
0
],
[
0
,
0
,
0
]],
'b'
:
[
0
,
0
,
0
,
0
],
'x'
:
{
...
...
test/torch/test_tensor.py
浏览文件 @
4086271e
...
...
@@ -12,20 +12,20 @@ _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
@
pytest
.
mark
.
unittest
class
TestTorchTensor
:
_DEMO_1
=
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
,
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
6
]])
,
'a'
:
[[
1
,
2
,
3
],
[
4
,
5
,
6
]]
,
'b'
:
[[
1
,
2
],
[
5
,
6
]]
,
'x'
:
{
'c'
:
torch
.
tensor
([
3
,
5
,
6
,
7
])
,
'd'
:
torch
.
tensor
([[[
1
,
2
],
[
8
,
9
]]])
,
'c'
:
[
3
,
5
,
6
,
7
]
,
'd'
:
[[[
1
,
2
],
[
8
,
9
]]]
,
}
})
_DEMO_2
=
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
,
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
60
]])
,
'a'
:
[[
1
,
2
,
3
],
[
4
,
5
,
6
]]
,
'b'
:
[[
1
,
2
],
[
5
,
60
]]
,
'x'
:
{
'c'
:
torch
.
tensor
([
3
,
5
,
6
,
7
])
,
'd'
:
torch
.
tensor
([[[
1
,
2
],
[
8
,
9
]]])
,
'c'
:
[
3
,
5
,
6
,
7
]
,
'd'
:
[[[
1
,
2
],
[
8
,
9
]]]
,
}
})
...
...
@@ -48,11 +48,11 @@ class TestTorchTensor:
def
test_to
(
self
):
assert
ttorch
.
all
(
self
.
_DEMO_1
.
to
(
torch
.
float32
)
==
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
torch
.
float32
),
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
6
]],
dtype
=
torch
.
float32
),
'a'
:
torch
.
FloatTensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]
),
'b'
:
torch
.
FloatTensor
([[
1
,
2
],
[
5
,
6
]]
),
'x'
:
{
'c'
:
torch
.
tensor
([
3
,
5
,
6
,
7
],
dtype
=
torch
.
float32
),
'd'
:
torch
.
tensor
([[[
1
,
2
],
[
8
,
9
]]],
dtype
=
torch
.
float32
),
'c'
:
torch
.
FloatTensor
([
3
,
5
,
6
,
7
]
),
'd'
:
torch
.
FloatTensor
([[[
1
,
2
],
[
8
,
9
]]]
),
}
}))
...
...
treetensor/__init__.py
浏览文件 @
4086271e
from
.common
import
Tree
Object
from
.common
import
Object
from
.numpy
import
ndarray
from
.torch
import
Tensor
treetensor/common/trees.py
浏览文件 @
4086271e
import
builtins
import
io
import
os
from
abc
import
ABCMeta
from
functools
import
partial
from
typing
import
Optional
,
Tuple
,
Callable
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue
import
general_tree_value
,
TreeValue
from
treevalue.tree.common
import
BaseTree
from
treevalue.tree.tree.tree
import
get_data_property
from
treevalue.utils
import
post_process
from
..utils
import
replaceable_partial
,
args_mapping
__all__
=
[
'BaseTreeStruct'
,
"TreeObject"
,
'print_tree'
,
'BaseTreeStruct'
,
"Object"
,
'print_tree'
,
'clsmeta'
,
]
...
...
@@ -78,7 +83,7 @@ def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, fil
print
(
repr_
(
tree
),
file
=
file
)
class
BaseTreeStruct
(
general_tree_value
()
,
metaclass
=
ABCMeta
):
class
BaseTreeStruct
(
general_tree_value
()):
"""
Overview:
Base structure of all the trees in ``treetensor``.
...
...
@@ -93,5 +98,32 @@ class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
return
self
.
__repr__
()
class
TreeObject
(
BaseTreeStruct
):
def
clsmeta
(
cls
:
type
,
allow_dict
:
bool
=
False
,
allow_data
:
bool
=
True
):
class
_TempTreeValue
(
TreeValue
):
pass
_types
=
(
TreeValue
,
*
((
dict
,)
if
allow_dict
else
()),
*
((
BaseTree
,)
if
allow_data
else
()),
)
func_treelize
=
post_process
(
post_process
(
args_mapping
(
lambda
i
,
x
:
TreeValue
(
x
)
if
isinstance
(
x
,
_types
)
else
x
)))(
replaceable_partial
(
original_func_treelize
,
return_type
=
_TempTreeValue
)
)
_torch_size
=
func_treelize
()(
cls
)
class
_MetaClass
(
type
):
def
__call__
(
cls
,
*
args
,
**
kwargs
):
_result
=
_torch_size
(
*
args
,
**
kwargs
)
if
isinstance
(
_result
,
_TempTreeValue
):
return
type
.
__call__
(
cls
,
_result
)
else
:
return
_result
return
_MetaClass
class
Object
(
BaseTreeStruct
):
pass
treetensor/numpy/array.py
浏览文件 @
4086271e
...
...
@@ -2,7 +2,7 @@ import numpy as np
from
treevalue
import
method_treelize
from
.base
import
TreeNumpy
from
..common
import
Tree
Object
,
ireduce
from
..common
import
Object
,
ireduce
from
..utils
import
current_names
__all__
=
[
...
...
@@ -18,34 +18,34 @@ class ndarray(TreeNumpy):
Real numpy tree.
"""
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
tolist
(
self
:
np
.
ndarray
):
return
self
.
tolist
()
@
property
@
ireduce
(
sum
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
size
(
self
:
np
.
ndarray
)
->
int
:
return
self
.
size
@
property
@
ireduce
(
sum
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
nbytes
(
self
:
np
.
ndarray
)
->
int
:
return
self
.
nbytes
@
ireduce
(
sum
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
sum
(
self
:
np
.
ndarray
,
*
args
,
**
kwargs
):
return
self
.
sum
(
*
args
,
**
kwargs
)
@
ireduce
(
all
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
all
(
self
:
np
.
ndarray
,
*
args
,
**
kwargs
):
return
self
.
all
(
*
args
,
**
kwargs
)
@
ireduce
(
any
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
any
(
self
:
np
.
ndarray
,
*
args
,
**
kwargs
):
return
self
.
any
(
*
args
,
**
kwargs
)
...
...
treetensor/numpy/funcs.py
浏览文件 @
4086271e
...
...
@@ -6,7 +6,7 @@ from treevalue import func_treelize as original_func_treelize
from
treevalue.utils
import
post_process
from
.array
import
ndarray
from
..common
import
ireduce
,
Tree
Object
from
..common
import
ireduce
,
Object
from
..utils
import
replaceable_partial
,
doc_from
,
args_mapping
__all__
=
[
...
...
@@ -22,7 +22,7 @@ func_treelize = post_process(post_process(args_mapping(
@
doc_from
(
np
.
all
)
@
ireduce
(
builtins
.
all
)
@
func_treelize
(
return_type
=
Tree
Object
)
@
func_treelize
(
return_type
=
Object
)
def
all
(
a
,
*
args
,
**
kwargs
):
return
np
.
all
(
a
,
*
args
,
**
kwargs
)
...
...
treetensor/torch/funcs.py
浏览文件 @
4086271e
"""
Overview:
Common functions, based on ``torch`` module.
"""
import
builtins
import
torch
from
treevalue
import
TreeValue
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue.tree.common
import
BaseTree
from
treevalue.utils
import
post_process
from
.tensor
import
Tensor
,
tireduce
from
..common
import
Tree
Object
,
ireduce
from
..common
import
Object
,
ireduce
from
..utils
import
replaceable_partial
,
doc_from
,
args_mapping
__all__
=
[
...
...
@@ -28,7 +24,7 @@ __all__ = [
]
func_treelize
=
post_process
(
post_process
(
args_mapping
(
lambda
i
,
x
:
TreeValue
(
x
)
if
isinstance
(
x
,
(
dict
,
TreeValue
))
else
x
)))(
lambda
i
,
x
:
TreeValue
(
x
)
if
isinstance
(
x
,
(
dict
,
BaseTree
,
TreeValue
))
else
x
)))(
replaceable_partial
(
original_func_treelize
,
return_type
=
Tensor
)
)
...
...
@@ -355,7 +351,7 @@ def empty_like(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
all
)
@
tireduce
(
torch
.
all
)
@
func_treelize
(
return_type
=
Tree
Object
)
@
func_treelize
(
return_type
=
Object
)
def
all
(
input
,
*
args
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``all`` result of a whole tree with this function.
...
...
@@ -394,7 +390,7 @@ def all(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
any
)
@
tireduce
(
torch
.
any
)
@
func_treelize
(
return_type
=
Tree
Object
)
@
func_treelize
(
return_type
=
Object
)
def
any
(
input
,
*
args
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``any`` result of a whole tree with this function.
...
...
@@ -433,7 +429,7 @@ def any(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
min
)
@
tireduce
(
torch
.
min
)
@
func_treelize
(
return_type
=
Tree
Object
)
@
func_treelize
(
return_type
=
Object
)
def
min
(
input
,
*
args
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``min`` result of a whole tree with this function.
...
...
@@ -472,7 +468,7 @@ def min(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
max
)
@
tireduce
(
torch
.
max
)
@
func_treelize
(
return_type
=
Tree
Object
)
@
func_treelize
(
return_type
=
Object
)
def
max
(
input
,
*
args
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``max`` result of a whole tree with this function.
...
...
@@ -511,7 +507,7 @@ def max(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@
doc_from
(
torch
.
sum
)
@
tireduce
(
torch
.
sum
)
@
func_treelize
(
return_type
=
Tree
Object
)
@
func_treelize
(
return_type
=
Object
)
def
sum
(
input
,
*
args
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``sum`` result of a whole tree with this function.
...
...
treetensor/torch/size.py
浏览文件 @
4086271e
from
functools
import
wraps
import
torch
from
treevalue
import
TreeValue
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue.tree.common
import
BaseTree
from
treevalue.utils
import
post_process
from
.base
import
TreeTorch
from
..common
import
TreeObject
from
..utils
import
replaceable_partial
,
doc_from
,
current_names
from
..common
import
Object
,
clsmeta
,
ireduce
from
..utils
import
replaceable_partial
,
doc_from
,
current_names
,
args_mapping
func_treelize
=
replaceable_partial
(
original_func_treelize
)
func_treelize
=
post_process
(
post_process
(
args_mapping
(
lambda
i
,
x
:
TreeValue
(
x
)
if
isinstance
(
x
,
(
dict
,
BaseTree
,
TreeValue
))
else
x
)))(
replaceable_partial
(
original_func_treelize
)
)
__all__
=
[
'Size'
]
def
_post_index
(
func
):
def
_has_non_none
(
tree
):
if
isinstance
(
tree
,
TreeValue
):
for
_
,
value
in
tree
:
if
_has_non_none
(
value
):
return
True
return
False
else
:
return
tree
is
not
None
@
wraps
(
func
)
def
_new_func
(
self
,
value
,
*
args
,
**
kwargs
):
_tree
=
func
(
self
,
value
,
*
args
,
**
kwargs
)
if
not
_has_non_none
(
_tree
):
raise
ValueError
(
f
'Can not find
{
repr
(
value
)
}
in all the sizes.'
)
else
:
return
_tree
return
_new_func
# noinspection PyTypeChecker
@
current_names
()
class
Size
(
TreeTorch
):
class
Size
(
TreeTorch
,
metaclass
=
clsmeta
(
torch
.
Size
,
allow_dict
=
True
)
):
@
doc_from
(
torch
.
Size
.
numel
)
@
func_treelize
(
return_type
=
TreeObject
)
def
numel
(
self
:
torch
.
Size
)
->
TreeObject
:
@
ireduce
(
sum
)
@
func_treelize
(
return_type
=
Object
)
def
numel
(
self
:
torch
.
Size
)
->
Object
:
"""
Get the numel sum of the sizes in this tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... }).numel()
26
"""
return
self
.
numel
()
@
doc_from
(
torch
.
Size
.
index
)
@
func_treelize
(
return_type
=
TreeObject
)
def
index
(
self
:
torch
.
Size
,
*
args
,
**
kwargs
)
->
TreeObject
:
return
self
.
index
(
*
args
,
**
kwargs
)
@
_post_index
@
func_treelize
(
return_type
=
Object
)
def
index
(
self
:
torch
.
Size
,
value
,
*
args
,
**
kwargs
)
->
Object
:
"""
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... 'c': [3, 5],
... }).index(2)
<Object 0x7fb412780e80>
├── a --> 1
├── b --> <Object 0x7fb412780eb8>
│ └── x --> 1
└── c --> None
.. note::
This method's behaviour is different from the :func:`torch.Size.index`.
No :class:`ValueError` will be raised unless the value can not be found
in any of the sizes, instead there will be nones returned in the tree.
"""
try
:
return
self
.
index
(
value
,
*
args
,
**
kwargs
)
except
ValueError
:
return
None
@
doc_from
(
torch
.
Size
.
count
)
@
func_treelize
(
return_type
=
TreeObject
)
def
count
(
self
:
torch
.
Size
,
*
args
,
**
kwargs
)
->
TreeObject
:
@
ireduce
(
sum
)
@
func_treelize
(
return_type
=
Object
)
def
count
(
self
:
torch
.
Size
,
*
args
,
**
kwargs
)
->
Object
:
"""
Get the occurrence count of the sizes in this tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.Size({
... 'a': [1, 2],
... 'b': {'x': [3, 2, 4]},
... }).count(2)
2
"""
return
self
.
count
(
*
args
,
**
kwargs
)
treetensor/torch/tensor.py
浏览文件 @
4086271e
"""
Overview:
``Tensor`` class, based on ``torch`` module.
"""
import
numpy
as
np
import
torch
from
treevalue
import
method_treelize
...
...
@@ -10,7 +5,7 @@ from treevalue.utils import pre_process
from
.base
import
TreeTorch
from
.size
import
Size
from
..common
import
TreeObject
,
ireduce
from
..common
import
Object
,
ireduce
,
clsmeta
from
..numpy
import
ndarray
from
..utils
import
current_names
,
doc_from
...
...
@@ -22,9 +17,19 @@ _reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
tireduce
=
pre_process
(
lambda
rfunc
:
((
_reduce_tensor_wrap
(
rfunc
),),
{}))(
ireduce
)
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
def
_to_tensor
(
*
args
,
**
kwargs
):
if
(
len
(
args
)
==
1
and
not
kwargs
)
or
\
(
not
args
and
set
(
kwargs
.
keys
())
==
{
'data'
}):
data
=
args
[
0
]
if
len
(
args
)
==
1
else
kwargs
[
'data'
]
if
isinstance
(
data
,
torch
.
Tensor
):
return
data
return
torch
.
tensor
(
*
args
,
**
kwargs
)
# noinspection PyTypeChecker
@
current_names
()
class
Tensor
(
TreeTorch
):
class
Tensor
(
TreeTorch
,
metaclass
=
clsmeta
(
_to_tensor
,
allow_dict
=
True
)
):
@
doc_from
(
torch
.
Tensor
.
numpy
)
@
method_treelize
(
return_type
=
ndarray
)
def
numpy
(
self
:
torch
.
Tensor
)
->
np
.
ndarray
:
...
...
@@ -36,7 +41,7 @@ class Tensor(TreeTorch):
return
self
.
numpy
()
@
doc_from
(
torch
.
Tensor
.
tolist
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
tolist
(
self
:
torch
.
Tensor
):
"""
Get the dump result of tree tensor.
...
...
@@ -106,7 +111,7 @@ class Tensor(TreeTorch):
@
doc_from
(
torch
.
Tensor
.
numel
)
@
ireduce
(
sum
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
numel
(
self
:
torch
.
Tensor
):
"""
See :func:`treetensor.torch.numel`
...
...
@@ -137,7 +142,7 @@ class Tensor(TreeTorch):
@
doc_from
(
torch
.
Tensor
.
all
)
@
tireduce
(
torch
.
all
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
all
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
bool
:
"""
See :func:`treetensor.torch.all`
...
...
@@ -146,7 +151,7 @@ class Tensor(TreeTorch):
@
doc_from
(
torch
.
Tensor
.
any
)
@
tireduce
(
torch
.
any
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
any
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
bool
:
"""
See :func:`treetensor.torch.any`
...
...
@@ -155,7 +160,7 @@ class Tensor(TreeTorch):
@
doc_from
(
torch
.
Tensor
.
max
)
@
tireduce
(
torch
.
max
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
max
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.max`
...
...
@@ -164,7 +169,7 @@ class Tensor(TreeTorch):
@
doc_from
(
torch
.
Tensor
.
min
)
@
tireduce
(
torch
.
min
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
min
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.min`
...
...
@@ -173,7 +178,7 @@ class Tensor(TreeTorch):
@
doc_from
(
torch
.
Tensor
.
sum
)
@
tireduce
(
torch
.
sum
)
@
method_treelize
(
return_type
=
Tree
Object
)
@
method_treelize
(
return_type
=
Object
)
def
sum
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.sum`
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录