Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b63e0ccb
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b63e0ccb
编写于
10月 28, 2020
作者:
Z
Zhou Wei
提交者:
GitHub
10月 28, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix load check_point bug of LinearWarmup (#28280)
上级
0b678d40
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
57 addition
and
16 deletion
+57
-16
python/paddle/fluid/tests/unittests/test_lr_scheduler.py
python/paddle/fluid/tests/unittests/test_lr_scheduler.py
+36
-14
python/paddle/optimizer/lr.py
python/paddle/optimizer/lr.py
+21
-2
未找到文件。
python/paddle/fluid/tests/unittests/test_lr_scheduler.py
浏览文件 @
b63e0ccb
...
...
@@ -284,11 +284,19 @@ def linear_warmup_lr(epoch_num,
start_lr
,
end_lr
,
verbose
=
False
):
if
epoch_num
<
warmup_steps
:
tmp
=
epoch_num
-
warmup_steps
if
tmp
<
0
:
return
start_lr
+
(
end_lr
-
start_lr
)
*
(
float
(
epoch_num
)
/
float
(
warmup_steps
))
elif
paddle
.
in_dynamic_mode
():
if
tmp
<
3
:
return
0.5
elif
tmp
<
6
:
return
0.2
else
:
return
0.1
else
:
return
learning_rate
return
0.5
def
multi_step_lr
(
epoch_num
,
...
...
@@ -407,6 +415,9 @@ class TestLRScheduler(unittest.TestCase):
paddle
.
disable_static
(
place
)
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
10
,
10
]).
astype
(
"float32"
)
linear
=
paddle
.
nn
.
Linear
(
10
,
10
)
if
paddle_api
.
__name__
==
"LinearWarmup"
:
kwarg
[
'learning_rate'
]
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
[
3
,
6
],
[
0.5
,
0.2
,
0.1
])
scheduler
=
paddle_api
(
**
kwarg
)
adam
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
scheduler
,
parameters
=
linear
.
parameters
())
...
...
@@ -420,12 +431,26 @@ class TestLRScheduler(unittest.TestCase):
adam
.
clear_grad
()
current_lr
=
adam
.
get_lr
()
expected_lr
=
python_func
(
epoch
,
**
kwarg
)
if
paddle_api
.
__name__
!=
"CosineAnnealingDecay"
:
self
.
assertEqual
(
current_lr
,
expected_lr
)
scheduler
.
step
()
else
:
if
paddle_api
.
__name__
==
"CosineAnnealingDecay"
:
self
.
assertAlmostEqual
(
current_lr
,
expected_lr
)
scheduler
.
step
(
epoch
+
1
)
elif
paddle_api
.
__name__
==
"LinearWarmup"
:
self
.
assertAlmostEqual
(
current_lr
,
expected_lr
)
state_dict
=
adam
.
state_dict
()
scheduler1
=
paddle
.
optimizer
.
lr
.
LinearWarmup
(
**
kwarg
)
adam1
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
scheduler1
,
parameters
=
linear
.
parameters
())
adam1
.
set_state_dict
(
state_dict
)
self
.
assertEqual
(
scheduler
.
last_epoch
,
scheduler1
.
last_epoch
)
self
.
assertEqual
(
scheduler
.
last_lr
,
scheduler1
.
last_lr
)
self
.
assertEqual
(
scheduler
.
learning_rate
.
last_lr
,
scheduler1
.
learning_rate
.
last_lr
)
self
.
assertEqual
(
scheduler
.
learning_rate
.
last_epoch
,
scheduler1
.
learning_rate
.
last_epoch
)
scheduler
.
step
()
else
:
self
.
assertEqual
(
current_lr
,
expected_lr
)
scheduler
.
step
()
def
test_scheduler
(
self
):
with
self
.
assertRaises
(
NotImplementedError
):
...
...
@@ -464,8 +489,7 @@ class TestLRScheduler(unittest.TestCase):
"decay_steps"
:
20
,
"end_lr"
:
0
,
"power"
:
1.0
,
"cycle"
:
False
,
"verbose"
:
True
"cycle"
:
False
}),
(
polynomial_lr
,
paddle
.
optimizer
.
lr
.
PolynomialDecay
,
{
"learning_rate"
:
0.5
,
"decay_steps"
:
20
,
...
...
@@ -475,10 +499,9 @@ class TestLRScheduler(unittest.TestCase):
"verbose"
:
False
}),
(
linear_warmup_lr
,
paddle
.
optimizer
.
lr
.
LinearWarmup
,
{
'learning_rate'
:
0.5
,
'warmup_steps'
:
2
0
,
'warmup_steps'
:
1
0
,
'start_lr'
:
0
,
'end_lr'
:
0.5
,
"verbose"
:
True
'end_lr'
:
0.5
}),
(
exponential_lr
,
paddle
.
optimizer
.
lr
.
ExponentialDecay
,
{
"learning_rate"
:
0.5
,
"gamma"
:
0.9
,
...
...
@@ -486,8 +509,7 @@ class TestLRScheduler(unittest.TestCase):
}),
(
multi_step_lr
,
paddle
.
optimizer
.
lr
.
MultiStepDecay
,
{
"learning_rate"
:
0.5
,
"milestones"
:
[
3
,
6
,
9
,
15
,
20
],
"gamma"
:
0.8
,
"verbose"
:
True
"gamma"
:
0.8
}),
(
step_lr
,
paddle
.
optimizer
.
lr
.
StepDecay
,
{
"learning_rate"
:
0.5
,
"step_size"
:
2
,
...
...
@@ -510,7 +532,7 @@ class TestLRScheduler(unittest.TestCase):
for
place
in
places
:
paddle
.
enable_static
()
#
self._test_static(python_func, paddle_api, kwarg, place)
self
.
_test_static
(
python_func
,
paddle_api
,
kwarg
,
place
)
paddle
.
disable_static
(
place
)
self
.
_test_dygraph
(
python_func
,
paddle_api
,
kwarg
,
place
)
paddle
.
enable_static
()
...
...
python/paddle/optimizer/lr.py
浏览文件 @
b63e0ccb
...
...
@@ -365,7 +365,6 @@ class PiecewiseDecay(LRScheduler):
last_epoch
=
last_epoch
,
verbose
=
verbose
)
def
get_lr
(
self
):
for
i
in
range
(
len
(
self
.
boundaries
)):
if
self
.
last_epoch
<
self
.
boundaries
[
i
]:
return
self
.
values
[
i
]
...
...
@@ -750,14 +749,34 @@ class LinearWarmup(LRScheduler):
end_lr
,
start_lr
)
super
(
LinearWarmup
,
self
).
__init__
(
start_lr
,
last_epoch
,
verbose
)
def
state_dict
(
self
):
"""
Returns the state of the LinearWarmup scheduler as a :class:`dict`.
It is a subset of ``self.__dict__`` .
"""
state_dict
=
super
(
LinearWarmup
,
self
).
state_dict
()
if
isinstance
(
self
.
learning_rate
,
LRScheduler
):
state_dict
[
"LinearWarmup_LR"
]
=
self
.
learning_rate
.
state_dict
()
return
state_dict
def
set_state_dict
(
self
,
state_dict
):
"""
Loads state_dict for LinearWarmup scheduler.
"""
super
(
LinearWarmup
,
self
).
set_state_dict
(
state_dict
)
if
isinstance
(
self
.
learning_rate
,
LRScheduler
):
self
.
learning_rate
.
set_state_dict
(
state_dict
[
"LinearWarmup_LR"
])
def
get_lr
(
self
):
if
self
.
last_epoch
<
self
.
warmup_steps
:
return
(
self
.
end_lr
-
self
.
start_lr
)
*
float
(
self
.
last_epoch
)
/
float
(
self
.
warmup_steps
)
+
self
.
start_lr
else
:
if
isinstance
(
self
.
learning_rate
,
LRScheduler
):
lr_value
=
self
.
learning_rate
()
self
.
learning_rate
.
step
()
return
self
.
learning_rate
()
return
lr_value
return
self
.
learning_rate
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录