Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
ca149e3f
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ca149e3f
编写于
9月 28, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): upgrade max, min, sum
上级
0150535c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
268 addition
and
96 deletion
+268
-96
test/torch/funcs/test_reduction.py
test/torch/funcs/test_reduction.py
+42
-6
test/torch/tensor/test_reduction.py
test/torch/tensor/test_reduction.py
+42
-9
treetensor/torch/base/torch.py
treetensor/torch/base/torch.py
+10
-2
treetensor/torch/funcs/operation.py
treetensor/torch/funcs/operation.py
+4
-2
treetensor/torch/funcs/reduction.py
treetensor/torch/funcs/reduction.py
+123
-63
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+47
-14
未找到文件。
test/torch/funcs/test_reduction.py
浏览文件 @
ca149e3f
...
...
@@ -145,10 +145,24 @@ class TestTorchFuncsReduction:
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
t1
==
torch
.
tensor
(
1.0
)
assert
ttorch
.
isclose
(
ttorch
.
min
(
ttorch
.
tensor
({
tt0
=
ttorch
.
tensor
({
'a'
:
[
1.0
,
2.0
,
1.5
],
'b'
:
{
'x'
:
[[
1.8
,
0.9
],
[
1.3
,
2.5
]]},
})),
ttorch
.
tensor
(
0.9
),
atol
=
1e-4
)
})
assert
ttorch
.
isclose
(
ttorch
.
min
(
tt0
),
ttorch
.
tensor
(
0.9
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
min
(
tt0
,
reduce
=
False
)
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
1.0
,
'b'
:
0.9
,
}),
atol
=
1e-4
).
all
()
tt2_a
,
tt2_b
=
ttorch
.
min
(
tt0
,
dim
=
0
)
assert
ttorch
.
isclose
(
tt2_a
,
ttorch
.
tensor
({
'a'
:
1.0
,
'b'
:
[
1.3
,
0.9
],
}),
atol
=
1e-4
).
all
()
assert
(
tt2_b
==
ttorch
.
tensor
({
'a'
:
0
,
'b'
:
[
1
,
0
],
})).
all
()
@
choose_mark
()
def
test_max
(
self
):
...
...
@@ -156,18 +170,40 @@ class TestTorchFuncsReduction:
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
t1
==
torch
.
tensor
(
2.0
)
assert
ttorch
.
isclose
(
ttorch
.
max
(
ttorch
.
tensor
({
tt0
=
ttorch
.
tensor
({
'a'
:
[
1.0
,
2.0
,
1.5
],
'b'
:
{
'x'
:
[[
1.8
,
0.9
],
[
1.3
,
2.5
]]},
})),
ttorch
.
tensor
(
2.5
),
atol
=
1e-4
)
})
assert
ttorch
.
isclose
(
ttorch
.
max
(
tt0
),
ttorch
.
tensor
(
2.5
),
atol
=
1e-4
)
tt1
=
ttorch
.
max
(
tt0
,
reduce
=
False
)
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
2.0
,
'b'
:
2.5
,
}),
atol
=
1e-4
).
all
()
tt2_a
,
tt2_b
=
ttorch
.
max
(
tt0
,
dim
=
0
)
assert
ttorch
.
isclose
(
tt2_a
,
ttorch
.
tensor
({
'a'
:
2.0
,
'b'
:
[
1.8
,
2.5
],
}),
atol
=
1e-4
).
all
()
assert
(
tt2_b
==
ttorch
.
tensor
({
'a'
:
1
,
'b'
:
[
0
,
1
],
})).
all
()
@
choose_mark
()
def
test_sum
(
self
):
assert
ttorch
.
sum
(
torch
.
tensor
([
1.0
,
2.0
,
1.5
]))
==
torch
.
tensor
(
4.5
)
assert
(
ttorch
.
sum
(
ttorch
.
tensor
({
tt0
=
ttorch
.
tensor
({
'a'
:
[
1.0
,
2.0
,
1.5
],
'b'
:
{
'x'
:
[[
1.8
,
0.9
],
[
1.3
,
2.5
]]},
}))
==
torch
.
tensor
(
11.0
)).
all
()
})
assert
ttorch
.
isclose
(
ttorch
.
sum
(
tt0
),
torch
.
tensor
(
11.0
),
atol
=
1e-4
).
all
()
assert
ttorch
.
isclose
(
ttorch
.
sum
(
tt0
,
reduce
=
False
),
ttorch
.
tensor
({
'a'
:
4.5
,
'b'
:
{
'x'
:
6.5
},
}),
atol
=
1e-4
).
all
()
assert
ttorch
.
isclose
(
ttorch
.
sum
(
tt0
,
dim
=
0
),
ttorch
.
tensor
({
'a'
:
4.5
,
'b'
:
{
'x'
:
[
3.1
,
3.4
]},
}),
atol
=
1e-4
).
all
()
@
choose_mark
()
def
test_mean
(
self
):
...
...
test/torch/tensor/test_reduction.py
浏览文件 @
ca149e3f
...
...
@@ -86,30 +86,63 @@ class TestTorchTensorReduction:
@
choose_mark
()
def
test_max
(
self
):
t
1
=
ttorch
.
Tensor
({
t
0
=
ttorch
.
Tensor
({
'a'
:
[
1
,
2
],
'b'
:
{
'x'
:
[[
0
,
3
],
[
2
,
-
1
]]}
}).
max
()
})
t1
=
t0
.
max
()
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
t1
.
tolist
()
==
3
assert
(
t1
==
torch
.
tensor
(
3
)).
all
()
t2
=
t0
.
max
(
reduce
=
False
)
assert
(
t2
==
ttorch
.
tensor
({
'a'
:
2
,
'b'
:
{
'x'
:
3
}})).
all
()
t3_a
,
t3_b
=
t0
.
max
(
dim
=
0
)
assert
(
t3_a
==
ttorch
.
tensor
({
'a'
:
2
,
'b'
:
{
'x'
:
[
2
,
3
]},
})).
all
()
assert
(
t3_b
==
ttorch
.
tensor
({
'a'
:
1
,
'b'
:
{
'x'
:
[
1
,
0
]},
})).
all
()
@
choose_mark
()
def
test_min
(
self
):
t
1
=
ttorch
.
Tensor
({
t
0
=
ttorch
.
Tensor
({
'a'
:
[
1
,
2
],
'b'
:
{
'x'
:
[[
0
,
3
],
[
2
,
-
1
]]}
}).
min
()
})
t1
=
t0
.
min
()
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
t1
.
tolist
()
==
-
1
assert
(
t1
==
torch
.
tensor
(
-
1
)).
all
()
t2
=
t0
.
min
(
reduce
=
False
)
assert
(
t2
==
ttorch
.
tensor
({
'a'
:
1
,
'b'
:
{
'x'
:
-
1
}})).
all
()
t3_a
,
t3_b
=
t0
.
min
(
dim
=
0
)
assert
(
t3_a
==
ttorch
.
tensor
({
'a'
:
1
,
'b'
:
{
'x'
:
[
0
,
-
1
]},
})).
all
()
assert
(
t3_b
==
ttorch
.
tensor
({
'a'
:
0
,
'b'
:
{
'x'
:
[
0
,
1
]},
})).
all
()
@
choose_mark
()
def
test_sum
(
self
):
t
1
=
ttorch
.
Tensor
({
t
0
=
ttorch
.
Tensor
({
'a'
:
[
1
,
2
],
'b'
:
{
'x'
:
[[
0
,
3
],
[
2
,
-
1
]]}
}).
sum
()
})
t1
=
t0
.
sum
()
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
t1
.
tolist
()
==
7
assert
(
t1
==
ttorch
.
tensor
(
7
)).
all
()
t2
=
t0
.
sum
(
reduce
=
False
)
assert
(
t2
==
ttorch
.
tensor
({
'a'
:
3
,
'b'
:
{
'x'
:
4
}})).
all
()
t3
=
t0
.
sum
(
dim
=
0
)
assert
(
t3
==
ttorch
.
tensor
({
'a'
:
3
,
'b'
:
{
'x'
:
[
2
,
2
]},
})).
all
()
@
choose_mark
()
def
test_mean
(
self
):
...
...
treetensor/torch/base/torch.py
浏览文件 @
ca149e3f
...
...
@@ -11,5 +11,13 @@ class Torch(BaseTreeStruct):
pass
def
auto_torch
(
value
,
cls
:
Type
[
Torch
]):
return
typetrans
(
value
,
cls
)
if
isinstance
(
value
,
TreeValue
)
else
value
# noinspection PyArgumentList
def
auto_torch
(
v
,
cls
:
Type
[
Torch
]):
if
isinstance
(
v
,
TreeValue
):
return
typetrans
(
v
,
cls
)
elif
isinstance
(
v
,
(
tuple
,
list
,
set
)):
return
type
(
v
)((
auto_torch
(
item
,
cls
)
for
item
in
v
))
elif
isinstance
(
v
,
dict
):
return
type
(
v
)({
key
:
auto_torch
(
value
,
cls
)
for
key
,
value
in
v
.
items
()})
else
:
return
v
treetensor/torch/funcs/operation.py
浏览文件 @
ca149e3f
...
...
@@ -117,7 +117,8 @@ def cat(tensors, *args, **kwargs):
# noinspection PyShadowingNames
@
doc_from_base
()
@
post_process
(
lambda
r
:
tuple
(
map
(
auto_tensor
,
r
)))
@
post_process
(
lambda
r
:
tuple
(
r
))
@
post_process
(
auto_tensor
)
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
def
split
(
tensor
,
split_size_or_sections
,
*
args
,
**
kwargs
):
...
...
@@ -207,7 +208,8 @@ def split(tensor, split_size_or_sections, *args, **kwargs):
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
post_process
(
lambda
r
:
tuple
(
map
(
auto_tensor
,
r
)))
@
post_process
(
lambda
r
:
tuple
(
r
))
@
post_process
(
auto_tensor
)
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
def
chunk
(
input
,
chunks
,
*
args
,
**
kwargs
):
...
...
treetensor/torch/funcs/reduction.py
浏览文件 @
ca149e3f
import
torch
from
treevalue
import
TreeValue
from
treevalue.utils
import
post_process
from
.base
import
doc_from_base
,
func_treelize
from
.base
import
doc_from_base
,
func_treelize
,
auto_tensor
from
..base
import
rmreduce
,
post_reduce
,
auto_reduce
from
...common
import
Object
...
...
@@ -93,29 +95,39 @@ def any(input, *args, reduce=None, **kwargs):
>>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}}))
tensor(False)
.. note::
In this ``any`` function, the return value should be a tensor with single boolean value.
If what you need is a tree of boolean tensors, you should do like this
>>> ttorch.any(ttorch.tensor({'a': [True, False], 'b': {'x': [False, False]}}), reduce=False)
<Tensor 0x7fd45b52d518>
├── a --> tensor(True)
└── b --> <Tensor 0x7fd45b52d470>
└── x --> tensor(False)
>>> ttorch.tensor({
>>> 'a': [True, False],
>>> 'b': {'x': [False, False]},
>>> }).map(lambda x: torch.any(x))
<Tensor 0x7ff363bc6898>
├── a --> tensor(True)
└── b --> <Tensor 0x7ff363bc67f0>
└── x --> tensor(False)
>>> ttorch.any(ttorch.tensor({'a': [False, False], 'b': {'x': [False, False]}}), dim=0)
<Tensor 0x7fd45b534128>
├── a --> tensor(False)
└── b --> <Tensor 0x7fd45b534080>
└── x --> tensor(False)
"""
pass
# pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_reduce
(
torch
.
min
)
@
func_treelize
(
return_type
=
Object
)
def
_min_r
(
input
,
*
args
,
**
kwargs
):
return
input
# noinspection PyShadowingBuiltins
@
post_process
(
auto_tensor
)
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
_min_nr
(
input
,
*
args
,
**
kwargs
):
return
torch
.
min
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
doc_from_base
()
@
rmreduce
(
torch
.
min
)
@
func_treelize
(
return_type
=
Object
)
def
min
(
input
,
*
args
,
**
kwargs
):
@
auto_reduce
(
_min_r
,
_min_nr
)
def
min
(
input
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``min`` result of a whole tree with this function.
...
...
@@ -132,29 +144,52 @@ def min(input, *args, **kwargs):
... }))
tensor(0.9000)
.. note::
>>> ttorch.min(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), reduce=False)
<Tensor 0x7fd45b5913c8>
├── a --> tensor(1.)
└── b --> <Tensor 0x7fd45b5912e8>
└── x --> tensor(0.9000)
In this ``min`` function, the return value should be a tensor with single value.
>>> ttorch.min(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), dim=0)
torch.return_types.min(
values=<Tensor 0x7fd45b52d2e8>
├── a --> tensor(1.)
└── b --> <Tensor 0x7fd45b52d208>
└── x --> tensor([1.3000, 0.9000])
,
indices=<Tensor 0x7fd45b591cc0>
├── a --> tensor(0)
└── b --> <Tensor 0x7fd45b52d3c8>
└── x --> tensor([1, 0])
)
"""
pass
# pragma: no cover
If what you need is a tree of tensors, you should do like this
>>> ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.min(x))
<Tensor 0x7ff363bbb2b0>
├── a --> tensor(1.)
└── b --> <Tensor 0x7ff363bbb0b8>
└── x --> tensor(0.9000)
"""
return
torch
.
min
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_reduce
(
torch
.
max
)
@
func_treelize
(
return_type
=
Object
)
def
_max_r
(
input
,
*
args
,
**
kwargs
):
return
input
# noinspection PyShadowingBuiltins
@
post_process
(
auto_tensor
)
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
_max_nr
(
input
,
*
args
,
**
kwargs
):
return
torch
.
max
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
doc_from_base
()
@
rmreduce
(
torch
.
max
)
@
func_treelize
(
return_type
=
Object
)
def
max
(
input
,
*
args
,
**
kwargs
):
@
auto_reduce
(
_max_r
,
_max_nr
)
def
max
(
input
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``max`` result of a whole tree with this function.
...
...
@@ -171,29 +206,51 @@ def max(input, *args, **kwargs):
... }))
tensor(2.5000)
.. note::
>>> ttorch.max(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), reduce=False)
<Tensor 0x7fd45b52d940>
├── a --> tensor(2.)
└── b --> <Tensor 0x7fd45b52d908>
└── x --> tensor(2.5000)
In this ``max`` function, the return value should be a tensor with single value.
>>> ttorch.max(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), dim=0)
torch.return_types.max(
values=<Tensor 0x7fd45b5345f8>
├── a --> tensor(2.)
└── b --> <Tensor 0x7fd45b5345c0>
└── x --> tensor([1.8000, 2.5000])
,
indices=<Tensor 0x7fd45b5346d8>
├── a --> tensor(1)
└── b --> <Tensor 0x7fd45b5346a0>
└── x --> tensor([0, 1])
)
"""
pass
# pragma: no cover
If what you need is a tree of tensors, you should do like this
>>> ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.max(x))
<Tensor 0x7ff363bc6b00>
├── a --> tensor(2.)
└── b --> <Tensor 0x7ff363bc6c18>
└── x --> tensor(2.5000)
"""
return
torch
.
max
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_reduce
(
torch
.
sum
)
@
func_treelize
(
return_type
=
Object
)
def
_sum_r
(
input
,
*
args
,
**
kwargs
):
return
input
# noinspection PyShadowingBuiltins
@
func_treelize
()
def
_sum_nr
(
input
,
*
args
,
**
kwargs
):
return
torch
.
sum
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
doc_from_base
()
@
rmreduce
(
torch
.
sum
)
@
func_treelize
(
return_type
=
Object
)
def
sum
(
input
,
*
args
,
**
kwargs
):
@
auto_reduce
(
_sum_r
,
_sum_nr
)
def
sum
(
input
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``sum`` result of a whole tree with this function.
...
...
@@ -210,22 +267,25 @@ def sum(input, *args, **kwargs):
... }))
tensor(11.)
.. note::
In this ``sum`` function, the return value should be a tensor with single value.
If what you need is a tree of tensors, you should do like this
>>> ttorch.sum(ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), reduce=False)
<Tensor 0x7fd45b534898>
├── a --> tensor(4.5000)
└── b --> <Tensor 0x7fd45b5344e0>
└── x --> tensor(6.5000)
>>>
ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }).map(lambda x: torch.sum(x)
)
<Tensor 0x7ff363bbbda0
>
├── a --> tensor(4.5000)
└── b --> <Tensor 0x7ff363bbbcf8
>
└── x --> tensor(6.5000
)
>>> ttorch.sum(
ttorch.tensor({
... 'a': [1.0, 2.0, 1.5],
... 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
... }), dim=0
)
<Tensor 0x7f3640703128
>
├── a --> tensor(4.5000)
└── b --> <Tensor 0x7f3640703080
>
└── x --> tensor([3.1000, 3.4000]
)
"""
return
torch
.
sum
(
input
,
*
args
,
**
kwargs
)
pass
# pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
...
...
treetensor/torch/tensor.py
浏览文件 @
ca149e3f
...
...
@@ -224,32 +224,65 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
pass
# pragma: no cover
@
doc_from_base
()
@
rm
reduce
(
torch
.
max
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_
reduce
(
torch
.
max
)
@
method_treelize
(
return_type
=
Object
)
def
max
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
def
__max_r
(
self
,
*
args
,
**
kwargs
):
return
self
# noinspection PyShadowingBuiltins
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_torch
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
__max_nr
(
self
,
*
args
,
**
kwargs
):
return
torch
.
max
(
self
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
auto_reduce
(
__max_r
,
__max_nr
)
def
max
(
self
:
torch
.
Tensor
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
See :func:`treetensor.torch.max`
"""
return
self
.
max
(
*
args
,
**
kwargs
)
pass
# pragma: no cover
@
doc_from_base
()
@
rm
reduce
(
torch
.
min
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_
reduce
(
torch
.
min
)
@
method_treelize
(
return_type
=
Object
)
def
min
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
def
__min_r
(
self
,
*
args
,
**
kwargs
):
return
self
# noinspection PyShadowingBuiltins
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_torch
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
__min_nr
(
self
,
*
args
,
**
kwargs
):
return
torch
.
min
(
self
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
auto_reduce
(
__min_r
,
__min_nr
)
def
min
(
self
:
torch
.
Tensor
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
See :func:`treetensor.torch.min`
"""
return
self
.
min
(
*
args
,
**
kwargs
)
pass
# pragma: no cover
@
doc_from_base
()
@
rm
reduce
(
torch
.
sum
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_
reduce
(
torch
.
sum
)
@
method_treelize
(
return_type
=
Object
)
def
sum
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
def
__sum_r
(
self
,
*
args
,
**
kwargs
):
return
self
# noinspection PyShadowingBuiltins
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_torch
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
__sum_nr
(
self
,
*
args
,
**
kwargs
):
return
torch
.
sum
(
self
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
auto_reduce
(
__sum_r
,
__sum_nr
)
def
sum
(
self
:
torch
.
Tensor
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
See :func:`treetensor.torch.sum`
"""
return
self
.
sum
(
*
args
,
**
kwargs
)
pass
# pragma: no cover
@
method_treelize
()
def
__eq__
(
self
,
other
):
...
...
@@ -681,7 +714,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return
self
.
log10_
(
*
args
,
**
kwargs
)
@
doc_from_base
()
@
post_process
(
lambda
r
:
tuple
(
map
(
replaceable_partial
(
auto_torch
,
cls
=
Tensor
),
r
)
))
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_torch
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
def
split
(
self
,
split_size
,
*
args
,
**
kwargs
):
...
...
@@ -691,7 +724,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return
self
.
split
(
split_size
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
post_process
(
lambda
r
:
tuple
(
map
(
replaceable_partial
(
auto_torch
,
cls
=
Tensor
),
r
)
))
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_torch
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
def
chunk
(
self
,
chunks
,
*
args
,
**
kwargs
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录