callbacks.py 40.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
16
import time
17
import numbers
L
LiuChiachi 已提交
18 19 20
import warnings

import numpy as np
21

22
import paddle
Z
zhaoyingli 已提交
23
from paddle.fluid.dygraph.parallel import ParallelEnv
24
from paddle.utils import try_import
25 26 27

from .progressbar import ProgressBar

# Empty: ``from <this module> import *`` exports no names.
__all__ = []


def config_callbacks(
    callbacks=None,
    model=None,
    batch_size=None,
    epochs=None,
    steps=None,
    log_freq=2,
    verbose=2,
    save_freq=1,
    save_dir=None,
    metrics=None,
    mode='train',
):
    """Build the ``CallbackList`` used by a train / eval / predict loop.

    User callbacks are completed with the default ones:

    - a ``ProgBarLogger`` is prepended when ``verbose`` is truthy and the
      user supplied none;
    - a ``ModelCheckpoint`` is appended when the user supplied none;
    - an ``LRScheduler`` is appended when the user supplied none.

    Every ``EarlyStopping`` instance also receives ``save_dir``.

    Args:
        callbacks (Callback|list|tuple|None): User callbacks. A single
            callback is accepted as well as a list or tuple of callbacks.
        model (paddle.Model): Passed to every callback via ``set_model``.
        batch_size (int|None): Stored in the callbacks' params.
        epochs (int|None): Stored in the callbacks' params.
        steps (int|None): Stored in the callbacks' params.
        log_freq (int): Logging frequency of the default ProgBarLogger.
        verbose (int): Verbosity of the default ProgBarLogger; a falsy
            value means no default ProgBarLogger is added.
        save_freq (int): Saving frequency of the default ModelCheckpoint.
        save_dir (str|None): Directory used by the default ModelCheckpoint
            and by every EarlyStopping callback.
        metrics (list|None): Metric names stored in params; replaced by an
            empty list when ``mode == 'test'``.
        mode (str): Running mode of the caller.

    Returns:
        CallbackList: The configured callbacks, with model and params set.
    """
    cbks = callbacks or []
    # Always work on a fresh list: a tuple cannot be concatenated with the
    # lists below, and copying avoids mutating the caller's container.
    cbks = list(cbks) if isinstance(cbks, (list, tuple)) else [cbks]

    if verbose and not any(isinstance(k, ProgBarLogger) for k in cbks):
        cbks = [ProgBarLogger(log_freq, verbose=verbose)] + cbks

    if not any(isinstance(k, ModelCheckpoint) for k in cbks):
        cbks = cbks + [ModelCheckpoint(save_freq, save_dir)]

    # EarlyStopping has no save_dir constructor argument; it is injected
    # here so the best model is saved next to the regular checkpoints.
    for k in cbks:
        if isinstance(k, EarlyStopping):
            k.save_dir = save_dir

    if not any(isinstance(k, LRScheduler) for k in cbks):
        cbks = cbks + [LRScheduler()]

    cbk_list = CallbackList(cbks)
    cbk_list.set_model(model)
    # Parenthesised for readability; ``or`` already binds tighter than the
    # conditional expression, so the meaning is unchanged.
    metrics = (metrics or []) if mode != 'test' else []
    params = {
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps,
        'verbose': verbose,
        'metrics': metrics,
    }
    cbk_list.set_params(params)
    return cbk_list


class CallbackList(object):
    """Dispatch training-loop hook calls to a sequence of callbacks.

    Args:
        callbacks (iterable|None): Callbacks to manage. The sequence is
            copied, so later changes to the caller's container are not
            seen here. ``None`` means no callbacks.
    """

    def __init__(self, callbacks=None):
        # Copy so appending to this list never mutates the caller's list.
        self.callbacks = list(callbacks) if callbacks is not None else []
        self.params = {}
        self.model = None

    def append(self, callback):
        """Add ``callback`` to the end of the list."""
        self.callbacks.append(callback)

    def __iter__(self):
        return iter(self.callbacks)

    def set_params(self, params):
        """Remember ``params`` and forward them to every callback."""
        self.params = params
        for c in self.callbacks:
            c.set_params(params)

    def set_model(self, model):
        """Remember ``model`` and forward it to every callback."""
        self.model = model
        for c in self.callbacks:
            c.set_model(model)

    def _call(self, name, *args):
        # Invoke the hook called ``name`` on every callback, in order.
        for c in self.callbacks:
            func = getattr(c, name)
            func(*args)

    def _check_mode(self, mode):
        # Raises AssertionError for any mode other than train/eval/predict.
        assert mode in [
            'train',
            'eval',
            'predict',
        ], 'mode should be train, eval or predict'

    def on_begin(self, mode, logs=None):
        """Call ``on_<mode>_begin(logs)`` on every callback."""
        self._check_mode(mode)
        name = 'on_{}_begin'.format(mode)
        self._call(name, logs)

    def on_end(self, mode, logs=None):
        """Call ``on_<mode>_end(logs)`` on every callback."""
        self._check_mode(mode)
        name = 'on_{}_end'.format(mode)
        self._call(name, logs)

    def on_epoch_begin(self, epoch=None, logs=None):
        """Call ``on_epoch_begin(epoch, logs)`` on every callback."""
        self._call('on_epoch_begin', epoch, logs)

    def on_epoch_end(self, epoch=None, logs=None):
        """Call ``on_epoch_end(epoch, logs)`` on every callback."""
        self._call('on_epoch_end', epoch, logs)

    def on_batch_begin(self, mode, step=None, logs=None):
        """Call ``on_<mode>_batch_begin(step, logs)`` on every callback."""
        self._check_mode(mode)
        name = 'on_{}_batch_begin'.format(mode)
        self._call(name, step, logs)

    def on_batch_end(self, mode, step=None, logs=None):
        """Call ``on_<mode>_batch_end(step, logs)`` on every callback."""
        self._check_mode(mode)
        name = 'on_{}_batch_end'.format(mode)
        self._call(name, step, logs)


class Callback(object):
    """
    Base class used to build new callbacks. New callbacks can also
    terminate training by setting `model.stop_training=True`.

    Every hook defined here is a no-op; subclasses override only the hooks
    they need. ``model`` and ``params`` are filled in through ``set_model``
    and ``set_params`` (``config_callbacks`` calls both).

    Examples:

        .. code-block:: python

            import paddle

            # build a simple model checkpoint callback
            class ModelCheckpoint(paddle.callbacks.Callback):
                def __init__(self, save_freq=1, save_dir=None):
                    super().__init__()
                    self.save_freq = save_freq
                    self.save_dir = save_dir

                def on_epoch_end(self, epoch, logs=None):
                    if self.model is not None and epoch % self.save_freq == 0:
                        path = '{}/{}'.format(self.save_dir, epoch)
                        print('save checkpoint at {}'.format(path))
                        self.model.save(path)

    """

    def __init__(self):
        self.model = None
        self.params = {}

    def set_params(self, params):
        """
        Set parameters, which is a dict. The keys contain:

          - 'batch_size': an integer. Number of samples per batch.
          - 'epochs': an integer. Number of epochs.
          - 'steps': an integer. Number of steps of one epoch.
          - 'verbose': an integer. Verbose mode is 0, 1 or 2. 0 = silent, 1 = progress bar, 2 = one line per epoch.
          - 'metrics': a list of str. Names of metrics, including 'loss' and the names of paddle.metric.Metric.

        """
        self.params = params

    def set_model(self, model):
        """model is an instance of paddle.Model."""
        self.model = model

    def on_train_begin(self, logs=None):
        """Called at the start of training.

        Args:
            logs (dict|None): Training logs.
        """

    def on_train_end(self, logs=None):
        """Called at the end of training.

        Args:
            logs (dict|None): Training logs. The keys of the logs passed by
                paddle.Model contain 'loss', metric names and `batch_size`.
        """

    def on_eval_begin(self, logs=None):
        """Called at the start of evaluation.

        Args:
            logs (dict|None): Evaluation logs. The keys of the logs passed by
                paddle.Model contain 'steps' and 'metrics'.
                `steps` is the total number of steps of the validation dataset.
                `metrics` is a list of str including 'loss' and the names
                of paddle.metric.Metric.
        """

    def on_eval_end(self, logs=None):
        """Called at the end of evaluation.

        Args:
            logs (dict|None): Evaluation logs. The `logs` passed by
                paddle.Model is a dict containing 'loss', metrics and
                'batch_size' of the last batch of the validation dataset.
        """

    def on_predict_begin(self, logs=None):
        """Called at the beginning of predict.

        Args:
            logs (dict|None): Predict logs.
        """

    def on_predict_end(self, logs=None):
        """Called at the end of predict.

        Args:
            logs (dict|None): Predict logs.
        """

    def on_epoch_begin(self, epoch, logs=None):
        """Called at the beginning of each epoch.

        Args:
            epoch (int): The index of epoch.
            logs (dict|None): Epoch logs. The `logs` passed by
                paddle.Model is None.
        """

    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of each epoch.

        Args:
            epoch (int): The index of epoch.
            logs (dict|None): Epoch logs. The `logs` passed by paddle.Model
                is a dict containing 'loss', metrics and 'batch_size'
                of the last batch.
        """

    def on_train_batch_begin(self, step, logs=None):
        """Called at the beginning of each batch in training.

        Args:
            step (int): The index of step (or iteration).
            logs (dict|None): Batch logs. The `logs` passed by
                paddle.Model is empty.
        """

    def on_train_batch_end(self, step, logs=None):
        """Called at the end of each batch in training.

        Args:
            step (int): The index of step (or iteration).
            logs (dict|None): Batch logs. The `logs` passed by paddle.Model
                is a dict containing 'loss', metrics and 'batch_size'
                of the current batch.
        """

    def on_eval_batch_begin(self, step, logs=None):
        """Called at the beginning of each batch in evaluation.

        Args:
            step (int): The index of step (or iteration).
            logs (dict|None): Batch logs. The `logs` passed by
                paddle.Model is empty.
        """

    def on_eval_batch_end(self, step, logs=None):
        """Called at the end of each batch in evaluation.

        Args:
            step (int): The index of step (or iteration).
            logs (dict|None): Batch logs. The `logs` passed by paddle.Model
                is a dict containing 'loss', metrics and 'batch_size'
                of the current batch.
        """

    def on_predict_batch_begin(self, step, logs=None):
        """Called at the beginning of each batch in predict.

        Args:
            step (int): The index of step (or iteration).
            logs (dict|None): Batch logs.
        """

    def on_predict_batch_end(self, step, logs=None):
        """Called at the end of each batch in predict.

        Args:
            step (int): The index of step (or iteration).
            logs (dict|None): Batch logs.
        """


class ProgBarLogger(Callback):
    """
    Logger callback function to print loss and metrics to stdout. It supports
    silent mode (no printing), a progress bar, or one line per printing;
    see the arguments for details.

    Args:
        log_freq (int): The frequency, in number of steps,
            the logs such as loss, metrics are printed. Default: 1.
        verbose (int): The verbosity mode, should be 0, 1, 2 or 3.
            0 = silent, 1 = progress bar, 2 = one line each printing, 3 = 2 +
            time counter, such as average reader cost, samples per second.
            Default: 2.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.vision.datasets import MNIST
            from paddle.static import InputSpec

            inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
            labels = [InputSpec([None, 1], 'int64', 'label')]

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)

            lenet = paddle.vision.models.LeNet()
            model = paddle.Model(lenet,
                inputs, labels)

            optim = paddle.optimizer.Adam(0.001, parameters=lenet.parameters())
            model.prepare(optimizer=optim,
                        loss=paddle.nn.CrossEntropyLoss(),
                        metrics=paddle.metric.Accuracy())

            callback = paddle.callbacks.ProgBarLogger(log_freq=10)
            model.fit(train_dataset, batch_size=64, callbacks=callback)
    """

    def __init__(self, log_freq=1, verbose=2):
        # Initialise ``model``/``params`` like every other Callback.
        super(ProgBarLogger, self).__init__()
        self.epochs = None
        self.steps = None
        self.progbar = None
        self.verbose = verbose
        self.log_freq = log_freq

    def _is_print(self):
        # Print only when not silent and only on the process with local rank 0.
        return self.verbose and ParallelEnv().local_rank == 0

    @staticmethod
    def _make_timer():
        # Fresh timing accumulators for one phase (train / eval / predict).
        return {
            'data_time': 0,
            'batch_time': 0,
            'count': 0,
            'samples': 0,
        }

    @staticmethod
    def _mark_data_ready(timer):
        # Accumulate the time spent waiting for the batch's data since the
        # previous batch finished.
        timer['batch_data_end_time'] = time.time()
        timer['data_time'] += (
            timer['batch_data_end_time'] - timer['batch_start_time']
        )

    @staticmethod
    def _mark_batch_done(timer, samples):
        # Accumulate the time spent running the batch once its data arrived.
        timer['batch_time'] += time.time() - timer['batch_data_end_time']
        timer['count'] += 1
        timer['samples'] += samples

    def on_train_begin(self, logs=None):
        self.epochs = self.params['epochs']
        assert self.epochs
        self.train_metrics = self.params['metrics']
        assert self.train_metrics

        self._train_timer = self._make_timer()
        if self._is_print():
            print(
                "The loss value printed in the log is the current step, and the metric is the average value of previous steps."
            )

    def on_epoch_begin(self, epoch=None, logs=None):
        self.steps = self.params['steps']
        self.epoch = epoch
        self.train_step = 0
        if self.epochs and self._is_print():
            print('Epoch %d/%d' % (epoch + 1, self.epochs))
        self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose)

        self._train_timer['batch_start_time'] = time.time()

    def _updates(self, logs, mode):
        """Push the current values of phase ``mode`` to its progress bar.

        Args:
            logs (dict): Values reported by the model for the last batch.
            mode (str): One of 'train', 'eval' or 'test'; selects the
                ``<mode>_metrics``, ``<mode>_progbar`` and ``<mode>_step``
                attributes.
        """
        metrics = getattr(self, '%s_metrics' % (mode))
        progbar = getattr(self, '%s_progbar' % (mode))
        steps = getattr(self, '%s_step' % (mode))

        values = [(k, logs[k]) for k in metrics if k in logs]

        timer_name = '_%s_timer' % (mode)
        if self.verbose == 3 and hasattr(self, timer_name):
            timer = getattr(self, timer_name)
            cnt = timer['count'] if timer['count'] > 0 else 1.0
            samples = timer['samples'] if timer['samples'] > 0 else 1.0
            elapsed = timer['data_time'] + timer['batch_time']
            # No batch may have been timed since the last report; avoid a
            # ZeroDivisionError in that case.
            ips = samples / elapsed if elapsed > 0 else 0.0
            values.append(
                ('avg_reader_cost', "%.5f sec" % (timer['data_time'] / cnt))
            )
            values.append(
                ('avg_batch_cost', "%.5f sec" % (timer['batch_time'] / cnt))
            )
            values.append(('ips', "%.5f samples/sec" % ips))
            # Reset so the next report only covers batches run after this one.
            timer['count'] = 0
            timer['samples'] = 0
            timer['data_time'] = 0.0
            timer['batch_time'] = 0.0

        progbar.update(steps, values)

    def on_train_batch_begin(self, step, logs=None):
        self._mark_data_ready(self._train_timer)

    def on_train_batch_end(self, step, logs=None):
        logs = logs or {}
        self.train_step += 1
        self._mark_batch_done(self._train_timer, logs.get('batch_size', 1))

        # The final step of a known-length epoch is reported by on_epoch_end.
        if self._is_print() and self.train_step % self.log_freq == 0:
            if self.steps is None or self.train_step < self.steps:
                self._updates(logs, 'train')
        self._train_timer['batch_start_time'] = time.time()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if self._is_print() and (self.steps is not None):
            self._updates(logs, 'train')

    def on_eval_begin(self, logs=None):
        logs = logs or {}
        self.eval_steps = logs.get('steps', None)
        self.eval_metrics = logs.get('metrics', [])
        self.eval_step = 0
        self.evaled_samples = 0

        self._eval_timer = self._make_timer()

        self.eval_progbar = ProgressBar(
            num=self.eval_steps, verbose=self.verbose
        )
        if self._is_print():
            print('Eval begin...')

        self._eval_timer['batch_start_time'] = time.time()

    def on_eval_batch_begin(self, step, logs=None):
        self._mark_data_ready(self._eval_timer)

    def on_eval_batch_end(self, step, logs=None):
        logs = logs or {}
        self.eval_step += 1
        samples = logs.get('batch_size', 1)
        self.evaled_samples += samples
        self._mark_batch_done(self._eval_timer, samples)

        if self._is_print() and self.eval_step % self.log_freq == 0:
            if self.eval_steps is None or self.eval_step < self.eval_steps:
                self._updates(logs, 'eval')

        self._eval_timer['batch_start_time'] = time.time()

    def on_predict_begin(self, logs=None):
        logs = logs or {}
        self.test_steps = logs.get('steps', None)
        self.test_metrics = logs.get('metrics', [])
        self.test_step = 0
        self.tested_samples = 0

        self._test_timer = self._make_timer()

        self.test_progbar = ProgressBar(
            num=self.test_steps, verbose=self.verbose
        )
        if self._is_print():
            print('Predict begin...')

        self._test_timer['batch_start_time'] = time.time()

    def on_predict_batch_begin(self, step, logs=None):
        self._mark_data_ready(self._test_timer)

    def on_predict_batch_end(self, step, logs=None):
        logs = logs or {}
        self.test_step += 1
        samples = logs.get('batch_size', 1)
        self.tested_samples += samples
        self._mark_batch_done(self._test_timer, samples)

        if self.test_step % self.log_freq == 0 and self._is_print():
            if self.test_steps is None or self.test_step < self.test_steps:
                self._updates(logs, 'test')

        self._test_timer['batch_start_time'] = time.time()

    def on_eval_end(self, logs=None):
        logs = logs or {}
        if self._is_print() and (self.eval_steps is not None):
            self._updates(logs, 'eval')
            print('Eval samples: %d' % (self.evaled_samples))

    def on_predict_end(self, logs=None):
        logs = logs or {}
        if self._is_print():
            if self.test_step % self.log_freq != 0 or self.verbose == 1:
                self._updates(logs, 'test')
            print('Predict samples: %d' % (self.tested_samples))


class ModelCheckpoint(Callback):
    """
    Model checkpoint callback function to save model weights and optimizer
    state during training in conjunction with model.fit(). Currently,
    ModelCheckpoint only supports saving after a fixed number of epochs.

    Args:
        save_freq(int): The frequency, in number of epochs, at which the
            model checkpoint is saved. Default: 1.
        save_dir(str|None): The directory to save checkpoint during training.
            If None, will not save checkpoint. Default: None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.vision.datasets import MNIST
            from paddle.static import InputSpec

            inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
            labels = [InputSpec([None, 1], 'int64', 'label')]

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)

            lenet = paddle.vision.models.LeNet()
            model = paddle.Model(lenet,
                inputs, labels)

            optim = paddle.optimizer.Adam(0.001, parameters=lenet.parameters())
            model.prepare(optimizer=optim,
                        loss=paddle.nn.CrossEntropyLoss(),
                        metrics=paddle.metric.Accuracy())

            callback = paddle.callbacks.ModelCheckpoint(save_dir='./temp')
            model.fit(train_dataset, batch_size=64, callbacks=callback)
    """

    def __init__(self, save_freq=1, save_dir=None):
        # Initialise ``model``/``params`` like every other Callback.
        super(ModelCheckpoint, self).__init__()
        self.save_freq = save_freq
        self.save_dir = save_dir

    def on_epoch_begin(self, epoch=None, logs=None):
        self.epoch = epoch

    def _is_save(self):
        # Save only when a model and a directory are set, and only on the
        # process with local rank 0.
        return self.model and self.save_dir and ParallelEnv().local_rank == 0

    def on_epoch_end(self, epoch, logs=None):
        # Use the epoch given to this hook so saving does not depend on
        # on_epoch_begin having been called first.
        if self._is_save() and epoch % self.save_freq == 0:
            path = '{}/{}'.format(self.save_dir, epoch)
            print('save checkpoint at {}'.format(os.path.abspath(path)))
            self.model.save(path)

    def on_train_end(self, logs=None):
        # Always save a final checkpoint, independent of save_freq.
        if self._is_save():
            path = '{}/final'.format(self.save_dir)
            print('save checkpoint at {}'.format(os.path.abspath(path)))
            self.model.save(path)


class LRScheduler(Callback):
    """Learning rate scheduler callback.

    Calls ``step()`` on the optimizer's learning-rate scheduler either after
    every training batch (``by_step``) or at the end of every epoch
    (``by_epoch``). Nothing happens when the model's optimizer does not use
    a ``paddle.optimizer.lr.LRScheduler``, or when both options are False.

    Args:
        by_step(bool, optional): whether to update learning rate scheduler
            by step. Default: True.
        by_epoch(bool, optional): whether to update learning rate scheduler
            by epoch. Default: False.

    Raises:
        ValueError: If both ``by_step`` and ``by_epoch`` are True.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.static import InputSpec

            inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
            labels = [InputSpec([None, 1], 'int64', 'label')]

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)

            lenet = paddle.vision.models.LeNet()
            model = paddle.Model(lenet,
                inputs, labels)

            base_lr = 1e-3
            boundaries = [5, 8]
            warmup_steps = 4

            def make_optimizer(parameters=None):
                momentum = 0.9
                weight_decay = 5e-4
                values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
                learning_rate = paddle.optimizer.lr.PiecewiseDecay(
                    boundaries=boundaries, values=values)
                learning_rate = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=learning_rate,
                    warmup_steps=warmup_steps,
                    start_lr=base_lr / 5.,
                    end_lr=base_lr,
                    verbose=True)
                optimizer = paddle.optimizer.Momentum(
                    learning_rate=learning_rate,
                    weight_decay=weight_decay,
                    momentum=momentum,
                    parameters=parameters)
                return optimizer

            optim = make_optimizer(parameters=lenet.parameters())
            model.prepare(optimizer=optim,
                        loss=paddle.nn.CrossEntropyLoss(),
                        metrics=paddle.metric.Accuracy())

            # if LRScheduler callback not set, an instance LRScheduler update by step
            # will be created auto.
            model.fit(train_dataset, batch_size=64)

            # create a learning rate scheduler update by epoch
            callback = paddle.callbacks.LRScheduler(by_step=False, by_epoch=True)
            model.fit(train_dataset, batch_size=64, callbacks=callback)
    """

    def __init__(self, by_step=True, by_epoch=False):
        # Initialise ``model``/``params`` like every other Callback.
        super(LRScheduler, self).__init__()
        if by_step and by_epoch:
            raise ValueError(
                "by_step option is mutually exclusive with by_epoch"
            )

        self.by_step = by_step
        self.by_epoch = by_epoch

    def _step_lr(self):
        # Advance the optimizer's learning-rate scheduler, if the model has an
        # optimizer whose learning rate is a paddle LRScheduler. Shared by the
        # per-batch and per-epoch hooks.
        optimizer = getattr(self.model, '_optimizer', None)
        if optimizer and isinstance(
            getattr(optimizer, '_learning_rate', None),
            paddle.optimizer.lr.LRScheduler,
        ):
            optimizer._learning_rate.step()

    def on_epoch_end(self, epoch, logs=None):
        if self.by_epoch:
            self._step_lr()

    def on_train_batch_end(self, step, logs=None):
        if self.by_step:
            self._step_lr()


class EarlyStopping(Callback):
    """Stop training when the given monitor stopped improving during evaluation
    by setting `model.stop_training=True`.

    Args:
        monitor(str): Quantity to be monitored. Default: 'loss'.
        mode(str|None): Mode should be one of 'auto', 'min' or 'max'. In 'min'
            mode, training will stop when the monitored quantity stops
            decreasing. In 'max' mode, training will stop when the monitored
            quantity stops increasing. In 'auto' mode, the exact mode is
            inferred from the name of the monitor: if 'acc' is in monitor,
            the mode is considered 'max', otherwise it is set to 'min'.
            Default: 'auto'.
        patience(int): Number of evaluations with no improvement after which
            training will be stopped. Training is never stopped right after
            an evaluation that improved, so 0 and 1 behave the same.
            Default: 0.
        verbose(int): The verbosity mode, should be 0 or 1. When verbose=0,
            logs will not be printed. When verbose=1, logs will be printed.
            Default: 1.
        min_delta(int|float): The minimum change of monitored quantity. If
            the change is less than min_delta, model could be considered as no
            improvement. Default: 0.
        baseline(int|float|None): Baseline value for the monitored quantity.
            Training will stop if the model doesn't show improvement over the
            baseline. Default: None.
        save_best_model(bool): Whether to save best model. Default: True.

    Examples:
        .. code-block:: python

            import paddle
            from paddle import Model
            from paddle.static import InputSpec
            from paddle.vision.models import LeNet
            from paddle.vision.datasets import MNIST
            from paddle.metric import Accuracy
            from paddle.nn import CrossEntropyLoss
            import paddle.vision.transforms as T

            device = paddle.set_device('cpu')
            sample_num = 200
            save_dir = './best_model_checkpoint'
            transform = T.Compose(
                [T.Transpose(), T.Normalize([127.5], [127.5])])
            train_dataset = MNIST(mode='train', transform=transform)
            val_dataset = MNIST(mode='test', transform=transform)
            net = LeNet()
            optim = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=net.parameters())

            inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
            labels = [InputSpec([None, 1], 'int64', 'label')]

            model = Model(net, inputs=inputs, labels=labels)
            model.prepare(
                optim,
                loss=CrossEntropyLoss(reduction="sum"),
                metrics=[Accuracy()])
            callbacks = paddle.callbacks.EarlyStopping(
                'loss',
                mode='min',
                patience=1,
                verbose=1,
                min_delta=0,
                baseline=None,
                save_best_model=True)
            model.fit(train_dataset,
                      val_dataset,
                      batch_size=64,
                      log_freq=200,
                      save_freq=10,
                      save_dir=save_dir,
                      epochs=20,
                      callbacks=[callbacks])
    """

    def __init__(
        self,
        monitor='loss',
        mode='auto',
        patience=0,
        verbose=1,
        min_delta=0,
        baseline=None,
        save_best_model=True,
    ):
        super(EarlyStopping, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.baseline = baseline
        self.min_delta = abs(min_delta)
        self.wait_epoch = 0
        self.best_weights = None
        self.stopped_epoch = 0
        self.save_best_model = save_best_model
        # The value of `save_dir` is set in function `config_callbacks`
        self.save_dir = None
        if mode not in ['auto', 'min', 'max']:
            warnings.warn(
                'EarlyStopping mode %s is unknown, '
                'fallback to auto mode.' % mode
            )
            mode = 'auto'
        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # 'auto': accuracy-like monitors are maximised, others minimised.
            self.monitor_op = np.greater if 'acc' in self.monitor else np.less

        # Give min_delta the sign of the improving direction, so that
        # ``monitor_op(current - min_delta, best)`` requires an improvement
        # of at least |min_delta| in both modes.
        if self.monitor_op != np.greater:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        # Reset all per-run state so the callback can be reused across fits.
        self.wait_epoch = 0
        self.stopped_epoch = 0
        self.best_weights = None
        if self.baseline is not None:
            self.best_value = self.baseline
        else:
            self.best_value = np.inf if self.monitor_op == np.less else -np.inf

    def on_eval_end(self, logs=None):
        if logs is None or self.monitor not in logs:
            warnings.warn(
                'Monitor of EarlyStopping should be loss or metric name.'
            )
            return
        current = logs[self.monitor]
        # A value reported as a list/tuple is judged by its first entry;
        # anything that is not a number is ignored.
        if isinstance(current, (list, tuple)):
            current = current[0]
        elif not isinstance(current, numbers.Number):
            return

        if self.monitor_op(current - self.min_delta, self.best_value):
            self.best_value = current
            self.wait_epoch = 0
            if self.save_best_model and self.save_dir is not None:
                path = os.path.join(self.save_dir, 'best_model')
                self.model.save(path)
        else:
            self.wait_epoch += 1
            # Only an evaluation without improvement can stop training;
            # checking after an improvement would stop immediately when
            # patience=0 (the default).
            if self.wait_epoch >= self.patience:
                self.model.stop_training = True
                if self.verbose > 0:
                    print(
                        'Epoch %d: Early stopping.' % (self.stopped_epoch + 1)
                    )
                    if self.save_best_model and self.save_dir is not None:
                        print(
                            'Best checkpoint has been saved at %s'
                            % (
                                os.path.abspath(
                                    os.path.join(self.save_dir, 'best_model')
                                )
                            )
                        )
        self.stopped_epoch += 1


class VisualDL(Callback):
881 882 883
    """
    VisualDL callback function.

884 885 886 887 888 889 890
    Args:
        log_dir (str): The directory to save visualdl log file.

    Examples:
        .. code-block:: python

            import paddle
891
            import paddle.vision.transforms as T
892 893 894 895 896
            from paddle.static import InputSpec

            inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
            labels = [InputSpec([None, 1], 'int64', 'label')]

897 898 899 900 901 902
            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
            eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
903

904
            net = paddle.vision.models.LeNet()
905 906 907 908 909 910
            model = paddle.Model(net, inputs, labels)

            optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
            model.prepare(optimizer=optim,
                        loss=paddle.nn.CrossEntropyLoss(),
                        metrics=paddle.metric.Accuracy())
911

912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964
            ## uncomment following lines to fit model with visualdl callback function
            # callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
            # model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)

    """

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.epochs = None
        self.steps = None
        self.epoch = 0

    def _is_write(self):
        return ParallelEnv().local_rank == 0

    def on_train_begin(self, logs=None):
        self.epochs = self.params['epochs']
        assert self.epochs
        self.train_metrics = self.params['metrics']
        assert self.train_metrics
        self._is_fit = True
        self.train_step = 0

    def on_epoch_begin(self, epoch=None, logs=None):
        self.steps = self.params['steps']
        self.epoch = epoch

    def _updates(self, logs, mode):
        if not self._is_write():
            return
        if not hasattr(self, 'writer'):
            visualdl = try_import('visualdl')
            self.writer = visualdl.LogWriter(self.log_dir)

        metrics = getattr(self, '%s_metrics' % (mode))
        current_step = getattr(self, '%s_step' % (mode))

        if mode == 'train':
            total_step = current_step
        else:
            total_step = self.epoch

        for k in metrics:
            if k in logs:
                temp_tag = mode + '/' + k

                if isinstance(logs[k], (list, tuple)):
                    temp_value = logs[k][0]
                elif isinstance(logs[k], numbers.Number):
                    temp_value = logs[k]
                else:
                    continue

965 966 967
                self.writer.add_scalar(
                    tag=temp_tag, step=total_step, value=temp_value
                )
968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993

    def on_train_batch_end(self, step, logs=None):
        logs = logs or {}
        self.train_step += 1

        if self._is_write():
            self._updates(logs, 'train')

    def on_eval_begin(self, logs=None):
        self.eval_steps = logs.get('steps', None)
        self.eval_metrics = logs.get('metrics', [])
        self.eval_step = 0
        self.evaled_samples = 0

    def on_train_end(self, logs=None):
        if hasattr(self, 'writer'):
            self.writer.close()
            delattr(self, 'writer')

    def on_eval_end(self, logs=None):
        if self._is_write():
            self._updates(logs, 'eval')

            if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'):
                self.writer.close()
                delattr(self, 'writer')
L
LielinJiang 已提交
994 995 996 997 998 999 1000 1001


class ReduceLROnPlateau(Callback):
    """Reduce learning rate when a metric of evaluation has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This callback monitors a
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Patience and cooldown are counted in evaluations, i.e. calls to
    ``on_eval_end``. The optimizer's learning rate must be a plain ``float``.
    Any other type, such as a learning-rate scheduler object, is left
    untouched and a warning is emitted.

    Args:
        monitor(str, optional): Quantity to be monitored. Default: 'loss'.
        factor(float, optional): factor by which the learning rate will be reduced.
            `new_lr = lr * factor`. Must be less than 1.0. Default: 0.1.
        patience(int, optional): Number of epochs with no improvement after which
            learning rate will be reduced. Default: 10.
        verbose(int, optional): The verbosity mode. 0: quiet, 1: update messages.
            Default: 1.
        mode(str, optional): One of `{'auto', 'min', 'max'}`.
            - `'min'`: the learning rate is reduced when the monitored
              quantity has stopped decreasing.
            - `'max'`: the learning rate is reduced when the monitored
              quantity has stopped increasing.
            - `'auto'`: the mode is inferred from the name of monitor. If
              'acc' is in monitor, the mode is 'max'; otherwise it is 'min'.
            Unknown values fall back to 'auto' with a warning. Default: 'auto'.
        min_delta(int|float, optional): threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        cooldown(int, optional): number of epochs to wait before resuming normal operation after
            lr has been reduced. Default: 0.
        min_lr(float, optional): lower bound on the learning rate. Default: 0.

    Raises:
        ValueError: If ``factor >= 1.0``.

    Examples:
        .. code-block:: python

            import paddle
            from paddle import Model
            from paddle.static import InputSpec
            from paddle.vision.models import LeNet
            from paddle.vision.datasets import MNIST
            from paddle.metric import Accuracy
            from paddle.nn.layer.loss import CrossEntropyLoss
            import paddle.vision.transforms as T

            sample_num = 200
            transform = T.Compose(
                [T.Transpose(), T.Normalize([127.5], [127.5])])
            train_dataset = MNIST(mode='train', transform=transform)
            val_dataset = MNIST(mode='test', transform=transform)
            net = LeNet()
            optim = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=net.parameters())

            inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
            labels = [InputSpec([None, 1], 'int64', 'label')]

            model = Model(net, inputs=inputs, labels=labels)
            model.prepare(
                optim,
                loss=CrossEntropyLoss(),
                metrics=[Accuracy()])

            callbacks = paddle.callbacks.ReduceLROnPlateau(patience=3, verbose=1)
            model.fit(train_dataset,
                      val_dataset,
                      batch_size=64,
                      log_freq=200,
                      save_freq=10,
                      epochs=20,
                      callbacks=[callbacks])

    """

    def __init__(
        self,
        monitor='loss',
        factor=0.1,
        patience=10,
        verbose=1,
        mode='auto',
        min_delta=1e-4,
        cooldown=0,
        min_lr=0,
    ):
        super(ReduceLROnPlateau, self).__init__()

        self.monitor = monitor
        if factor >= 1.0:
            raise ValueError(
                'ReduceLROnPlateau does not support a factor >= 1.0.'
            )

        self.factor = factor
        self.min_lr = min_lr
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0  # Cooldown counter.
        self.wait = 0
        self.best = 0
        self.mode = mode
        self.monitor_op = None
        self.epoch = 0
        self._reset()

    def _reset(self):
        """Reset the wait and cooldown counters and pick the comparison op.

        Also validates ``self.mode``. An unknown mode falls back to 'auto'.
        """
        if self.mode not in ['auto', 'min', 'max']:
            warnings.warn(
                'Learning rate reduction mode %s is unknown, '
                'fallback to auto mode.' % self.mode
            )
            self.mode = 'auto'
        if self.mode == 'min' or (
            self.mode == 'auto' and 'acc' not in self.monitor
        ):
            # Lower is better: a new value must undercut best by min_delta.
            self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
            # np.inf rather than np.Inf: the capitalised alias was removed
            # in NumPy 2.0.
            self.best = np.inf
        else:
            # Higher is better: a new value must exceed best by min_delta.
            self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
            self.best = -np.inf
        self.cooldown_counter = 0
        self.wait = 0

    def on_train_begin(self, logs=None):
        self._reset()

    def on_eval_end(self, logs=None):
        """Update the plateau state after an evaluation and maybe reduce the LR.

        Args:
            logs (dict, optional): Evaluation logs keyed by loss/metric name.
                The monitored value is either a number or a list/tuple whose
                first element is used.
        """
        if logs is None or self.monitor not in logs:
            warnings.warn(
                'Monitor of ReduceLROnPlateau should be loss or metric name.'
            )
            return

        # Best effort: the learning rate may not be reachable on this model.
        try:
            lr = self.model._optimizer._learning_rate
        except Exception as e:
            warnings.warn(
                'Something went wrong when getting learning_rate from '
                'optimizer: {}.'.format(e)
            )
            return
        # Only a plain float learning rate is rescaled by this callback.
        if not isinstance(lr, float):
            warnings.warn(
                'Expected learning_rate to be float, but got {}.'.format(
                    type(lr)
                )
            )
            return

        current = logs[self.monitor]
        if isinstance(current, (list, tuple)):
            # An empty sequence carries no value to compare; skip it.
            if not current:
                return
            current = current[0]
        elif not isinstance(current, numbers.Number):
            return

        if self.in_cooldown():
            self.cooldown_counter -= 1
            self.wait = 0

        if self.monitor_op(current, self.best):
            self.best = current
            self.wait = 0
        elif not self.in_cooldown():
            self.wait += 1
            if self.wait >= self.patience:
                old_lr = self.model._optimizer.get_lr()
                if old_lr > np.float32(self.min_lr):
                    new_lr = max(old_lr * self.factor, self.min_lr)
                    self.model._optimizer._learning_rate = new_lr
                    if self.verbose > 0 and ParallelEnv().local_rank == 0:
                        print(
                            '\nEpoch %d: ReduceLROnPlateau reducing learning '
                            'rate to %s.' % (self.epoch + 1, new_lr)
                        )
                    self.cooldown_counter = self.cooldown
                    self.wait = 0
        self.epoch += 1

    def in_cooldown(self):
        """Return True while the post-reduction cooldown is still running."""
        return self.cooldown_counter > 0