Unverified · Commit 70385518 authored by LiuChiachi, committed by GitHub

Add EarlyStopping (#28691)

* add early stopping

* add doc for early stopping

* fix sample code bugs

* update mode inference, update docs, add unittests to increase coverage

* fix sample code for early stopping

* update sample code and unittests

* reduce time cost of test_callbacks unittest

* fix model.py code style error
Parent 8c8b42f2
@@ -14,6 +14,9 @@
import os
import numbers
import warnings
import numpy as np
import paddle
from paddle.distributed import ParallelEnv
@@ -22,7 +25,8 @@ from paddle.utils import try_import
from .progressbar import ProgressBar
__all__ = [
    'Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL', 'LRScheduler',
    'EarlyStopping'
]
@@ -45,6 +49,9 @@ def config_callbacks(callbacks=None,
    if not any(isinstance(k, ModelCheckpoint) for k in cbks):
        cbks = cbks + [ModelCheckpoint(save_freq, save_dir)]
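    # Propagate `save_dir` to any EarlyStopping callback so it knows
    # where to save the best model (see `save_best_model` below).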
for k in cbks:
if isinstance(k, EarlyStopping):
k.save_dir = save_dir
    if not any(isinstance(k, LRScheduler) for k in cbks):
        cbks = cbks + [LRScheduler()]
@@ -581,6 +588,157 @@ class LRScheduler(Callback):
            self.model._optimizer._learning_rate.step()
class EarlyStopping(Callback):
"""Stop training when the given monitor stopped improving during evaluation.
Args:
monitor(str): Quantity to be monitored. Default: 'loss'.
mode(str|None): Mode should be one of 'auto', 'min' or 'max'. In 'min'
mode, training will stop until monitored quantity stops decreasing.
In 'max' mode, training will stop until monitored quantity stops
increasing. In 'auto' mode, exact mode can be inferred by the name
of monitor. If 'acc' in monitor, the mode will be considered as
'max', otherwise the mode will be set to 'min'. Default: 'auto'.
patience(int): Number of epochs with no improvement after which
training will be stopped. Default: 0.
verbose(int): The verbosity mode, should be 0 or 1. When verbose=0,
logs will not be printed. When verbose=1, logs will be printed.
Default: 1.
min_delta(int|float): The minimum change of monitored quantity. If
the change is less than min_delta, model could be considered as no
improvement. Default: 0.
baseline(int|float|None): Baseline value for the monitored quantity.
Training will stop if the model doesn't show improvement over the
baseline. Default: None.
save_best_model(bool): Whether to save best model. Default: True.
Examples:
.. code-block:: python
import paddle
from paddle import Model
from paddle.static import InputSpec
from paddle.vision.models import LeNet
from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy
from paddle.nn.layer.loss import CrossEntropyLoss
import paddle.vision.transforms as T
device = paddle.set_device('cpu')
save_dir = './best_model_checkpoint'
transform = T.Compose(
[T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = MNIST(mode='train', transform=transform)
val_dataset = MNIST(mode='test', transform=transform)
net = LeNet()
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=net.parameters())
inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(net, inputs=inputs, labels=labels)
model.prepare(
optim,
loss=CrossEntropyLoss(reduction="sum"),
metrics=[Accuracy()])
callbacks = paddle.callbacks.EarlyStopping(
'loss',
mode='min',
patience=1,
verbose=1,
min_delta=0,
baseline=None,
save_best_model=True)
model.fit(train_dataset,
val_dataset,
batch_size=64,
log_freq=200,
save_freq=10,
save_dir=save_dir,
epochs=20,
callbacks=[callbacks])
"""
def __init__(self,
monitor='loss',
mode='auto',
patience=0,
verbose=1,
min_delta=0,
baseline=None,
save_best_model=True):
super(EarlyStopping, self).__init__()
self.monitor = monitor
self.patience = patience
self.verbose = verbose
self.baseline = baseline
self.min_delta = abs(min_delta)
self.wait_epoch = 0
self.best_weights = None
self.stopped_epoch = 0
self.save_best_model = save_best_model
        self.save_dir = None  # `save_dir` is set by `config_callbacks`
if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, '
                          'falling back to auto mode.' % mode)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
elif mode == 'max':
self.monitor_op = np.greater
        # When mode == 'auto', infer the mode from `self.monitor`.
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
else:
self.monitor_op = np.less
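        # Flip the sign of min_delta so that the single comparison in
        # `on_eval_end`, monitor_op(current - min_delta, best_value),
        # tests for improvement in both modes.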
        if self.monitor_op == np.less:
            self.min_delta *= -1
def on_train_begin(self, logs=None):
self.wait_epoch = 0
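        # With a baseline, the first evaluation must already beat the baseline
        # to count as an improvement; otherwise the best value starts at
        # +/-inf and the first evaluation always improves on it.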
if self.baseline is not None:
self.best_value = self.baseline
else:
self.best_value = np.inf if self.monitor_op == np.less else -np.inf
self.best_weights = None
def on_eval_end(self, logs=None):
if logs is None or self.monitor not in logs:
warnings.warn(
'Monitor of EarlyStopping should be loss or metric name.')
return
current = logs[self.monitor]
        if isinstance(current, (list, tuple)):
            current = current[0]
        elif not isinstance(current, numbers.Number):
            return
if self.monitor_op(current - self.min_delta, self.best_value):
self.best_value = current
self.wait_epoch = 0
if self.save_best_model and self.save_dir is not None:
path = os.path.join(self.save_dir, 'best_model')
self.model.save(path)
else:
self.wait_epoch += 1
if self.wait_epoch >= self.patience:
self.model.stop_training = True
if self.verbose > 0:
print('Epoch %d: Early stopping.' % (self.stopped_epoch + 1))
if self.save_best_model and self.save_dir is not None:
print('Best checkpoint has been saved at %s' %
(os.path.abspath(
os.path.join(self.save_dir, 'best_model'))))
self.stopped_epoch += 1
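For intuition, the signed min_delta set in __init__ lets a single comparison cover both modes. A minimal standalone sketch (plain Python with illustrative names, not part of the diff):

import numpy as np

def improved(current, best, monitor_op, min_delta):
    # 'min' mode: monitor_op is np.less and min_delta has been negated,
    # so this checks current < best - |min_delta|.
    # 'max' mode: monitor_op is np.greater and min_delta stays positive,
    # so this checks current > best + |min_delta|.
    return monitor_op(current - min_delta, best)

assert improved(0.80, 1.0, np.less, -0.1)        # loss fell by more than 0.1
assert not improved(0.95, 1.0, np.less, -0.1)    # fall smaller than min_delta
assert improved(0.92, 0.90, np.greater, 0.01)    # acc rose by more than 0.01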
class VisualDL(Callback):
    """VisualDL callback function
    Args:
...
@@ -50,7 +50,7 @@ from paddle.fluid.dygraph.layers import Layer
from paddle.metric import Metric
from paddle.static import InputSpec as Input
from .callbacks import config_callbacks, EarlyStopping
from .model_summary import summary
__all__ = ['Model', ]
@@ -872,6 +872,7 @@ class Model(object):
        self._input_info = None
        self._is_shape_inferred = False
        self._test_dataloader = None
self.stop_training = False
        if not in_dygraph_mode():
            if not isinstance(inputs, (list, dict, Input)):
@@ -1479,9 +1480,11 @@ class Model(object):
            verbose=verbose,
            metrics=self._metrics_name(), )
if any(isinstance(k, EarlyStopping) for k in cbks) and not do_eval:
warnings.warn("EarlyStopping needs validation data.")
        cbks.on_begin('train')
        for epoch in range(epochs):
            cbks.on_epoch_begin(epoch)
            logs = self._run_one_epoch(train_loader, cbks, 'train')
            cbks.on_epoch_end(epoch, logs)
@@ -1497,6 +1500,8 @@ class Model(object):
                eval_logs = self._run_one_epoch(eval_loader, cbks, 'eval')
                cbks.on_end('eval', eval_logs)
if self.stop_training:
break
        cbks.on_end('train', logs)
        self._test_dataloader = None
...
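The flag and the hook cooperate: EarlyStopping.on_eval_end sets model.stop_training, and fit checks it after each evaluation. A condensed paraphrase of the training loop in the hunks above (illustrative, not the exact source; eval_freq is fit's evaluation interval):

for epoch in range(epochs):
    cbks.on_epoch_begin(epoch)
    logs = self._run_one_epoch(train_loader, cbks, 'train')
    cbks.on_epoch_end(epoch, logs)
    if do_eval and epoch % eval_freq == 0:
        eval_logs = self._run_one_epoch(eval_loader, cbks, 'eval')
        cbks.on_end('eval', eval_logs)  # EarlyStopping.on_eval_end runs here
        if self.stop_training:          # set by the callback once patience is exhausted
            break
cbks.on_end('train', logs)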
@@ -18,13 +18,36 @@ import time
import random
import tempfile
import shutil
import numpy as np
import paddle
from paddle import Model
from paddle.static import InputSpec
from paddle.vision.models import LeNet
from paddle.hapi.callbacks import config_callbacks
import paddle.vision.transforms as T
from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy
from paddle.nn.layer.loss import CrossEntropyLoss
class MnistDataset(MNIST):
def __init__(self, mode, return_label=True, sample_num=None):
super(MnistDataset, self).__init__(mode=mode)
self.return_label = return_label
if sample_num:
self.images = self.images[:sample_num]
self.labels = self.labels[:sample_num]
    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        img = np.reshape(img, [1, 28, 28])
        if self.return_label:
            return img, np.array(label).astype('int64')
        return img,
def __len__(self):
return len(self.images)
class TestCallbacks(unittest.TestCase):
@@ -134,6 +157,77 @@ class TestCallbacks(unittest.TestCase):
            batch_size=64,
            callbacks=callback)
def test_earlystopping(self):
paddle.seed(2020)
for dynamic in [True, False]:
            if not dynamic:
                paddle.enable_static()
device = paddle.set_device('cpu')
sample_num = 100
train_dataset = MnistDataset(mode='train', sample_num=sample_num)
val_dataset = MnistDataset(mode='test', sample_num=sample_num)
net = LeNet()
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=net.parameters())
inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(net, inputs=inputs, labels=labels)
model.prepare(
optim,
loss=CrossEntropyLoss(reduction="sum"),
metrics=[Accuracy()])
callbacks_0 = paddle.callbacks.EarlyStopping(
'loss',
mode='min',
patience=1,
verbose=1,
min_delta=0,
baseline=None,
save_best_model=True)
callbacks_1 = paddle.callbacks.EarlyStopping(
'acc',
mode='auto',
patience=1,
verbose=1,
min_delta=0,
baseline=0,
save_best_model=True)
callbacks_2 = paddle.callbacks.EarlyStopping(
'loss',
mode='auto_',
patience=1,
verbose=1,
min_delta=0,
baseline=None,
save_best_model=True)
callbacks_3 = paddle.callbacks.EarlyStopping(
'acc_',
mode='max',
patience=1,
verbose=1,
min_delta=0,
baseline=0,
save_best_model=True)
model.fit(
train_dataset,
val_dataset,
batch_size=64,
save_freq=10,
save_dir=self.save_dir,
epochs=10,
verbose=0,
callbacks=[callbacks_0, callbacks_1, callbacks_2, callbacks_3])
# Test for no val_loader
model.fit(train_dataset,
batch_size=64,
save_freq=10,
save_dir=self.save_dir,
epochs=10,
verbose=0,
callbacks=[callbacks_0])
if __name__ == '__main__':
    unittest.main()