Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9faa32fc
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9faa32fc
编写于
9月 04, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/imperative): fix grad callback
GitOrigin-RevId: 6f843b0106117ca24d08efb5685cd09171197430
上级
6d4fd938
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
22 addition
and
16 deletion
+22
-16
imperative/python/megengine/autodiff/__init__.py
imperative/python/megengine/autodiff/__init__.py
+9
-0
imperative/python/megengine/autodiff/grad_manager.py
imperative/python/megengine/autodiff/grad_manager.py
+11
-4
imperative/python/megengine/core/autodiff/grad.py
imperative/python/megengine/core/autodiff/grad.py
+1
-5
imperative/python/megengine/optimizer/multi_step_lr.py
imperative/python/megengine/optimizer/multi_step_lr.py
+1
-1
imperative/python/megengine/optimizer/sgd.py
imperative/python/megengine/optimizer/sgd.py
+0
-6
未找到文件。
imperative/python/megengine/autodiff/__init__.py
0 → 100644
浏览文件 @
9faa32fc
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
.grad_manager
import
GradManager
imperative/python/megengine/autodiff/grad_manager.py
浏览文件 @
9faa32fc
...
@@ -10,8 +10,8 @@ class GradManager:
...
@@ -10,8 +10,8 @@ class GradManager:
self
.
_recording
=
False
self
.
_recording
=
False
self
.
_grad
=
None
self
.
_grad
=
None
def
register
(
self
,
params
,
callback
=
None
):
def
register
(
self
,
params
,
callback
s
=
None
):
self
.
_call_back_pair
.
append
([
params
,
callback
])
self
.
_call_back_pair
.
append
([
list
(
params
),
callbacks
or
[]
])
def
backward
(
self
,
ys
,
dys
=
None
):
def
backward
(
self
,
ys
,
dys
=
None
):
if
not
self
.
_recording
:
if
not
self
.
_recording
:
...
@@ -24,7 +24,7 @@ class GradManager:
...
@@ -24,7 +24,7 @@ class GradManager:
if
not
isinstance
(
ys
,
(
tuple
,
list
)):
if
not
isinstance
(
ys
,
(
tuple
,
list
)):
ys
=
[
ys
]
ys
=
[
ys
]
if
dys
is
None
:
if
dys
is
None
:
dys
=
[
tensor
(
1
).
broadcast
(
y
.
shape
)
for
y
in
ys
]
dys
=
[
tensor
(
1
.0
)
for
y
in
ys
]
if
not
isinstance
(
dys
,
(
tuple
,
list
)):
if
not
isinstance
(
dys
,
(
tuple
,
list
)):
dys
=
[
dys
]
dys
=
[
dys
]
try
:
try
:
...
@@ -42,7 +42,14 @@ class GradManager:
...
@@ -42,7 +42,14 @@ class GradManager:
self
.
_recording
=
True
self
.
_recording
=
True
self
.
_grad
=
grad
self
.
_grad
=
grad
for
params
,
callbacks
in
self
.
_call_back_pair
:
for
params
,
callbacks
in
self
.
_call_back_pair
:
grad
.
wrt
(
*
params
,
callback
=
callbacks
)
def
callback
(
param
,
grad
,
callbacks
=
callbacks
):
ret
=
grad
for
cb
in
callbacks
:
ret
=
cb
(
param
,
ret
)
param
.
grad
=
ret
grad
.
wrt
(
*
params
,
callback
=
callback
)
with
grad
:
with
grad
:
yield
yield
finally
:
finally
:
...
...
imperative/python/megengine/core/autodiff/grad.py
浏览文件 @
9faa32fc
...
@@ -260,13 +260,9 @@ class Grad:
...
@@ -260,13 +260,9 @@ class Grad:
cache
[
v
]
=
g
cache
[
v
]
=
g
if
last_written_to
[
v
]
==
(
seqno
,
i
):
if
last_written_to
[
v
]
==
(
seqno
,
i
):
if
v
.
callback
:
if
v
.
callback
:
grad
=
v
.
callback
(
v
.
callback
(
v
.
owner
(),
Wrapper
(
cache
[
v
])
if
Wrapper
else
cache
[
v
]
v
.
owner
(),
Wrapper
(
cache
[
v
])
if
Wrapper
else
cache
[
v
]
)
)
if
getattr
(
v
.
owner
(),
"grad"
,
None
)
is
None
:
v
.
owner
().
grad
=
grad
else
:
v
.
owner
().
grad
+=
grad
if
v
.
opnode
is
None
:
if
v
.
opnode
is
None
:
# won't read by backward, mark consumed
# won't read by backward, mark consumed
cache
[
v
]
=
None
cache
[
v
]
=
None
...
...
imperative/python/megengine/optimizer/multi_step_lr.py
浏览文件 @
9faa32fc
...
@@ -9,8 +9,8 @@
...
@@ -9,8 +9,8 @@
from
bisect
import
bisect_right
from
bisect
import
bisect_right
from
typing
import
Iterable
as
Iter
from
typing
import
Iterable
as
Iter
from
.optimizer
import
Optimizer
from
.lr_scheduler
import
LRScheduler
from
.lr_scheduler
import
LRScheduler
from
.optimizer
import
Optimizer
class
MultiStepLR
(
LRScheduler
):
class
MultiStepLR
(
LRScheduler
):
...
...
imperative/python/megengine/optimizer/sgd.py
浏览文件 @
9faa32fc
...
@@ -53,10 +53,6 @@ class SGD(Optimizer):
...
@@ -53,10 +53,6 @@ class SGD(Optimizer):
for
param
in
param_group
[
"params"
]:
for
param
in
param_group
[
"params"
]:
if
param
.
__wrapped__
in
self
.
_grad_skip
:
self
.
_grad_skip
.
remove
(
param
.
__wrapped__
)
continue
if
not
isinstance
(
param
.
grad
,
Buffer
):
if
not
isinstance
(
param
.
grad
,
Buffer
):
raise
TypeError
(
raise
TypeError
(
"grad must be a Buffer, maybe you forget to call backward()?"
"grad must be a Buffer, maybe you forget to call backward()?"
...
@@ -76,5 +72,3 @@ class SGD(Optimizer):
...
@@ -76,5 +72,3 @@ class SGD(Optimizer):
self
.
_state
[
param
][
"momentum_buffer"
].
_reset
(
v
)
self
.
_state
[
param
][
"momentum_buffer"
].
_reset
(
v
)
else
:
else
:
param
-=
lr
*
grad
param
-=
lr
*
grad
assert
len
(
self
.
_grad_skip
)
==
0
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录