magicwindyyd / mindspore (fork of MindSpore / mindspore)
Commit 713ed15b
Authored May 25, 2020 by mindspore-ci-bot; committed via Gitee, May 25, 2020
!1332 Fix some bugs for optimizer.
Merge pull request !1332 from liuxiao/fix-bug-for-optimizer
Parents: 28029388, e7a7de2c
Showing 3 changed files with 19 additions and 6 deletions (+19 −6):

    mindspore/nn/optim/optimizer.py      +11  −1
    mindspore/nn/optim/sgd.py             +7  −3
    mindspore/ops/operations/nn_ops.py    +1  −2
mindspore/nn/optim/optimizer.py
@@ -49,7 +49,15 @@ class Optimizer(Cell):
     applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.

     Args:
-        learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
+        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
+            Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step
+            will take the i-th value as the learning rate. When the learning_rate is float or learning_rate is
+            a Tensor but the dims of the Tensor is 0, use fixed learning rate. Other cases are not supported.
+            Should be greater than 0. If the type of `learning_rate` input is int, it will be converted to float.
     parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
         updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of `dict`,
         the "params", "lr" and "weight_decay" are the keys can be parsed.
@@ -96,6 +104,8 @@ class Optimizer(Cell):
         self.is_group = False
         self.is_group_lr = False
         self.loss_scale = loss_scale
+        if isinstance(learning_rate, int):
+            learning_rate = float(learning_rate)
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
             self.gather = None
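With these two hunks, an Optimizer accepts every learning_rate form the new docstring lists, including a plain int. A minimal usage sketch for illustration, not taken from the diff, assuming the 2020-era MindSpore API (`nn.SGD` as the optimizer, `nn.Dense` as a stand-in network):

    # Sketch of the three learning_rate forms the new docstring describes.
    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    net = nn.Dense(3, 2)  # stand-in network for illustration

    # Fixed learning rate: a float (a 0-dim Tensor behaves the same).
    opt_fixed = nn.SGD(net.trainable_params(), learning_rate=0.01)

    # An int is now converted to float by this commit instead of being rejected.
    opt_int = nn.SGD(net.trainable_params(), learning_rate=1)

    # Dynamic learning rate: a 1-dim Tensor (or Iterable); step i uses value i.
    lr_steps = Tensor(np.array([0.1, 0.05, 0.01]), mstype.float32)
    opt_dynamic = nn.SGD(net.trainable_params(), learning_rate=lr_steps)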
mindspore/nn/optim/sgd.py
@@ -106,22 +106,26 @@ class SGD(Optimizer):
         super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)

+        if isinstance(momentum, int):
+            momentum = float(momentum)
         if not isinstance(momentum, float):
             raise TypeError("momentum should be float number!")

         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))

-        if not isinstance(dampening, float):
-            raise TypeError("dampening should be float number")
+        if isinstance(dampening, int):
+            dampening = float(dampening)
+        if not isinstance(dampening, float):
+            raise TypeError("dampening should be float number")

         if dampening < 0.0:
             raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
         self.dampening = dampening

+        if isinstance(weight_decay, int):
+            weight_decay = float(weight_decay)

         validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
         self.nesterov = nesterov
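The net effect of this hunk: integer values for momentum, dampening, and weight_decay are normalized to float up front instead of raising a TypeError. A hedged usage sketch (not from the diff; assumes the 2020-era nn.SGD signature):

    import mindspore.nn as nn

    net = nn.Dense(3, 2)  # stand-in network for illustration

    # Before this fix, momentum=1 hit TypeError("momentum should be float number!")
    # and dampening=0 hit the equivalent dampening check; both now pass.
    opt = nn.SGD(net.trainable_params(), learning_rate=0.1,
                 momentum=1, dampening=0, weight_decay=0)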
mindspore/ops/operations/nn_ops.py
@@ -2591,8 +2591,7 @@ class Adam(PrimitiveWithInfer):
         Mean square gradients, has the same type as `var`.
     - **beta1_power** (float) - :math:`beta_1^t` in the updating formula.
     - **beta2_power** (float) - :math:`beta_2^t` in the updating formula.
-    - **lr** (Union[float, Tensor, Iterable]) - :math:`l` in the updating formula.
-        Iterable type is used for the dynamic learning rate.
+    - **lr** (float) - :math:`l` in the updating formula.
     - **beta1** (float) - The exponential decay rate for the 1st moment estimates.
     - **beta2** (float) - The exponential decay rate for the 2nd moment estimates.
     - **epsilon** (float) - Term added to the denominator to improve numerical stability.
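For context, these scalar inputs parameterize the usual Adam step; a hedged reconstruction of the standard rule (the full docstring above this excerpt is authoritative), with g the gradient, (m, v) the moment estimates, and w the variable being updated:

    m_{t+1} = \beta_1 \cdot m_t + (1 - \beta_1) \cdot g
    v_{t+1} = \beta_2 \cdot v_t + (1 - \beta_2) \cdot g^2
    l_t     = l \cdot \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)
    w       = w - l_t \cdot m_{t+1} / (\sqrt{v_{t+1}} + \epsilon)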