MegEngine 天元 / MegEngine — Commit 4e95c136

Authored Sep 02, 2021 by Megvii Engine Team

feat(sgd): sgd supports nesterov momentum

GitOrigin-RevId: 13eda179da9b79573f692916a02b2a51a4449a14
Parent: ff431e72
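For context: plain momentum SGD updates the buffer as v = momentum * v + grad and steps with param -= lr * v. With nesterov=True this patch keeps the same buffer update but steps along the look-ahead direction grad + momentum * v, following the paper cited in the docstring below. A minimal sketch of the two rules (plain Python; all names here are illustrative, not MegEngine API):

def momentum_step(param, grad, v, lr, momentum, nesterov):
    # Buffer update is identical in both modes.
    v = momentum * v + grad
    if nesterov:
        # Look-ahead step: evaluate the move as if momentum were applied once more.
        param = param - lr * (grad + momentum * v)
    else:
        # Classic momentum step.
        param = param - lr * v
    return param, v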
Showing 2 changed files with 23 additions and 13 deletions:

  imperative/python/megengine/optimizer/sgd.py            +17 -11
  imperative/python/test/integration/test_optimizer.py     +6  -2
imperative/python/megengine/optimizer/sgd.py

@@ -16,7 +16,7 @@ from .optimizer import Optimizer
 class SGD(Optimizer):
     r"""Implements stochastic gradient descent.
 
     Nesterov momentum is based on the formula from
     `"On the importance of initialization and momentum in deep learning" <http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf>`_ .
@@ -25,6 +25,7 @@ class SGD(Optimizer):
             parameter groups.
         lr: learning rate.
         momentum: momentum factor. Default: 0.0
+        nesterov: enables Nesterov momentum. Default: False
         weight_decay: weight decay (L2 penalty). Default: 0.0
     """
@@ -33,6 +34,7 @@ class SGD(Optimizer):
         params: Union[Iterable[Parameter], dict],
         lr: float,
         momentum: float = 0.0,
+        nesterov: bool = False,
         weight_decay: float = 0.0,
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
@@ -40,9 +42,11 @@ class SGD(Optimizer):
         assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(weight_decay)
+        assert not nesterov or momentum > 0.0, "Nesterov momentum requires a momentum"
 
         defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
         super().__init__(params, defaults)
+        self.nesterov = nesterov
         self._disable_type_convert = True
 
     def _create_state(self, param_group):
@@ -76,20 +80,22 @@ class SGD(Optimizer):
                 grad = grad + param * _weight_decay
 
             if inplace_mode:
-                if momentum:
+                if momentum != 0.0:
                     v = self._state[param]["momentum_buffer"]
                     _inplace_add_(v, grad, alpha=_momentum, beta=c1)
-                    _inplace_add_(param, v, alpha=c1, beta=_neg_lr)
-                else:
-                    _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
+                    if self.nesterov:
+                        grad = grad + v * _momentum
+                    else:
+                        grad = v
+                _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
                 continue
 
-            if momentum:
+            if momentum != 0.0:
                 v = self._state[param]["momentum_buffer"]
                 # v = v * _momentum + grad
                 v *= _momentum
                 v += grad
-                param -= _lr * v
-            else:
-                param -= _lr * grad
+                if self.nesterov:
+                    grad = grad + v * _momentum
+                else:
+                    grad = v
+            param -= _lr * grad
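Reading the inplace branch: judging by how it is called here, _inplace_add_(dest, src, alpha=a, beta=b) appears to compute dest = a * dest + b * src, so _inplace_add_(v, grad, alpha=_momentum, beta=c1) is the buffer update and _inplace_add_(param, grad, alpha=c1, beta=_neg_lr) is the parameter step; the non-inplace branch spells out the same arithmetic on tensors. After this change the feature is switched on from the constructor; a minimal usage sketch (the Linear module is a placeholder chosen here, not part of this commit):

import megengine.module as M
from megengine.optimizer import SGD

model = M.Linear(4, 2)  # any Module works; a tiny layer keeps the sketch small
opt = SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
# nesterov=True together with momentum == 0.0 trips the new assert above.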
imperative/python/test/integration/test_optimizer.py

@@ -124,6 +124,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     "case",
     [
         {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
+        {"momentum": 0.9, "lr": 0.01, "nesterov": True},  # with nesterov momentum
         {"lr": 0.01},  # simple SGD
         {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
     ],
@@ -144,9 +145,12 @@ def test_sgd(monkeypatch, case, update_lr, inplace_mode):
             grad = param.grad.numpy()
             if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
                 grad = grad + ori_params[param] * self.weight_decay
-            if hasattr(self, "momentum"):
+            if hasattr(self, "momentum") and self.momentum != 0.0:
                 self.slots[param] = grad + self.slots[param] * self.momentum
-                delta = -self.lr * self.slots[param]
+                if hasattr(self, "nesterov") and self.nesterov:
+                    delta = -self.lr * (grad + self.slots[param] * self.momentum)
+                else:
+                    delta = -self.lr * self.slots[param]
             else:
                 delta = -self.lr * grad
             np.testing.assert_almost_equal(
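The reference the test compares against is only a few lines of NumPy. A standalone restatement (the function name and signature are invented here for illustration; the test keeps this state on a class instead):

import numpy as np

def reference_sgd_step(param, grad, slot, lr, momentum=0.0, weight_decay=0.0, nesterov=False):
    # Mirrors the check above: fold weight decay into the gradient,
    # update the momentum buffer (`slot`), then take the step.
    grad = grad + param * weight_decay
    if momentum != 0.0:
        slot = grad + slot * momentum
        delta = -lr * (grad + slot * momentum) if nesterov else -lr * slot
    else:
        delta = -lr * grad
    return param + delta, slot

# One step on toy data:
p, s = reference_sgd_step(np.ones(3), np.full(3, 0.5), np.zeros(3), lr=0.01, momentum=0.9, nesterov=True)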