Commit c95215bc, magicwindyyd/mindspore (fork of MindSpore/mindspore)

Authored May 11, 2020 by guohongzilong

    seperate lr groups and weight_decay groups

Parent: 3c4c0da8

Showing 6 changed files with 25 additions and 14 deletions (+25 -14)
Changed files:
  mindspore/nn/optim/adam.py (+1 -1)
  mindspore/nn/optim/momentum.py (+1 -1)
  mindspore/nn/optim/optimizer.py (+12 -3)
  mindspore/nn/optim/rmsprop.py (+2 -2)
  mindspore/nn/optim/sgd.py (+1 -1)
  tests/ut/python/optimizer/test_optimize_with_parameter_groups.py (+8 -6)
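This commit splits the single is_group flag into is_group (some per-group override is present) and is_group_lr (a group overrides the learning rate specifically), so a weight_decay-only group no longer takes the grouped-learning-rate code path. A minimal usage sketch of the grouped-parameter API these diffs touch; the tiny nn.Dense network and the name-based split are illustrative assumptions, not the repository's test code:

```python
import mindspore.nn as nn

net = nn.Dense(3, 4)  # any small network with named parameters will do
decay_params = [p for p in net.trainable_params() if 'weight' in p.name]
other_params = [p for p in net.trainable_params() if 'weight' not in p.name]

group_params = [
    {'params': decay_params, 'weight_decay': 0.01},  # weight_decay-only override
    {'params': other_params, 'lr': 0.1},             # lr-only override
]
opt = nn.Momentum(group_params, learning_rate=0.05, momentum=0.9)
# After this commit: opt.is_group is True, and opt.is_group_lr is True only because
# the second group carries an 'lr' key.
```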
mindspore/nn/optim/adam.py

@@ -243,7 +243,7 @@ class Adam(Optimizer):
         self.beta1_power = beta1_power
         beta2_power = self.beta2_power * self.beta2
         self.beta2_power = beta2_power
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power,
                                                self.beta1, self.beta2, self.eps),
                                      lr, gradients, params, moment1, moment2)
mindspore/nn/optim/momentum.py

@@ -111,7 +111,7 @@ class Momentum(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum),
                                      lr, gradients, params, moments)
         else:
             success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr),
                                      gradients, params, moments)
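The same one-line change recurs in each optimizer: when per-group learning rates are active, lr is a tuple with one entry per parameter and is passed as a mapped argument; otherwise the single lr is bound into F.partial and only the per-parameter tensors are mapped. A rough plain-Python stand-in for that hyper_map dispatch (apply_momentum is a hypothetical substitute for the fused momentum_opt kernel):

```python
from functools import partial

def apply_momentum(momentum, lr, grad, param, moment):
    # Hypothetical stand-in for the fused momentum_opt update: SGD with momentum.
    moment = momentum * moment + grad
    return param - lr * moment

def hyper_map(fn, *seqs):
    # Roughly what the optimizer's hyper_map does: apply fn element-wise across sequences.
    return tuple(fn(*args) for args in zip(*seqs))

params, grads, moments = (1.0, 2.0), (0.1, 0.2), (0.0, 0.0)

# Grouped learning rates (is_group_lr True): lr is a per-parameter tuple, mapped with the rest.
group_lr = (0.01, 0.1)
updated_group = hyper_map(partial(apply_momentum, 0.9), group_lr, grads, params, moments)

# Single learning rate (is_group_lr False): lr is bound into the partial up front.
updated_single = hyper_map(partial(apply_momentum, 0.9, 0.05), grads, params, moments)
```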
mindspore/nn/optim/optimizer.py

@@ -94,6 +94,7 @@ class Optimizer(Cell):
         validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)
         self.is_group = False
+        self.is_group_lr = False
         self.loss_scale = loss_scale
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
@@ -116,14 +117,17 @@ class Optimizer(Cell):
             self.group_weight_decay = []
             self._init_group_params(parameters, learning_rate, weight_decay)
-        if self.is_group:
+        if self.is_group_lr:
             self.learning_rate = ParameterTuple(self.group_lr)
+        else:
+            self.learning_rate = Parameter(learning_rate, name="learning_rate")
+        if self.is_group:
             self.parameters = ParameterTuple(self.params)
             self.weight_decay = tuple(self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
         else:
             self.learning_rate = Parameter(learning_rate, name="learning_rate")
             self.parameters = ParameterTuple(parameters)
             self.weight_decay = weight_decay * loss_scale
             decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
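In the grouped branch the per-group weight decays are kept as a tuple of floats and decay_flags records which entries actually apply decay. A quick illustration of that filter in isolation (plain Python, made-up values):

```python
# Per-group weight decays collected as floats (already multiplied by loss_scale).
weight_decay = (0.01, 0.0, 0.01, 0.0)

decay_filter = lambda x: x > 0
decay_flags = tuple(decay_filter(x) for x in weight_decay)
print(decay_flags)  # (True, False, True, False): decay applies only where a group set a positive value
```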
@@ -207,6 +211,7 @@ class Optimizer(Cell):
         for group_param in parameters:
             lr_length = dynamic_lr_length
             if 'lr' in group_param.keys():
+                self.is_group_lr = True
                 self._get_single_lr(group_param['lr'])
                 if isinstance(group_param['lr'], Iterable):
                     lr_length = len(group_param['lr'])
@@ -247,6 +252,10 @@ class Optimizer(Cell):
             else:
                 weight_decay_ = weight_decay * self.loss_scale
+            for key in group_param.keys():
+                if key not in ('params', 'lr', 'weight_decay'):
+                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
             for param in group_param['params']:
                 if param in params_store:
                     raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
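The added loop warns, rather than failing, when a group dict carries a key the optimizer does not understand. A small sketch of the check on its own (hypothetical group dict; the real code uses MindSpore's logger):

```python
import logging
logger = logging.getLogger("optimizer")

# Hypothetical group dict with a key this version does not recognize.
group_param = {'params': [], 'lr': 0.1, 'frozen': True}
for key in group_param.keys():
    if key not in ('params', 'lr', 'weight_decay'):
        logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
```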
@@ -261,7 +270,7 @@
         Returns:
             float, the learning rate of current step.
         """
-        if self.is_group:
+        if self.is_group_lr:
             lr = self.learning_rate
             if self.dynamic_lr:
                 lr = ()
mindspore/nn/optim/rmsprop.py

@@ -176,7 +176,7 @@ class RMSProp(Optimizer):
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.centered:
-            if self.is_group:
+            if self.is_group_lr:
                 success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
                                                    self.momentum), lr, params, self.mg, self.ms, self.moment,
                                          gradients)
             else:
@@ -184,7 +184,7 @@ class RMSProp(Optimizer):
                                                    self.momentum, lr), params, self.mg, self.ms, self.moment,
                                          gradients)
         else:
-            if self.is_group:
+            if self.is_group_lr:
                 success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
                                                    self.momentum), lr, params, self.ms, self.moment, gradients)
             else:
mindspore/nn/optim/sgd.py

@@ -139,7 +139,7 @@ class SGD(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum),
                                      lr, gradients, params, accum, stat)
         else:
             success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr),
                                      gradients, params, accum, stat)
tests/ut/python/optimizer/test_optimize_with_parameter_groups.py

@@ -65,12 +65,13 @@ def test_group_lr():
     opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
     assert opt.is_group is True
+    assert opt.is_group_lr is True
     assert opt.dynamic_lr is False
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(conv_lr, mstype.float32)
+            assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
         else:
-            assert lr.data == Tensor(default_lr, mstype.float32)
+            assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
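The assertions switch from a bare == comparison to np.all(...) because comparing tensors compares element-wise and does not reduce to a single boolean. A short numpy-only illustration of the difference (MindSpore Tensor comparison behaves analogously in this test):

```python
import numpy as np

lr_data = np.array([0.1, 0.1, 0.1], dtype=np.float32)
expected = np.array([0.1, 0.1, 0.1], dtype=np.float32)

elementwise = lr_data == expected  # array([ True,  True,  True]), not a single bool
assert np.all(elementwise)         # reduce to one boolean before asserting
```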
@@ -96,9 +97,9 @@ def test_group_dynamic_1():
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32))
+            assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
         else:
-            assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32))
+            assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy())
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -124,9 +125,9 @@ def test_group_dynamic_2():
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))
+            assert np.all(lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)))
         else:
-            assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))
+            assert np.all(lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)))
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -184,6 +185,7 @@ def test_weight_decay():
     opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
     assert opt.is_group is True
+    assert opt.is_group_lr is False
     for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
         if param in conv_params:
             assert weight_decay == conv_weight_decay
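The new assertion in test_weight_decay captures the point of the commit: a group that only overrides weight_decay sets is_group but no longer forces the grouped-learning-rate path. A self-contained sketch of that case (the tiny network and name-based split are assumptions for illustration, not the test's actual setup):

```python
import mindspore.nn as nn

net = nn.Dense(3, 4)
decay_params = [p for p in net.trainable_params() if 'weight' in p.name]
other_params = [p for p in net.trainable_params() if 'weight' not in p.name]

# Only weight_decay is overridden per group, so is_group is True while is_group_lr stays False.
group_params = [
    {'params': decay_params, 'weight_decay': 0.8},
    {'params': other_params},
]
opt = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0)
assert opt.is_group is True
assert opt.is_group_lr is False
```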