Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3af10563
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3af10563
编写于
9月 29, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/grad): attach grad immediately
GitOrigin-RevId: e3a168c03ab78aafabfe3d41b0e18c61ee07c256
上级
dc3c17ba
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
34 addition
and
10 deletion
+34
-10
imperative/python/megengine/autodiff/grad_manager.py
imperative/python/megengine/autodiff/grad_manager.py
+19
-10
imperative/python/test/unit/autodiff/test_grad_manger.py
imperative/python/test/unit/autodiff/test_grad_manger.py
+15
-0
未找到文件。
imperative/python/megengine/autodiff/grad_manager.py
浏览文件 @
3af10563
...
...
@@ -3,7 +3,7 @@ from contextlib import contextmanager
from
typing
import
Callable
from
..core.autodiff.grad
import
Grad
from
..tensor
import
tensor
from
..tensor
import
Tensor
,
tensor
from
..utils.future
import
Future
backwarding_grad_manager
=
None
...
...
@@ -84,10 +84,15 @@ class GradManager:
callbacks
=
[]
if
isinstance
(
callbacks
,
Callable
):
callbacks
=
[
callbacks
]
if
isinstance
(
params
,
Tensor
):
params
=
[
params
]
for
p
in
params
:
self
.
_param_dict
[
id
(
p
)]
=
p
for
cb
in
callbacks
:
self
.
_call_back_dict
[
id
(
p
)].
append
(
cb
)
if
self
.
_grad
is
not
None
:
for
p
in
params
:
self
.
_record_param
(
id
(
p
))
return
self
def
_register_after_backward_callback
(
self
,
callback
):
...
...
@@ -143,17 +148,21 @@ class GradManager:
self
.
_recording
=
True
self
.
_grad
=
grad
for
param_id
in
self
.
_param_dict
.
keys
():
param_wrapper
=
self
.
_param_dict
[
param_id
]
callbacks
=
self
.
_call_back_dict
[
param_id
]
self
.
_record_param
(
param_id
)
grad
.
__enter__
()
def
callback
(
param
,
grad
,
callbacks
=
callbacks
,
p
=
param_wrapper
,
gm
=
self
):
ret
=
grad
for
cb
in
callbacks
:
ret
=
cb
(
param
,
ret
)
gm
.
_gradients
[
id
(
p
)]
=
ret
def
_record_param
(
self
,
param_id
):
param_wrapper
=
self
.
_param_dict
[
param_id
]
callbacks
=
self
.
_call_back_dict
[
param_id
]
grad
.
wrt
(
param_wrapper
,
callback
=
callback
)
grad
.
__enter__
()
def
callback
(
param
,
grad
,
callbacks
=
callbacks
,
p
=
param_wrapper
,
gm
=
self
):
ret
=
grad
for
cb
in
callbacks
:
ret
=
cb
(
param
,
ret
)
gm
.
_gradients
[
id
(
p
)]
=
ret
# NOTE: override prev callback wrt when called serval times
self
.
_grad
.
wrt
(
param_wrapper
,
callback
=
callback
)
def
release
(
self
):
r
"""Stops recording and releases resources for gradients calculation.
...
...
imperative/python/test/unit/autodiff/test_grad_manger.py
0 → 100644
浏览文件 @
3af10563
import
numpy
as
np
import
megengine
as
mge
from
megengine
import
autodiff
as
ad
def
test_attach_in_with_block
():
a
=
mge
.
Parameter
([
1.0
])
g
=
ad
.
GradManager
()
with
g
:
b
=
a
*
3
g
.
attach
(
b
)
c
=
b
+
1
g
.
backward
(
c
)
assert
int
(
b
.
grad
.
numpy
())
==
1
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录