magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 1f2ca74c
Authored on Apr 27, 2020 by mindspore-ci-bot; committed via Gitee on Apr 27, 2020
!709 fix weight decay error in optimizer AdamWeightDecay
Merge pull request !709 from wangnan39/fix_bug_in_adamweightdecay
Parents: b3bea9d8, ddc558fd
Showing 3 changed files with 61 additions and 21 deletions (+61, -21)
mindspore/nn/optim/adam.py  +46 -14
mindspore/nn/optim/lamb.py   +8  -6
mindspore/nn/optim/sgd.py    +7  -1
mindspore/nn/optim/adam.py
@@ -31,8 +31,8 @@ _learning_rate_update_func = ['linear', 'cos', 'sin']
 adam_opt = C.MultitypeFuncGraph("adam_opt")
 
 
-@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
-def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient):
+@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
     """
     Update parameters.
@@ -67,7 +67,8 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
     next_v = op_mul(beta2, v) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient))
 
     update = next_m / (op_sqrt(next_v) + eps)
-    update = update + op_mul(weight_decay_tensor, param)
+    if decay_flag:
+        update = update + op_mul(weight_decay_tensor, param)
 
     update_with_lr = op_mul(lr, update)
     next_param = param - op_reshape(update_with_lr, op_shape(param))
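The two hunks above are the heart of the fix: _update_run_op gains a decay_flag input, and the weight-decay term is added to the update only when that flag is set. Below is a minimal NumPy sketch of the same update rule, written independently of MindSpore for illustration; the helper name is hypothetical, and the parameter names mirror _update_run_op.

import numpy as np

def adam_weight_decay_step(param, m, v, gradient, lr=1e-3, beta1=0.9, beta2=0.999,
                           eps=1e-6, weight_decay=0.01, decay_flag=True):
    """One AdamWeightDecay-style update; decay is applied only when decay_flag is True."""
    next_m = beta1 * m + (1.0 - beta1) * gradient          # first-moment estimate
    next_v = beta2 * v + (1.0 - beta2) * gradient ** 2     # second-moment estimate
    update = next_m / (np.sqrt(next_v) + eps)
    if decay_flag:                                         # the conditional behaviour this commit introduces
        update = update + weight_decay * param
    next_param = param - lr * update
    return next_param, next_m, next_v

# Parameters excluded by the decay filter (e.g. biases, LayerNorm gamma/beta)
# are updated with decay_flag=False and therefore skip the decay term entirely.
p, m, v = np.ones(3), np.zeros(3), np.zeros(3)
p, m, v = adam_weight_decay_step(p, m, v, gradient=np.full(3, 0.1), decay_flag=False)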
@@ -90,6 +91,17 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
     validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
 
 
+def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):
+    """Check the type of inputs."""
+    validator.check_float_positive('learning_rate', learning_rate, prim_name)
+    validator.check_float_legal_value('learning_rate', learning_rate, prim_name)
+    validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
+    validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
+    validator.check_float_positive('power', power, prim_name)
+    validator.check_float_legal_value('power', power, prim_name)
+    validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
+
+
 @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor")
 def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1,
@@ -126,8 +138,13 @@ class Adam(Optimizer):
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
             should be class mindspore.Parameter.
-        learning_rate (Union[float, Tensor, Iterable]): The Learning rate.
-            Iterable type is used for the dynamic learning rate.
+        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
+                                                        Iterable or a Tensor and the dims of the Tensor is 1,
+                                                        use dynamic learning rate, then the i-th step will
+                                                        take the i-th value as the learning rate.
+                                                        When the learning_rate is float or learning_rate is a Tensor
+                                                        but the dims of the Tensor is 0, use fixed learning rate.
+                                                        Other cases are not supported. Default: 1e-3.
         beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
         beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
         eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
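The reworded learning_rate docstring describes two usage modes: a fixed rate and a per-step schedule. A short, hedged usage sketch follows; the tiny Dense network and the decaying list schedule are illustrative assumptions, not taken from the repository.

from mindspore import nn

net = nn.Dense(3, 2)

# Fixed learning rate: a plain float (or a 0-dim Tensor).
optim_fixed = nn.Adam(net.trainable_params(), learning_rate=1e-3)

# Dynamic learning rate: an Iterable (or a 1-dim Tensor); step i uses the i-th value.
schedule = [1e-3 * (0.95 ** i) for i in range(1000)]
optim_dynamic = nn.Adam(net.trainable_params(), learning_rate=schedule)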
@@ -140,6 +157,8 @@ class Adam(Optimizer):
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
         loss_scale (float): A floating point value for the loss scale. Default: 1.0.
             Should be equal to or greater than 1.
+        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
+            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -207,7 +226,13 @@ class AdamWeightDecay(Optimizer):
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
             should be class mindspore.Parameter.
-        learning_rate (float): A floating point value for the learning rate. Default: 1e-3.
+        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
+                                                        Iterable or a Tensor and the dims of the Tensor is 1,
+                                                        use dynamic learning rate, then the i-th step will
+                                                        take the i-th value as the learning rate.
+                                                        When the learning_rate is float or learning_rate is a Tensor
+                                                        but the dims of the Tensor is 0, use fixed learning rate.
+                                                        Other cases are not supported. Default: 1e-3.
         beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
             Should be in range (0.0, 1.0).
         beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
@@ -215,6 +240,8 @@ class AdamWeightDecay(Optimizer):
         eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
             Should be greater than 0.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
+            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -228,10 +255,10 @@ class AdamWeightDecay(Optimizer):
         >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     """
-    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
+    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(AdamWeightDecay, self).__init__(learning_rate, params)
         _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
-        self.lr = Tensor(np.array([learning_rate]).astype(np.float32))
         self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
         self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
         self.eps = Tensor(np.array([eps]).astype(np.float32))
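With the new decay_filter argument shown above, callers can control which parameters receive weight decay. A hedged usage sketch follows; the tiny Dense network and the bias-only filter are illustrative assumptions, not code from the repository.

from mindspore import nn

net = nn.Dense(3, 2)  # parameters are named '...weight' and '...bias'

# Apply weight decay to every parameter except biases.
optim = nn.AdamWeightDecay(params=net.trainable_params(),
                           learning_rate=1e-3,
                           weight_decay=0.01,
                           decay_filter=lambda x: 'bias' not in x.name)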
@@ -240,13 +267,15 @@ class AdamWeightDecay(Optimizer):
         self.params = self.parameters
         self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
         self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
+        self.decay_flag = tuple(decay_filter(x) for x in self.params)
 
         self.hyper_map = C.HyperMap()
 
     def construct(self, gradients):
-        updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, self.lr,
+        lr = self.get_lr()
+        updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                     self.weight_decay_tensor),
-                                          self.params, self.moments1, self.moments2, gradients)
+                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
 
         return updated_velocity
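In construct, self.decay_flag is a tuple with one boolean per parameter, so the hyper_map call pairs each parameter with its own flag. A plain-Python analogue of that pairing, for illustration only (this is not MindSpore's HyperMap, and the function name is hypothetical):

def hyper_map_like(update_fn, params, moments1, moments2, gradients, decay_flags):
    """Apply update_fn element-wise across the per-parameter tuples."""
    return tuple(update_fn(p, m1, m2, g, flag)
                 for p, m1, m2, g, flag in zip(params, moments1, moments2, gradients, decay_flags))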
@@ -269,6 +298,8 @@ class AdamWeightDecayDynamicLR(Optimizer):
         eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
             Should be greater than 0.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
+        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
+            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -291,10 +322,11 @@ class AdamWeightDecayDynamicLR(Optimizer):
                  beta1=0.9,
                  beta2=0.999,
                  eps=1e-6,
-                 weight_decay=0.0):
+                 weight_decay=0.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
         _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
+        _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
         # turn them to scalar when me support scalar/tensor mix operations
         self.global_step = Parameter(initializer(0, [1]), name="global_step")
         self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
@@ -308,7 +340,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
         self.params = self.parameters
         self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
         self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
-
+        self.decay_flag = tuple(decay_filter(x) for x in self.params)
         self.hyper_map = C.HyperMap()
         self.min = P.Minimum()
         self.pow = P.Pow()
@@ -320,7 +352,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
         lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
         updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                     self.weight_decay_tensor),
-                                          self.params, self.moments1, self.moments2, gradients)
+                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
         added_global_step = self.global_step + self.one
         F.control_depend(lr, added_global_step)
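For reference, the learning rate in AdamWeightDecayDynamicLR.construct follows a polynomial-decay schedule. A plain-Python sketch of that expression, under the assumptions that diff_learning_rate equals learning_rate - end_learning_rate and that p is min(global_step / decay_steps, 1.0):

def poly_decay_lr(global_step, decay_steps, learning_rate, end_learning_rate, power):
    """Polynomial-decay schedule analogous to the lr expression in the hunk above."""
    p = min(global_step / decay_steps, 1.0)
    return (learning_rate - end_learning_rate) * (1.0 - p) ** power + end_learning_rate

# Example: learning rate after 100 of 1000 decay steps with power=1.0.
print(poly_decay_lr(100, 1000, learning_rate=1e-3, end_learning_rate=1e-5, power=1.0))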
mindspore/nn/optim/lamb.py
@@ -112,16 +112,18 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
                        end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
     """Check the type of inputs."""
-    validator.check_value_type("decay_steps", decay_steps, [int], prim_name)
-    validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name)
-    validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
-    validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
-    validator.check_value_type("power", power, [float], prim_name)
+    validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
+    validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
+    validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
+    validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
+    validator.check_float_positive('power', power, prim_name)
+    validator.check_float_legal_value('power', power, prim_name)
+    validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
+    validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
     validator.check_value_type("beta1", beta1, [float], prim_name)
     validator.check_value_type("beta2", beta2, [float], prim_name)
     validator.check_value_type("eps", eps, [float], prim_name)
     validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
-    validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name)
     validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
     validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
     validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
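The lamb.py change replaces plain type checks on the schedule parameters with positivity and finiteness checks. A plain-Python equivalent of one such check, for illustration only (the real code uses the validator helpers shown in the hunk above; this helper name is hypothetical):

import math

def check_positive_finite_float(name, value):
    """Roughly mirrors check_float_positive followed by check_float_legal_value."""
    if not isinstance(value, float) or not math.isfinite(value) or value <= 0.0:
        raise ValueError(f"'{name}' must be a positive, finite float, got {value!r}")

check_positive_finite_float("start_learning_rate", 0.1)    # passes silently
# check_positive_finite_float("power", -1.0)               # would raise ValueError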
mindspore/nn/optim/sgd.py
@@ -42,7 +42,13 @@ class SGD(Optimizer):
     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
             should be class mindspore.Parameter.
-        learning_rate (float): A floating point value for the learning rate. Default: 0.1.
+        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
+                                                        Iterable or a Tensor and the dims of the Tensor is 1,
+                                                        use dynamic learning rate, then the i-th step will
+                                                        take the i-th value as the learning rate.
+                                                        When the learning_rate is float or learning_rate is a Tensor
+                                                        but the dims of the Tensor is 0, use fixed learning rate.
+                                                        Other cases are not supported. Default: 0.1.
         momentum (float): A floating point value the momentum. Default: 0.
         dampening (float): A floating point value of dampening for momentum. Default: 0.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.