trainer.py 32.9 KB
Newer Older
F
Feng Ni 已提交
1 2 3 4 5 6 7 8 9 10 11 12
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
K
Kaipeng Deng 已提交
13 14 15 16 17 18 19
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
G
George Ni 已提交
20
import sys
21
import copy
K
Kaipeng Deng 已提交
22
import time
M
Manuel Garcia 已提交
23

K
Kaipeng Deng 已提交
24
import numpy as np
M
Mark Ma 已提交
25
import typing
F
Feng Ni 已提交
26 27
from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
K
Kaipeng Deng 已提交
28 29

import paddle
W
wangguanzhong 已提交
30 31
import paddle.distributed as dist
from paddle.distributed import fleet
32
from paddle import amp
K
Kaipeng Deng 已提交
33
from paddle.static import InputSpec
34
from ppdet.optimizer import ModelEMA
K
Kaipeng Deng 已提交
35 36 37

from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
C
cnn 已提交
38
from ppdet.utils.visualizer import visualize_results, save_result
Z
zhiboniu 已提交
39
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
40 41
from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from ppdet.data.source.sniper_coco import SniperCOCODataSet
K
Kaipeng Deng 已提交
42
from ppdet.data.source.category import get_categories
K
Kaipeng Deng 已提交
43
import ppdet.utils.stats as stats
44
from ppdet.utils import profiler
K
Kaipeng Deng 已提交
45

46
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
G
Guanghua Yu 已提交
47
from .export_utils import _dump_infer_config, _prune_input_spec
K
Kaipeng Deng 已提交
48 49

from ppdet.utils.logger import setup_logger
50
# Module-level logger shared by everything in this file.
logger = setup_logger('ppdet.engine')

__all__ = ['Trainer']

# Architecture names that Trainer treats as multi-object-tracking models
# (they select the *MOTDataset / TestMOTReader config entries).
MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack']
55

K
Kaipeng Deng 已提交
56 57 58 59 60 61 62

class Trainer(object):
    def __init__(self, cfg, mode='train'):
        """Build dataset, loader, model, optimizer, callbacks and metrics.

        Args:
            cfg: Configuration object. It is read both as a mapping
                (``cfg['TestReader']``) and through attributes
                (``cfg.architecture``), and some entries are written back
                (embedding-head identity counts, ``collate_batch``).
            mode (str): 'train', 'eval' or 'test' (case-insensitive).

        Notes:
            * ``self.loader`` is only created here for 'train' and 'eval';
              in 'test' mode it is built later, once images are set.
            * ``self.ema`` only exists when ``use_ema`` is truthy and
              ``self.pruner`` only when ``unstructured_prune`` is truthy.
            * The process exits with status 1 when asked to train DeepSORT.
        """
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False

        # build data loader
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
        else:
            self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT has no need of training on mot dataset.')
            sys.exit(1)

        if cfg.architecture == 'FairMOT' and self.mode == 'eval':
            images = self.parse_mot_images(cfg)
            self.dataset.set_images(images)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        # The embedding heads need the identity counts from the training
        # dataset, so they are written into cfg before the model is created.
        if cfg.architecture == 'JDE' and self.mode == 'train':
            # JDE only supports single-class MOT for now.
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            # FairMOT supports both single-class and multi-class MOT.
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict

        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            # A model object supplied through cfg is used as-is and treated
            # as already carrying its weights (see load_weights).
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # normalize params for deploy
        if 'slim' in cfg and cfg['slim_type'] == 'OFA':
            self.model.model.load_meanstd(cfg['TestReader'][
                'sample_transforms'])
        else:
            self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                use_thres_step=True,
                cycle_epoch=cycle_epoch)

        # EvalDataset build with BatchSampler to evaluate in single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            if cfg.architecture == 'FairMOT':
                self.loader = create('EvalMOTReader')(self.dataset, 0)
            else:
                self._eval_batch_sampler = paddle.io.BatchSampler(
                    self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
                reader_name = '{}Reader'.format(self.mode.capitalize())
                # If metric is VOC, need to be set collate_batch=False.
                if cfg.metric == 'VOC':
                    cfg[reader_name]['collate_batch'] = False
                self.loader = create(reader_name)(self.dataset, cfg.worker_num,
                                                  self._eval_batch_sampler)
        # TestDataset build after user set images, skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            # Unstructured pruner is only enabled in the train mode.
            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        """Install the default callbacks for the current mode.

        * train: LogPrinter and Checkpointer, plus VisualDLWriter when
          ``use_vdl`` is set and SniperProposalsGenerator when
          ``save_proposals`` is set.
        * eval: LogPrinter, plus WiferFaceEval for the 'WiderFace' metric.
        * test: VisualDLWriter only, and only when ``use_vdl`` is set.
        * otherwise: no callbacks and ``self._compose_callback`` is None.
        """
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
            if self.cfg.get('use_vdl', False):
                self._callbacks.append(VisualDLWriter(self))
            if self.cfg.get('save_proposals', False):
                self._callbacks.append(SniperProposalsGenerator(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
        else:
            self._callbacks = []
            self._compose_callback = None

K
Kaipeng Deng 已提交
180 181
    def _init_metrics(self, validate=False):
        """Create ``self._metrics`` according to ``cfg.metric``.

        Args:
            validate (bool): Whether evaluation runs during training. In
                'train' mode metrics are only built when this is True, and
                the annotation file is then taken from ``cfg['EvalDataset']``
                instead of the training dataset.

        In 'test' mode, or 'train' mode without validation, the metric list
        is left empty. An unrecognised metric type logs a warning and also
        yields an empty list.
        """
        if self.mode == 'test' or (self.mode == 'train' and not validate):
            self._metrics = []
            return
        classwise = self.cfg.get('classwise', False)
        if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
            # TODO: bias should be unified
            bias = self.cfg.get('bias', 0)
            output_eval = self.cfg.get('output_eval', None)
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            dataset = self.dataset
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()
                dataset = eval_dataset

            IouType = self.cfg.get('IouType', 'bbox')
            if self.cfg.metric == "COCO":
                self._metrics = [
                    COCOMetric(
                        anno_file=anno_file,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
            elif self.cfg.metric == "SNIPERCOCO":  # sniper
                self._metrics = [
                    SNIPERCOCOMetric(
                        anno_file=anno_file,
                        dataset=dataset,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg.get('bias', 0)
            output_eval = self.cfg.get('output_eval', None)
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
                    label_list=self.dataset.get_label_list(),
                    class_num=self.cfg.num_classes,
                    map_type=self.cfg.map_type,
                    classwise=classwise)
            ]
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.get('multi_scale_eval', True)
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
        else:
            logger.warning("Metric not support for metric type {}".format(
                self.cfg.metric))
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        """Append extra callbacks and rebuild the composed callback.

        Args:
            callbacks: Iterable of ``Callback`` instances; ``None`` entries
                are silently dropped.

        Raises:
            AssertionError: If a non-None entry is not a ``Callback``.
        """
        callbacks = [c for c in callbacks if c is not None]
        for c in callbacks:
            # The message previously talked about "metrics"/"Metric", a
            # copy-paste slip from register_metrics.
            assert isinstance(c, Callback), \
                    "callbacks should be instances of subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                    "metrics shoule be instances of subclass of Metric"
        self._metrics.extend(metrics)

K
Kaipeng Deng 已提交
329
    def load_weights(self, weights):
330 331
        if self.is_loaded_weights:
            return
K
Kaipeng Deng 已提交
332
        self.start_epoch = 0
333
        load_pretrain_weight(self.model, weights)
K
Kaipeng Deng 已提交
334 335
        logger.debug("Load weights {} to start training".format(weights))

336 337 338 339 340 341 342
    def load_weights_sde(self, det_weights, reid_weights):
        """Load detector weights (if the model has a detector) and ReID weights.

        Args:
            det_weights: Weights for ``self.model.detector``; only used when
                that attribute is truthy.
            reid_weights: Weights for ``self.model.reid``; always loaded.
        """
        has_detector = self.model.detector
        if has_detector:
            load_weight(self.model.detector, det_weights)
        # The ReID branch is loaded regardless of whether a detector exists.
        load_weight(self.model.reid, reid_weights)

K
Kaipeng Deng 已提交
343
    def resume_weights(self, weights):
        """Restore a checkpoint and set ``self.start_epoch`` from it.

        Args:
            weights: Checkpoint location forwarded to ``load_weight``.

        For models exposing ``student_model`` (distillation) only the student
        and the optimizer are restored. Otherwise the whole model, the
        optimizer and, when EMA is enabled, the EMA state are restored.
        """
        # support Distill resume weights
        if hasattr(self.model, 'student_model'):
            self.start_epoch = load_weight(self.model.student_model, weights,
                                           self.optimizer)
        else:
            self.start_epoch = load_weight(self.model, weights, self.optimizer,
                                           self.ema if self.use_ema else None)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))
K
Kaipeng Deng 已提交
352

K
Kaipeng Deng 已提交
353
    def train(self, validate=False):
        """Run the training loop from ``self.start_epoch`` to ``cfg.epoch``.

        Args:
            validate (bool): When True, evaluate on ``cfg.EvalDataset`` at
                every snapshot epoch (on rank 0 only when distributed).

        Raises:
            AssertionError: If the trainer was not built in 'train' mode.
        """
        assert self.mode == 'train', "Model not in 'train' mode"
        # Metrics are rebuilt for validation only once, on the first
        # validating epoch.
        metrics_initialized = False

        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
                   self.cfg.use_gpu and self._nranks > 1)
        if sync_bn:
            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        # `model` is the (possibly wrapped) object used for forward passes;
        # `self.model` stays the unwrapped network for state_dict/eval.
        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg.get('find_unused_parameters',
                                                  False)
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # enable auto mixed precision mode
        if self.cfg.get('amp', False):
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu or self.cfg.use_npu,
                init_loss_scaling=1024)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        # NOTE: the key spelling 'training_staus' is kept as-is because it is
        # a runtime key stored in self.status, which is passed to callbacks.
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num)
            self._flops(flops_loader)
        profiler_options = self.cfg.get('profiler_options', None)

        self._compose_callback.on_train_begin(self.status)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                data['epoch_id'] = epoch_id

                if self.cfg.get('amp', False):
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()
                # Read the LR used for this step before the scheduler advances.
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                if self.cfg.get('unstructured_prune'):
                    self.pruner.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update()
                iter_tic = time.time()

            if self.cfg.get('unstructured_prune'):
                self.pruner.update_params()

            is_snapshot = (self._nranks < 2 or self._local_rank == 0) \
                       and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 or epoch_id == self.end_epoch - 1)
            if is_snapshot and self.use_ema:
                # apply ema weight on model; the raw weights are kept in
                # status['weight'] for the epoch-end callbacks and restored
                # below.
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())
                self.status['weight'] = weight

            self._compose_callback.on_epoch_end(self.status)

            if validate and is_snapshot:
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    # If metric is VOC, need to be set collate_batch=False.
                    if self.cfg.metric == 'VOC':
                        self.cfg['EvalReader']['collate_batch'] = False
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be
                # re-initialised; the flag makes sure this runs only once
                if not metrics_initialized:
                    metrics_initialized = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()

                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            if is_snapshot and self.use_ema:
                # reset original weight
                self.model.set_dict(weight)
                self.status.pop('weight')

        self._compose_callback.on_train_end(self.status)

K
Kaipeng Deng 已提交
493
    def _eval_with_loader(self, loader):
        """Run one evaluation pass over ``loader`` and log the metrics.

        Args:
            loader: Iterable of batches. A batch is either a mapping with an
                'im_id' entry or a sequence of such mappings (multi-scale
                input), in which case the first element supplies 'im_id'.

        Writes 'mode', 'step_id', 'sample_num' and 'cost_time' into
        ``self.status`` and resets all metrics before returning.
        """
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
            self._flops(flops_loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            # multi-scale inputs: all inputs have same im_id
            if isinstance(data, typing.Sequence):
                sample_num += data[0]['im_id'].numpy().shape[0]
            else:
                sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states for metric may performed multiple times
        self._reset_metrics()

K
Kaipeng Deng 已提交
531
    def evaluate(self):
        """Evaluate the model on ``self.loader`` with gradients disabled."""
        with paddle.no_grad():
            self._eval_with_loader(self.loader)
K
Kaipeng Deng 已提交
534

C
cnn 已提交
535 536 537 538 539
    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
        """Run inference on ``images`` and save visualised results.

        Args:
            images: Images handed to ``self.dataset.set_images``.
            draw_threshold (float): Threshold forwarded to
                ``visualize_results`` and ``save_result``.
            output_dir (str): Directory the rendered images are written to
                (created if missing).
            save_txt (bool): When True, also write a ``.txt`` file next to
                each saved image through ``save_result``.
        """
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('TestReader')(self.dataset, 0)
            self._flops(flops_loader)
        results = []
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            # Carry the image meta-data along with the predictions; for
            # sequence (multi-scale) batches the first element is used.
            for key in ['im_shape', 'scale_factor', 'im_id']:
                if isinstance(data, typing.Sequence):
                    outs[key] = data[0][key]
                else:
                    outs[key] = data[key]
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()
            results.append(outs)
        # sniper
        if type(self.dataset) == SniperCOCODataSet:
            results = self.dataset.anno_cropper.aggregate_chips_detections(
                results)

        for outs in results:
            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            # `start`/`end` slice this image's detections out of the flat
            # per-batch result lists.
            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                image = Image.open(image_path).convert('RGB')
                image = ImageOps.exif_transpose(image)
                self.status['original_image'] = np.array(image.copy())

                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                        if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    # Use a separate name so the outer `results` list that
                    # is being iterated is not rebound.
                    txt_results = {}
                    txt_results["im_id"] = im_id
                    if bbox_res:
                        txt_results["bbox_res"] = bbox_res
                    if keypoint_res:
                        txt_results["keypoint_res"] = keypoint_res
                    save_result(save_path, txt_results, catid2name,
                                draw_threshold)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext

G
Guanghua Yu 已提交
627
    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
        """
        Put the model into its export configuration, dump the inference
        config and build the input spec used when saving the model.

        Args:
            save_dir (str): directory in which ``infer_cfg.yml`` is written.
            prune_input (bool): when True, the model is converted with
                ``paddle.jit.to_static`` and the input spec is pruned against
                the converted program; when False, no conversion is done.

        Returns:
            tuple: ``(static_model, pruned_input_spec)``. ``static_model`` is
                None when ``prune_input`` is False.
        """
        # MOT architectures describe their test inputs in a dedicated reader.
        if self.cfg.architecture in MOT_ARCH:
            reader_key = 'TestMOTReader'
        else:
            reader_key = 'TestReader'
        reader_cfg = self.cfg[reader_key]

        image_shape = None
        if 'inputs_def' in reader_cfg:
            image_shape = reader_cfg['inputs_def'].get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        if len(image_shape) == 3:
            # No batch dimension configured: prepend a dynamic one.
            image_shape = [None] + image_shape
            batch_dim = None
        else:
            # Reuse the configured batch dimension for the auxiliary inputs.
            batch_dim = image_shape[0]
        im_shape = [batch_dim, 2]
        scale_factor = [batch_dim, 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True

        for sublayer in self.model.sublayers():
            if hasattr(sublayer, 'convert_to_deploy'):
                sublayer.convert_to_deploy()

        # Export options: without an 'export' section post-process and nms
        # default to True; with one, any missing key defaults to False.
        if hasattr(self.cfg, 'export'):
            export_cfg = self.cfg['export']
            export_post_process = export_cfg.get('post_process', False)
            export_nms = export_cfg.get('nms', False)
            export_benchmark = export_cfg.get('benchmark', False)
        else:
            export_post_process = True
            export_nms = True
            export_benchmark = False

        if hasattr(self.model, 'fuse_norm'):
            # NOTE(review): this reads 'TestReader' even when the MOT reader
            # was selected above - confirm that is intended.
            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                              False)
        # Benchmark export switches both post-process and nms off.
        if hasattr(self.model, 'export_post_process'):
            self.model.export_post_process = (False if export_benchmark else
                                              export_post_process)
        if hasattr(self.model, 'export_nms'):
            self.model.export_nms = False if export_benchmark else export_nms
        if export_post_process and not export_benchmark:
            # Exporting with post-process: make the batch dimension dynamic.
            image_shape = [None] + image_shape[1:]

        # Save infer cfg
        infer_cfg_path = os.path.join(save_dir, 'infer_cfg.yml')
        _dump_infer_config(self.cfg, infer_cfg_path, image_shape, self.model)

        spec = {
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor'),
        }
        if self.cfg.architecture == 'DeepSORT':
            spec["crops"] = InputSpec(shape=[None, 3, 192, 64], name='crops')
        input_spec = [spec]

        if not prune_input:
            static_model = None
            pruned_input_spec = input_spec
        else:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st do not pruned program, but jit.save will prune program
            # input spec, prune input spec here and save with pruned input spec
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)

        # TODO: Hard code, delete it when support prune input_spec.
        if self.cfg.architecture == 'PicoDet' and not export_post_process:
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec

    def export(self, output_dir='output_inference'):
        """
        Export the model for inference.

        The model is put in eval mode and saved, together with the inference
        config, under ``<output_dir>/<name>`` where ``<name>`` is the base
        name of ``self.cfg.filename`` without its extension. When the config
        has a ``slim`` entry and ``slim_type`` is ``'QAT'`` the model is saved
        through ``self.cfg.slim.save_quantized_model``; otherwise the
        converted static model is saved with ``paddle.jit.save``.

        Args:
            output_dir (str): parent directory of the export directory.
        """
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        # exist_ok avoids the check-then-create race on the export directory.
        os.makedirs(save_dir, exist_ok=True)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # dy2st and save model
        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))
732

G
Guanghua Yu 已提交
733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752
    def post_quant(self, output_dir='output_inference'):
        """
        Run calibration batches through the model and save it as a
        post-training quantized model.

        Batches from ``self.loader`` are fed to the model until
        ``quant_batch_num`` of them (config value, default 10) have been
        processed or the loader is exhausted. The model is then saved with
        ``self.cfg.slim.save_quantized_model`` under ``<output_dir>/<name>``,
        where ``<name>`` is the base name of ``self.cfg.filename`` without
        its extension.

        Args:
            output_dir (str): parent directory of the export directory.
        """
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        # exist_ok avoids the check-then-create race on the export directory.
        os.makedirs(save_dir, exist_ok=True)

        # Read the limit once; check it BEFORE the forward pass so that
        # exactly quant_batch_num batches are run (checking afterwards ran
        # one batch too many).
        quant_batch_num = int(self.cfg.get('quant_batch_num', 10))
        for idx, data in enumerate(self.loader):
            if idx == quant_batch_num:
                break
            self.model(data)

        # TODO: support prune input_spec
        _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)

        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
        logger.info("Export Post-Quant model and saved in {}".format(save_dir))
G
Guanghua Yu 已提交
753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777

    def _flops(self, loader):
        """
        Log the FLOPs of the model, measured on the first sample of the
        first batch yielded by ``loader``.

        The model is put in eval mode. If ``paddleslim`` cannot be imported,
        or the loader yields no batch, a warning is logged and nothing is
        computed. The value is logged in units of 1000**3 ("G").

        Args:
            loader: iterable of batches; each batch is indexed with the keys
                ``'image'``, ``'im_shape'`` and ``'scale_factor'``.
        """
        self.model.eval()
        try:
            # Import probe only: paddleslim is an optional dependency.
            import paddleslim
        except Exception:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops

        # Only the first batch is needed.
        input_data = next(iter(loader), None)
        if input_data is None:
            # An empty loader would otherwise fail when subscripting None.
            logger.warning(
                'Unable to calculate flops, the data loader yields no data.')
            return

        # Keep a single sample but restore the leading batch dimension.
        image = input_data['image'][0].unsqueeze(0)
        input_spec = [{
            "image": image,
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        gflops = dygraph_flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            gflops, image.shape))
778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800

    def parse_mot_images(self, cfg):
        """
        Collect the image files of every sequence of the MOT eval dataset.

        Every entry of ``<dataset_dir>/<data_root>`` (both read from
        ``cfg['EvalMOTDataset']``) is treated as a sequence directory. The
        entries are visited in sorted order and the jpg/jpeg/png/bmp files
        (lower- or upper-case extension) found directly inside each one are
        appended in sorted order.

        Args:
            cfg: config indexed with ``'EvalMOTDataset'``; that entry exposes
                ``dataset_dir`` and ``data_root`` attributes.

        Returns:
            list: image paths of all sequences.

        Raises:
            AssertionError: if an entry of the data root is not a directory
                or a sequence directory contains no image.
        """
        import glob
        # for quant
        dataset_dir = cfg['EvalMOTDataset'].dataset_dir
        data_root = cfg['EvalMOTDataset'].data_root
        data_root = '{}/{}'.format(dataset_dir, data_root)

        # Accepted extensions, lower- and upper-case; built once for all
        # sequences.
        exts = ['jpg', 'jpeg', 'png', 'bmp']
        exts += [ext.upper() for ext in exts]

        all_images = []
        for seq in sorted(os.listdir(data_root)):
            infer_dir = os.path.join(data_root, seq)
            assert os.path.isdir(infer_dir), \
                "{} is not a directory".format(infer_dir)
            # A set drops files matched by more than one pattern
            # (presumably relevant on case-insensitive filesystems).
            images = set()
            for ext in exts:
                images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
            images = sorted(images)
            assert len(images) > 0, "no image found in {}".format(infer_dir)
            all_images.extend(images)
        # Report the real total once, instead of one "in total" message per
        # sequence carrying only that sequence's count.
        logger.info("Found {} inference images in total.".format(
            len(all_images)))
        return all_images