Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
fbfdb128
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
fbfdb128
编写于
9月 13, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add ttorch.tensor
上级
eba7817e
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
140 addition
and
59 deletion
+140
-59
test/torch/__init__.py
test/torch/__init__.py
+2
-2
test/torch/test_funcs.py
test/torch/test_funcs.py
+63
-36
test/torch/test_tensor.py
test/torch/test_tensor.py
+5
-5
treetensor/__init__.py
treetensor/__init__.py
+1
-1
treetensor/torch/funcs.py
treetensor/torch/funcs.py
+44
-9
treetensor/torch/size.py
treetensor/torch/size.py
+2
-2
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+4
-4
treetensor/utils/func.py
treetensor/utils/func.py
+19
-0
未找到文件。
test/torch/__init__.py
浏览文件 @
fbfdb128
from
.test_funcs
import
TestT
ensor
Funcs
from
.test_t
reetensor
import
TestTensorTreet
ensor
from
.test_funcs
import
TestT
orch
Funcs
from
.test_t
ensor
import
TestTorchT
ensor
test/torch/test_funcs.py
浏览文件 @
fbfdb128
...
...
@@ -7,7 +7,34 @@ import treetensor.torch as ttorch
# noinspection DuplicatedCode
@
pytest
.
mark
.
unittest
class
TestTensorFuncs
:
class
TestTorchFuncs
:
def
test_tensor
(
self
):
t1
=
ttorch
.
tensor
(
True
)
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
t1
t2
=
ttorch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])
assert
isinstance
(
t2
,
torch
.
Tensor
)
assert
(
t2
==
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])).
all
()
t3
=
ttorch
.
tensor
({
'a'
:
[
1
,
2
],
'b'
:
[[
3
,
4
],
[
5
,
6.2
]],
'x'
:
{
'c'
:
True
,
'd'
:
[
False
,
True
],
}
})
assert
isinstance
(
t3
,
ttorch
.
Tensor
)
assert
(
t3
==
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([
1
,
2
]),
'b'
:
torch
.
tensor
([[
3
,
4
],
[
5
,
6.2
]]),
'x'
:
{
'c'
:
torch
.
tensor
(
True
),
'd'
:
torch
.
tensor
([
False
,
True
]),
}
})).
all
()
def
test_zeros
(
self
):
assert
ttorch
.
all
(
ttorch
.
zeros
((
2
,
3
))
==
torch
.
zeros
(
2
,
3
))
assert
ttorch
.
all
(
ttorch
.
zeros
(
TreeValue
({
...
...
@@ -16,7 +43,7 @@ class TestTensorFuncs:
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
}))
==
ttorch
.
T
reeT
ensor
({
}))
==
ttorch
.
Tensor
({
'a'
:
torch
.
zeros
(
2
,
3
),
'b'
:
torch
.
zeros
(
5
,
6
),
'x'
:
{
...
...
@@ -30,14 +57,14 @@ class TestTensorFuncs:
torch
.
tensor
([[
0
,
0
,
0
],
[
0
,
0
,
0
]]),
)
assert
ttorch
.
all
(
ttorch
.
zeros_like
(
ttorch
.
TreeTensor
({
ttorch
.
zeros_like
(({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
'c'
:
torch
.
tensor
([
5
,
6
,
7
]),
'd'
:
torch
.
tensor
([[[
8
,
9
]]]),
}
}))
==
ttorch
.
T
reeT
ensor
({
}))
==
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([[
0
,
0
,
0
],
[
0
,
0
,
0
]]),
'b'
:
torch
.
tensor
([
0
,
0
,
0
,
0
]),
'x'
:
{
...
...
@@ -55,7 +82,7 @@ class TestTensorFuncs:
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
}))
==
ttorch
.
T
reeT
ensor
({
}))
==
ttorch
.
Tensor
({
'a'
:
torch
.
ones
(
2
,
3
),
'b'
:
torch
.
ones
(
5
,
6
),
'x'
:
{
...
...
@@ -69,14 +96,14 @@ class TestTensorFuncs:
torch
.
tensor
([[
1
,
1
,
1
],
[
1
,
1
,
1
]])
)
assert
ttorch
.
all
(
ttorch
.
ones_like
(
ttorch
.
TreeTensor
({
ttorch
.
ones_like
(({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
'c'
:
torch
.
tensor
([
5
,
6
,
7
]),
'd'
:
torch
.
tensor
([[[
8
,
9
]]]),
}
}))
==
ttorch
.
T
reeT
ensor
({
}))
==
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([[
1
,
1
,
1
],
[
1
,
1
,
1
]]),
'b'
:
torch
.
tensor
([
1
,
1
,
1
,
1
]),
'x'
:
{
...
...
@@ -99,7 +126,7 @@ class TestTensorFuncs:
'c'
:
(
2
,
3
,
4
),
}
}))
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
...
...
@@ -113,7 +140,7 @@ class TestTensorFuncs:
assert
0.98
<=
_target
.
view
(
60000
).
std
().
tolist
()
<=
1.02
assert
_target
.
shape
==
torch
.
Size
([
200
,
300
])
_target
=
ttorch
.
randn_like
(
ttorch
.
TreeTensor
({
_target
=
ttorch
.
randn_like
(({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
torch
.
float32
),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
],
dtype
=
torch
.
float32
),
'x'
:
{
...
...
@@ -121,7 +148,7 @@ class TestTensorFuncs:
'd'
:
torch
.
tensor
([[[
8
,
9
]]],
dtype
=
torch
.
float32
),
}
}))
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
...
...
@@ -140,7 +167,7 @@ class TestTensorFuncs:
}))
assert
ttorch
.
all
(
_target
<
10
)
assert
ttorch
.
all
(
-
10
<=
_target
)
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
...
...
@@ -157,7 +184,7 @@ class TestTensorFuncs:
}))
assert
ttorch
.
all
(
_target
<
10
)
assert
ttorch
.
all
(
0
<=
_target
)
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
...
...
@@ -166,7 +193,7 @@ class TestTensorFuncs:
})
def
test_randint_like
(
self
):
_target
=
ttorch
.
randint_like
(
ttorch
.
TreeTensor
({
_target
=
ttorch
.
randint_like
(({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
...
...
@@ -176,7 +203,7 @@ class TestTensorFuncs:
}),
-
10
,
10
)
assert
ttorch
.
all
(
_target
<
10
)
assert
ttorch
.
all
(
-
10
<=
_target
)
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
...
...
@@ -185,7 +212,7 @@ class TestTensorFuncs:
}
})
_target
=
ttorch
.
randint_like
(
ttorch
.
TreeTensor
({
_target
=
ttorch
.
randint_like
(({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
...
...
@@ -195,7 +222,7 @@ class TestTensorFuncs:
}),
10
)
assert
ttorch
.
all
(
_target
<
10
)
assert
ttorch
.
all
(
0
<=
_target
)
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
...
...
@@ -213,7 +240,7 @@ class TestTensorFuncs:
}
}),
233
)
assert
ttorch
.
all
(
_target
==
233
)
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
...
...
@@ -222,7 +249,7 @@ class TestTensorFuncs:
})
def
test_full_like
(
self
):
_target
=
ttorch
.
full_like
(
ttorch
.
TreeTensor
({
_target
=
ttorch
.
full_like
(({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
...
...
@@ -231,7 +258,7 @@ class TestTensorFuncs:
}
}),
233
)
assert
ttorch
.
all
(
_target
==
233
)
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
...
...
@@ -248,7 +275,7 @@ class TestTensorFuncs:
'c'
:
(
2
,
3
,
4
),
}
}))
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
...
...
@@ -257,7 +284,7 @@ class TestTensorFuncs:
})
def
test_empty_like
(
self
):
_target
=
ttorch
.
empty_like
(
ttorch
.
TreeTensor
({
_target
=
ttorch
.
empty_like
(({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
...
...
@@ -265,7 +292,7 @@ class TestTensorFuncs:
'd'
:
torch
.
tensor
([[[
8
,
9
]]]),
}
}))
assert
_target
.
shape
==
ttorch
.
Tree
Size
({
assert
_target
.
shape
==
ttorch
.
Size
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
...
...
@@ -290,7 +317,7 @@ class TestTensorFuncs:
assert
r3
==
torch
.
tensor
(
False
)
assert
not
r3
r4
=
ttorch
.
all
(
ttorch
.
TreeTensor
({
r4
=
ttorch
.
all
(({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
True
]),
})).
all
()
...
...
@@ -298,7 +325,7 @@ class TestTensorFuncs:
assert
r4
==
torch
.
tensor
(
True
)
assert
r4
r5
=
ttorch
.
all
(
ttorch
.
TreeTensor
({
r5
=
ttorch
.
all
(({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
False
]),
})).
all
()
...
...
@@ -306,7 +333,7 @@ class TestTensorFuncs:
assert
r5
==
torch
.
tensor
(
False
)
assert
not
r5
r6
=
ttorch
.
all
(
ttorch
.
TreeTensor
({
r6
=
ttorch
.
all
(({
'a'
:
torch
.
tensor
([
False
,
False
,
False
]),
'b'
:
torch
.
tensor
([
False
,
False
,
False
]),
})).
all
()
...
...
@@ -330,7 +357,7 @@ class TestTensorFuncs:
assert
r3
==
torch
.
tensor
(
False
)
assert
not
r3
r4
=
ttorch
.
any
(
ttorch
.
TreeTensor
({
r4
=
ttorch
.
any
(({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
True
]),
})).
all
()
...
...
@@ -338,7 +365,7 @@ class TestTensorFuncs:
assert
r4
==
torch
.
tensor
(
True
)
assert
r4
r5
=
ttorch
.
any
(
ttorch
.
TreeTensor
({
r5
=
ttorch
.
any
(({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
False
]),
})).
all
()
...
...
@@ -346,7 +373,7 @@ class TestTensorFuncs:
assert
r5
==
torch
.
tensor
(
True
)
assert
r5
r6
=
ttorch
.
any
(
ttorch
.
TreeTensor
({
r6
=
ttorch
.
any
(({
'a'
:
torch
.
tensor
([
False
,
False
,
False
]),
'b'
:
torch
.
tensor
([
False
,
False
,
False
]),
})).
all
()
...
...
@@ -360,17 +387,17 @@ class TestTensorFuncs:
assert
ttorch
.
eq
(
torch
.
tensor
([
1
,
1
,
1
]),
1
).
all
()
assert
not
ttorch
.
eq
(
torch
.
tensor
([
1
,
1
,
2
]),
1
).
all
()
assert
ttorch
.
eq
(
ttorch
.
TreeTensor
({
assert
ttorch
.
eq
(({
'a'
:
torch
.
tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
tensor
([
4
,
5
,
6
]),
}),
ttorch
.
TreeTensor
({
}),
({
'a'
:
torch
.
tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
tensor
([
4
,
5
,
6
]),
})).
all
()
assert
not
ttorch
.
eq
(
ttorch
.
TreeTensor
({
assert
not
ttorch
.
eq
(({
'a'
:
torch
.
tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
tensor
([
4
,
5
,
6
]),
}),
ttorch
.
TreeTensor
({
}),
({
'a'
:
torch
.
tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
tensor
([
4
,
5
,
5
]),
})).
all
()
...
...
@@ -384,20 +411,20 @@ class TestTensorFuncs:
assert
isinstance
(
p2
,
bool
)
assert
not
p2
p3
=
ttorch
.
equal
(
ttorch
.
TreeTensor
({
p3
=
ttorch
.
equal
(({
'a'
:
torch
.
tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
tensor
([
4
,
5
,
6
]),
}),
ttorch
.
TreeTensor
({
}),
({
'a'
:
torch
.
tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
tensor
([
4
,
5
,
6
]),
}))
assert
isinstance
(
p3
,
bool
)
assert
p3
p4
=
ttorch
.
equal
(
ttorch
.
TreeTensor
({
p4
=
ttorch
.
equal
(({
'a'
:
torch
.
tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
tensor
([
4
,
5
,
6
]),
}),
ttorch
.
TreeTensor
({
}),
({
'a'
:
torch
.
tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
tensor
([
4
,
5
,
5
]),
}))
...
...
test/torch/test_t
reet
ensor.py
→
test/torch/test_tensor.py
浏览文件 @
fbfdb128
...
...
@@ -6,12 +6,12 @@ from treevalue import func_treelize
import
treetensor.numpy
as
tnp
import
treetensor.torch
as
ttorch
_all_is
=
func_treelize
(
return_type
=
ttorch
.
T
reeT
ensor
)(
lambda
x
,
y
:
x
is
y
)
_all_is
=
func_treelize
(
return_type
=
ttorch
.
Tensor
)(
lambda
x
,
y
:
x
is
y
)
@
pytest
.
mark
.
unittest
class
TestT
ensorTreet
ensor
:
_DEMO_1
=
ttorch
.
T
reeT
ensor
({
class
TestT
orchT
ensor
:
_DEMO_1
=
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
6
]]),
'x'
:
{
...
...
@@ -20,7 +20,7 @@ class TestTensorTreetensor:
}
})
_DEMO_2
=
ttorch
.
T
reeT
ensor
({
_DEMO_2
=
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
60
]]),
'x'
:
{
...
...
@@ -47,7 +47,7 @@ class TestTensorTreetensor:
assert
_all_is
(
self
.
_DEMO_1
.
cpu
(),
self
.
_DEMO_1
).
reduce
(
lambda
**
kws
:
all
(
kws
.
values
()))
def
test_to
(
self
):
assert
ttorch
.
all
(
self
.
_DEMO_1
.
to
(
torch
.
float32
)
==
ttorch
.
T
reeT
ensor
({
assert
ttorch
.
all
(
self
.
_DEMO_1
.
to
(
torch
.
float32
)
==
ttorch
.
Tensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
torch
.
float32
),
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
6
]],
dtype
=
torch
.
float32
),
'x'
:
{
...
...
treetensor/__init__.py
浏览文件 @
fbfdb128
from
.common
import
TreeObject
from
.numpy
import
TreeNumpy
from
.torch
import
T
reeT
ensor
from
.torch
import
Tensor
treetensor/torch/funcs.py
浏览文件 @
fbfdb128
import
builtins
import
torch
from
treevalue
import
TreeValue
from
treevalue
import
func_treelize
as
original_func_treelize
from
treevalue.utils
import
post_process
from
.tensor
import
T
reeT
ensor
,
tireduce
from
.tensor
import
Tensor
,
tireduce
from
..common
import
TreeObject
,
ireduce
from
..utils
import
replaceable_partial
,
doc_from
from
..utils
import
replaceable_partial
,
doc_from
,
args_mapping
__all__
=
[
'zeros'
,
'zeros_like'
,
...
...
@@ -16,9 +18,13 @@ __all__ = [
'empty'
,
'empty_like'
,
'all'
,
'any'
,
'eq'
,
'equal'
,
'tensor'
,
]
func_treelize
=
replaceable_partial
(
original_func_treelize
,
return_type
=
TreeTensor
)
func_treelize
=
post_process
(
post_process
(
args_mapping
(
lambda
i
,
x
:
Tensor
(
x
)
if
isinstance
(
x
,
(
dict
,
TreeValue
))
else
x
)))(
replaceable_partial
(
original_func_treelize
,
return_type
=
Tensor
)
)
@
doc_from
(
torch
.
zeros
)
...
...
@@ -102,18 +108,20 @@ def all(input_, *args, **kwargs):
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> all(torch.tensor([True, True])) # the same as torch.all
torch.tensor(True)
>>> all(
TreeT
ensor({
>>> 'a':
torch.tensor([True, True])
,
>>> 'b':
torch.tensor([True, True])
,
>>> all(
ttorch.t
ensor({
>>> 'a':
[True, True]
,
>>> 'b':
[True, True]
,
>>> }))
torch.tensor(True)
>>> all(T
reeT
ensor({
>>> 'a':
torch.tensor([True, True])
,
>>> 'b':
torch.tensor([True, False])
,
>>> all(Tensor({
>>> 'a':
[True, True]
,
>>> 'b':
[True, False]
,
>>> }))
torch.tensor(False)
...
...
@@ -139,3 +147,30 @@ def eq(input_, other, *args, **kwargs):
@
func_treelize
()
def
equal
(
input_
,
other
,
*
args
,
**
kwargs
):
return
torch
.
equal
(
input_
,
other
,
*
args
,
**
kwargs
)
@
doc_from
(
torch
.
tensor
)
@
func_treelize
()
def
tensor
(
*
args
,
**
kwargs
):
"""
In ``treetensor``, you can create a tree tensor with simple data structure.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.tensor(True) # the same as torch.tensor(True)
torch.tensor(True)
>>> ttorch.tensor([1, 2, 3]) # the same as torch.tensor([1, 2, 3])
torch.tensor([1, 2, 3])
>>> ttorch.tensor({'a': 1, 'b': [1, 2, 3], 'c': [[True, False], [False, True]]})
ttorch.Tensor({
'a': torch.tensor(1),
'b': torch.tensor([1, 2, 3]),
'c': torch.tensor([[True, False], [False, True]]),
})
"""
return
torch
.
tensor
(
*
args
,
**
kwargs
)
treetensor/torch/size.py
浏览文件 @
fbfdb128
...
...
@@ -7,12 +7,12 @@ from ..utils import replaceable_partial
func_treelize
=
replaceable_partial
(
original_func_treelize
)
__all__
=
[
'
Tree
Size'
'Size'
]
# noinspection PyTypeChecker
class
Tree
Size
(
TreeObject
):
class
Size
(
TreeObject
):
@
func_treelize
(
return_type
=
TreeObject
)
def
numel
(
self
:
torch
.
Size
)
->
TreeObject
:
return
self
.
numel
()
...
...
treetensor/torch/tensor.py
浏览文件 @
fbfdb128
...
...
@@ -3,13 +3,13 @@ import torch
from
treevalue
import
method_treelize
from
treevalue.utils
import
pre_process
from
.size
import
Tree
Size
from
.size
import
Size
from
..common
import
TreeObject
,
TreeData
,
ireduce
from
..numpy
import
TreeNumpy
from
..utils
import
inherit_names
,
current_names
,
doc_from
__all__
=
[
'T
reeT
ensor'
'Tensor'
]
_reduce_tensor_wrap
=
pre_process
(
lambda
it
:
((
torch
.
tensor
([
*
it
]),),
{}))
...
...
@@ -19,7 +19,7 @@ tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduc
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
@
current_names
()
@
inherit_names
(
TreeData
)
class
T
reeT
ensor
(
TreeData
):
class
Tensor
(
TreeData
):
@
doc_from
(
torch
.
Tensor
.
numpy
)
@
method_treelize
(
return_type
=
TreeNumpy
)
def
numpy
(
self
:
torch
.
Tensor
)
->
np
.
ndarray
:
...
...
@@ -53,7 +53,7 @@ class TreeTensor(TreeData):
@
property
@
doc_from
(
torch
.
Tensor
.
shape
)
@
method_treelize
(
return_type
=
Tree
Size
)
@
method_treelize
(
return_type
=
Size
)
def
shape
(
self
:
torch
.
Tensor
):
return
self
.
shape
...
...
treetensor/utils/func.py
浏览文件 @
fbfdb128
from
functools
import
wraps
from
typing
import
Callable
,
Union
,
Any
__all__
=
[
'replaceable_partial'
,
'args_mapping'
,
]
def
replaceable_partial
(
func
,
**
kws
):
@
wraps
(
func
)
def
_new_func
(
*
args
,
**
kwargs
):
return
func
(
*
args
,
**
{
**
kws
,
**
kwargs
})
return
_new_func
def
args_mapping
(
mapper
:
Callable
[[
Union
[
int
,
str
],
Any
],
Any
]):
def
_decorator
(
func
):
@
wraps
(
func
)
def
_new_func
(
*
args
,
**
kwargs
):
return
func
(
*
(
mapper
(
i
,
x
)
for
i
,
x
in
enumerate
(
args
)),
**
{
k
:
mapper
(
k
,
v
)
for
k
,
v
in
kwargs
.
items
()},
)
return
_new_func
return
_decorator
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录