trainer.py 33.3 KB
Newer Older
F
Feng Ni 已提交
1 2 3 4 5 6 7 8 9 10 11 12
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
K
Kaipeng Deng 已提交
13 14 15 16 17 18 19
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
G
George Ni 已提交
20
import sys
21
import copy
K
Kaipeng Deng 已提交
22
import time
F
Feng Ni 已提交
23
from tqdm import tqdm
M
Manuel Garcia 已提交
24

K
Kaipeng Deng 已提交
25
import numpy as np
M
Mark Ma 已提交
26
import typing
F
Feng Ni 已提交
27 28
from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
K
Kaipeng Deng 已提交
29 30

import paddle
W
wangguanzhong 已提交
31 32
import paddle.distributed as dist
from paddle.distributed import fleet
33
from paddle import amp
K
Kaipeng Deng 已提交
34
from paddle.static import InputSpec
35
from ppdet.optimizer import ModelEMA
K
Kaipeng Deng 已提交
36 37 38

from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
C
cnn 已提交
39
from ppdet.utils.visualizer import visualize_results, save_result
Z
zhiboniu 已提交
40
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
41 42
from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from ppdet.data.source.sniper_coco import SniperCOCODataSet
K
Kaipeng Deng 已提交
43
from ppdet.data.source.category import get_categories
K
Kaipeng Deng 已提交
44
import ppdet.utils.stats as stats
45
from ppdet.utils import profiler
K
Kaipeng Deng 已提交
46

47
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
G
Guanghua Yu 已提交
48
from .export_utils import _dump_infer_config, _prune_input_spec
K
Kaipeng Deng 已提交
49 50

from ppdet.utils.logger import setup_logger
51
logger = setup_logger('ppdet.engine')
K
Kaipeng Deng 已提交
52 53 54

__all__ = ['Trainer']

# Multi-object-tracking architectures. Outside training they read the
# '{Eval,Test}MOTDataset' config entries and export through 'TestMOTReader'.
MOT_ARCH = [
    'DeepSORT',
    'JDE',
    'FairMOT',
    'ByteTrack',
]


class Trainer(object):
    def __init__(self, cfg, mode='train'):
        """Build datasets, loaders, model, optimizer, callbacks and metrics.

        Args:
            cfg: global config; read both as a mapping (``cfg['TestReader']``)
                and through attributes (``cfg.architecture``).
            mode (str): 'train', 'eval' or 'test' (case-insensitive).
        """
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False

        # Reject DeepSORT training before any dataset entry is looked up, so
        # the user sees this message rather than a config lookup error.
        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT has no need of training on mot dataset.')
            sys.exit(1)

        # build data loader; MOT architectures use '{Eval,Test}MOTDataset'
        # outside of training
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
        else:
            self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

        if cfg.architecture == 'FairMOT' and self.mode == 'eval':
            images = self.parse_mot_images(cfg)
            self.dataset.set_images(images)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        if cfg.architecture == 'JDE' and self.mode == 'train':
            # JDE only supports single-class MOT for now.
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            # FairMOT supports single-class and multi-class MOT.
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict

        # build model; a model handed over through cfg is treated as already
        # loaded, which turns load_weights() into a no-op
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # normalize params for deploy: slim wrappers keep the network that
        # owns load_meanstd in a sub-attribute
        slim_type = cfg['slim_type'] if 'slim' in cfg else None
        test_transforms = cfg['TestReader']['sample_transforms']
        if slim_type == 'OFA':
            self.model.model.load_meanstd(test_transforms)
        elif slim_type == 'Distill' or (slim_type == 'DistillPrune' and
                                        self.mode == 'train'):
            self.model.student_model.load_meanstd(test_transforms)
        else:
            self.model.load_meanstd(test_transforms)

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                use_thres_step=True,
                cycle_epoch=cycle_epoch)

        # EvalDataset build with BatchSampler to evaluate in single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            if cfg.architecture == 'FairMOT':
                self.loader = create('EvalMOTReader')(self.dataset, 0)
            else:
                self._eval_batch_sampler = paddle.io.BatchSampler(
                    self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
                reader_name = '{}Reader'.format(self.mode.capitalize())
                # If metric is VOC, need to be set collate_batch=False.
                if cfg.metric == 'VOC':
                    cfg[reader_name]['collate_batch'] = False
                self.loader = create(reader_name)(self.dataset, cfg.worker_num,
                                                  self._eval_batch_sampler)
        # TestDataset build after user set images, skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            # Unstructured pruner is only enabled in the train mode.
            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
171
            if self.cfg.get('use_vdl', False):
172
                self._callbacks.append(VisualDLWriter(self))
173 174
            if self.cfg.get('save_proposals', False):
                self._callbacks.append(SniperProposalsGenerator(self))
K
Kaipeng Deng 已提交
175 176 177
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
178 179
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
K
Kaipeng Deng 已提交
180
            self._compose_callback = ComposeCallback(self._callbacks)
181
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
182 183
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
K
Kaipeng Deng 已提交
184 185 186 187
        else:
            self._callbacks = []
            self._compose_callback = None

K
Kaipeng Deng 已提交
188 189
    def _init_metrics(self, validate=False):
        if self.mode == 'test' or (self.mode == 'train' and not validate):
G
Guanghua Yu 已提交
190 191
            self._metrics = []
            return
192
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
193
        if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
W
wangxinxin08 已提交
194
            # TODO: bias should be unified
195
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
S
shangliang Xu 已提交
196 197
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
198
            save_prediction_only = self.cfg.get('save_prediction_only', False)
199 200 201

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
K
Kaipeng Deng 已提交
202 203
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None
204 205 206 207

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
208
            dataset = self.dataset
209 210 211 212
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()
213
                dataset = eval_dataset
214

215
            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
216 217 218 219 220 221 222 223 224 225 226
            if self.cfg.metric == "COCO":
                self._metrics = [
                    COCOMetric(
                        anno_file=anno_file,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
227
            elif self.cfg.metric == "SNIPERCOCO":  # sniper
228 229 230 231 232 233 234 235 236
                self._metrics = [
                    SNIPERCOCOMetric(
                        anno_file=anno_file,
                        dataset=dataset,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
237
                        save_prediction_only=save_prediction_only)
238
                ]
239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
K
Kaipeng Deng 已提交
268 269 270
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
271
                    label_list=self.dataset.get_label_list(),
K
Kaipeng Deng 已提交
272
                    class_num=self.cfg.num_classes,
273 274
                    map_type=self.cfg.map_type,
                    classwise=classwise)
K
Kaipeng Deng 已提交
275
            ]
276 277 278 279 280 281 282 283 284
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
285 286 287 288
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
289
            save_prediction_only = self.cfg.get('save_prediction_only', False)
290
            self._metrics = [
291 292 293 294 295 296
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
297
            ]
Z
zhiboniu 已提交
298 299 300 301
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
302
            save_prediction_only = self.cfg.get('save_prediction_only', False)
Z
zhiboniu 已提交
303
            self._metrics = [
304 305 306 307 308 309
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
Z
zhiboniu 已提交
310
            ]
G
George Ni 已提交
311 312
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
K
Kaipeng Deng 已提交
313
        else:
314
            logger.warning("Metric not support for metric type {}".format(
K
Kaipeng Deng 已提交
315
                self.cfg.metric))
K
Kaipeng Deng 已提交
316 317 318 319 320 321 322
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        """Append user callbacks and rebuild the composed dispatcher.

        Args:
            callbacks: iterable of ``Callback`` instances; ``None`` entries
                are skipped.

        Raises:
            AssertionError: if an entry is not a ``Callback``.
        """
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            assert isinstance(c, Callback), \
                    "callbacks should be instances of subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                    "metrics shoule be instances of subclass of Metric"
        self._metrics.extend(metrics)

K
Kaipeng Deng 已提交
337
    def load_weights(self, weights):
338 339
        if self.is_loaded_weights:
            return
K
Kaipeng Deng 已提交
340
        self.start_epoch = 0
341
        load_pretrain_weight(self.model, weights)
K
Kaipeng Deng 已提交
342 343
        logger.debug("Load weights {} to start training".format(weights))

344 345 346 347 348 349 350
    def load_weights_sde(self, det_weights, reid_weights):
        """Load weights for a separate-detection-and-embedding MOT model.

        The detector weights are loaded only when ``self.model.detector`` is
        truthy. The reid weights are always loaded, after the detector's.
        """
        detector = self.model.detector
        if detector:
            load_weight(detector, det_weights)
        load_weight(self.model.reid, reid_weights)

    def resume_weights(self, weights):
        """Resume training state from ``weights`` and set ``start_epoch``.

        For distillation models (those with a ``student_model`` attribute)
        only the student and the optimizer are restored. Otherwise the EMA
        state is restored as well when EMA is enabled.
        """
        is_distill = hasattr(self.model, 'student_model')
        if is_distill:
            self.start_epoch = load_weight(self.model.student_model, weights,
                                           self.optimizer)
        else:
            ema = self.ema if self.use_ema else None
            self.start_epoch = load_weight(self.model, weights, self.optimizer,
                                           ema)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))

    def train(self, validate=False):
        """Run the training loop from ``self.start_epoch`` to ``cfg.epoch``.

        Args:
            validate (bool): if True, evaluate on ``cfg.EvalDataset`` at each
                snapshot epoch. Snapshot epochs only happen on rank 0 (or in
                single-process runs).
        """
        assert self.mode == 'train', "Model not in 'train' mode"
        metrics_reinitialized = False

        use_sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
                       self.cfg.use_gpu and self._nranks > 1)
        if use_sync_bn:
            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        # `model` is what the forward pass goes through (possibly a
        # distributed wrapper); `self.model` stays the bare network.
        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            model = paddle.DataParallel(
                self.model,
                find_unused_parameters=self.cfg.get('find_unused_parameters',
                                                    False))

        # enable auto mixed precision mode
        use_amp = self.cfg.get('amp', False)
        if use_amp:
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu or self.cfg.use_npu,
                init_loss_scaling=1024)

        status = self.status
        status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })
        log_iter = self.cfg.log_iter
        status['batch_time'] = stats.SmoothedValue(log_iter, fmt='{avg:.4f}')
        status['data_time'] = stats.SmoothedValue(log_iter, fmt='{avg:.4f}')
        # the key spelling is kept as-is; presumably callbacks read it by name
        status['training_staus'] = stats.TrainingStats(log_iter)

        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num)
            self._flops(flops_loader)
        profiler_options = self.cfg.get('profiler_options', None)

        self._compose_callback.on_train_begin(status)

        is_main_rank = self._nranks < 2 or self._local_rank == 0
        use_pruner = self.cfg.get('unstructured_prune')

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            status['mode'] = 'train'
            status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()

            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                status['data_time'].update(time.time() - iter_tic)
                status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(status)
                data['epoch_id'] = epoch_id

                if use_amp:
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        outputs = model(data)
                        loss = outputs['loss']
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize == optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    outputs = model(data)
                    loss = outputs['loss']
                    loss.backward()
                    self.optimizer.step()

                # read the lr used for this step before the scheduler advances
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                if use_pruner:
                    self.pruner.step()
                self.optimizer.clear_grad()
                status['learning_rate'] = curr_lr

                if is_main_rank:
                    status['training_staus'].update(outputs)

                status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(status)
                if self.use_ema:
                    self.ema.update()
                iter_tic = time.time()

            if use_pruner:
                self.pruner.update_params()

            is_snapshot = is_main_rank and (
                (epoch_id + 1) % self.cfg.snapshot_epoch == 0 or
                epoch_id == self.end_epoch - 1)

            if is_snapshot and self.use_ema:
                # swap EMA weights in for checkpointing/validation and keep
                # the raw weights so they can be restored afterwards
                raw_weights = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())
                status['weight'] = raw_weights

            self._compose_callback.on_epoch_end(status)

            if validate and is_snapshot:
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader once
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = paddle.io.BatchSampler(
                        self._eval_dataset,
                        batch_size=self.cfg.EvalReader['batch_size'])
                    # If metric is VOC, need to be set collate_batch=False.
                    if self.cfg.metric == 'VOC':
                        self.cfg['EvalReader']['collate_batch'] = False
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)

                # metrics were built without validation in __init__; rebuild
                # them for the evaluation set the first time only
                if not metrics_reinitialized:
                    metrics_reinitialized = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()

                with paddle.no_grad():
                    status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            if is_snapshot and self.use_ema:
                # reset original weight
                self.model.set_dict(raw_weights)
                status.pop('weight')

        self._compose_callback.on_train_end(status)

    def _eval_with_loader(self, loader):
K
Kaipeng Deng 已提交
502 503 504
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
K
Kaipeng Deng 已提交
505 506
        self.status['mode'] = 'eval'
        self.model.eval()
G
Guanghua Yu 已提交
507
        if self.cfg.get('print_flops', False):
G
Guanghua Yu 已提交
508 509 510
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
            self._flops(flops_loader)
F
Feng Ni 已提交
511
        for step_id, data in enumerate(loader):
K
Kaipeng Deng 已提交
512 513 514 515 516 517 518 519 520
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

M
Mark Ma 已提交
521 522 523 524 525
            # multi-scale inputs: all inputs have same im_id
            if isinstance(data, typing.Sequence):
                sample_num += data[0]['im_id'].numpy().shape[0]
            else:
                sample_num += data['im_id'].numpy().shape[0]
K
Kaipeng Deng 已提交
526 527 528 529 530 531 532 533 534
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
535
        self._compose_callback.on_epoch_end(self.status)
K
Kaipeng Deng 已提交
536 537 538
        # reset metric states for metric may performed multiple times
        self._reset_metrics()

K
Kaipeng Deng 已提交
539
    def evaluate(self):
        """Evaluate the model once over ``self.loader`` (built in eval mode)."""
        eval_loader = self.loader
        # evaluation never needs gradients
        with paddle.no_grad():
            self._eval_with_loader(eval_loader)

    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
        """Run inference on ``images`` and save the visualized detections.

        Args:
            images: passed to ``self.dataset.set_images`` (presumably a list
                of image file paths — confirm against the dataset class).
            draw_threshold (float): score threshold forwarded to
                ``visualize_results`` and ``save_result``.
            output_dir (str): directory that receives the rendered images
                and, with ``save_txt``, the ``.txt`` result files.
            save_txt (bool): also write per-image results via ``save_result``.
        """
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('TestReader')(self.dataset, 0)
            self._flops(flops_loader)
        results = []
        for step_id, data in enumerate(tqdm(loader)):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            # multi-scale inputs share their meta info; take it from the first
            meta = data[0] if isinstance(data, typing.Sequence) else data
            for key in ['im_shape', 'scale_factor', 'im_id']:
                outs[key] = meta[key]
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()
            results.append(outs)

        # sniper: merge the per-chip detections back into per-image results
        if isinstance(self.dataset, SniperCOCODataSet):
            results = self.dataset.anno_cropper.aggregate_chips_detections(
                results)

        for outs in results:
            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                # convert() returns a fully loaded copy, so the source file
                # handle can be closed right away instead of leaking
                with Image.open(image_path) as src_image:
                    image = src_image.convert('RGB')
                image = ImageOps.exif_transpose(image)
                self.status['original_image'] = np.array(image.copy())

                # bbox_num[i] detections of this image start at `start`
                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                        if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    # separate name so the outer `results` list is not shadowed
                    txt_results = {}
                    txt_results["im_id"] = im_id
                    if bbox_res:
                        txt_results["bbox_res"] = bbox_res
                    if keypoint_res:
                        txt_results["keypoint_res"] = keypoint_res
                    save_result(save_path, txt_results, catid2name,
                                draw_threshold)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.

        ``output_dir`` is created (including parents) if it does not exist.

        Args:
            output_dir (str): directory the result image is written to.
            image_path (str): path of the source image.

        Returns:
            str: ``output_dir`` joined with the file name (base name and
            extension) of ``image_path``.
        """
        # exist_ok=True avoids the exists()/makedirs() race where another
        # process creates the directory between the check and the call.
        os.makedirs(output_dir, exist_ok=True)
        image_name = os.path.basename(image_path)
        return os.path.join(output_dir, image_name)

G
Guanghua Yu 已提交
635
    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
        """
        Switch the model to its deploy form, write ``infer_cfg.yml`` and
        build the input specs used for static-graph export.

        Args:
            save_dir (str): directory that receives ``infer_cfg.yml``.
            prune_input (bool): if True, convert the model with
                ``paddle.jit.to_static`` and prune the input specs against
                the converted program; if False, skip the conversion and
                return the unpruned specs.

        Returns:
            tuple: ``(static_model, input_spec)``; ``static_model`` is None
            when ``prune_input`` is False.
        """
        reader_name = ('TestMOTReader'
                       if self.cfg.architecture in MOT_ARCH else 'TestReader')
        reader_cfg = self.cfg[reader_name]

        image_shape = None
        if 'inputs_def' in reader_cfg:
            image_shape = reader_cfg['inputs_def'].get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        # im_shape / scale_factor reuse the image's leading dim when the
        # configured shape already has one; otherwise it is None.
        if len(image_shape) == 3:
            image_shape = [None] + image_shape
            lead_dim = None
        else:
            lead_dim = image_shape[0]
        im_shape = [lead_dim, 2]
        scale_factor = [lead_dim, 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True

        for sublayer in self.model.sublayers():
            if hasattr(sublayer, 'convert_to_deploy'):
                sublayer.convert_to_deploy()

        if hasattr(self.cfg, 'export'):
            export_cfg = self.cfg['export']
            export_post_process = export_cfg.get('post_process', False)
            export_nms = export_cfg.get('nms', False)
            export_benchmark = export_cfg.get('benchmark', False)
        else:
            # no ``export`` section: keep post-process and NMS, no benchmark
            export_post_process = True
            export_nms = True
            export_benchmark = False

        if hasattr(self.model, 'fuse_norm'):
            # NOTE(review): this reads ``TestReader`` even for MOT
            # architectures, whose inputs_def came from ``TestMOTReader``
            # above -- confirm this is intended.
            self.model.fuse_norm = self.cfg['TestReader'].get(
                'fuse_normalize', False)

        # a benchmark export disables post-process and NMS on the model
        if hasattr(self.model, 'export_post_process'):
            self.model.export_post_process = (False if export_benchmark else
                                              export_post_process)
        if hasattr(self.model, 'export_nms'):
            self.model.export_nms = False if export_benchmark else export_nms
        if export_post_process and not export_benchmark:
            image_shape = [None] + image_shape[1:]

        # Save infer cfg
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'),
                           image_shape, self.model)

        spec = {
            "image": InputSpec(shape=image_shape, name='image'),
            "im_shape": InputSpec(shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(shape=scale_factor, name='scale_factor'),
        }
        if self.cfg.architecture == 'DeepSORT':
            spec["crops"] = InputSpec(shape=[None, 3, 192, 64], name='crops')
        input_spec = [spec]

        if not prune_input:
            static_model = None
            pruned_input_spec = input_spec
        else:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st does not prune the program but jit.save does, so
            # prune the input spec here and save with the pruned version.
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)

        # TODO: Hard code, delete it when support prune input_spec.
        if self.cfg.architecture == 'PicoDet' and not export_post_process:
            pruned_input_spec = [{
                "image": InputSpec(shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec

    def export(self, output_dir='output_inference'):
        """
        Export the model for inference.

        The model is saved to ``<output_dir>/<config file stem>/model`` and
        ``infer_cfg.yml`` is written to the same directory. QAT models are
        saved with the slim object's ``save_quantized_model``; all other
        models are saved with ``paddle.jit.save``.

        Args:
            output_dir (str): root directory for exported models.
        """
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        # exist_ok=True avoids the exists()/makedirs() race
        os.makedirs(save_dir, exist_ok=True)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # ``slim_type`` may be missing even when ``slim`` is configured; treat
        # that as non-QAT instead of raising KeyError.
        is_qat = 'slim' in self.cfg and self.cfg.get('slim_type') == 'QAT'
        save_path = os.path.join(save_dir, 'model')
        if is_qat:
            self.cfg.slim.save_quantized_model(
                self.model, save_path, input_spec=pruned_input_spec)
        else:
            # dy2st and save model
            paddle.jit.save(
                static_model, save_path, input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))
740

G
Guanghua Yu 已提交
741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760
    def post_quant(self, output_dir='output_inference'):
        """
        Run post-training-quantization calibration and save the quantized
        model.

        Batches from ``self.loader`` are fed through the model. The model is
        then saved with the slim object's ``save_quantized_model`` to
        ``<output_dir>/<config file stem>/model``, together with
        ``infer_cfg.yml``.

        Args:
            output_dir (str): root directory for the exported model.
        """
        cfg_file = os.path.split(self.cfg.filename)[-1]
        save_dir = os.path.join(output_dir, os.path.splitext(cfg_file)[0])
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # Calibration pass.
        # NOTE(review): the loop breaks *after* running the batch whose index
        # equals quant_batch_num, so quant_batch_num + 1 batches are fed
        # (11 with the default of 10) -- confirm this is intended.
        for batch_idx, batch in enumerate(self.loader):
            self.model(batch)
            if batch_idx == int(self.cfg.get('quant_batch_num', 10)):
                break

        # TODO: support prune input_spec
        _, input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)

        model_path = os.path.join(save_dir, 'model')
        self.cfg.slim.save_quantized_model(
            self.model, model_path, input_spec=input_spec)
        logger.info("Export Post-Quant model and saved in {}".format(save_dir))
G
Guanghua Yu 已提交
761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785

    def _flops(self, loader):
        """
        Log the model's FLOPs (in G) measured on a single sample.

        Best effort: if paddleslim cannot be imported, or ``loader`` yields
        no batch, a warning is logged and nothing is computed. The model is
        left in eval mode.

        Args:
            loader: iterable of batches. Only the first sample of the first
                batch is used, and the batch must provide ``image``,
                ``im_shape`` and ``scale_factor`` entries.
        """
        self.model.eval()
        try:
            # FLOPs reporting is optional diagnostics, so any import failure
            # (not only ImportError) is downgraded to a warning.
            from paddleslim.analysis import dygraph_flops
        except Exception:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        input_data = next(iter(loader), None)
        if input_data is None:
            logger.warning('Unable to calculate flops, the loader is empty.')
            return

        # keep only the first sample, restoring a leading batch dim of 1
        input_spec = [{
            "image": input_data['image'][0].unsqueeze(0),
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        gflops = dygraph_flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            gflops, input_spec[0]["image"].shape))
786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808

    def parse_mot_images(self, cfg):
        """
        Collect the image paths of all MOT evaluation sequences (used for
        quantization).

        Every entry of ``<dataset_dir>/<data_root>`` of ``EvalMOTDataset`` is
        treated as a sequence directory. Files with a jpg/jpeg/png/bmp
        extension (lower or upper case) directly inside it are gathered.

        Args:
            cfg: config whose ``EvalMOTDataset`` entry exposes
                ``dataset_dir`` and ``data_root`` attributes.

        Returns:
            list[str]: image paths, sequences in sorted order and images
            sorted within each sequence.

        Raises:
            AssertionError: if an entry of the data root is not a directory
                or a sequence contains no images.
        """
        import glob
        # for quant
        dataset = cfg['EvalMOTDataset']
        data_root = '{}/{}'.format(dataset.dataset_dir, dataset.data_root)
        exts = ['jpg', 'jpeg', 'png', 'bmp']
        exts += [ext.upper() for ext in exts]

        all_images = []
        for seq in sorted(os.listdir(data_root)):
            infer_dir = os.path.join(data_root, seq)
            assert os.path.isdir(infer_dir), \
                "{} is not a directory".format(infer_dir)
            # a set de-duplicates paths matched by more than one pattern
            images = set()
            for ext in exts:
                images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
            assert len(images) > 0, "no image found in {}".format(infer_dir)
            all_images.extend(sorted(images))
        # log once with the real total, not per sequence
        logger.info("Found {} inference images in total.".format(
            len(all_images)))
        return all_images