Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
a1957a0a
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a1957a0a
编写于
9月 28, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): complete autograd part
上级
fbf5cb25
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
358 addition
and
11 deletion
+358
-11
docs/source/api_doc/common/object.rst
docs/source/api_doc/common/object.rst
+1
-1
test/common/test_object.py
test/common/test_object.py
+10
-0
test/torch/funcs/__init__.py
test/torch/funcs/__init__.py
+1
-0
test/torch/funcs/test_autograd.py
test/torch/funcs/test_autograd.py
+31
-0
test/torch/tensor/__init__.py
test/torch/tensor/__init__.py
+1
-0
test/torch/tensor/test_autograd.py
test/torch/tensor/test_autograd.py
+80
-0
treetensor/common/object.py
treetensor/common/object.py
+43
-0
treetensor/torch/funcs/__init__.py
treetensor/torch/funcs/__init__.py
+3
-0
treetensor/torch/funcs/autograd.py
treetensor/torch/funcs/autograd.py
+83
-0
treetensor/torch/funcs/operation.py
treetensor/torch/funcs/operation.py
+2
-6
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+103
-4
未找到文件。
docs/source/api_doc/common/object.rst
浏览文件 @
a1957a0a
...
...
@@ -7,5 +7,5 @@ Object
-----------------
.. autoclass:: Object
:members: __init__
:members: __init__
, all, any
test/common/test_object.py
浏览文件 @
a1957a0a
...
...
@@ -14,3 +14,13 @@ class TestCommonObject:
assert
Object
({
'a'
:
1
,
'b'
:
2
})
==
typetrans
(
TreeValue
({
'a'
:
1
,
'b'
:
2
}),
Object
)
def
test_all
(
self
):
assert
not
Object
({
'a'
:
False
,
'b'
:
{
'x'
:
False
}}).
all
()
assert
not
Object
({
'a'
:
True
,
'b'
:
{
'x'
:
False
}}).
all
()
assert
Object
({
'a'
:
True
,
'b'
:
{
'x'
:
True
}}).
all
()
def
test_any
(
self
):
assert
not
Object
({
'a'
:
False
,
'b'
:
{
'x'
:
False
}}).
any
()
assert
Object
({
'a'
:
True
,
'b'
:
{
'x'
:
False
}}).
any
()
assert
Object
({
'a'
:
True
,
'b'
:
{
'x'
:
True
}}).
any
()
test/torch/funcs/__init__.py
浏览文件 @
a1957a0a
from
.test_autograd
import
TestTorchFuncsAutograd
from
.test_comparison
import
TestTorchFuncsComparison
from
.test_construct
import
TestTorchFuncsConstruct
from
.test_math
import
TestTorchFuncsMath
...
...
test/torch/funcs/test_autograd.py
0 → 100644
浏览文件 @
a1957a0a
import
treetensor.torch
as
ttorch
from
.base
import
choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchFuncsAutograd:
    """Tests for the autograd-related functions in ``treetensor.torch``."""

    @staticmethod
    def _grad_tree():
        # Fresh tree tensor with gradients enabled on every leaf.
        return ttorch.tensor({
            'a': [2, 3, 4.0],
            'b': {'x': [[5, 6], [7, 8.0]]}
        }, requires_grad=True)

    @choose_mark()
    def test_detach(self):
        source = self._grad_tree()
        assert source.requires_grad.all()

        detached = ttorch.detach(source)
        # The source keeps its grad flags; the result is a new, detached tree.
        assert source.requires_grad.all()
        assert detached is not source
        assert not detached.requires_grad.any()

    @choose_mark()
    def test_detach_(self):
        source = self._grad_tree()
        assert source.requires_grad.all()

        result = ttorch.detach_(source)
        # The in-place variant returns the very same tree, now detached.
        assert result is source
        assert not source.requires_grad.any()
test/torch/tensor/__init__.py
浏览文件 @
a1957a0a
from
.test_autograd
import
TestTorchTensorAutograd
from
.test_clazz
import
TestTorchTensorClass
from
.test_comparison
import
TestTorchTensorComparison
from
.test_math
import
TestTorchTensorMath
...
...
test/torch/tensor/test_autograd.py
0 → 100644
浏览文件 @
a1957a0a
import
treetensor.torch
as
ttorch
from
.base
import
choose_mark
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorAutograd:
    """Tests for the autograd-related methods of ``treetensor.torch.Tensor``."""

    @staticmethod
    def _make_tree(requires_grad=False):
        # Sample tree tensor shared by all autograd test cases.
        return ttorch.tensor({
            'a': [2, 3, 4.0],
            'b': {'x': [[5, 6], [7, 8.0]]}
        }, requires_grad=requires_grad)

    @choose_mark()
    def test_requires_grad(self):
        tree = self._make_tree(requires_grad=True)
        assert tree.requires_grad.all()

        tree.a.requires_grad_(False)
        # One branch disabled: no longer all(), but still any().
        assert not tree.requires_grad.all()
        assert tree.requires_grad.any()

        tree.b.x.requires_grad_(False)
        # Every branch disabled now.
        assert not tree.requires_grad.all()
        assert not tree.requires_grad.any()

    @choose_mark()
    def test_requires_grad_(self):
        tree = self._make_tree()
        assert not tree.requires_grad.any()

        tree.requires_grad_(True)
        assert tree.requires_grad.all()

        tree.a.requires_grad_(False)
        assert not tree.requires_grad.all()
        assert tree.requires_grad.any()

        tree.b.x.requires_grad_(False)
        assert not tree.requires_grad.all()
        assert not tree.requires_grad.any()

    @choose_mark()
    def test_grad(self):
        tree = self._make_tree(requires_grad=True)
        loss = tree.mean() ** 2
        loss.backward()

        # d(mean(t)**2)/dt_i = 2 * mean(t) / 7; the 7 leaf values average to 5,
        # so every gradient entry is 10 / 7 ~= 1.4286.
        expected = ttorch.tensor({
            'a': [1.4286, 1.4286, 1.4286],
            'b': {'x': [[1.4286, 1.4286],
                        [1.4286, 1.4286]]},
        })
        assert ttorch.isclose(tree.grad, expected, atol=1e-4).all()

    @choose_mark()
    def test_detach(self):
        tree = self._make_tree(requires_grad=True)
        assert tree.requires_grad.all()

        detached = tree.detach()
        # Original keeps its grad flags; the result is a distinct tree.
        assert tree.requires_grad.all()
        assert detached is not tree
        assert not detached.requires_grad.any()

    @choose_mark()
    def test_detach_(self):
        tree = self._make_tree(requires_grad=True)
        assert tree.requires_grad.all()

        result = tree.detach_()
        # In-place variant returns the same tree, fully detached.
        assert result is tree
        assert not tree.requires_grad.any()
treetensor/common/object.py
浏览文件 @
a1957a0a
import
builtins
from
treevalue
import
method_treelize
from
.trees
import
BaseTreeStruct
,
clsmeta
from
.wrappers
import
ireduce
__all__
=
[
"Object"
,
...
...
@@ -33,3 +38,41 @@ class Object(BaseTreeStruct, metaclass=clsmeta(_object, allow_dict=True)):
└── c --> 233
"""
super
(
BaseTreeStruct
,
self
).
__init__
(
data
)
@
ireduce
(
builtins
.
all
,
piter
=
list
)
@
method_treelize
()
def
all
(
self
):
"""
The values in this tree is all true or not.
Examples::
>>> from treetensor.common import Object
>>> Object({'a': False, 'b': {'x': False}}).all()
False
>>> Object({'a': True, 'b': {'x': False}}).all()
False
>>> Object({'a': True, 'b': {'x': True}}).all()
True
"""
return
not
not
self
@
ireduce
(
builtins
.
any
,
piter
=
list
)
@
method_treelize
()
def
any
(
self
):
"""
The values in this tree is not all False or yes.
Examples::
>>> from treetensor.common import Object
>>> Object({'a': False, 'b': {'x': False}}).any()
False
>>> Object({'a': True, 'b': {'x': False}}).any()
True
>>> Object({'a': True, 'b': {'x': True}}).any()
True
"""
return
not
not
self
treetensor/torch/funcs/__init__.py
浏览文件 @
a1957a0a
import
sys
from
.autograd
import
*
from
.autograd
import
__all__
as
_autograd_all
from
.comparison
import
*
from
.comparison
import
__all__
as
_comparison_all
from
.construct
import
*
...
...
@@ -15,6 +17,7 @@ from .reduction import __all__ as _reduction_all
from
...utils
import
module_autoremove
__all__
=
[
*
_autograd_all
,
*
_comparison_all
,
*
_construct_all
,
*
_math_all
,
...
...
treetensor/torch/funcs/autograd.py
0 → 100644
浏览文件 @
a1957a0a
import
torch
from
.base
import
doc_from_base
,
func_treelize
from
...common
import
return_self
__all__
=
[
'detach'
,
'detach_'
]
# noinspection PyShadowingBuiltins
@doc_from_base()
@func_treelize()
def detach(input):
    """
    Return a new tree whose tensors are detached from the autograd graph.

    The input tree is left untouched; every leaf of the result has
    ``requires_grad=False``.

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> tt = ttorch.randn({
        ...     'a': (2, 3),
        ...     'b': {'x': (3, 4)},
        ... })
        >>> tt.requires_grad_(True)
        >>> ttorch.detach(tt).requires_grad.any()
        False
        >>> tt.requires_grad.all()  # original is unchanged
        True
    """
    # Applied leaf-wise by ``func_treelize``: ``input`` is a plain
    # ``torch.Tensor`` here, so this is the native detach.
    return torch.detach(input)
# noinspection PyShadowingBuiltins
@doc_from_base()
@return_self
@func_treelize()
def detach_(input):
    """
    In-place version of :func:`treetensor.torch.detach`.

    Detaches every leaf tensor of ``input`` from the autograd graph and
    returns the same tree object (via ``return_self``).

    Examples::

        >>> import torch
        >>> import treetensor.torch as ttorch
        >>> tt = ttorch.randn({
        ...     'a': (2, 3),
        ...     'b': {'x': (3, 4)},
        ... })
        >>> tt.requires_grad_(True)
        >>> ttorch.detach_(tt) is tt
        True
        >>> tt.requires_grad.any()
        False
    """
    # Applied leaf-wise; the per-leaf return value is discarded because
    # ``return_self`` makes the whole call yield the input tree itself.
    return torch.detach_(input)
treetensor/torch/funcs/operation.py
浏览文件 @
a1957a0a
...
...
@@ -117,10 +117,8 @@ def cat(tensors, *args, **kwargs):
# noinspection PyShadowingNames
@
doc_from_base
()
@
post_process
(
lambda
r
:
tuple
(
r
))
@
post_process
(
auto_tensor
)
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
split
(
tensor
,
split_size_or_sections
,
*
args
,
**
kwargs
):
"""
Splits the tensor into chunks. Each chunk is a view of the original tensor.
...
...
@@ -208,10 +206,8 @@ def split(tensor, split_size_or_sections, *args, **kwargs):
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
post_process
(
lambda
r
:
tuple
(
r
))
@
post_process
(
auto_tensor
)
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
@
func_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
chunk
(
input
,
chunks
,
*
args
,
**
kwargs
):
"""
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
...
...
treetensor/torch/tensor.py
浏览文件 @
a1957a0a
...
...
@@ -174,6 +174,72 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return
self
.
shape
@
property
@
method_treelize
()
def
grad
(
self
):
"""
Return the grad data of the whole tree.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt
<Tensor 0x7feec3bcce80>
├── a --> tensor([[-1.4375, 0.0988, 1.2198],
│ [-0.7627, -0.8797, -0.9299]], requires_grad=True)
└── b --> <Tensor 0x7feec3bccdd8>
└── x --> tensor([[ 0.2149, -0.5839, -0.6049, -0.9151],
[ 1.5381, -1.4386, 0.1831, 0.2018],
[-0.0725, -0.9062, -2.6212, 0.5929]], requires_grad=True)
>>> mq = tt.mean() ** 2
>>> mq.backward()
>>> tt.grad
<Tensor 0x7feec3c0fa90>
├── a --> tensor([[-0.0438, -0.0438, -0.0438],
│ [-0.0438, -0.0438, -0.0438]])
└── b --> <Tensor 0x7feec3c0f9e8>
└── x --> tensor([[-0.0438, -0.0438, -0.0438, -0.0438],
[-0.0438, -0.0438, -0.0438, -0.0438],
[-0.0438, -0.0438, -0.0438, -0.0438]])
"""
return
self
.
grad
@
property
@
method_treelize
(
return_type
=
Object
)
def
requires_grad
(
self
):
"""
Return the grad situation of current tree.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt.requires_grad
<Object 0x7feec3c229e8>
├── a --> True
└── b --> <Object 0x7feec3c22940>
└── x --> True
>>> tt.a.requires_grad_(False)
>>> tt.requires_grad
<Object 0x7feec3c0fa58>
├── a --> False
└── b --> <Object 0x7feec3c0f5f8>
└── x --> True
"""
return
self
.
requires_grad
@
doc_from_base
()
@
return_self
@
method_treelize
()
...
...
@@ -181,9 +247,44 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
Change if autograd should record operations on this tensor:
sets this tensor’s ``requires_grad`` attribute in-place. Returns this tensor.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> tt = ttorch.randn({
... 'a': (2, 3),
... 'b': {'x': (3, 4)},
... })
>>> tt.requires_grad_(True)
>>> tt
<Tensor 0x7feec3c22240>
├── a --> tensor([[ 1.4754, 1.1167, 1.5431],
│ [-0.5816, 0.4746, 0.8392]], requires_grad=True)
└── b --> <Tensor 0x7feec3c22128>
└── x --> tensor([[ 0.3361, 0.8194, 0.1297, -0.5547],
[ 0.2531, -0.0637, 0.9822, 2.1618],
[ 2.0140, -0.0929, 0.9304, 1.5430]], requires_grad=True)
"""
return
self
.
requires_grad_
(
requires_grad
)
@
doc_from_base
()
@
method_treelize
()
def
detach
(
self
):
"""
See :func:`treetensor.torch.detach`.
"""
return
self
.
detach
()
@
doc_from_base
()
@
return_self
@
method_treelize
()
def
detach_
(
self
):
"""
In-place version of :meth:`Tensor.detach`.
"""
return
self
.
detach_
()
# noinspection PyShadowingBuiltins,PyUnusedLocal
@
post_reduce
(
torch
.
all
)
@
method_treelize
(
return_type
=
Object
)
...
...
@@ -715,8 +816,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@
doc_from_base
()
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_torch
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
split
(
self
,
split_size
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.split`.
...
...
@@ -725,8 +825,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
@
doc_from_base
()
@
post_process
(
lambda
r
:
replaceable_partial
(
auto_torch
,
cls
=
Tensor
)(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
dict
(
template
=
[
None
]))
@
post_process
(
lambda
r
:
list
(
r
))
@
method_treelize
(
return_type
=
TreeValue
,
rise
=
True
)
def
chunk
(
self
,
chunks
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.chunk`.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录