Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
6350b914
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6350b914
编写于
9月 26, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev, doc, test(hansbug): complete norm, dist, std, mean, chunk
上级
596fa782
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
523 addition
and
2 deletion
+523
-2
test/torch/funcs/test_math.py
test/torch/funcs/test_math.py
+78
-0
test/torch/funcs/test_operation.py
test/torch/funcs/test_operation.py
+43
-0
test/torch/tensor/test_math.py
test/torch/tensor/test_math.py
+83
-1
test/torch/tensor/test_operation.py
test/torch/tensor/test_operation.py
+43
-0
treetensor/torch/funcs/math.py
treetensor/torch/funcs/math.py
+171
-0
treetensor/torch/funcs/operation.py
treetensor/torch/funcs/operation.py
+63
-1
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+42
-0
未找到文件。
test/torch/funcs/test_math.py
浏览文件 @
6350b914
...
...
@@ -721,3 +721,81 @@ class TestTorchFuncsMath:
'b'
:
{
'x'
:
[[
math
.
nan
,
0.0792
,
-
0.6021
],
[
1.2041
,
0.5740
,
math
.
nan
]]},
}),
rtol
=
1e-4
,
atol
=
1e-4
,
equal_nan
=
True
).
all
()
@
choose_mark
()
def
test_std
(
self
):
t1
=
torch
.
tensor
([[
25.5133
,
24.2050
,
8.1067
],
[
22.7316
,
-
17.8863
,
-
37.9171
]]).
std
()
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
ttorch
.
isclose
(
t1
,
torch
.
tensor
(
26.3619
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
tensor
({
'a'
:
[[
-
48.6580
,
30.9506
,
-
16.1800
],
[
37.6667
,
10.3850
,
-
5.7679
]],
'b'
:
{
'x'
:
[[
-
17.9371
,
8.4873
,
-
49.0445
,
4.7368
],
[
21.3990
,
-
11.2385
,
-
15.9331
,
-
41.6838
],
[
-
7.1814
,
-
38.1301
,
-
2.2320
,
10.1392
]]},
}).
std
()
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
32.0483
,
'b'
:
{
'x'
:
22.1754
},
}),
atol
=
1e-4
).
all
()
@
choose_mark
()
def
test_mean
(
self
):
t1
=
torch
.
tensor
([[
11.8069
,
16.7822
,
-
11.8583
],
[
-
10.0426
,
38.7326
,
30.0298
]]).
mean
()
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
ttorch
.
isclose
(
t1
,
torch
.
tensor
(
12.5751
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
tensor
({
'a'
:
[[
-
29.3862
,
10.3668
,
-
19.8407
],
[
11.3299
,
-
0.7511
,
-
13.8404
]],
'b'
:
{
'x'
:
[[
-
25.1722
,
22.6307
,
-
9.3588
,
-
6.8217
],
[
-
31.4652
,
6.6465
,
36.9483
,
-
4.0487
],
[
-
17.2146
,
24.0029
,
35.4574
,
-
29.2970
]]},
}).
mean
()
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
-
7.0203
,
'b'
:
{
'x'
:
0.1923
},
}),
atol
=
1e-4
).
all
()
@
choose_mark
()
def
test_dist
(
self
):
t1
=
torch
.
tensor
([
-
0.6566
,
1.2243
,
1.5018
,
-
0.1492
,
0.8947
]).
dist
(
torch
.
tensor
([
0.5898
,
0.6839
,
0.0388
,
0.4649
,
0.7964
]),
)
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
ttorch
.
isclose
(
t1
,
torch
.
tensor
(
2.0911
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
tensor
({
'a'
:
[
-
0.5491
,
1.5006
,
-
0.0483
,
1.2282
,
-
1.4837
],
'b'
:
{
'x'
:
[
-
1.8414
,
1.2913
,
0.0943
,
0.3473
,
1.2717
,
0.6013
]},
}).
dist
(
ttorch
.
tensor
({
'a'
:
[
0.1389
,
-
0.7804
,
-
1.3048
,
-
1.1066
,
1.3225
],
'b'
:
{
'x'
:
[
1.4873
,
0.2218
,
-
0.1063
,
-
0.8726
,
-
0.6756
,
0.4805
]},
}))
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
4.5366
,
'b'
:
{
'x'
:
4.1904
}
}),
atol
=
1e-4
).
all
()
@
choose_mark
()
def
test_norm
(
self
):
t1
=
torch
.
tensor
([[
0.0363
,
-
1.7385
,
1.0669
,
2.6967
],
[
0.0848
,
0.2735
,
0.3538
,
0.2271
],
[
-
0.1014
,
1.1351
,
-
0.5761
,
-
1.2671
]]).
norm
()
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
ttorch
.
isclose
(
t1
,
torch
.
tensor
(
3.8638
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
tensor
({
'a'
:
[[
-
0.5012
,
2.0900
,
0.0151
],
[
-
0.5035
,
0.2144
,
0.8370
]],
'b'
:
{
'x'
:
[[
0.3911
,
0.3557
,
-
2.2156
,
0.3653
],
[
-
0.3503
,
1.2182
,
-
0.2364
,
-
0.2854
],
[
-
1.5770
,
-
0.7349
,
0.8391
,
-
0.2845
]]},
}).
norm
()
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
2.3706
,
'b'
:
{
'x'
:
3.2982
},
}),
atol
=
1e-4
).
all
()
test/torch/funcs/test_operation.py
浏览文件 @
6350b914
...
...
@@ -138,6 +138,49 @@ class TestTorchFuncsOperation:
[
58
,
54
,
78
]]]
})).
all
()
@
choose_mark
()
def
test_chunk
(
self
):
t
=
torch
.
tensor
([[
54
,
97
,
12
,
48
,
62
],
[
92
,
87
,
28
,
53
,
54
],
[
65
,
82
,
40
,
26
,
61
],
[
75
,
43
,
86
,
99
,
7
]])
t_a
,
t_b
=
ttorch
.
chunk
(
t
,
2
)
assert
isinstance
(
t_a
,
torch
.
Tensor
)
assert
isinstance
(
t_b
,
torch
.
Tensor
)
assert
(
t_a
==
torch
.
tensor
([[
54
,
97
,
12
,
48
,
62
],
[
92
,
87
,
28
,
53
,
54
]])).
all
()
assert
(
t_b
==
torch
.
tensor
([[
65
,
82
,
40
,
26
,
61
],
[
75
,
43
,
86
,
99
,
7
]])).
all
()
tt
=
ttorch
.
tensor
({
'a'
:
[[
80
,
2
,
15
,
45
,
48
],
[
38
,
89
,
34
,
10
,
34
],
[
18
,
99
,
33
,
38
,
20
],
[
43
,
21
,
35
,
43
,
37
]],
'b'
:
{
'x'
:
[[[
19
,
17
,
39
,
68
],
[
41
,
69
,
33
,
89
],
[
31
,
88
,
39
,
14
]],
[[
27
,
81
,
84
,
35
],
[
29
,
65
,
17
,
72
],
[
53
,
50
,
75
,
0
]]]},
})
tt_a
,
tt_b
=
ttorch
.
chunk
(
tt
,
2
)
assert
(
tt_a
==
ttorch
.
tensor
({
'a'
:
[[
80
,
2
,
15
,
45
,
48
],
[
38
,
89
,
34
,
10
,
34
]],
'b'
:
{
'x'
:
[[[
19
,
17
,
39
,
68
],
[
41
,
69
,
33
,
89
],
[
31
,
88
,
39
,
14
]]]},
})).
all
()
assert
(
tt_b
==
ttorch
.
tensor
({
'a'
:
[[
18
,
99
,
33
,
38
,
20
],
[
43
,
21
,
35
,
43
,
37
]],
'b'
:
{
'x'
:
[[[
27
,
81
,
84
,
35
],
[
29
,
65
,
17
,
72
],
[
53
,
50
,
75
,
0
]]]},
})).
all
()
@
choose_mark
()
def
test_stack
(
self
):
t1
=
torch
.
tensor
([[
17
,
15
,
27
],
...
...
test/torch/tensor/test_math.py
浏览文件 @
6350b914
...
...
@@ -6,7 +6,7 @@ import treetensor.torch as ttorch
from
.base
import
choose_mark
# noinspection PyUnresolvedReferences
# noinspection PyUnresolvedReferences
,DuplicatedCode
class
TestTorchTensorMath
:
@
choose_mark
()
def
test_abs
(
self
):
...
...
@@ -889,3 +889,85 @@ class TestTorchTensorMath:
'b'
:
{
'x'
:
[[
math
.
nan
,
0.0792
,
-
0.6021
],
[
1.2041
,
0.5740
,
math
.
nan
]]},
}),
rtol
=
1e-4
,
atol
=
1e-4
,
equal_nan
=
True
).
all
()
@
choose_mark
()
def
test_std
(
self
):
t1
=
ttorch
.
std
(
torch
.
tensor
([[
25.5133
,
24.2050
,
8.1067
],
[
22.7316
,
-
17.8863
,
-
37.9171
]]))
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
ttorch
.
isclose
(
t1
,
torch
.
tensor
(
26.3619
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
std
(
ttorch
.
tensor
({
'a'
:
[[
-
48.6580
,
30.9506
,
-
16.1800
],
[
37.6667
,
10.3850
,
-
5.7679
]],
'b'
:
{
'x'
:
[[
-
17.9371
,
8.4873
,
-
49.0445
,
4.7368
],
[
21.3990
,
-
11.2385
,
-
15.9331
,
-
41.6838
],
[
-
7.1814
,
-
38.1301
,
-
2.2320
,
10.1392
]]},
}))
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
32.0483
,
'b'
:
{
'x'
:
22.1754
},
}),
atol
=
1e-4
).
all
()
@
choose_mark
()
def
test_mean
(
self
):
t1
=
ttorch
.
mean
(
torch
.
tensor
([[
11.8069
,
16.7822
,
-
11.8583
],
[
-
10.0426
,
38.7326
,
30.0298
]]))
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
ttorch
.
isclose
(
t1
,
torch
.
tensor
(
12.5751
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
mean
(
ttorch
.
tensor
({
'a'
:
[[
-
29.3862
,
10.3668
,
-
19.8407
],
[
11.3299
,
-
0.7511
,
-
13.8404
]],
'b'
:
{
'x'
:
[[
-
25.1722
,
22.6307
,
-
9.3588
,
-
6.8217
],
[
-
31.4652
,
6.6465
,
36.9483
,
-
4.0487
],
[
-
17.2146
,
24.0029
,
35.4574
,
-
29.2970
]]},
}))
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
-
7.0203
,
'b'
:
{
'x'
:
0.1923
},
}),
atol
=
1e-4
).
all
()
@
choose_mark
()
def
test_dist
(
self
):
t1
=
ttorch
.
dist
(
torch
.
tensor
([
-
0.6566
,
1.2243
,
1.5018
,
-
0.1492
,
0.8947
]),
torch
.
tensor
([
0.5898
,
0.6839
,
0.0388
,
0.4649
,
0.7964
]),
)
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
ttorch
.
isclose
(
t1
,
torch
.
tensor
(
2.0911
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
dist
(
ttorch
.
tensor
({
'a'
:
[
-
0.5491
,
1.5006
,
-
0.0483
,
1.2282
,
-
1.4837
],
'b'
:
{
'x'
:
[
-
1.8414
,
1.2913
,
0.0943
,
0.3473
,
1.2717
,
0.6013
]},
}),
ttorch
.
tensor
({
'a'
:
[
0.1389
,
-
0.7804
,
-
1.3048
,
-
1.1066
,
1.3225
],
'b'
:
{
'x'
:
[
1.4873
,
0.2218
,
-
0.1063
,
-
0.8726
,
-
0.6756
,
0.4805
]},
})
)
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
4.5366
,
'b'
:
{
'x'
:
4.1904
}
}),
atol
=
1e-4
).
all
()
@
choose_mark
()
def
test_norm
(
self
):
t1
=
ttorch
.
norm
(
torch
.
tensor
([[
0.0363
,
-
1.7385
,
1.0669
,
2.6967
],
[
0.0848
,
0.2735
,
0.3538
,
0.2271
],
[
-
0.1014
,
1.1351
,
-
0.5761
,
-
1.2671
]]))
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
ttorch
.
isclose
(
t1
,
torch
.
tensor
(
3.8638
),
atol
=
1e-4
).
all
()
tt1
=
ttorch
.
norm
(
ttorch
.
tensor
({
'a'
:
[[
-
0.5012
,
2.0900
,
0.0151
],
[
-
0.5035
,
0.2144
,
0.8370
]],
'b'
:
{
'x'
:
[[
0.3911
,
0.3557
,
-
2.2156
,
0.3653
],
[
-
0.3503
,
1.2182
,
-
0.2364
,
-
0.2854
],
[
-
1.5770
,
-
0.7349
,
0.8391
,
-
0.2845
]]},
}))
assert
ttorch
.
isclose
(
tt1
,
ttorch
.
tensor
({
'a'
:
2.3706
,
'b'
:
{
'x'
:
3.2982
},
}),
atol
=
1e-4
).
all
()
test/torch/tensor/test_operation.py
浏览文件 @
6350b914
...
...
@@ -78,6 +78,49 @@ class TestTorchTensorOperation:
[
58
,
54
,
78
]]]
})).
all
()
@
choose_mark
()
def
test_chunk
(
self
):
t
=
torch
.
tensor
([[
54
,
97
,
12
,
48
,
62
],
[
92
,
87
,
28
,
53
,
54
],
[
65
,
82
,
40
,
26
,
61
],
[
75
,
43
,
86
,
99
,
7
]])
t_a
,
t_b
=
t
.
chunk
(
2
)
assert
isinstance
(
t_a
,
torch
.
Tensor
)
assert
isinstance
(
t_b
,
torch
.
Tensor
)
assert
(
t_a
==
torch
.
tensor
([[
54
,
97
,
12
,
48
,
62
],
[
92
,
87
,
28
,
53
,
54
]])).
all
()
assert
(
t_b
==
torch
.
tensor
([[
65
,
82
,
40
,
26
,
61
],
[
75
,
43
,
86
,
99
,
7
]])).
all
()
tt
=
ttorch
.
tensor
({
'a'
:
[[
80
,
2
,
15
,
45
,
48
],
[
38
,
89
,
34
,
10
,
34
],
[
18
,
99
,
33
,
38
,
20
],
[
43
,
21
,
35
,
43
,
37
]],
'b'
:
{
'x'
:
[[[
19
,
17
,
39
,
68
],
[
41
,
69
,
33
,
89
],
[
31
,
88
,
39
,
14
]],
[[
27
,
81
,
84
,
35
],
[
29
,
65
,
17
,
72
],
[
53
,
50
,
75
,
0
]]]},
})
tt_a
,
tt_b
=
tt
.
chunk
(
2
)
assert
(
tt_a
==
ttorch
.
tensor
({
'a'
:
[[
80
,
2
,
15
,
45
,
48
],
[
38
,
89
,
34
,
10
,
34
]],
'b'
:
{
'x'
:
[[[
19
,
17
,
39
,
68
],
[
41
,
69
,
33
,
89
],
[
31
,
88
,
39
,
14
]]]},
})).
all
()
assert
(
tt_b
==
ttorch
.
tensor
({
'a'
:
[[
18
,
99
,
33
,
38
,
20
],
[
43
,
21
,
35
,
43
,
37
]],
'b'
:
{
'x'
:
[[[
27
,
81
,
84
,
35
],
[
29
,
65
,
17
,
72
],
[
53
,
50
,
75
,
0
]]]},
})).
all
()
@
choose_mark
()
def
test_reshape
(
self
):
t1
=
torch
.
tensor
([[
1
,
2
],
[
3
,
4
]]).
reshape
((
-
1
,))
...
...
treetensor/torch/funcs/math.py
浏览文件 @
6350b914
...
...
@@ -9,6 +9,7 @@ __all__ = [
'add'
,
'sub'
,
'mul'
,
'div'
,
'pow'
,
'neg'
,
'neg_'
,
'exp'
,
'exp_'
,
'exp2'
,
'exp2_'
,
'sqrt'
,
'sqrt_'
,
'log'
,
'log_'
,
'log2'
,
'log2_'
,
'log10'
,
'log10_'
,
'mean'
,
'std'
,
'dist'
,
'norm'
,
]
...
...
@@ -1076,3 +1077,173 @@ def log10_(input):
[ 1.2041, 0.5740, nan]])
"""
return
torch
.
log10_
(
input
)
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
func_treelize
()
def
std
(
input
,
*
args
,
**
kwargs
):
"""
Returns the standard-deviation of all elements in the ``input`` tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = torch.randn((2, 3)) * 30
>>> t
tensor([[ 25.5133, 24.2050, 8.1067],
[ 22.7316, -17.8863, -37.9171]])
>>> ttorch.std(t)
tensor(26.3619)
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... }) * 30
>>> tt
<Tensor 0x7f7c7288ca58>
├── a --> tensor([[-48.6580, 30.9506, -16.1800],
│ [ 37.6667, 10.3850, -5.7679]])
└── b --> <Tensor 0x7f7c7288c978>
└── x --> tensor([[-17.9371, 8.4873, -49.0445, 4.7368],
[ 21.3990, -11.2385, -15.9331, -41.6838],
[ -7.1814, -38.1301, -2.2320, 10.1392]])
>>> ttorch.std(tt)
<Tensor 0x7f7c7288c470>
├── a --> tensor(32.0483)
└── b --> <Tensor 0x7f7c7288c3c8>
└── x --> tensor(22.1754)
.. note::
Reduction will not be processed in :func:`treetensor.torch.std`.
It means the result should be a tree of tensors instead of one tensor.
"""
return
torch
.
std
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
func_treelize
()
def
mean
(
input
,
*
args
,
**
kwargs
):
"""
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = torch.randn((2, 3)) * 30
>>> t
tensor([[ 11.8069, 16.7822, -11.8583],
[-10.0426, 38.7326, 30.0298]])
>>> ttorch.mean(t)
tensor(12.5751)
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... }) * 30
>>> tt
<Tensor 0x7f95f684f6a0>
├── a --> tensor([[-29.3862, 10.3668, -19.8407],
│ [ 11.3299, -0.7511, -13.8404]])
└── b --> <Tensor 0x7f95f684f828>
└── x --> tensor([[-25.1722, 22.6307, -9.3588, -6.8217],
[-31.4652, 6.6465, 36.9483, -4.0487],
[-17.2146, 24.0029, 35.4574, -29.2970]])
>>> ttorch.mean(tt)
<Tensor 0x7f95f6849e80>
├── a --> tensor(-7.0203)
└── b --> <Tensor 0x7f95f6849470>
└── x --> tensor(0.1923)
.. note::
Reduction will not be processed in :func:`treetensor.torch.std`.
It means the result should be a tree of tensors instead of one tensor.
"""
return
torch
.
mean
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
func_treelize
()
def
dist
(
input
,
other
,
*
args
,
**
kwargs
):
"""
Returns the p-norm of (``input`` - ``other``)
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randn(5)
>>> t1
tensor([-0.6566, 1.2243, 1.5018, -0.1492, 0.8947])
>>> t2 = torch.randn(5)
>>> t2
tensor([0.5898, 0.6839, 0.0388, 0.4649, 0.7964])
>>> ttorch.dist(t1, t2)
tensor(2.0911)
>>> tt1 = ttorch.randn({'a': (5, ), 'b': {'x': (6, )}})
>>> tt1
<Tensor 0x7f95f68495f8>
├── a --> tensor([-0.5491, 1.5006, -0.0483, 1.2282, -1.4837])
└── b --> <Tensor 0x7f95f68494e0>
└── x --> tensor([-1.8414, 1.2913, 0.0943, 0.3473, 1.2717, 0.6013])
>>> tt2 = ttorch.randn({'a': (5, ), 'b': {'x': (6, )}})
>>> tt2
<Tensor 0x7f95f68ef2b0>
├── a --> tensor([ 0.1389, -0.7804, -1.3048, -1.1066, 1.3225])
└── b --> <Tensor 0x7f95f6849dd8>
└── x --> tensor([ 1.4873, 0.2218, -0.1063, -0.8726, -0.6756, 0.4805])
>>> ttorch.dist(tt1, tt2)
<Tensor 0x7f95f6849358>
├── a --> tensor(4.5366)
└── b --> <Tensor 0x7f95f68494a8>
└── x --> tensor(4.1904)
"""
return
torch
.
dist
(
input
,
other
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
func_treelize
()
def
norm
(
input
,
*
args
,
**
kwargs
):
"""
Returns the matrix norm or vector norm of a given tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randn(3, 4)
>>> t1
tensor([[ 0.0363, -1.7385, 1.0669, 2.6967],
[ 0.0848, 0.2735, 0.3538, 0.2271],
[-0.1014, 1.1351, -0.5761, -1.2671]])
>>> ttorch.norm(t1)
tensor(3.8638)
>>> tt1 = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt1
<Tensor 0x7f95f684f4a8>
├── a --> tensor([[-0.5012, 2.0900, 0.0151],
│ [-0.5035, 0.2144, 0.8370]])
└── b --> <Tensor 0x7f95f684f400>
└── x --> tensor([[ 0.3911, 0.3557, -2.2156, 0.3653],
[-0.3503, 1.2182, -0.2364, -0.2854],
[-1.5770, -0.7349, 0.8391, -0.2845]])
>>> ttorch.norm(tt1)
<Tensor 0x7f95f684fa20>
├── a --> tensor(2.3706)
└── b --> <Tensor 0x7f95f684f978>
└── x --> tensor(3.2982)
"""
return
torch
.
norm
(
input
,
*
args
,
**
kwargs
)
treetensor/torch/funcs/operation.py
浏览文件 @
6350b914
...
...
@@ -5,7 +5,8 @@ from treevalue.utils import post_process
from
.base
import
doc_from_base
,
func_treelize
,
auto_tensor
__all__
=
[
'cat'
,
'split'
,
'stack'
,
'reshape'
,
'where'
,
'squeeze'
,
'unsqueeze'
,
'cat'
,
'split'
,
'chunk'
,
'stack'
,
'reshape'
,
'where'
,
'squeeze'
,
'unsqueeze'
,
]
...
...
@@ -203,6 +204,67 @@ def split(tensor, split_size_or_sections, *args, **kwargs):
return
torch
.
split
(
tensor
,
split_size_or_sections
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
post_process
(
lambda
r
:
tuple
(
map
(
auto_tensor
,
r
)))
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
def
chunk
(
input
,
chunks
,
*
args
,
**
kwargs
):
"""
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t = torch.randint(100, (4, 5))
>>> t
tensor([[54, 97, 12, 48, 62],
[92, 87, 28, 53, 54],
[65, 82, 40, 26, 61],
[75, 43, 86, 99, 7]])
>>> ttorch.chunk(t, 2)
(tensor([[54, 97, 12, 48, 62],
[92, 87, 28, 53, 54]]), tensor([[65, 82, 40, 26, 61],
[75, 43, 86, 99, 7]]))
>>> tt = ttorch.randint(100, {
... 'a': (4, 5),
... 'b': {'x': (2, 3, 4)},
... })
>>> tt
<Tensor 0x7f667e2fb358>
├── a --> tensor([[80, 2, 15, 45, 48],
│ [38, 89, 34, 10, 34],
│ [18, 99, 33, 38, 20],
│ [43, 21, 35, 43, 37]])
└── b --> <Tensor 0x7f667e2fb278>
└── x --> tensor([[[19, 17, 39, 68],
[41, 69, 33, 89],
[31, 88, 39, 14]],
[[27, 81, 84, 35],
[29, 65, 17, 72],
[53, 50, 75, 0]]])
>>> ttorch.chunk(tt, 2)
(<Tensor 0x7f667e9b7eb8>
├── a --> tensor([[80, 2, 15, 45, 48],
│ [38, 89, 34, 10, 34]])
└── b --> <Tensor 0x7f667e2e7cf8>
└── x --> tensor([[[19, 17, 39, 68],
[41, 69, 33, 89],
[31, 88, 39, 14]]])
, <Tensor 0x7f66f176dac8>
├── a --> tensor([[18, 99, 33, 38, 20],
│ [43, 21, 35, 43, 37]])
└── b --> <Tensor 0x7f668030ba58>
└── x --> tensor([[[27, 81, 84, 35],
[29, 65, 17, 72],
[53, 50, 75, 0]]])
"""
return
torch
.
chunk
(
input
,
chunks
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
func_treelize
(
subside
=
dict
(
return_type
=
TreeValue
))
def
stack
(
tensors
,
*
args
,
**
kwargs
):
...
...
treetensor/torch/tensor.py
浏览文件 @
6350b914
...
...
@@ -662,6 +662,16 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return
self
.
split
(
split_size
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
post_process
(
lambda
r
:
tuple
(
map
(
replaceable_partial
(
_auto_torch
,
cls
=
Tensor
),
r
)))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
def
chunk
(
self
,
chunks
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.chunk`.
"""
return
self
.
chunk
(
chunks
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
method_treelize
()
def
reshape
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -713,3 +723,35 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
See :func:`treetensor.torch.where`.
"""
return
self
.
where
(
condition
,
y
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
method_treelize
()
def
std
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.std`.
"""
return
self
.
std
(
*
args
,
**
kwargs
)
@
doc_from_base
()
@
method_treelize
()
def
mean
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.mean`.
"""
return
self
.
mean
(
*
args
,
**
kwargs
)
@
doc_from_base
()
@
method_treelize
()
def
dist
(
self
,
other
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.dist`.
"""
return
self
.
dist
(
other
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
method_treelize
()
def
norm
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.norm`.
"""
return
self
.
norm
(
*
args
,
**
kwargs
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录