Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
c5c230d8
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c5c230d8
编写于
9月 10, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): refactor the current code
上级
daa01e40
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
291 addition
and
127 deletion
+291
-127
test/numpy/test_funcs.py
test/numpy/test_funcs.py
+22
-1
test/tensor/test_funcs.py
test/tensor/test_funcs.py
+54
-12
test/tensor/test_treetensor.py
test/tensor/test_treetensor.py
+13
-0
treetensor/common/__init__.py
treetensor/common/__init__.py
+1
-1
treetensor/common/trees.py
treetensor/common/trees.py
+22
-17
treetensor/common/wrappers.py
treetensor/common/wrappers.py
+34
-4
treetensor/numpy/funcs.py
treetensor/numpy/funcs.py
+19
-8
treetensor/numpy/numpy.py
treetensor/numpy/numpy.py
+4
-4
treetensor/tensor/funcs.py
treetensor/tensor/funcs.py
+83
-51
treetensor/tensor/size.py
treetensor/tensor/size.py
+6
-3
treetensor/tensor/tensor.py
treetensor/tensor/tensor.py
+27
-26
treetensor/utils/__init__.py
treetensor/utils/__init__.py
+1
-0
treetensor/utils/func.py
treetensor/utils/func.py
+5
-0
未找到文件。
test/numpy/test_funcs.py
浏览文件 @
c5c230d8
...
...
@@ -35,13 +35,25 @@ class TestNumpyFuncs:
}
})
def
test__numpy_all
(
self
):
def
test_all
(
self
):
assert
not
_numpy_all
(
np
.
array
([
True
,
True
,
False
]))
assert
_numpy_all
(
np
.
array
([
True
,
True
,
True
]))
assert
not
_numpy_all
(
self
.
_DEMO_1
==
self
.
_DEMO_2
)
assert
_numpy_all
(
self
.
_DEMO_1
==
self
.
_DEMO_3
)
assert
not
_numpy_all
(
np
.
array
([
1
,
2
,
3
])
==
np
.
array
([
1
,
2
,
4
]))
assert
_numpy_all
(
np
.
array
([
1
,
2
,
3
])
==
np
.
array
([
1
,
2
,
3
]))
def
test_equal
(
self
):
assert
_numpy_all
(
equal
(
np
.
array
([
1
,
2
,
3
]),
np
.
array
([
1
,
2
,
3
]),
))
assert
not
_numpy_all
(
equal
(
np
.
array
([
1
,
2
,
3
]),
np
.
array
([
1
,
2
,
4
]),
))
assert
_numpy_all
(
equal
(
self
.
_DEMO_1
,
self
.
_DEMO_2
)
==
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
True
,
True
],
[
True
,
True
,
False
]]),
...
...
@@ -64,6 +76,15 @@ class TestNumpyFuncs:
)
def
test_array_equal
(
self
):
assert
_numpy_all
(
array_equal
(
np
.
array
([
1
,
2
,
3
]),
np
.
array
([
1
,
2
,
3
]),
))
assert
not
_numpy_all
(
array_equal
(
np
.
array
([
1
,
2
,
3
]),
np
.
array
([
1
,
2
,
4
]),
))
assert
array_equal
(
self
.
_DEMO_1
,
self
.
_DEMO_2
)
==
TreeNumpy
({
'a'
:
False
,
'b'
:
True
,
...
...
test/tensor/test_funcs.py
浏览文件 @
c5c230d8
import
pytest
import
torch
from
treevalue
import
TreeValue
from
treetensor.tensor
import
TreeTensor
,
zeros
,
zeros_like
,
ones
,
ones_like
,
randint
,
randint_like
,
randn
,
\
randn_like
,
full
,
full_like
,
TreeSize
...
...
@@ -11,13 +12,13 @@ from treetensor.tensor import all as _tensor_all
class
TestTensorFuncs
:
def
test_zeros
(
self
):
assert
_tensor_all
(
zeros
((
2
,
3
))
==
torch
.
zeros
(
2
,
3
))
assert
_tensor_all
(
zeros
({
assert
_tensor_all
(
zeros
(
TreeValue
(
{
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
})
==
TreeTensor
({
})
)
==
TreeTensor
({
'a'
:
torch
.
zeros
(
2
,
3
),
'b'
:
torch
.
zeros
(
5
,
6
),
'x'
:
{
...
...
@@ -50,13 +51,13 @@ class TestTensorFuncs:
def
test_ones
(
self
):
assert
_tensor_all
(
ones
((
2
,
3
))
==
torch
.
ones
(
2
,
3
))
assert
_tensor_all
(
ones
({
assert
_tensor_all
(
ones
(
TreeValue
(
{
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
})
==
TreeTensor
({
})
)
==
TreeTensor
({
'a'
:
torch
.
ones
(
2
,
3
),
'b'
:
torch
.
ones
(
5
,
6
),
'x'
:
{
...
...
@@ -93,13 +94,13 @@ class TestTensorFuncs:
assert
0.98
<=
_target
.
view
(
60000
).
std
().
tolist
()
<=
1.02
assert
_target
.
shape
==
torch
.
Size
([
200
,
300
])
_target
=
randn
({
_target
=
randn
(
TreeValue
(
{
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
})
})
)
assert
_target
.
shape
==
TreeSize
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
...
...
@@ -132,13 +133,13 @@ class TestTensorFuncs:
})
def
test_randint
(
self
):
_target
=
randint
({
_target
=
randint
(
TreeValue
(
{
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
},
-
10
,
10
)
}
)
,
-
10
,
10
)
assert
_tensor_all
(
_target
<
10
)
assert
_tensor_all
(
-
10
<=
_target
)
assert
_target
.
shape
==
TreeSize
({
...
...
@@ -149,13 +150,13 @@ class TestTensorFuncs:
}
})
_target
=
randint
({
_target
=
randint
(
TreeValue
(
{
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
},
10
)
}
)
,
10
)
assert
_tensor_all
(
_target
<
10
)
assert
_tensor_all
(
0
<=
_target
)
assert
_target
.
shape
==
TreeSize
({
...
...
@@ -206,13 +207,13 @@ class TestTensorFuncs:
})
def
test_full
(
self
):
_target
=
full
({
_target
=
full
(
TreeValue
(
{
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
},
233
)
}
)
,
233
)
assert
_tensor_all
(
_target
==
233
)
assert
_target
.
shape
==
TreeSize
({
'a'
:
torch
.
Size
([
2
,
3
]),
...
...
@@ -240,3 +241,44 @@ class TestTensorFuncs:
'd'
:
torch
.
Size
([
1
,
1
,
2
]),
}
})
def
test_all
(
self
):
r1
=
_tensor_all
(
torch
.
tensor
([
1
,
1
,
1
])
==
1
)
assert
torch
.
is_tensor
(
r1
)
assert
r1
==
torch
.
tensor
(
True
)
r2
=
_tensor_all
(
torch
.
tensor
([
1
,
1
,
2
])
==
1
)
assert
torch
.
is_tensor
(
r2
)
assert
r2
==
torch
.
tensor
(
False
)
r3
=
_tensor_all
(
TreeTensor
({
'a'
:
torch
.
Tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
Tensor
([
4
,
5
,
6
]),
'x'
:
{
'c'
:
torch
.
Tensor
([
7
,
8
,
9
])
}
})
==
TreeTensor
({
'a'
:
torch
.
Tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
Tensor
([
4
,
5
,
6
]),
'x'
:
{
'c'
:
torch
.
Tensor
([
7
,
8
,
9
])
}
}))
assert
torch
.
is_tensor
(
r3
)
assert
r3
==
torch
.
tensor
(
True
)
r4
=
_tensor_all
(
TreeTensor
({
'a'
:
torch
.
Tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
Tensor
([
4
,
5
,
6
]),
'x'
:
{
'c'
:
torch
.
Tensor
([
7
,
8
,
9
])
}
})
==
TreeTensor
({
'a'
:
torch
.
Tensor
([
1
,
2
,
3
]),
'b'
:
torch
.
Tensor
([
4
,
5
,
6
]),
'x'
:
{
'c'
:
torch
.
Tensor
([
7
,
8
,
8
])
}
}))
assert
torch
.
is_tensor
(
r4
)
assert
r4
==
torch
.
tensor
(
False
)
test/tensor/test_treetensor.py
浏览文件 @
c5c230d8
...
...
@@ -22,6 +22,15 @@ class TestTensorTreetensor:
}
})
_DEMO_2
=
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
60
]]),
'x'
:
{
'c'
:
torch
.
tensor
([
3
,
5
,
6
,
7
]),
'd'
:
torch
.
tensor
([[[
1
,
2
],
[
8
,
9
]]]),
}
})
def
test_numel
(
self
):
assert
self
.
_DEMO_1
.
numel
()
==
18
...
...
@@ -48,3 +57,7 @@ class TestTensorTreetensor:
'd'
:
torch
.
tensor
([[[
1
,
2
],
[
8
,
9
]]],
dtype
=
torch
.
float32
),
}
}))
def
test_all
(
self
):
assert
(
self
.
_DEMO_1
==
self
.
_DEMO_1
).
all
()
assert
not
(
self
.
_DEMO_1
==
self
.
_DEMO_2
).
all
()
treetensor/common/__init__.py
浏览文件 @
c5c230d8
from
.trees
import
TreeData
,
TreeObject
,
BaseTreeStruct
from
.wrappers
import
kwreduce
,
vreduce
from
.wrappers
import
kwreduce
,
vreduce
,
ireduce
treetensor/common/trees.py
浏览文件 @
c5c230d8
import
operator
from
abc
import
ABCMeta
from
treevalue
import
func_treelize
,
general_tree_valu
e
from
treevalue
import
general_tree_value
,
method_treeliz
e
class
BaseTreeStruct
(
general_tree_value
(),
metaclass
=
ABCMeta
):
...
...
@@ -12,29 +11,35 @@ class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
pass
_OPERATORS
=
{}
for
_op_name
in
getattr
(
operator
,
'__all__'
):
_OPERATORS
[
_op_name
]
=
func_treelize
()(
getattr
(
operator
,
_op_name
))
class
TreeData
(
BaseTreeStruct
,
metaclass
=
ABCMeta
):
"""
Overview:
In ``TreeData`` class, all the comparison operators will be override.
"""
@
method_treelize
()
def
__eq__
(
self
,
other
):
return
self
==
other
class
TreeData
(
BaseTreeStruct
):
def
__
l
e__
(
self
,
other
):
return
_OPERATORS
[
'le'
](
self
,
other
)
@
method_treelize
()
def
__
n
e__
(
self
,
other
):
return
self
!=
other
@
method_treelize
()
def
__lt__
(
self
,
other
):
return
_OPERATORS
[
'lt'
](
self
,
other
)
return
self
<
other
def
__ge__
(
self
,
other
):
return
_OPERATORS
[
'ge'
](
self
,
other
)
@
method_treelize
()
def
__le__
(
self
,
other
):
return
self
<=
other
@
method_treelize
()
def
__gt__
(
self
,
other
):
return
_OPERATORS
[
'gt'
](
self
,
other
)
def
__eq__
(
self
,
other
):
return
_OPERATORS
[
'eq'
](
self
,
other
)
return
self
>
other
def
__ne__
(
self
,
other
):
return
_OPERATORS
[
'ne'
](
self
,
other
)
@
method_treelize
()
def
__ge__
(
self
,
other
):
return
self
>=
other
class
TreeObject
(
BaseTreeStruct
):
...
...
treetensor/common/wrappers.py
浏览文件 @
c5c230d8
from
collections
import
namedtuple
from
functools
import
wraps
from
itertools
import
chain
from
treevalue
import
TreeValue
from
treevalue
import
reduce_
as
treevalue_reduce
def
kwreduce
(
r
educe_
func
):
def
kwreduce
(
rfunc
):
def
_decorator
(
func
):
@
wraps
(
func
)
def
_new_func
(
*
args
,
**
kwargs
):
_result
=
func
(
*
args
,
**
kwargs
)
if
isinstance
(
_result
,
TreeValue
):
return
treevalue_reduce
(
_result
,
r
educe_
func
)
return
treevalue_reduce
(
_result
,
rfunc
)
else
:
return
_result
...
...
@@ -19,5 +21,33 @@ def kwreduce(reduce_func):
return
_decorator
def
vreduce
(
vreduce_func
):
return
kwreduce
(
lambda
**
kws
:
vreduce_func
(
kws
.
values
()))
def
vreduce
(
rfunc
):
return
kwreduce
(
lambda
**
kws
:
rfunc
(
kws
.
values
()))
def
ireduce
(
rfunc
):
_IterReduceWrapper
=
namedtuple
(
"_IterReduceWrapper"
,
[
'v'
])
def
_reduce_func
(
values
):
_list
=
[]
for
item
in
values
:
if
isinstance
(
item
,
_IterReduceWrapper
):
_list
.
append
(
item
.
v
)
else
:
_list
.
append
([
item
])
return
_IterReduceWrapper
(
chain
(
*
_list
))
def
_decorator
(
func
):
rifunc
=
vreduce
(
_reduce_func
)(
func
)
@
wraps
(
func
)
def
_new_func
(
*
args
,
**
kwargs
):
_iw
=
rifunc
(
*
args
,
**
kwargs
)
if
isinstance
(
_iw
,
_IterReduceWrapper
):
return
rfunc
(
_iw
.
v
)
else
:
return
_iw
return
_new_func
return
_decorator
treetensor/numpy/funcs.py
浏览文件 @
c5c230d8
from
functools
import
partial
import
numpy
as
np
from
treevalue
import
func_treelize
from
treevalue
import
func_treelize
as
original_func_treelize
from
.numpy
import
TreeNumpy
from
..common
import
vreduce
from
..common
import
ireduce
from
..utils
import
replaceable_partial
func_treelize
=
replaceable_partial
(
original_func_treelize
,
return_type
=
TreeNumpy
)
@
ireduce
(
all
)
@
func_treelize
()
def
all
(
a
,
*
args
,
**
kwargs
):
return
np
.
all
(
a
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
equal
(
x1
,
x2
,
*
args
,
**
kwargs
):
return
np
.
equal
(
x1
,
x2
,
*
args
,
**
kwargs
)
_treelize
=
partial
(
func_treelize
,
return_type
=
TreeNumpy
)
all
=
vreduce
(
all
)(
_treelize
()(
np
.
all
)
)
equal
=
_treelize
()(
np
.
equal
)
array_equal
=
_treelize
()(
np
.
array_equal
)
@
func_treelize
(
)
def
array_equal
(
a1
,
a2
,
*
args
,
**
kwargs
):
return
np
.
array_equal
(
a1
,
a2
,
*
args
,
**
kwargs
)
treetensor/numpy/numpy.py
浏览文件 @
c5c230d8
import
numpy
as
np
from
treevalue
import
method_treelize
from
..common
import
TreeObject
,
TreeData
,
v
reduce
from
..common
import
TreeObject
,
TreeData
,
i
reduce
class
TreeNumpy
(
TreeData
):
...
...
@@ -15,18 +15,18 @@ class TreeNumpy(TreeData):
return
self
.
tolist
()
@
property
@
v
reduce
(
sum
)
@
i
reduce
(
sum
)
@
method_treelize
(
return_type
=
TreeObject
)
def
size
(
self
:
np
.
ndarray
)
->
int
:
return
self
.
size
@
property
@
v
reduce
(
sum
)
@
i
reduce
(
sum
)
@
method_treelize
(
return_type
=
TreeObject
)
def
nbytes
(
self
:
np
.
ndarray
)
->
int
:
return
self
.
nbytes
@
v
reduce
(
sum
)
@
i
reduce
(
sum
)
@
method_treelize
(
return_type
=
TreeObject
)
def
sum
(
self
:
np
.
ndarray
,
*
args
,
**
kwargs
):
return
self
.
sum
(
*
args
,
**
kwargs
)
treetensor/tensor/funcs.py
浏览文件 @
c5c230d8
from
functools
import
partial
,
wraps
from
typing
import
Tuple
import
torch
from
treevalue
import
func_treelize
,
TreeValue
from
.tensor
import
TreeTensor
from
..common
import
vreduce
_treelize
=
partial
(
func_treelize
,
return_type
=
TreeTensor
)
_python_all
=
all
def
_size_based_treelize
(
*
args_
,
prefix
:
bool
=
False
,
tuple_
:
bool
=
False
,
**
kwargs_
):
def
_decorator
(
func
):
@
_treelize
(
*
args_
,
**
kwargs_
)
def
_sub_func
(
size
:
Tuple
[
int
,
...],
*
args
,
**
kwargs
):
_size_args
=
(
size
,)
if
tuple_
else
size
_args
=
(
*
args
,
*
_size_args
)
if
prefix
else
(
*
_size_args
,
*
args
)
return
func
(
*
_args
,
**
kwargs
)
@
wraps
(
func
)
def
_new_func
(
size
,
*
args
,
**
kwargs
):
if
isinstance
(
size
,
(
TreeValue
,
dict
)):
size
=
TreeTensor
(
size
)
return
_sub_func
(
size
,
*
args
,
**
kwargs
)
return
_new_func
return
_decorator
# Tensor generation based on shapes
zeros
=
_size_based_treelize
()(
torch
.
zeros
)
randn
=
_size_based_treelize
()(
torch
.
randn
)
randint
=
_size_based_treelize
(
prefix
=
True
,
tuple_
=
True
)(
torch
.
randint
)
ones
=
_size_based_treelize
()(
torch
.
ones
)
full
=
_size_based_treelize
(
tuple_
=
True
)(
torch
.
full
)
empty
=
_size_based_treelize
()(
torch
.
empty
)
# Tensor generation based on another tensor
zeros_like
=
_treelize
()(
torch
.
zeros_like
)
randn_like
=
_treelize
()(
torch
.
randn_like
)
randint_like
=
_treelize
()(
torch
.
randint_like
)
ones_like
=
_treelize
()(
torch
.
ones_like
)
full_like
=
_treelize
()(
torch
.
full_like
)
empty_like
=
_treelize
()(
torch
.
empty_like
)
# Tensor operators
all
=
vreduce
(
all
)(
_treelize
()(
torch
.
all
))
eq
=
_treelize
()(
torch
.
eq
)
equal
=
_treelize
()(
torch
.
equal
)
from
treevalue
import
func_treelize
as
original_func_treelize
from
.tensor
import
TreeTensor
,
_reduce_tensor_wrap
from
..common
import
TreeObject
,
ireduce
from
..utils
import
replaceable_partial
func_treelize
=
replaceable_partial
(
original_func_treelize
,
return_type
=
TreeTensor
)
@
func_treelize
()
def
zeros
(
size
,
*
args
,
**
kwargs
):
return
torch
.
zeros
(
*
size
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
zeros_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
zeros_like
(
input_
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
randn
(
size
,
*
args
,
**
kwargs
):
return
torch
.
randn
(
*
size
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
randn_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
randn_like
(
input_
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
randint
(
size
,
*
args
,
**
kwargs
):
return
torch
.
randint
(
*
args
,
size
,
**
kwargs
)
@
func_treelize
()
def
randint_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
randint_like
(
input_
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
ones
(
size
,
*
args
,
**
kwargs
):
return
torch
.
ones
(
*
size
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
ones_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
ones_like
(
input_
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
full
(
size
,
*
args
,
**
kwargs
):
return
torch
.
full
(
size
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
full_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
full_like
(
input_
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
empty
(
size
,
*
args
,
**
kwargs
):
return
torch
.
empty
(
size
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
empty_like
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
empty_like
(
input_
,
*
args
,
**
kwargs
)
@
ireduce
(
_reduce_tensor_wrap
(
torch
.
all
))
@
func_treelize
(
return_type
=
TreeObject
)
def
all
(
input_
,
*
args
,
**
kwargs
):
return
torch
.
all
(
input_
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
eq
(
input_
,
other
,
*
args
,
**
kwargs
):
return
torch
.
eq
(
input_
,
other
,
*
args
,
**
kwargs
)
@
func_treelize
()
def
equal
(
input_
,
other
,
*
args
,
**
kwargs
):
return
torch
.
equal
(
input_
,
other
,
*
args
,
**
kwargs
)
treetensor/tensor/size.py
浏览文件 @
c5c230d8
import
torch
from
treevalue
import
func_treelize
from
treevalue
import
func_treelize
as
original_func_treelize
from
..common
import
BaseTreeStruct
,
TreeObject
from
..common
import
TreeObject
from
..utils
import
replaceable_partial
func_treelize
=
replaceable_partial
(
original_func_treelize
)
# noinspection PyTypeChecker
class
TreeSize
(
BaseTreeStru
ct
):
class
TreeSize
(
TreeObje
ct
):
@
func_treelize
(
return_type
=
TreeObject
)
def
numel
(
self
:
torch
.
Size
)
->
TreeObject
:
return
self
.
numel
()
...
...
treetensor/tensor/tensor.py
浏览文件 @
c5c230d8
import
numpy
as
np
import
torch
from
treevalue
import
method_treelize
,
TreeValue
from
treevalue
import
method_treelize
from
treevalue.utils
import
pre_process
from
.size
import
TreeSize
from
..common
import
TreeObject
,
TreeData
,
v
reduce
from
..common
import
TreeObject
,
TreeData
,
i
reduce
from
..numpy
import
TreeNumpy
def
_same_merge
(
eq
,
hash_
,
**
kwargs
):
kws
=
{
key
:
value
for
key
,
value
in
kwargs
.
items
()
if
not
(
isinstance
(
value
,
TreeValue
)
and
not
value
)
}
class
_Wrapper
:
def
__init__
(
self
,
v
):
self
.
v
=
v
def
__hash__
(
self
):
return
hash_
(
self
.
v
)
def
__eq__
(
self
,
other
):
return
eq
(
self
.
v
,
other
.
v
)
if
len
(
set
(
_Wrapper
(
v
)
for
v
in
kws
.
values
()))
==
1
:
return
list
(
kws
.
values
())[
0
]
else
:
return
TreeTensor
(
kws
)
_reduce_tensor_wrap
=
pre_process
(
lambda
it
:
((
torch
.
tensor
([
*
it
]),),
{}))
# noinspection PyTypeChecker,PyShadowingBuiltins,PyArgumentList
...
...
@@ -51,7 +32,7 @@ class TreeTensor(TreeData):
def
to
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
to
(
*
args
,
**
kwargs
)
@
v
reduce
(
sum
)
@
i
reduce
(
sum
)
@
method_treelize
(
return_type
=
TreeObject
)
def
numel
(
self
:
torch
.
Tensor
):
return
self
.
numel
()
...
...
@@ -61,7 +42,27 @@ class TreeTensor(TreeData):
def
shape
(
self
:
torch
.
Tensor
):
return
self
.
shape
@
vreduce
(
all
)
@
ireduce
(
_reduce_tensor_wrap
(
torch
.
all
)
)
@
method_treelize
(
return_type
=
TreeObject
)
def
all
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
def
all
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
bool
:
return
self
.
all
(
*
args
,
**
kwargs
)
@
ireduce
(
_reduce_tensor_wrap
(
torch
.
any
))
@
method_treelize
(
return_type
=
TreeObject
)
def
any
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
bool
:
return
self
.
any
(
*
args
,
**
kwargs
)
@
ireduce
(
_reduce_tensor_wrap
(
torch
.
max
))
@
method_treelize
(
return_type
=
TreeObject
)
def
max
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
max
(
*
args
,
**
kwargs
)
@
ireduce
(
_reduce_tensor_wrap
(
torch
.
min
))
@
method_treelize
(
return_type
=
TreeObject
)
def
min
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
min
(
*
args
,
**
kwargs
)
@
ireduce
(
_reduce_tensor_wrap
(
torch
.
sum
))
@
method_treelize
(
return_type
=
TreeObject
)
def
sum
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
):
return
self
.
sum
(
*
args
,
**
kwargs
)
treetensor/utils/__init__.py
0 → 100644
浏览文件 @
c5c230d8
from
.func
import
replaceable_partial
treetensor/utils/func.py
0 → 100644
浏览文件 @
c5c230d8
def
replaceable_partial
(
func
,
**
kws
):
def
_new_func
(
*
args
,
**
kwargs
):
return
func
(
*
args
,
**
{
**
kws
,
**
kwargs
})
return
_new_func
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录