Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
0150535c
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
0150535c
编写于
9月 28, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): upgrade all and any
上级
980844e9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
175 addition
and
38 deletion
+175
-38
test/torch/funcs/test_reduction.py
test/torch/funcs/test_reduction.py
+56
-6
test/torch/tensor/test_reduction.py
test/torch/tensor/test_reduction.py
+42
-0
treetensor/torch/funcs/construct.py
treetensor/torch/funcs/construct.py
+5
-2
treetensor/torch/funcs/reduction.py
treetensor/torch/funcs/reduction.py
+42
-20
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+30
-10
未找到文件。
test/torch/funcs/test_reduction.py
浏览文件 @
0150535c
...
...
@@ -27,7 +27,7 @@ class TestTorchFuncsReduction:
r4
=
ttorch
.
all
({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
True
]),
})
.
all
()
})
assert
torch
.
is_tensor
(
r4
)
assert
r4
==
torch
.
tensor
(
True
)
assert
r4
...
...
@@ -35,7 +35,7 @@ class TestTorchFuncsReduction:
r5
=
ttorch
.
all
({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
False
]),
})
.
all
()
})
assert
torch
.
is_tensor
(
r5
)
assert
r5
==
torch
.
tensor
(
False
)
assert
not
r5
...
...
@@ -43,11 +43,36 @@ class TestTorchFuncsReduction:
r6
=
ttorch
.
all
({
'a'
:
torch
.
tensor
([
False
,
False
,
False
]),
'b'
:
torch
.
tensor
([
False
,
False
,
False
]),
})
.
all
()
})
assert
torch
.
is_tensor
(
r6
)
assert
r6
==
torch
.
tensor
(
False
)
assert
not
r6
r7
=
ttorch
.
all
(
ttorch
.
tensor
({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
False
]),
}),
reduce
=
False
)
assert
(
r7
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
False
})).
all
()
r8
=
ttorch
.
all
(
ttorch
.
tensor
({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
False
]),
}),
dim
=
0
)
assert
(
r8
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
False
})).
all
()
with
pytest
.
warns
(
UserWarning
):
r9
=
ttorch
.
all
(
ttorch
.
tensor
({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
False
]),
}),
dim
=
0
,
reduce
=
True
)
assert
(
r9
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
False
})).
all
()
@
choose_mark
()
def
test_any
(
self
):
r1
=
ttorch
.
any
(
torch
.
tensor
([
True
,
True
,
True
]))
...
...
@@ -68,7 +93,7 @@ class TestTorchFuncsReduction:
r4
=
ttorch
.
any
({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
True
]),
})
.
all
()
})
assert
torch
.
is_tensor
(
r4
)
assert
r4
==
torch
.
tensor
(
True
)
assert
r4
...
...
@@ -76,7 +101,7 @@ class TestTorchFuncsReduction:
r5
=
ttorch
.
any
({
'a'
:
torch
.
tensor
([
True
,
True
,
True
]),
'b'
:
torch
.
tensor
([
True
,
True
,
False
]),
})
.
all
()
})
assert
torch
.
is_tensor
(
r5
)
assert
r5
==
torch
.
tensor
(
True
)
assert
r5
...
...
@@ -84,11 +109,36 @@ class TestTorchFuncsReduction:
r6
=
ttorch
.
any
({
'a'
:
torch
.
tensor
([
False
,
False
,
False
]),
'b'
:
torch
.
tensor
([
False
,
False
,
False
]),
})
.
all
()
})
assert
torch
.
is_tensor
(
r6
)
assert
r6
==
torch
.
tensor
(
False
)
assert
not
r6
r7
=
ttorch
.
any
(
ttorch
.
tensor
({
'a'
:
torch
.
tensor
([
True
,
True
,
False
]),
'b'
:
torch
.
tensor
([
False
,
False
,
False
]),
}),
reduce
=
False
)
assert
(
r7
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
False
})).
all
()
r8
=
ttorch
.
any
(
ttorch
.
tensor
({
'a'
:
torch
.
tensor
([
True
,
True
,
False
]),
'b'
:
torch
.
tensor
([
False
,
False
,
False
]),
}),
dim
=
0
)
assert
(
r8
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
False
})).
all
()
with
pytest
.
warns
(
UserWarning
):
r9
=
ttorch
.
any
(
ttorch
.
tensor
({
'a'
:
torch
.
tensor
([
True
,
True
,
False
]),
'b'
:
torch
.
tensor
([
False
,
False
,
False
]),
}),
dim
=
0
,
reduce
=
True
)
assert
(
r9
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
False
})).
all
()
@
choose_mark
()
def
test_min
(
self
):
t1
=
ttorch
.
min
(
torch
.
tensor
([
1.0
,
2.0
,
1.5
]))
...
...
test/torch/tensor/test_reduction.py
浏览文件 @
0150535c
import
pytest
import
torch
import
treetensor.torch
as
ttorch
...
...
@@ -24,6 +25,22 @@ class TestTorchTensorReduction:
assert
t2
.
dtype
==
torch
.
bool
assert
not
t2
t3
=
ttorch
.
tensor
({
'a'
:
[
True
,
False
],
'b'
:
{
'x'
:
[[
True
,
True
,
],
[
True
,
True
,
]]}
}).
all
(
reduce
=
False
)
assert
(
t3
==
ttorch
.
tensor
({
'a'
:
False
,
'b'
:
{
'x'
:
True
},
})).
all
()
t4
=
ttorch
.
tensor
({
'a'
:
[
True
,
False
],
'b'
:
{
'x'
:
[[
True
,
True
,
],
[
True
,
True
,
]]}
}).
all
(
dim
=
0
)
assert
(
t4
==
ttorch
.
tensor
({
'a'
:
False
,
'b'
:
{
'x'
:
[
True
,
True
]},
})).
all
()
@
choose_mark
()
def
test_any
(
self
):
t1
=
ttorch
.
Tensor
({
...
...
@@ -42,6 +59,31 @@ class TestTorchTensorReduction:
assert
t2
.
dtype
==
torch
.
bool
assert
not
t2
t3
=
ttorch
.
Tensor
({
'a'
:
[
True
,
False
],
'b'
:
{
'x'
:
[[
False
,
False
,
],
[
False
,
False
,
]]}
}).
any
(
reduce
=
False
)
assert
(
t3
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
False
,
}))
t4
=
ttorch
.
Tensor
({
'a'
:
[
True
,
False
],
'b'
:
{
'x'
:
[[
False
,
False
,
],
[
False
,
False
,
]]}
}).
any
(
dim
=
0
)
assert
(
t4
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
[
False
,
False
],
}))
with
pytest
.
warns
(
UserWarning
):
t5
=
ttorch
.
Tensor
({
'a'
:
[
True
,
False
],
'b'
:
{
'x'
:
[[
False
,
False
,
],
[
False
,
False
,
]]}
}).
any
(
dim
=
0
,
reduce
=
True
)
assert
(
t5
==
ttorch
.
tensor
({
'a'
:
True
,
'b'
:
[
False
,
False
],
}))
@
choose_mark
()
def
test_max
(
self
):
t1
=
ttorch
.
Tensor
({
...
...
treetensor/torch/funcs/construct.py
浏览文件 @
0150535c
...
...
@@ -15,7 +15,7 @@ __all__ = [
@
doc_from_base
()
@
func_treelize
()
def
tensor
(
*
args
,
**
kwargs
):
def
tensor
(
data
,
*
args
,
**
kwargs
):
"""
In ``treetensor``, you can create a tree tensor with simple data structure.
...
...
@@ -36,7 +36,10 @@ def tensor(*args, **kwargs):
└── c --> tensor([[ True, False],
[False, True]])
"""
return
torch
.
tensor
(
*
args
,
**
kwargs
)
if
torch
.
is_tensor
(
data
):
return
data
else
:
return
torch
.
tensor
(
data
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins
...
...
treetensor/torch/funcs/reduction.py
浏览文件 @
0150535c
...
...
@@ -11,11 +11,23 @@ __all__ = [
]
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_reduce
(
torch
.
all
)
@
func_treelize
(
return_type
=
Object
)
def
_all_r
(
input
,
*
args
,
**
kwargs
):
return
input
# noinspection PyShadowingBuiltins
@
func_treelize
()
def
_all_nr
(
input
,
*
args
,
**
kwargs
):
return
torch
.
all
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
doc_from_base
()
@
rmreduce
(
torch
.
all
)
@
func_treelize
(
return_type
=
Object
)
def
all
(
input
,
*
args
,
**
kwargs
):
@
auto_reduce
(
_all_r
,
_all_nr
)
def
all
(
input
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``all`` result of a whole tree with this function.
...
...
@@ -32,29 +44,39 @@ def all(input, *args, **kwargs):
>>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}))
tensor(False)
.. note::
In this ``all`` function, the return value should be a tensor with single boolean value.
If what you need is a tree of boolean tensors, you should do like this
>>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}), reduce=False)
<Tensor 0x7fcda55652b0>
├── a --> tensor(True)
└── b --> <Tensor 0x7fcda5565208>
└── x --> tensor(False)
>>> ttorch.tensor({
... 'a': [True, True],
... 'b': {'x': [True, False]},
... }).map(lambda x: torch.all(x))
<Tensor 0x7ff363bbc588>
>>> ttorch.all(ttorch.tensor({'a': [True, True], 'b': {'x': [True, False]}}), dim=0)
<Tensor 0x7fcda5565780>
├── a --> tensor(True)
└── b --> <Tensor 0x7ff363bb643
8>
└── b --> <Tensor 0x7fcda55656d
8>
└── x --> tensor(False)
"""
return
torch
.
all
(
input
,
*
args
,
**
kwargs
)
pass
# pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_reduce
(
torch
.
any
)
@
func_treelize
(
return_type
=
Object
)
def
_any_r
(
input
,
*
args
,
**
kwargs
):
return
input
# noinspection PyShadowingBuiltins
@
func_treelize
()
def
_any_nr
(
input
,
*
args
,
**
kwargs
):
return
torch
.
any
(
input
,
*
args
,
**
kwargs
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
doc_from_base
()
@
rmreduce
(
torch
.
any
)
@
func_treelize
(
return_type
=
Object
)
def
any
(
input
,
*
args
,
**
kwargs
):
@
auto_reduce
(
_any_r
,
_any_nr
)
def
any
(
input
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
In ``treetensor``, you can get the ``any`` result of a whole tree with this function.
...
...
@@ -86,7 +108,7 @@ def any(input, *args, **kwargs):
└── b --> <Tensor 0x7ff363bc67f0>
└── x --> tensor(False)
"""
return
torch
.
any
(
input
,
*
args
,
**
kwargs
)
pass
# pragma: no cover
# noinspection PyShadowingBuiltins
...
...
treetensor/torch/tensor.py
浏览文件 @
0150535c
...
...
@@ -184,25 +184,45 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return
self
.
requires_grad_
(
requires_grad
)
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_reduce
(
torch
.
all
)
@
method_treelize
(
return_type
=
Object
)
def
__all_r
(
self
,
*
args
,
**
kwargs
):
return
self
# noinspection PyShadowingBuiltins
@
method_treelize
()
def
__all_nr
(
self
,
*
args
,
**
kwargs
):
return
torch
.
all
(
self
,
*
args
,
**
kwargs
)
# noinspection PyArgumentList
@
doc_from_base
()
@
rmreduce
(
torch
.
all
)
@
method_treelize
(
return_type
=
Object
)
def
all
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
bool
:
@
auto_reduce
(
__all_r
,
__all_nr
)
def
all
(
self
:
torch
.
Tensor
,
*
args
,
reduce
=
None
,
**
kwargs
)
->
bool
:
"""
See :func:`treetensor.torch.all`
"""
return
self
.
all
(
*
args
,
**
kwargs
)
pass
# pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_reduce
(
torch
.
any
)
@
method_treelize
(
return_type
=
Object
)
def
__any_r
(
self
,
*
args
,
**
kwargs
):
return
self
# noinspection PyShadowingBuiltins
@
method_treelize
()
def
__any_nr
(
self
,
*
args
,
**
kwargs
):
return
torch
.
any
(
self
,
*
args
,
**
kwargs
)
# noinspection PyArgumentList
@
doc_from_base
()
@
rmreduce
(
torch
.
any
)
@
method_treelize
(
return_type
=
Object
)
def
any
(
self
:
torch
.
Tensor
,
*
args
,
**
kwargs
)
->
bool
:
@
auto_reduce
(
__any_r
,
__any_nr
)
def
any
(
self
:
torch
.
Tensor
,
*
args
,
reduce
=
None
,
**
kwargs
)
->
bool
:
"""
See :func:`treetensor.torch.any`
"""
return
self
.
any
(
*
args
,
**
kwargs
)
pass
# pragma: no cover
@
doc_from_base
()
@
rmreduce
(
torch
.
max
)
...
...
@@ -762,7 +782,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@
doc_from_base
()
@
auto_reduce
(
__std_r
,
__std_nr
)
@
method_treelize
()
def
std
(
self
,
*
args
,
**
kwargs
):
def
std
(
self
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
See :func:`treetensor.torch.std`.
"""
...
...
@@ -781,7 +801,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@
doc_from_base
()
@
auto_reduce
(
__mean_r
,
__mean_nr
)
@
method_treelize
()
def
mean
(
self
,
*
args
,
**
kwargs
):
def
mean
(
self
,
*
args
,
reduce
=
None
,
**
kwargs
):
"""
See :func:`treetensor.torch.mean`.
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录