magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the upstream project)
Commit d4dead93
Authored Apr 03, 2020 by zhaoting
Committed by 高东海 on Apr 08, 2020
add weight decay in RMSProp optimizer
Parent: a62a9f43
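The commit adds two optional arguments to the nn.RMSProp constructor: weight_decay, an L2 penalty coefficient, and decay_filter, a predicate deciding which parameters the penalty applies to (by default, every parameter whose name contains neither 'beta' nor 'gamma'). A minimal usage sketch, with a hypothetical placeholder network and placeholder hyperparameter values that are not part of the commit:

import mindspore.nn as nn

class SimpleNet(nn.Cell):
    # Hypothetical placeholder network, not part of the commit.
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Dense(16, 4)

    def construct(self, x):
        return self.fc(x)

net = SimpleNet()
# weight_decay and decay_filter are the arguments introduced by this commit.
# The default decay_filter skips parameters whose names contain 'beta' or
# 'gamma' (typically BatchNorm scale/shift parameters).
opt = nn.RMSProp(params=net.trainable_params(),
                 learning_rate=0.1,
                 weight_decay=1e-4)

With the optimizer built this way, the existing docstring example (model = Model(net, loss, opt)) continues to work unchanged.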
Showing 1 changed file with 12 additions and 2 deletions

mindspore/nn/optim/rmsprop.py  +12  -2
mindspore/nn/optim/rmsprop.py

@@ -18,7 +18,8 @@ from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore._checkparam import ParamValidator as validator
 import mindspore.common.dtype as mstype
-from .optimizer import Optimizer, grad_scale
+from mindspore.common import Tensor
+from .optimizer import Optimizer, grad_scale, apply_decay
 
 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -118,6 +119,9 @@ class RMSProp(Optimizer):
         use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
         centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
+                                 lambda x: 'beta' not in x.name and 'gamma' not in x.name.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -132,7 +136,8 @@ class RMSProp(Optimizer):
         >>> model = Model(net, loss, opt)
     """
     def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
-                 use_locking=False, centered=False, loss_scale=1.0):
+                 use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(RMSProp, self).__init__(learning_rate, params)
         if isinstance(momentum, float) and momentum < 0.0:
@@ -159,6 +164,7 @@ class RMSProp(Optimizer):
         self.assignadd = P.AssignAdd()
         self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
         self.axis = 0
+        self.one = Tensor(1, mstype.int32)
         self.momentum = momentum
@@ -167,10 +173,14 @@ class RMSProp(Optimizer):
         self.hyper_map = C.HyperMap()
         self.decay = decay
+        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.reciprocal_scale = 1.0 / loss_scale
+        self.weight_decay = weight_decay * loss_scale
 
     def construct(self, gradients):
         params = self.parameters
+        if self.weight_decay > 0:
+            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
         if self.reciprocal_scale != 1.0:
             gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
         if self.dynamic_lr:
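Reading the changes together: __init__ now records self.decay_tf, a tuple of booleans marking which parameters pass decay_filter, and stores weight_decay multiplied by loss_scale so that the penalty term matches the scaling of the incoming gradients (decay is applied before the reciprocal_scale correction in construct). In construct, apply_decay, newly imported from .optimizer, is mapped over the (flag, parameter, gradient) triples whenever weight_decay > 0. apply_decay itself is defined in mindspore/nn/optim/optimizer.py and is not shown in this diff; the sketch below only illustrates the conventional L2 weight-decay arithmetic that the call implies, using plain NumPy and invented parameter names.

import numpy as np

def apply_decay_sketch(weight_decay, decay_flag, param, grad):
    # Illustrative stand-in for apply_decay: add the L2 penalty term to the
    # gradient of every parameter whose decay flag is True.
    if decay_flag:
        return grad + weight_decay * param
    return grad

# Mirrors self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
# with the default filter; the parameter names here are invented.
names = ['conv.weight', 'bn.gamma', 'bn.beta', 'fc.bias']
decay_tf = tuple('beta' not in n and 'gamma' not in n for n in names)
# decay_tf == (True, False, False, True)

params = [np.full(3, float(i)) for i in range(1, 5)]
grads = [np.zeros(3) for _ in params]
weight_decay = 1e-4
new_grads = [apply_decay_sketch(weight_decay, flag, p, g)
             for flag, p, g in zip(decay_tf, params, grads)]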