Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f92fdfb8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f92fdfb8
编写于
11月 28, 2020
作者:
L
LielinJiang
提交者:
GitHub
11月 28, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ReduceLROnPlateau (#29113)
* add ReduceLROnPlateau
上级
b818429a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
275 addition
and
1 deletion
+275
-1
python/paddle/hapi/callbacks.py
python/paddle/hapi/callbacks.py
+169
-1
python/paddle/tests/test_callback_reduce_lr_on_plateau.py
python/paddle/tests/test_callback_reduce_lr_on_plateau.py
+106
-0
未找到文件。
python/paddle/hapi/callbacks.py
浏览文件 @
f92fdfb8
...
...
@@ -27,7 +27,7 @@ from .progressbar import ProgressBar
__all__
=
[
'Callback'
,
'ProgBarLogger'
,
'ModelCheckpoint'
,
'VisualDL'
,
'LRScheduler'
,
'EarlyStopping'
'EarlyStopping'
,
'ReduceLROnPlateau'
]
...
...
@@ -946,3 +946,171 @@ class VisualDL(Callback):
if
(
not
hasattr
(
self
,
'_is_fit'
))
and
hasattr
(
self
,
'writer'
):
self
.
writer
.
close
()
delattr
(
self
,
'writer'
)
class
ReduceLROnPlateau
(
Callback
):
"""Reduce learning rate when a metric of evaluation has stopped improving.
Models often benefit from reducing the learning rate by a factor
of 2-10 once learning stagnates. This callback monitors a
quantity and if no improvement is seen for a 'patience' number
of epochs, the learning rate is reduced.
Args:
monitor(str, optional): Quantity to be monitored. Default: 'loss'.
factor(float, optional): factor by which the learning rate will be reduced.
`new_lr = lr * factor`. Default: 0.1.
patience(int, optional): Number of epochs with no improvement after which
learning rate will be reduced. Default: 10.
verbose(int, optional): The verbosity mode. 0: quiet, 1: update messages.
Default: 1.
mode(str, optional): one of `{'auto', 'min', 'max'}`. In `'min'` mode,
the learning rate will be reduced when the quantity monitored has
stopped decreasing. In 'max' mode, learning rate will reduce until
monitored quantity stops increasing. In 'auto' mode, exact mode
can be inferred by the name of monitor. If 'acc' in monitor, the
mode will be considered as 'max', otherwise the mode will be set
to 'min'. Default: 'auto'.
min_delta(int|float, optional): threshold for measuring the new optimum,
to only focus on significant changes. Default: 0.
cooldown(int, optional): number of epochs to wait before resuming normal operation after
lr has been reduced. Default: 0.
min_lr(float, optional): lower bound on the learning rate. Default: 0.
Examples:
.. code-block:: python
import paddle
from paddle import Model
from paddle.static import InputSpec
from paddle.vision.models import LeNet
from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy
from paddle.nn.layer.loss import CrossEntropyLoss
import paddle.vision.transforms as T
sample_num = 200
transform = T.Compose(
[T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = MNIST(mode='train', transform=transform)
val_dataset = MNIST(mode='test', transform=transform)
net = LeNet()
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=net.parameters())
inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(net, inputs=inputs, labels=labels)
model.prepare(
optim,
loss=CrossEntropyLoss(),
metrics=[Accuracy()])
callbacks = paddle.callbacks.ReduceLROnPlateau(patience=3, verbose=1)
model.fit(train_dataset,
val_dataset,
batch_size=64,
log_freq=200,
save_freq=10,
epochs=20,
callbacks=[callbacks])
"""
def
__init__
(
self
,
monitor
=
'loss'
,
factor
=
0.1
,
patience
=
10
,
verbose
=
1
,
mode
=
'auto'
,
min_delta
=
1e-4
,
cooldown
=
0
,
min_lr
=
0
):
super
(
ReduceLROnPlateau
,
self
).
__init__
()
self
.
monitor
=
monitor
if
factor
>=
1.0
:
raise
ValueError
(
'ReduceLROnPlateau '
'does not support a factor >= 1.0.'
)
self
.
factor
=
factor
self
.
min_lr
=
min_lr
self
.
min_delta
=
min_delta
self
.
patience
=
patience
self
.
verbose
=
verbose
self
.
cooldown
=
cooldown
self
.
cooldown_counter
=
0
# Cooldown counter.
self
.
wait
=
0
self
.
best
=
0
self
.
mode
=
mode
self
.
monitor_op
=
None
self
.
epoch
=
0
self
.
_reset
()
def
_reset
(
self
):
"""Resets wait counter and cooldown counter.
"""
if
self
.
mode
not
in
[
'auto'
,
'min'
,
'max'
]:
warnings
.
warn
(
'Learning rate reduction mode %s is unknown, '
'fallback to auto mode.'
%
self
.
mode
)
self
.
mode
=
'auto'
if
(
self
.
mode
==
'min'
or
(
self
.
mode
==
'auto'
and
'acc'
not
in
self
.
monitor
)):
self
.
monitor_op
=
lambda
a
,
b
:
np
.
less
(
a
,
b
-
self
.
min_delta
)
self
.
best
=
np
.
Inf
else
:
self
.
monitor_op
=
lambda
a
,
b
:
np
.
greater
(
a
,
b
+
self
.
min_delta
)
self
.
best
=
-
np
.
Inf
self
.
cooldown_counter
=
0
self
.
wait
=
0
def
on_train_begin
(
self
,
logs
=
None
):
self
.
_reset
()
def
on_eval_end
(
self
,
logs
=
None
):
if
logs
is
None
or
self
.
monitor
not
in
logs
:
warnings
.
warn
(
'Monitor of ReduceLROnPlateau should be loss or metric name.'
)
return
else
:
try
:
lr
=
self
.
model
.
_optimizer
.
_learning_rate
if
not
isinstance
(
lr
,
float
):
warnings
.
warn
(
'Expected learning_rate be float, bug got {}.'
.
format
(
type
(
lr
)))
return
except
Exception
as
e
:
warnings
.
warn
(
'There are something wrong when get learning_rate from optimizer: {}.'
.
format
(
e
))
return
current
=
logs
[
self
.
monitor
]
if
isinstance
(
current
,
(
list
,
tuple
)):
current
=
current
[
0
]
elif
isinstance
(
current
,
numbers
.
Number
):
current
=
current
else
:
return
if
self
.
in_cooldown
():
self
.
cooldown_counter
-=
1
self
.
wait
=
0
if
self
.
monitor_op
(
current
,
self
.
best
):
self
.
best
=
current
self
.
wait
=
0
elif
not
self
.
in_cooldown
():
self
.
wait
+=
1
if
self
.
wait
>=
self
.
patience
:
old_lr
=
self
.
model
.
_optimizer
.
get_lr
()
if
old_lr
>
np
.
float32
(
self
.
min_lr
):
new_lr
=
old_lr
*
self
.
factor
new_lr
=
max
(
new_lr
,
self
.
min_lr
)
self
.
model
.
_optimizer
.
_learning_rate
=
new_lr
if
self
.
verbose
>
0
and
ParallelEnv
().
local_rank
==
0
:
print
(
'
\n
Epoch %d: ReduceLROnPlateau reducing learning '
'rate to %s.'
%
(
self
.
epoch
+
1
,
new_lr
))
self
.
cooldown_counter
=
self
.
cooldown
self
.
wait
=
0
self
.
epoch
+=
1
def
in_cooldown
(
self
):
return
self
.
cooldown_counter
>
0
python/paddle/tests/test_callback_reduce_lr_on_plateau.py
0 → 100644
浏览文件 @
f92fdfb8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
unittest
import
time
import
random
import
tempfile
import
shutil
import
numpy
as
np
import
paddle
import
paddle.vision.transforms
as
T
from
paddle
import
Model
from
paddle.static
import
InputSpec
from
paddle.vision.models
import
LeNet
from
paddle.hapi.callbacks
import
config_callbacks
from
paddle.vision.datasets
import
MNIST
from
paddle.metric
import
Accuracy
from
paddle.nn.layer.loss
import
CrossEntropyLoss
# Accelerate unittest
class
CustomMnist
(
MNIST
):
def
__len__
(
self
):
return
8
class
TestReduceLROnPlateau
(
unittest
.
TestCase
):
def
test_reduce_lr_on_plateau
(
self
):
transform
=
T
.
Compose
([
T
.
Transpose
(),
T
.
Normalize
([
127.5
],
[
127.5
])])
train_dataset
=
CustomMnist
(
mode
=
'train'
,
transform
=
transform
)
val_dataset
=
CustomMnist
(
mode
=
'test'
,
transform
=
transform
)
net
=
LeNet
()
optim
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameters
=
net
.
parameters
())
inputs
=
[
InputSpec
([
None
,
1
,
28
,
28
],
'float32'
,
'x'
)]
labels
=
[
InputSpec
([
None
,
1
],
'int64'
,
'label'
)]
model
=
Model
(
net
,
inputs
=
inputs
,
labels
=
labels
)
model
.
prepare
(
optim
,
loss
=
CrossEntropyLoss
(),
metrics
=
[
Accuracy
()])
callbacks
=
paddle
.
callbacks
.
ReduceLROnPlateau
(
patience
=
1
,
verbose
=
1
,
cooldown
=
1
)
model
.
fit
(
train_dataset
,
val_dataset
,
batch_size
=
8
,
log_freq
=
1
,
save_freq
=
10
,
epochs
=
10
,
callbacks
=
[
callbacks
])
def
test_warn_or_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
paddle
.
callbacks
.
ReduceLROnPlateau
(
factor
=
2.0
)
# warning
paddle
.
callbacks
.
ReduceLROnPlateau
(
mode
=
'1'
,
patience
=
3
,
verbose
=
1
)
transform
=
T
.
Compose
([
T
.
Transpose
(),
T
.
Normalize
([
127.5
],
[
127.5
])])
train_dataset
=
CustomMnist
(
mode
=
'train'
,
transform
=
transform
)
val_dataset
=
CustomMnist
(
mode
=
'test'
,
transform
=
transform
)
net
=
LeNet
()
optim
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameters
=
net
.
parameters
())
inputs
=
[
InputSpec
([
None
,
1
,
28
,
28
],
'float32'
,
'x'
)]
labels
=
[
InputSpec
([
None
,
1
],
'int64'
,
'label'
)]
model
=
Model
(
net
,
inputs
=
inputs
,
labels
=
labels
)
model
.
prepare
(
optim
,
loss
=
CrossEntropyLoss
(),
metrics
=
[
Accuracy
()])
callbacks
=
paddle
.
callbacks
.
ReduceLROnPlateau
(
monitor
=
'miou'
,
patience
=
3
,
verbose
=
1
)
model
.
fit
(
train_dataset
,
val_dataset
,
batch_size
=
8
,
log_freq
=
1
,
save_freq
=
10
,
epochs
=
1
,
callbacks
=
[
callbacks
])
optim
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
([
0.001
,
0.0001
],
[
5
,
10
]),
parameters
=
net
.
parameters
())
model
.
prepare
(
optim
,
loss
=
CrossEntropyLoss
(),
metrics
=
[
Accuracy
()])
callbacks
=
paddle
.
callbacks
.
ReduceLROnPlateau
(
monitor
=
'acc'
,
mode
=
'max'
,
patience
=
3
,
verbose
=
1
,
cooldown
=
1
)
model
.
fit
(
train_dataset
,
val_dataset
,
batch_size
=
8
,
log_freq
=
1
,
save_freq
=
10
,
epochs
=
3
,
callbacks
=
[
callbacks
])
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录