Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
553432c9
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
553432c9
编写于
6月 12, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 12, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1958 Fix some description to API about optimizer.
Merge pull request !1958 from liuxiao/fix-for-issuse
上级
c27d4157
52790b74
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
24 addition
and
20 deletion
+24
-20
mindspore/nn/optim/adam.py
mindspore/nn/optim/adam.py
+9
-8
mindspore/nn/optim/lamb.py
mindspore/nn/optim/lamb.py
+6
-4
mindspore/nn/optim/lars.py
mindspore/nn/optim/lars.py
+2
-2
mindspore/nn/optim/sgd.py
mindspore/nn/optim/sgd.py
+5
-4
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+2
-2
未找到文件。
mindspore/nn/optim/adam.py
浏览文件 @
553432c9
...
@@ -162,13 +162,14 @@ class Adam(Optimizer):
...
@@ -162,13 +162,14 @@ class Adam(Optimizer):
in the value of 'order_params' but not in any group will use default learning rate and default weight
in the value of 'order_params' but not in any group will use default learning rate and default weight
decay.
decay.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
learning_rate (Union[int, float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
When the learning_rate is float or learning_rate is a
but the dims of the Tensor is 0, use fixed learning rate.
Tensor but the dims of the Tensor is 0, use fixed learning
Other cases are not supported. Default: 1e-3.
rate. Other cases are not supported. It should be equal to
or greater than 0. Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0). Default:
0.9.
0.9.
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0). Default:
...
@@ -181,7 +182,7 @@ class Adam(Optimizer):
...
@@ -181,7 +182,7 @@ class Adam(Optimizer):
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
If True, updates the gradients using NAG.
If True, updates the gradients using NAG.
If False, updates the gradients without using NAG. Default: False.
If False, updates the gradients without using NAG. Default: False.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
weight_decay (float): Weight decay (L2 penalty).
It should be equal to or greater than 0.
Default: 0.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
Inputs:
Inputs:
...
...
mindspore/nn/optim/lamb.py
浏览文件 @
553432c9
...
@@ -143,10 +143,12 @@ class Lamb(Optimizer):
...
@@ -143,10 +143,12 @@ class Lamb(Optimizer):
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be class mindspore.Parameter.
should be class mindspore.Parameter.
decay_steps (int): The steps of the lr decay. Should be equal to or greater than 1.
decay_steps (int): The steps of the lr decay. Should be equal to or greater than 1.
warmup_steps (int): The steps of lr warm up. Default: 0.
warmup_steps (int): The steps of lr warm up. Should be equal to or greater than 0. Default: 0.
start_learning_rate (float): A floating point value for the learning rate. Default: 0.1.
start_learning_rate (float): A floating point value for the learning rate. Should be equal to
end_learning_rate (float): A floating point value for the end learning rate. Default: 0.0001.
or greater than 0. Default: 0.1.
power (float): The power of the polynomial. Default: 1.0.
end_learning_rate (float): A floating point value for the end learning rate. Should be equal to
or greater than 0. Default: 0.0001.
power (float): The power of the polynomial. It must be positive. Default: 1.0.
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
Should be in range (0.0, 1.0).
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
...
...
mindspore/nn/optim/lars.py
浏览文件 @
553432c9
...
@@ -59,13 +59,13 @@ class LARS(Optimizer):
...
@@ -59,13 +59,13 @@ class LARS(Optimizer):
optimizer (Optimizer): MindSpore optimizer for which to wrap and modify gradients.
optimizer (Optimizer): MindSpore optimizer for which to wrap and modify gradients.
epsilon (float): Term added to the denominator to improve numerical stability. Default: 1e-05.
epsilon (float): Term added to the denominator to improve numerical stability. Default: 1e-05.
hyperpara (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
hyperpara (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
weight_decay (float): Weight decay (L2 penalty).
It should be equal to or greater than 0.
Default: 0.0.
use_clip (bool): Whether to use clip operation for calculating the local learning rate. Default: False.
use_clip (bool): Whether to use clip operation for calculating the local learning rate. Default: False.
decay_filter (Function): A function to determine whether apply weight decay on parameters. Default:
decay_filter (Function): A function to determine whether apply weight decay on parameters. Default:
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
lars_filter (Function): A function to determine whether apply lars algorithm. Default:
lars_filter (Function): A function to determine whether apply lars algorithm. Default:
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
loss_scale (float): A floating point value for the loss scale.
It should be greater than 0.
Default: 1.0.
Inputs:
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is
- **gradients** (tuple[Tensor]) - The gradients of `params` in optimizer, the shape is
...
...
mindspore/nn/optim/sgd.py
浏览文件 @
553432c9
...
@@ -73,10 +73,11 @@ class SGD(Optimizer):
...
@@ -73,10 +73,11 @@ class SGD(Optimizer):
take the i-th value as the learning rate.
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Default: 0.1.
Other cases are not supported. It should be equal to or
momentum (float): A floating point value the momentum. Default: 0.0.
greater than 0. Default: 0.1.
dampening (float): A floating point value of dampening for momentum. Default: 0.0.
momentum (float): A floating point value the momentum. should be at least 0.0. Default: 0.0.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
dampening (float): A floating point value of dampening for momentum. should be at least 0.0. Default: 0.0.
weight_decay (float): Weight decay (L2 penalty). It should be equal to or greater than 0. Default: 0.0.
nesterov (bool): Enables the Nesterov momentum. Default: False.
nesterov (bool): Enables the Nesterov momentum. Default: False.
loss_scale (float): A floating point value for the loss scale, which should be larger
loss_scale (float): A floating point value for the loss scale, which should be larger
than 0.0. Default: 1.0.
than 0.0. Default: 1.0.
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
553432c9
...
@@ -3465,7 +3465,7 @@ class SparseApplyFtrl(PrimitiveWithInfer):
...
@@ -3465,7 +3465,7 @@ class SparseApplyFtrl(PrimitiveWithInfer):
validator
.
check_value_type
(
"l1"
,
l1
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
"l1"
,
l1
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
"l2"
,
l2
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
"l2"
,
l2
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
"lr_power"
,
lr_power
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
"lr_power"
,
lr_power
,
[
float
],
self
.
name
)
self
.
lr
=
validator
.
check_number_range
(
"lr"
,
lr
,
0.0
,
float
(
"inf"
),
Rel
.
INC_
LEFT
,
self
.
name
)
self
.
lr
=
validator
.
check_number_range
(
"lr"
,
lr
,
0.0
,
float
(
"inf"
),
Rel
.
INC_
NEITHER
,
self
.
name
)
self
.
l1
=
validator
.
check_number
(
"l1"
,
l1
,
0.0
,
Rel
.
GE
,
self
.
name
)
self
.
l1
=
validator
.
check_number
(
"l1"
,
l1
,
0.0
,
Rel
.
GE
,
self
.
name
)
self
.
l2
=
validator
.
check_number
(
"l2"
,
l2
,
0.0
,
Rel
.
GE
,
self
.
name
)
self
.
l2
=
validator
.
check_number
(
"l2"
,
l2
,
0.0
,
Rel
.
GE
,
self
.
name
)
self
.
lr_power
=
validator
.
check_number
(
"lr_power"
,
lr_power
,
0
,
Rel
.
LE
,
self
.
name
)
self
.
lr_power
=
validator
.
check_number
(
"lr_power"
,
lr_power
,
0
,
Rel
.
LE
,
self
.
name
)
...
@@ -3656,7 +3656,7 @@ class CTCLoss(PrimitiveWithInfer):
...
@@ -3656,7 +3656,7 @@ class CTCLoss(PrimitiveWithInfer):
"""
"""
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
preprocess_collapse_repeated
=
False
,
ctc_merge_repeated
=
Fals
e
,
def
__init__
(
self
,
preprocess_collapse_repeated
=
False
,
ctc_merge_repeated
=
Tru
e
,
ignore_longer_outputs_than_inputs
=
False
):
ignore_longer_outputs_than_inputs
=
False
):
self
.
init_prim_io_names
(
inputs
=
[
"inputs"
,
"labels_indices"
,
"labels_values"
,
"sequence_length"
],
self
.
init_prim_io_names
(
inputs
=
[
"inputs"
,
"labels_indices"
,
"labels_values"
,
"sequence_length"
],
outputs
=
[
"loss"
,
"gradient"
])
outputs
=
[
"loss"
,
"gradient"
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录