callbacks.py 10.0 KB
Newer Older
K
Kaipeng Deng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
20
import sys
K
Kaipeng Deng 已提交
21
import datetime
22 23
import six
import numpy as np
K
Kaipeng Deng 已提交
24 25

import paddle
W
wangguanzhong 已提交
26
import paddle.distributed as dist
K
Kaipeng Deng 已提交
27 28

from ppdet.utils.checkpoint import save_model
W
wangxinxin08 已提交
29
from ppdet.optimizer import ModelEMA
K
Kaipeng Deng 已提交
30 31

from ppdet.utils.logger import setup_logger
32
logger = setup_logger('ppdet.engine')
K
Kaipeng Deng 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55

# Public callbacks exported by this module; keep in sync with the classes
# defined below (WiferFaceEval and VisualDLWriter were previously missing).
__all__ = [
    'Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer',
    'WiferFaceEval', 'VisualDLWriter'
]


class Callback(object):
    """Base class for training/evaluation hooks.

    Every hook below is a no-op returning ``None``; subclasses override only
    the hooks they need.

    Args:
        model: the object this callback is attached to, kept as
            ``self.model``. Subclasses in this module read ``cfg``,
            ``model`` and ``_metrics`` from it (presumably the Trainer).
    """

    def __init__(self, model):
        self.model = model

    def on_step_begin(self, status):
        """Hook for the start of a step. Does nothing by default."""

    def on_step_end(self, status):
        """Hook for the end of a step. Does nothing by default."""

    def on_epoch_begin(self, status):
        """Hook for the start of an epoch. Does nothing by default."""

    def on_epoch_end(self, status):
        """Hook for the end of an epoch. Does nothing by default."""


class ComposeCallback(object):
    """Forward each hook call to a sequence of ``Callback`` objects, in order.

    ``None`` entries are dropped on construction, so optional callbacks can
    be passed in unconditionally. Every remaining entry must be a
    ``Callback`` instance.

    Args:
        callbacks (iterable): callbacks to compose; ``None`` items are ignored.
    """

    def __init__(self, callbacks):
        kept = [cb for cb in callbacks if cb is not None]
        for cb in kept:
            assert isinstance(
                cb, Callback), "callback should be subclass of Callback"
        self._callbacks = kept

    def _dispatch(self, hook_name, status):
        """Call the hook named ``hook_name`` on every composed callback."""
        for cb in self._callbacks:
            getattr(cb, hook_name)(status)

    def on_step_begin(self, status):
        self._dispatch('on_step_begin', status)

    def on_step_end(self, status):
        self._dispatch('on_step_end', status)

    def on_epoch_begin(self, status):
        self._dispatch('on_epoch_begin', status)

    def on_epoch_end(self, status):
        self._dispatch('on_epoch_end', status)


class LogPrinter(Callback):
    """Log training progress and evaluation throughput.

    Only rank 0 of a distributed run (or the single process of a
    non-distributed run) writes log lines.
    """

    def __init__(self, model):
        super(LogPrinter, self).__init__(model)

    def on_step_end(self, status):
        """Log per-step progress.

        In ``train`` mode, a progress line is logged every ``cfg.log_iter``
        steps. In ``eval`` mode, the step index is logged every 100 steps.

        Args:
            status (dict): step state from the trainer. Keys read here are
                ``mode`` and ``step_id``, plus the train-only keys used by
                ``_log_train_step``.
        """
        if dist.get_world_size() >= 2 and dist.get_rank() != 0:
            return
        mode = status['mode']
        if mode == 'train':
            self._log_train_step(status)
        elif mode == 'eval':
            step_id = status['step_id']
            if step_id % 100 == 0:
                logger.info("Eval iter: {}".format(step_id))

    def _log_train_step(self, status):
        """Emit one training progress line if ``step_id`` is a logging step.

        The line contains the learning rate, the loss meters, the ETA, the
        batch and data cost, and the images-per-second.
        """
        cfg = self.model.cfg
        step_id = status['step_id']
        if step_id % cfg.log_iter != 0:
            return

        epoch_id = status['epoch_id']
        steps_per_epoch = status['steps_per_epoch']
        batch_time = status['batch_time']
        data_time = status['data_time']

        # Remaining steps = whole epochs left (current one included) minus
        # the steps already completed in the current epoch.
        eta_steps = (cfg.epoch - epoch_id) * steps_per_epoch - step_id
        eta_sec = eta_steps * batch_time.global_avg
        eta_str = str(datetime.timedelta(seconds=int(eta_sec)))

        # Looked up only on logging steps; this helper runs in train mode,
        # so the reader section is 'TrainReader'.
        batch_size = cfg['TrainReader']['batch_size']
        ips = float(batch_size) / batch_time.avg

        # Pad the step index to the width of steps_per_epoch.
        space_fmt = ':{}d'.format(len(str(steps_per_epoch)))
        fmt = ' '.join([
            'Epoch: [{}]',
            '[{' + space_fmt + '}/{}]',
            'learning_rate: {lr:.6f}',
            '{meters}',
            'eta: {eta}',
            'batch_cost: {btime}',
            'data_cost: {dtime}',
            'ips: {ips:.4f} images/s',
        ])
        logger.info(
            fmt.format(
                epoch_id,
                step_id,
                steps_per_epoch,
                lr=status['learning_rate'],
                meters=status['training_staus'].log(),
                eta=eta_str,
                btime=str(batch_time),
                dtime=str(data_time),
                ips=ips))

    def on_epoch_end(self, status):
        """After an eval epoch, log the sample count and average FPS.

        Args:
            status (dict): reads ``mode``; in eval mode it also reads
                ``sample_num`` and ``cost_time`` (elapsed time, presumably
                in seconds).
        """
        if dist.get_world_size() >= 2 and dist.get_rank() != 0:
            return
        if status['mode'] != 'eval':
            return
        sample_num = status['sample_num']
        cost_time = status['cost_time']
        if cost_time > 0:
            logger.info('Total sample number: {}, average FPS: {}'.format(
                sample_num, sample_num / cost_time))
        else:
            # A zero elapsed time would otherwise raise ZeroDivisionError.
            logger.info('Total sample number: {}'.format(sample_num))


class Checkpointer(Callback):
    """Save weights periodically in training and the best model in eval.

    - ``train`` mode: a checkpoint is written every ``cfg.snapshot_epoch``
      epochs, and after the last epoch under the name ``model_final``.
    - ``eval`` mode, when ``status['save_best_model']`` is truthy: the
      weights are saved as ``best_model`` whenever a metric's first
      ``bbox`` (else ``mask``) value beats the best seen so far.

    When ``cfg.use_ema`` is set, an EMA of the weights is updated every step
    and saved instead of the raw weights. Checkpoints go to
    ``cfg.save_dir/cfg.filename``. Only rank 0 (or a non-distributed run)
    saves.
    """

    def __init__(self, model):
        super(Checkpointer, self).__init__(model)
        cfg = self.model.cfg
        self.best_ap = 0.
        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        self.save_dir = os.path.join(self.model.cfg.save_dir,
                                     self.model.cfg.filename)
        # Models exposing `student_model` (presumably distillation setups)
        # only checkpoint the student.
        if hasattr(self.model.model, 'student_model'):
            self.weight = self.model.model.student_model
        else:
            self.weight = self.model.model
        if self.use_ema:
            self.ema = ModelEMA(
                cfg['ema_decay'], self.weight, use_thres_step=True)

    def _weight_to_save(self):
        """Return ``self.ema.apply()`` when EMA is on, else the live weights."""
        if self.use_ema:
            return self.ema.apply()
        return self.weight

    def on_step_end(self, status):
        """Update the EMA weights after every step when EMA is enabled."""
        if self.use_ema:
            self.ema.update(self.weight)

    def on_epoch_end(self, status):
        """Save a checkpoint if this epoch qualifies (see class docstring).

        Args:
            status (dict): reads ``mode`` and ``epoch_id``; in eval mode it
                also reads the optional ``save_best_model`` flag.
        """
        mode = status['mode']
        epoch_id = status['epoch_id']
        weight = None
        save_name = None
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            if mode == 'train':
                end_epoch = self.model.cfg.epoch
                is_last = epoch_id == end_epoch - 1
                if epoch_id % self.model.cfg.snapshot_epoch == 0 or is_last:
                    save_name = "model_final" if is_last else str(epoch_id)
                    weight = self._weight_to_save()
            elif mode == 'eval':
                if 'save_best_model' in status and status['save_best_model']:
                    for metric in self.model._metrics:
                        map_res = metric.get_results()
                        key = 'bbox' if 'bbox' in map_res else 'mask'
                        if key not in map_res:
                            logger.warning(
                                "Evaluation results empty, this may be due to "
                                "training iterations being too few or not "
                                "loading the correct weights.")
                            return
                        if map_res[key][0] > self.best_ap:
                            self.best_ap = map_res[key][0]
                            save_name = 'best_model'
                            weight = self._weight_to_save()
                        logger.info("Best test {} ap is {:0.3f}.".format(
                            key, self.best_ap))
            # Compare against None explicitly. Truthiness of a Layer or a
            # state dict does not reliably mean "something to save".
            if weight is not None:
                save_model(weight, self.model.optimizer, self.save_dir,
                           save_name, epoch_id + 1)


class WiferFaceEval(Callback):
    """Evaluation-only callback that runs the metrics once and then exits.

    At the start of the first epoch it calls ``update`` on every metric with
    the wrapped model, then terminates the process via ``sys.exit()``. It
    must only be attached when the trainer is in ``eval`` mode.
    """

    def __init__(self, model):
        super(WiferFaceEval, self).__init__(model)

    def on_epoch_begin(self, status):
        trainer = self.model
        assert trainer.mode == 'eval', \
            "WiferFaceEval can only be set during evaluation"
        network = trainer.model
        for m in trainer._metrics:
            m.update(network)
        # Evaluation is completed by the metrics themselves; stop here
        # instead of running the normal evaluation loop.
        sys.exit()


class VisualDLWriter(Callback):
    """
    Use VisualDL to log data or image.

    - ``train`` mode: every loss returned by ``training_staus.get()`` is
      logged as a scalar, one x-axis step per training step.
    - ``test`` mode: the original and result images are logged, grouped
      into frames of at most ten images.
    - ``eval`` mode: the first value of each metric result is logged as
      ``<key>-mAP`` at the end of the epoch.

    Only rank 0 (or a non-distributed run) writes.
    """

    def __init__(self, model):
        super(VisualDLWriter, self).__init__(model)

        assert six.PY3, "VisualDL requires Python >= 3.5"
        try:
            from visualdl import LogWriter
        except ImportError:
            logger.error('visualdl not found, please install visualdl. '
                         'for example: `pip install visualdl`.')
            # Bare raise keeps the original traceback.
            raise
        self.vdl_writer = LogWriter(model.cfg.vdl_log_dir)
        self.vdl_loss_step = 0
        self.vdl_mAP_step = 0
        self.vdl_image_step = 0
        self.vdl_image_frame = 0

    def on_step_end(self, status):
        """Log losses (train mode) or images (test mode) for this step."""
        mode = status['mode']
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            if mode == 'train':
                training_staus = status['training_staus']
                for loss_name, loss_value in training_staus.get().items():
                    self.vdl_writer.add_scalar(loss_name, loss_value,
                                               self.vdl_loss_step)
                # Advance once per training step, not once per loss. The
                # old per-loss increment staggered the losses along the
                # x-axis and scaled it by the number of losses.
                self.vdl_loss_step += 1
            elif mode == 'test':
                ori_image = status['original_image']
                result_image = status['result_image']
                self.vdl_writer.add_image(
                    "original/frame_{}".format(self.vdl_image_frame), ori_image,
                    self.vdl_image_step)
                self.vdl_writer.add_image(
                    "result/frame_{}".format(self.vdl_image_frame),
                    result_image, self.vdl_image_step)
                self.vdl_image_step += 1
                # each frame can display ten pictures at most.
                if self.vdl_image_step % 10 == 0:
                    self.vdl_image_step = 0
                    self.vdl_image_frame += 1

    def on_epoch_end(self, status):
        """In eval mode, log each metric's first result value as ``<key>-mAP``."""
        mode = status['mode']
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            if mode == 'eval':
                for metric in self.model._metrics:
                    for key, map_value in metric.get_results().items():
                        self.vdl_writer.add_scalar("{}-mAP".format(key),
                                                   map_value[0],
                                                   self.vdl_mAP_step)
                self.vdl_mAP_step += 1