Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2627e1f7
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2627e1f7
编写于
10月 28, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/grad_manager): allow multiple calls of `release`
GitOrigin-RevId: 38ca4c78ff1fb7c8b76716a6fe347333b33478ef
上级
67a543f3
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
66 addition
and
23 deletion
+66
-23
imperative/python/megengine/autodiff/grad_manager.py
imperative/python/megengine/autodiff/grad_manager.py
+23
-11
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+2
-5
imperative/python/test/unit/autodiff/test_grad_manger.py
imperative/python/test/unit/autodiff/test_grad_manger.py
+41
-5
imperative/python/test/unit/module/test_qat.py
imperative/python/test/unit/module/test_qat.py
+0
-2
未找到文件。
imperative/python/megengine/autodiff/grad_manager.py
浏览文件 @
2627e1f7
...
@@ -3,9 +3,12 @@ from contextlib import contextmanager
...
@@ -3,9 +3,12 @@ from contextlib import contextmanager
from
typing
import
Callable
from
typing
import
Callable
from
..core.autodiff.grad
import
Grad
from
..core.autodiff.grad
import
Grad
from
..tensor
import
Tensor
,
tensor
from
..logger
import
get_logger
from
..tensor
import
Tensor
from
..utils.future
import
Future
from
..utils.future
import
Future
logger
=
get_logger
(
__name__
)
backwarding_grad_manager
=
None
backwarding_grad_manager
=
None
...
@@ -67,7 +70,7 @@ class GradManager:
...
@@ -67,7 +70,7 @@ class GradManager:
self
.
_after_backward_callback
=
[]
self
.
_after_backward_callback
=
[]
self
.
_gradients
=
dict
()
self
.
_gradients
=
dict
()
def
attach
(
self
,
params
,
callbacks
=
None
):
def
attach
(
self
,
params
:
list
,
callbacks
=
None
):
r
"""Registers parameters that gradients should be calculated with respect to.
r
"""Registers parameters that gradients should be calculated with respect to.
Callback Functions should have a signature like this:
Callback Functions should have a signature like this:
...
@@ -77,7 +80,7 @@ class GradManager:
...
@@ -77,7 +80,7 @@ class GradManager:
# do something
# do something
return grad
return grad
:param params: registered parameters
:param params:
to be
registered parameters
:param callbacks: list of callback functions
:param callbacks: list of callback functions
"""
"""
if
callbacks
is
None
:
if
callbacks
is
None
:
...
@@ -95,6 +98,20 @@ class GradManager:
...
@@ -95,6 +98,20 @@ class GradManager:
self
.
_record_param
(
id
(
p
))
self
.
_record_param
(
id
(
p
))
return
self
return
self
def
detach
(
self
,
params
:
list
):
r
"""Remove specific registered parameters and callback functions.
:param params: registered parameters
"""
if
isinstance
(
params
,
Tensor
):
params
=
[
params
]
for
idx
,
param
in
enumerate
(
params
):
if
id
(
param
)
in
self
.
_param_dict
:
self
.
_param_dict
.
pop
(
id
(
param
))
self
.
_call_back_dict
.
pop
(
id
(
param
))
else
:
logger
.
warning
(
"params with index {} is not attached."
.
format
(
idx
))
def
_register_after_backward_callback
(
self
,
callback
):
def
_register_after_backward_callback
(
self
,
callback
):
self
.
_after_backward_callback
.
append
(
callback
)
self
.
_after_backward_callback
.
append
(
callback
)
return
self
return
self
...
@@ -136,7 +153,7 @@ class GradManager:
...
@@ -136,7 +153,7 @@ class GradManager:
else
:
else
:
param
.
grad
+=
grad
param
.
grad
+=
grad
finally
:
finally
:
self
.
_stop_record
()
self
.
release
()
backwarding_grad_manager
=
cache
backwarding_grad_manager
=
cache
def
record
(
self
):
def
record
(
self
):
...
@@ -167,15 +184,10 @@ class GradManager:
...
@@ -167,15 +184,10 @@ class GradManager:
def
release
(
self
):
def
release
(
self
):
r
"""Stops recording and releases resources for gradients calculation.
r
"""Stops recording and releases resources for gradients calculation.
"""
"""
if
not
self
.
_recording
:
raise
RuntimeError
(
"not recording"
)
self
.
_stop_record
()
def
_stop_record
(
self
):
if
self
.
_grad
is
not
None
:
if
self
.
_grad
is
not
None
:
self
.
_grad
.
__exit__
(
None
,
None
,
None
)
self
.
_grad
.
__exit__
(
None
,
None
,
None
)
self
.
_recording
=
False
self
.
_grad
=
None
self
.
_grad
=
None
self
.
_recording
=
False
self
.
_gradients
=
dict
()
self
.
_gradients
=
dict
()
def
__enter__
(
self
):
def
__enter__
(
self
):
...
@@ -183,4 +195,4 @@ class GradManager:
...
@@ -183,4 +195,4 @@ class GradManager:
return
self
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
_stop_record
()
self
.
release
()
imperative/python/megengine/tensor.py
浏览文件 @
2627e1f7
...
@@ -85,11 +85,8 @@ class Tensor(_Tensor):
...
@@ -85,11 +85,8 @@ class Tensor(_Tensor):
def
detach
(
self
):
def
detach
(
self
):
r
"""
r
"""
Returns a new tensor which is treated as constant during backward gradient calcuation,
Returns a new tensor sharing the same data memory, which is treated as a constant
i.e. its gradient is zero.
during backward gradient calcuation, i.e. its gradient is zero.
:param inp: input tensor
"""
"""
Wrapper
=
type
(
self
)
Wrapper
=
type
(
self
)
Tensor
=
type
(
self
.
__wrapped__
)
Tensor
=
type
(
self
.
__wrapped__
)
...
...
imperative/python/test/unit/autodiff/test_grad_manger.py
浏览文件 @
2627e1f7
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
import
numpy
as
np
import
pytest
import
megengine
as
mge
import
megengine
as
mge
from
megengine
import
autodiff
as
ad
import
megengine.functional
as
F
from
megengine.autodiff
import
GradManager
def
test_basic
():
x
=
mge
.
tensor
([
1.0
,
3.0
,
5.0
]).
reshape
(
1
,
3
)
w
=
mge
.
tensor
([
2.0
,
4.0
,
6.0
]).
reshape
(
3
,
1
)
b
=
mge
.
tensor
(
-
1.0
)
gm
=
GradManager
().
attach
([
w
,
b
])
gm
.
record
()
p
=
F
.
matmul
(
x
,
w
)
y
=
p
+
b
gm
.
backward
(
y
)
gm
.
release
()
# is not necessary
np
.
testing
.
assert_equal
(
w
.
grad
.
numpy
(),
[[
1
],
[
3
],
[
5
]])
np
.
testing
.
assert_equal
(
b
.
grad
.
numpy
(),
[
1
])
w
.
grad
=
None
b
.
grad
=
None
with
gm
:
p
=
F
.
matmul
(
x
,
w
)
y
=
p
+
b
gm
.
backward
(
y
)
np
.
testing
.
assert_equal
(
w
.
grad
.
numpy
(),
[[
1
],
[
3
],
[
5
]])
np
.
testing
.
assert_equal
(
b
.
grad
.
numpy
(),
[
1
])
def
test_attach_in_with_block
():
def
test_attach_in_with_block
():
a
=
mge
.
Parameter
([
1.0
])
a
=
mge
.
Parameter
([
1.0
])
g
=
ad
.
GradManager
()
g
m
=
GradManager
()
with
g
:
with
g
m
:
b
=
a
*
3
b
=
a
*
3
g
.
attach
(
b
)
g
m
.
attach
(
b
)
c
=
b
+
1
c
=
b
+
1
g
.
backward
(
c
)
g
m
.
backward
(
c
)
assert
int
(
b
.
grad
.
numpy
())
==
1
assert
int
(
b
.
grad
.
numpy
())
==
1
imperative/python/test/unit/module/test_qat.py
浏览文件 @
2627e1f7
...
@@ -27,8 +27,6 @@ def test_qat_convbn2d():
...
@@ -27,8 +27,6 @@ def test_qat_convbn2d():
disable_fake_quant
(
qat_module
)
disable_fake_quant
(
qat_module
)
inputs
=
tensor
(
np
.
random
.
randn
(
4
,
in_channels
,
32
,
32
).
astype
(
np
.
float32
))
inputs
=
tensor
(
np
.
random
.
randn
(
4
,
in_channels
,
32
,
32
).
astype
(
np
.
float32
))
normal_outputs
=
module
(
inputs
)
normal_outputs
=
module
(
inputs
)
# import pdb
# pdb.set_trace()
qat_outputs
=
qat_module
(
inputs
)
qat_outputs
=
qat_module
(
inputs
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
(),
atol
=
5e-6
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
(),
atol
=
5e-6
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录