Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e70b2f54
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e70b2f54
编写于
5月 11, 2020
作者:
G
guohongzilong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add optimizer.get_lr_parameter() method
上级
fd72534a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
69 additions
and
0 deletions
+69
-0
mindspore/nn/optim/optimizer.py
mindspore/nn/optim/optimizer.py
+31
-0
tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
...t/python/optimizer/test_optimize_with_parameter_groups.py
+38
-0
未找到文件。
mindspore/nn/optim/optimizer.py
浏览文件 @
e70b2f54
...
...
@@ -257,6 +257,7 @@ class Optimizer(Cell):
logger
.
warning
(
f
"The optimizer cannot parse '
{
key
}
' when setting parameter groups."
)
for
param
in
group_param
[
'params'
]:
validator
.
check_value_type
(
"parameter"
,
param
,
[
Parameter
],
self
.
cls_name
)
if
param
in
params_store
:
raise
RuntimeError
(
f
"The
{
param
.
name
}
parameter has appeared in parameter groups."
)
params_store
.
append
(
param
)
...
...
@@ -286,6 +287,36 @@ class Optimizer(Cell):
F
.
control_depend
(
lr
,
self
.
assignadd
(
self
.
global_step
,
1
))
return
lr
def get_lr_parameter(self, param):
    """
    Get the learning rate associated with a parameter or a list of parameters.

    Args:
        param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

    Returns:
        Parameter, a single `Parameter` or `list[Parameter]` according to the input type.

    Raises:
        TypeError: If `param` is neither a `Parameter` nor a list.
    """
    if not isinstance(param, (Parameter, list)):
        # Plain string: the original used an f-string with no placeholders (F541).
        raise TypeError("The 'param' only support 'Parameter' or 'list' type.")

    if isinstance(param, list):
        lr = []
        for p in param:
            # Each list element must itself be a Parameter; raises TypeError otherwise.
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
            if self.is_group_lr:
                # Group-wise lr: resolve the rate by the parameter's position
                # in the optimizer's flat parameter list.
                index = self.parameters.index(p)
                lr.append(self.learning_rate[index])
            else:
                # Single shared lr applies to every parameter.
                lr.append(self.learning_rate)
    else:
        if self.is_group_lr:
            index = self.parameters.index(param)
            lr = self.learning_rate[index]
        else:
            lr = self.learning_rate
    return lr
def construct(self, *hyper_params):
    """Apply gradients to parameters; subclasses must override this stub."""
    raise NotImplementedError
...
...
tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
浏览文件 @
e70b2f54
...
...
@@ -210,3 +210,41 @@ def test_group_repeat_param():
{
'params'
:
no_conv_params
}]
with
pytest
.
raises
(
RuntimeError
):
Adam
(
group_params
,
learning_rate
=
default_lr
)
def test_get_lr_parameter_with_group():
    """With per-group lr, each parameter maps to its own `lr_<name>` Parameter."""
    net = LeNet5()
    conv_lr = 0.1
    default_lr = 0.3
    conv_params = [x for x in net.trainable_params() if 'conv' in x.name]
    no_conv_params = [x for x in net.trainable_params() if 'conv' not in x.name]
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': no_conv_params, 'lr': default_lr}]
    opt = SGD(group_params)
    assert opt.is_group_lr is True

    # Single-Parameter lookups resolve to the matching per-parameter lr.
    for p in opt.parameters:
        assert opt.get_lr_parameter(p).name == 'lr_' + p.name

    # List lookup returns one lr per input parameter, in the same order.
    for rate, p in zip(opt.get_lr_parameter(conv_params), conv_params):
        assert rate.name == 'lr_' + p.name
def test_get_lr_parameter_with_no_group():
    """Without group lr, every parameter shares the optimizer's single lr;
    non-Parameter inputs are rejected with TypeError."""
    net = LeNet5()
    conv_weight_decay = 0.8
    conv_params = [x for x in net.trainable_params() if 'conv' in x.name]
    no_conv_params = [x for x in net.trainable_params() if 'conv' not in x.name]
    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
                    {'params': no_conv_params}]
    opt = SGD(group_params)
    assert opt.is_group_lr is False

    # Every parameter resolves to the one shared learning-rate Parameter.
    for p in opt.parameters:
        assert opt.get_lr_parameter(p).name == opt.learning_rate.name

    # A list of non-Parameter values must be rejected.
    params_error = [1, 2, 3]
    with pytest.raises(TypeError):
        opt.get_lr_parameter(params_error)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录