trainer.py 32.4 KB
Newer Older
F
Feng Ni 已提交
1 2 3 4 5 6 7 8 9 10 11 12
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
K
Kaipeng Deng 已提交
13 14 15 16 17 18 19
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
G
George Ni 已提交
20
import sys
21
import copy
K
Kaipeng Deng 已提交
22
import time
M
Manuel Garcia 已提交
23

K
Kaipeng Deng 已提交
24
import numpy as np
M
Mark Ma 已提交
25
import typing
F
Feng Ni 已提交
26 27
from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
K
Kaipeng Deng 已提交
28 29

import paddle
W
wangguanzhong 已提交
30 31
import paddle.distributed as dist
from paddle.distributed import fleet
32
from paddle import amp
K
Kaipeng Deng 已提交
33
from paddle.static import InputSpec
34
from ppdet.optimizer import ModelEMA
K
Kaipeng Deng 已提交
35 36 37

from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
C
cnn 已提交
38
from ppdet.utils.visualizer import visualize_results, save_result
Z
zhiboniu 已提交
39
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
40 41
from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from ppdet.data.source.sniper_coco import SniperCOCODataSet
K
Kaipeng Deng 已提交
42
from ppdet.data.source.category import get_categories
K
Kaipeng Deng 已提交
43
import ppdet.utils.stats as stats
44
from ppdet.utils import profiler
K
Kaipeng Deng 已提交
45

46
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
G
Guanghua Yu 已提交
47
from .export_utils import _dump_infer_config, _prune_input_spec
K
Kaipeng Deng 已提交
48 49

from ppdet.utils.logger import setup_logger
50
logger = setup_logger('ppdet.engine')
K
Kaipeng Deng 已提交
51 52 53

__all__ = ['Trainer']

54 55
MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']

K
Kaipeng Deng 已提交
56 57 58 59 60 61 62

class Trainer(object):
    def __init__(self, cfg, mode='train'):
        """Build dataset, loader, model, optimizer, callbacks and metrics.

        Args:
            cfg: Global configuration object. It is read both by key
                (``cfg['TestReader']``) and by attribute (``cfg.architecture``),
                and some entries are modified in place (see comments below).
            mode (str): One of 'train', 'eval' or 'test' (case-insensitive).

        Raises:
            AssertionError: If ``mode`` is not one of the supported values.
            SystemExit: If architecture is 'DeepSORT' and mode is 'train'.
        """
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        # Set to True below when a ready-made model is supplied through cfg;
        # load_weights() then becomes a no-op.
        self.is_loaded_weights = False

        # build data loader
        # MOT architectures use the '<Mode>MOTDataset' config entry for
        # eval/test; everything else uses '<Mode>Dataset'.
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
        else:
            self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT has no need of training on mot dataset.')
            # NOTE: terminates the whole process, not just this constructor.
            sys.exit(1)

        if cfg.architecture == 'FairMOT' and self.mode == 'eval':
            images = self.parse_mot_images(cfg)
            self.dataset.set_images(images)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        # The embedding heads need the identity counts from the training
        # dataset, so they are written into cfg before the model is created.
        if cfg.architecture == 'JDE' and self.mode == 'train':
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]
            # JDE only support single class MOT now.

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict
            # FairMOT support single class and multi-class MOT now.

        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            # A model object was provided in cfg; treat its weights as loaded.
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # normalize params for deploy
        # With OFA slim the call goes to ``self.model.model`` instead of
        # ``self.model``.
        if 'slim' in cfg and cfg['slim_type'] == 'OFA':
            self.model.model.load_meanstd(cfg['TestReader'][
                'sample_transforms'])
        else:
            self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            # Defaults used when the config does not set them.
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                use_thres_step=True,
                cycle_epoch=cycle_epoch)

        # EvalDataset build with BatchSampler to evaluate in single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            if cfg.architecture == 'FairMOT':
                self.loader = create('EvalMOTReader')(self.dataset, 0)
            else:
                self._eval_batch_sampler = paddle.io.BatchSampler(
                    self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
                reader_name = '{}Reader'.format(self.mode.capitalize())
                # If metric is VOC, need to be set collate_batch=False.
                if cfg.metric == 'VOC':
                    cfg[reader_name]['collate_batch'] = False
                self.loader = create(reader_name)(self.dataset, cfg.worker_num,
                                                  self._eval_batch_sampler)
        # TestDataset build after user set images, skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            # Unstructured pruner is only enabled in the train mode.
            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)

        # Distributed world size and rank of this process.
        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        # Mutable state dict shared with the callbacks.
        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
163
            if self.cfg.get('use_vdl', False):
164
                self._callbacks.append(VisualDLWriter(self))
165 166
            if self.cfg.get('save_proposals', False):
                self._callbacks.append(SniperProposalsGenerator(self))
K
Kaipeng Deng 已提交
167 168 169
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
170 171
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
K
Kaipeng Deng 已提交
172
            self._compose_callback = ComposeCallback(self._callbacks)
173
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
174 175
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
K
Kaipeng Deng 已提交
176 177 178 179
        else:
            self._callbacks = []
            self._compose_callback = None

K
Kaipeng Deng 已提交
180 181
    def _init_metrics(self, validate=False):
        if self.mode == 'test' or (self.mode == 'train' and not validate):
G
Guanghua Yu 已提交
182 183
            self._metrics = []
            return
184
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
185
        if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
W
wangxinxin08 已提交
186
            # TODO: bias should be unified
187
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
S
shangliang Xu 已提交
188 189
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
190
            save_prediction_only = self.cfg.get('save_prediction_only', False)
191 192 193

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
K
Kaipeng Deng 已提交
194 195
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None
196 197 198 199

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
200
            dataset = self.dataset
201 202 203 204
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()
205
                dataset = eval_dataset
206

207
            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
208 209 210 211 212 213 214 215 216 217 218
            if self.cfg.metric == "COCO":
                self._metrics = [
                    COCOMetric(
                        anno_file=anno_file,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
219
            elif self.cfg.metric == "SNIPERCOCO":  # sniper
220 221 222 223 224 225 226 227 228
                self._metrics = [
                    SNIPERCOCOMetric(
                        anno_file=anno_file,
                        dataset=dataset,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
229
                        save_prediction_only=save_prediction_only)
230
                ]
231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
K
Kaipeng Deng 已提交
260 261 262
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
263
                    label_list=self.dataset.get_label_list(),
K
Kaipeng Deng 已提交
264
                    class_num=self.cfg.num_classes,
265 266
                    map_type=self.cfg.map_type,
                    classwise=classwise)
K
Kaipeng Deng 已提交
267
            ]
268 269 270 271 272 273 274 275 276
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
277 278 279 280
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
281
            save_prediction_only = self.cfg.get('save_prediction_only', False)
282
            self._metrics = [
283 284 285 286 287 288
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
289
            ]
Z
zhiboniu 已提交
290 291 292 293
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
294
            save_prediction_only = self.cfg.get('save_prediction_only', False)
Z
zhiboniu 已提交
295
            self._metrics = [
296 297 298 299 300 301
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
Z
zhiboniu 已提交
302
            ]
G
George Ni 已提交
303 304
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
K
Kaipeng Deng 已提交
305
        else:
306
            logger.warning("Metric not support for metric type {}".format(
K
Kaipeng Deng 已提交
307
                self.cfg.metric))
K
Kaipeng Deng 已提交
308 309 310 311 312 313 314
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        """Append user callbacks to the default ones and rebuild the composer.

        Args:
            callbacks: Iterable of ``Callback`` instances; ``None`` entries
                are ignored.

        Raises:
            AssertionError: If an entry is not a ``Callback`` instance.
        """
        callbacks = [c for c in callbacks if c is not None]
        for c in callbacks:
            # The message now names the type that is actually checked here
            # (it previously referred to metrics/Metric).
            assert isinstance(c, Callback), \
                    "callbacks should be instances of subclass of Callback"
        self._callbacks.extend(callbacks)
        # Rebuild so the composed callback also covers the new entries (and
        # exists even when no default callback had been installed).
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                    "metrics shoule be instances of subclass of Metric"
        self._metrics.extend(metrics)

K
Kaipeng Deng 已提交
329
    def load_weights(self, weights):
330 331
        if self.is_loaded_weights:
            return
K
Kaipeng Deng 已提交
332
        self.start_epoch = 0
333
        load_pretrain_weight(self.model, weights)
K
Kaipeng Deng 已提交
334 335
        logger.debug("Load weights {} to start training".format(weights))

336 337 338 339 340 341 342
    def load_weights_sde(self, det_weights, reid_weights):
        """Load detector (when present) and ReID weights.

        Args:
            det_weights: Weights for ``self.model.detector``; only used when
                the model has a truthy ``detector``.
            reid_weights: Weights for ``self.model.reid``; always loaded.
        """
        sde_model = self.model
        if sde_model.detector:
            load_weight(sde_model.detector, det_weights)
        # the ReID branch is loaded with or without a detector
        load_weight(sde_model.reid, reid_weights)

K
Kaipeng Deng 已提交
343
    def resume_weights(self, weights):
        """Restore a checkpoint and set ``self.start_epoch`` from it.

        Args:
            weights: Checkpoint location, forwarded to ``load_weight``.
        """
        if hasattr(self.model, 'student_model'):
            # support Distill resume weights: only the student model is
            # restored, and no EMA object is passed
            resumed_epoch = load_weight(self.model.student_model, weights,
                                        self.optimizer)
        else:
            ema = self.ema if self.use_ema else None
            resumed_epoch = load_weight(self.model, weights, self.optimizer,
                                        ema)
        self.start_epoch = resumed_epoch
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))
K
Kaipeng Deng 已提交
352

K
Kaipeng Deng 已提交
353
    def train(self, validate=False):
        """Run the training loop from ``self.start_epoch`` to ``cfg.epoch``.

        Args:
            validate (bool): When True, the model is evaluated on
                ``cfg.EvalDataset`` at every snapshot epoch.

        Raises:
            AssertionError: If the trainer was not built in 'train' mode.
        """
        assert self.mode == 'train', "Model not in 'train' mode"
        # Becomes True once the validation metrics have been (re)built.
        Init_mark = False

        # Sync BN is only used on GPU with more than one rank.
        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
                   self.cfg.use_gpu and self._nranks > 1)
        if sync_bn:
            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        # ``model`` is the (possibly wrapped) object used for forward/train();
        # ``self.model`` stays the unwrapped model used for state dicts.
        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # enable auto mixed precision mode
        # ``scaler`` only exists when cfg 'amp' is set; it is used below under
        # the same condition.
        if self.cfg.get('amp', False):
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu, init_loss_scaling=1024)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        # NOTE(review): the key is spelled 'training_staus'; presumably other
        # code reading ``status`` relies on this spelling - verify before
        # renaming.
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            # A separate loader is built just for the FLOPs report.
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num)
            self._flops(flops_loader)
        profiler_options = self.cfg.get('profiler_options', None)

        self._compose_callback.on_train_begin(self.status)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                # Time spent waiting for this batch.
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                data['epoch_id'] = epoch_id

                if self.cfg.get('amp', False):
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()
                # Read the learning rate before the scheduler advances, so the
                # logged value is the one used for this step.
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                if self.cfg.get('unstructured_prune'):
                    self.pruner.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                # Training stats are only updated on a single process / rank 0.
                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update()
                iter_tic = time.time()

            if self.cfg.get('unstructured_prune'):
                self.pruner.update_params()

            # Snapshot on rank 0 (or single process) every ``snapshot_epoch``
            # epochs and on the last epoch.
            is_snapshot = (self._nranks < 2 or self._local_rank == 0) \
                       and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 or epoch_id == self.end_epoch - 1)
            if is_snapshot and self.use_ema:
                # apply ema weight on model
                # The original weights are kept in ``weight`` (and exposed via
                # status['weight']) so they can be restored after the
                # epoch-end callbacks and validation have run.
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())
                self.status['weight'] = weight

            self._compose_callback.on_epoch_end(self.status)

            if validate and is_snapshot:
                # The eval loader is built lazily on the first validation.
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    # If metric is VOC, need to be set collate_batch=False.
                    if self.cfg.metric == 'VOC':
                        self.cfg['EvalReader']['collate_batch'] = False
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be re-init
                # Init_mark makes sure this code will only execute once
                if validate and Init_mark == False:
                    Init_mark = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()

                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            if is_snapshot and self.use_ema:
                # reset original weight
                self.model.set_dict(weight)
                self.status.pop('weight')

        self._compose_callback.on_train_end(self.status)

K
Kaipeng Deng 已提交
492
    def _eval_with_loader(self, loader):
K
Kaipeng Deng 已提交
493 494 495
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
K
Kaipeng Deng 已提交
496 497
        self.status['mode'] = 'eval'
        self.model.eval()
G
Guanghua Yu 已提交
498
        if self.cfg.get('print_flops', False):
G
Guanghua Yu 已提交
499 500 501
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
            self._flops(flops_loader)
K
Kaipeng Deng 已提交
502
        for step_id, data in enumerate(loader):
K
Kaipeng Deng 已提交
503 504 505 506 507 508 509 510 511
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

M
Mark Ma 已提交
512 513 514 515 516
            # multi-scale inputs: all inputs have same im_id
            if isinstance(data, typing.Sequence):
                sample_num += data[0]['im_id'].numpy().shape[0]
            else:
                sample_num += data['im_id'].numpy().shape[0]
K
Kaipeng Deng 已提交
517 518 519 520 521 522 523 524 525
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
526
        self._compose_callback.on_epoch_end(self.status)
K
Kaipeng Deng 已提交
527 528 529
        # reset metric states for metric may performed multiple times
        self._reset_metrics()

K
Kaipeng Deng 已提交
530
    def evaluate(self):
        """Evaluate the model on ``self.loader`` with gradients disabled."""
        eval_loader = self.loader
        with paddle.no_grad():
            self._eval_with_loader(eval_loader)
K
Kaipeng Deng 已提交
533

C
cnn 已提交
534 535 536 537 538
    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
        """Run inference on ``images`` and save the visualized results.

        Args:
            images: Images to run on; passed to ``self.dataset.set_images``.
            draw_threshold (float): Forwarded to ``visualize_results`` and
                ``save_result``.
            output_dir (str): Directory the rendered images (and optional
                ``.txt`` files) are written to.
            save_txt (bool): When True, also write a ``.txt`` result file next
                to each saved image.
        """
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('TestReader')(self.dataset, 0)
            self._flops(flops_loader)
        results = []
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            # multi-scale inputs come as a sequence; the image meta is taken
            # from its first element
            first_input = data[0] if isinstance(data,
                                                typing.Sequence) else data
            for key in ['im_shape', 'scale_factor', 'im_id']:
                outs[key] = first_input[key]
            # convert everything exposing ``numpy()`` in place (only existing
            # keys are reassigned, so iterating the dict stays valid)
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()
            results.append(outs)
        # sniper
        if type(self.dataset) == SniperCOCODataSet:
            results = self.dataset.anno_cropper.aggregate_chips_detections(
                results)

        for outs in results:
            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                # close the file handle as soon as the pixels are converted
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB')
                image = ImageOps.exif_transpose(image)
                self.status['original_image'] = np.array(image.copy())

                # results of all images in the batch are concatenated;
                # bbox_num[i] is the number of entries belonging to image i
                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                        if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    # use a dedicated name so the ``results`` list being
                    # iterated by the outer loop is not shadowed
                    txt_result = {"im_id": im_id}
                    if bbox_res:
                        txt_result["bbox_res"] = bbox_res
                    if keypoint_res:
                        txt_result["keypoint_res"] = keypoint_res
                    save_result(save_path, txt_result, catid2name,
                                draw_threshold)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext

G
Guanghua Yu 已提交
626
    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
        """
        Dump the inference config and build the input spec used for export.

        The method reads the test reader config to get the image shape,
        switches the model attributes used for deployment (``deploy``,
        ``export_post_process``, ``fuse_norm`` and ``convert_to_deploy`` of
        sublayers, whenever they exist), writes ``infer_cfg.yml`` into
        ``save_dir`` and finally builds the ``InputSpec`` list.

        Args:
            save_dir (str): directory ``infer_cfg.yml`` is written to.
            prune_input (bool): if True, convert the model with
                ``paddle.jit.to_static`` and prune the input spec with
                ``_prune_input_spec``; if False, no static model is built
                and the input spec is returned as it is.

        Returns:
            tuple: ``(static_model, pruned_input_spec)``, ``static_model``
                is None when ``prune_input`` is False.
        """
        reader_key = 'TestMOTReader' \
            if self.cfg.architecture in MOT_ARCH else 'TestReader'

        image_shape = None
        if 'inputs_def' in self.cfg[reader_key]:
            reader_inputs = self.cfg[reader_key]['inputs_def']
            image_shape = reader_inputs.get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        # A 3-dim shape has no batch dim: prepend a dynamic one. Otherwise
        # im_shape and scale_factor share the batch dim of the image.
        if len(image_shape) == 3:
            image_shape = [None] + image_shape
            batch_dim = None
        else:
            batch_dim = image_shape[0]
        im_shape = [batch_dim, 2]
        scale_factor = [batch_dim, 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True

        for sublayer in self.model.sublayers():
            if hasattr(sublayer, 'convert_to_deploy'):
                sublayer.convert_to_deploy()

        export_post_process = self.cfg.get('export_post_process', False)
        if hasattr(self.model, 'export_post_process'):
            self.model.export_post_process = export_post_process
            # the batch dim is always exported as dynamic in this case
            image_shape = [None] + image_shape[1:]
        if hasattr(self.model, 'fuse_norm'):
            fuse_normalize = self.cfg['TestReader'].get('fuse_normalize',
                                                        False)
            self.model.fuse_norm = fuse_normalize

        # Save infer cfg
        infer_cfg_path = os.path.join(save_dir, 'infer_cfg.yml')
        _dump_infer_config(self.cfg, infer_cfg_path, image_shape, self.model)

        spec = {
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor')
        }
        if self.cfg.architecture == 'DeepSORT':
            spec["crops"] = InputSpec(shape=[None, 3, 192, 64], name='crops')
        input_spec = [spec]

        static_model = None
        pruned_input_spec = input_spec
        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st do not pruned program, but jit.save will prune program
            # input spec, prune input spec here and save with pruned input spec
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)

        # TODO: Hard code, delete it when support prune input_spec.
        if self.cfg.architecture == 'PicoDet' and not export_post_process:
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec

    def export(self, output_dir='output_inference'):
        """
        Export the model for inference.

        The model is saved with the ``model`` prefix in
        ``output_dir/<config file name without extension>``. The inference
        config and the input spec come from
        ``_get_infer_cfg_and_input_spec``.

        Args:
            output_dir (str): root directory of the exported model,
                'output_inference' by default.
        """
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        # exist_ok avoids the check-then-create race when the directory
        # appears between an existence test and the makedirs call.
        os.makedirs(save_dir, exist_ok=True)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # dy2st and save model
        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            # QAT models are saved by the slim object held in the config
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))
723

G
Guanghua Yu 已提交
724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743
    def post_quant(self, output_dir='output_inference'):
        """
        Feed batches of ``self.loader`` to the model and save the quantized
        model.

        The number of batches is read from the ``quant_batch_num`` config
        entry (10 by default), at least one batch is always fed if the
        loader is not empty. The model is saved with the ``model`` prefix in
        ``output_dir/<config file name without extension>``.

        Args:
            output_dir (str): root directory of the exported model,
                'output_inference' by default.
        """
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        # exist_ok avoids the check-then-create race when the directory
        # appears between an existence test and the makedirs call.
        os.makedirs(save_dir, exist_ok=True)

        # NOTE(review): presumably these forward passes let the slim object
        # collect statistics for quantization - confirm against slim config.
        quant_batch_num = int(self.cfg.get('quant_batch_num', 10))
        for idx, data in enumerate(self.loader):
            self.model(data)
            # idx is 0-based: stop once `quant_batch_num` batches were fed,
            # comparing `idx == quant_batch_num` would feed one batch more.
            if idx + 1 >= quant_batch_num:
                break

        # TODO: support prune input_spec
        _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)

        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
        logger.info("Export Post-Quant model and saved in {}".format(save_dir))
G
Guanghua Yu 已提交
744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768

    def _flops(self, loader):
        """
        Log the FLOPs of the model, measured with the first sample of the
        first batch of ``loader``.

        Nothing is computed (a warning is logged instead) when paddleslim
        can not be imported or when ``loader`` yields no batch.

        Args:
            loader (iterable): yields dict batches holding the 'image',
                'im_shape' and 'scale_factor' entries.
        """
        self.model.eval()
        try:
            # only checks that paddleslim is available
            import paddleslim
        except Exception:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops

        # Only the first batch is needed, an empty loader gives None here
        # instead of failing later on `None['image']`.
        input_data = next(iter(loader), None)
        if input_data is None:
            logger.warning(
                'Unable to calculate flops, the data loader is empty.')
            return

        # keep a single sample and restore the batch dim with unsqueeze(0)
        image = input_data['image'][0].unsqueeze(0)
        input_spec = [{
            "image": image,
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        model_flops = dygraph_flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            model_flops, image.shape))
769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793

    def parse_mot_images(self, cfg):
        """
        Collect the image paths of all sequences of the MOT eval dataset.

        Every entry of ``<dataset_dir>/<data_root>`` (both read from
        ``cfg['EvalMOTDataset']``) is treated as a sequence directory, its
        jpg/jpeg/png/bmp images (lower or upper case extension) are
        gathered in sorted order, sequence after sequence.

        Args:
            cfg (dict): config holding the 'EvalMOTDataset' entry.

        Returns:
            list: image paths, sorted inside each sequence, sequences being
                visited in sorted order.

        Raises:
            AssertionError: if an entry of the data root is not a directory
                or if a sequence directory contains no image.
        """
        import glob
        # for quant
        dataset_dir = cfg['EvalMOTDataset'].dataset_dir
        data_root = cfg['EvalMOTDataset'].data_root
        data_root = '{}/{}'.format(dataset_dir, data_root)

        # the extension list does not depend on the sequence
        exts = ['jpg', 'jpeg', 'png', 'bmp']
        exts += [ext.upper() for ext in exts]

        all_images = []
        for seq in sorted(os.listdir(data_root)):
            infer_dir = os.path.join(data_root, seq)
            assert os.path.isdir(infer_dir), \
                "{} is not a directory".format(infer_dir)
            # a set drops the paths matched by both the lower and the upper
            # case pattern on case-insensitive file systems
            images = set()
            for ext in exts:
                images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
            assert len(images) > 0, "no image found in {}".format(infer_dir)
            all_images.extend(sorted(images))
        # logged once with the real total, not with the last sequence count
        logger.info("Found {} inference images in total.".format(
            len(all_images)))
        return all_images