# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import copy
import time

import numpy as np
from PIL import Image, ImageOps

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle import amp
from paddle.static import InputSpec
from ppdet.optimizer import ModelEMA

from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results, save_result
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from ppdet.data.source.sniper_coco import SniperCOCODataSet
from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats
from ppdet.utils import profiler

from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
from .export_utils import _dump_infer_config, _prune_input_spec

from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')

__all__ = ['Trainer']

MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']

class Trainer(object):
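    """Train/eval/test engine built from a ppdet config.

    A minimal usage sketch (assuming ``cfg`` was loaded with
    ``ppdet.core.workspace.load_config`` and defines ``pretrain_weights``;
    both are assumptions, not guaranteed by this file):

        trainer = Trainer(cfg, mode='train')
        trainer.load_weights(cfg.pretrain_weights)
        trainer.train(validate=True)
    """
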
    def __init__(self, cfg, mode='train'):
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False

        # build data loader
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
        else:
            self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT has no need of training on mot dataset.')
            sys.exit(1)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        if cfg.architecture == 'JDE' and self.mode == 'train':
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]
            # JDE only supports single-class MOT now.

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict
            # FairMOT supports both single-class and multi-class MOT now.

        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # normalize params for deploy
        self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
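            # keep an exponential moving average of the model weights; the EMA
            # weights are applied for validation/checkpointing at the end of
            # each epoch and the raw weights are restored afterwards (see train())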
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                use_thres_step=True,
                cycle_epoch=cycle_epoch)

        # EvalDataset is built with a BatchSampler to evaluate on a single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            self._eval_batch_sampler = paddle.io.BatchSampler(
                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num, self._eval_batch_sampler)
        # TestDataset is built after the user sets images, so skip loader creation here

        # build optimizer in train mode
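        # 'LearningRate' and 'OptimizerBuilder' are instantiated by name from
        # the global config via ppdet.core.workspace.create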
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

        if self.cfg.get('unstructured_prune'):
            self.pruner = create('UnstructuredPruner')(self.model,
                                                       steps_per_epoch)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initialize default callbacks
        self._init_callbacks()

        # initialize default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
            if self.cfg.get('use_vdl', False):
                self._callbacks.append(VisualDLWriter(self))
            if self.cfg.get('save_proposals', False):
                self._callbacks.append(SniperProposalsGenerator(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
        else:
            self._callbacks = []
            self._compose_callback = None

    def _init_metrics(self, validate=False):
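        """Build metric objects according to cfg.metric (COCO, SNIPERCOCO,
        RBOX, VOC, WiderFace, keypoint or MOT detection metrics); no metrics
        are built in test mode or when training without validation."""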
        if self.mode == 'test' or (self.mode == 'train' and not validate):
            self._metrics = []
            return
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
        if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to the metric instance to avoid loading the
            # annotation file multiple times
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None

            # when doing validation during training, the annotation file should
            # come from EvalDataset instead of self.dataset (the train dataset)
            anno_file = self.dataset.get_anno()
            dataset = self.dataset
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()
                dataset = eval_dataset

            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
            if self.cfg.metric == "COCO":
                self._metrics = [
                    COCOMetric(
                        anno_file=anno_file,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
            elif self.cfg.metric == "SNIPERCOCO":  # sniper
                self._metrics = [
                    SNIPERCOCOMetric(
                        anno_file=anno_file,
                        dataset=dataset,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to the metric instance to avoid loading the
            # annotation file multiple times
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None

            # when doing validation during training, the annotation file should
            # come from EvalDataset instead of self.dataset (the train dataset)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
                    label_list=self.dataset.get_label_list(),
                    class_num=self.cfg.num_classes,
                    map_type=self.cfg.map_type,
                    classwise=classwise)
            ]
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
        else:
            logger.warning("Metric type {} is not supported".format(
                self.cfg.metric))
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
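        """Append user-provided Callback instances to the default callbacks."""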
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            assert isinstance(c, Callback), \
                    "callbacks should be instances of a subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
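        """Append user-provided Metric instances to the default metrics.

        A minimal sketch (CustomMetric is a hypothetical subclass of
        ppdet.metrics.Metric):

            trainer.register_metrics([CustomMetric()])
        """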
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                    "metrics should be instances of a subclass of Metric"
        self._metrics.extend(metrics)

    def load_weights(self, weights):
        if self.is_loaded_weights:
            return
        self.start_epoch = 0
        load_pretrain_weight(self.model, weights)
        logger.debug("Load weights {} to start training".format(weights))

    def load_weights_sde(self, det_weights, reid_weights):
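        """Load detector and ReID weights separately (e.g. for DeepSORT); if
        the model has no detector, only the ReID weights are loaded."""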
        if self.model.detector:
            load_weight(self.model.detector, det_weights)
            load_weight(self.model.reid, reid_weights)
        else:
            load_weight(self.model.reid, reid_weights)

    def resume_weights(self, weights):
        # for distillation models, resume weights into the student model
        if hasattr(self.model, 'student_model'):
            self.start_epoch = load_weight(self.model.student_model, weights,
                                           self.optimizer)
        else:
            self.start_epoch = load_weight(self.model, weights, self.optimizer)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))

    def train(self, validate=False):
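        """Run the training loop; if validate is True, evaluate on EvalDataset
        every cfg.snapshot_epoch epochs and after the final epoch."""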
        assert self.mode == 'train', "Model not in 'train' mode"
        Init_mark = False

        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # initialize fp16 (AMP) training
        if self.cfg.get('fp16', False):
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu, init_loss_scaling=1024)
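            # the scaler multiplies the loss before backward and unscales the
            # gradients inside scaler.minimize (see the fp16 branch below)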

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            self._flops(self.loader)
        profiler_options = self.cfg.get('profiler_options', None)

        self._compose_callback.on_train_begin(self.status)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                data['epoch_id'] = epoch_id

                if self.cfg.get('fp16', False):
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                if self.cfg.get('unstructured_prune'):
                    self.pruner.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update(self.model)
                iter_tic = time.time()

            # apply EMA weights to the model
            if self.use_ema:
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())
            if self.cfg.get('unstructured_prune'):
                self.pruner.update_params()

            self._compose_callback.on_epoch_end(self.status)

            if validate and (self._nranks < 2 or self._local_rank == 0) \
                    and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
                             or epoch_id == self.end_epoch - 1):
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be re-init
                # Init_mark makes sure this code will only execute once
                if validate and not Init_mark:
                    Init_mark = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()
                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            # restore original weights on the model
            if self.use_ema:
                self.model.set_dict(weight)

        self._compose_callback.on_train_end(self.status)

    def _eval_with_loader(self, loader):
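        """Run one evaluation pass over loader, updating, accumulating and
        logging all registered metrics."""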
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            self._flops(loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states, since metrics may be evaluated multiple times
        self._reset_metrics()

    def evaluate(self):
        with paddle.no_grad():
            self._eval_with_loader(self.loader)

    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
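        """Run inference on the given image paths and save visualized detection
        results (and optionally a .txt file per image) under output_dir."""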
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # run inference
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            self._flops(loader)
        results = []
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            for key in ['im_shape', 'scale_factor', 'im_id']:
                outs[key] = data[key]
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()
            results.append(outs)
        # sniper
        if type(self.dataset) == SniperCOCODataSet:
            results = self.dataset.anno_cropper.aggregate_chips_detections(
                results)

        for outs in results:
            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                image = Image.open(image_path).convert('RGB')
                image = ImageOps.exif_transpose(image)
                self.status['original_image'] = np.array(image.copy())

                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                        if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    results = {}
                    results["im_id"] = im_id
                    if bbox_res:
                        results["bbox_res"] = bbox_res
                    if keypoint_res:
                        results["keypoint_res"] = keypoint_res
                    save_result(save_path, results, catid2name, draw_threshold)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext

    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
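        """Dump infer_cfg.yml to save_dir and build the (optionally pruned)
        InputSpec list used for static-graph export."""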
        image_shape = None
        im_shape = [None, 2]
        scale_factor = [None, 2]
        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'
        else:
            test_reader_name = 'TestReader'
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        if len(image_shape) == 3:
            image_shape = [None] + image_shape
        else:
            im_shape = [image_shape[0], 2]
            scale_factor = [image_shape[0], 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True
        if hasattr(self.model, 'fuse_norm'):
            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                              False)

        # Save infer cfg
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
                           self.model)

        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor')
        }]
        if self.cfg.architecture == 'DeepSORT':
            input_spec[0].update({
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })
        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st do not pruned program, but jit.save will prune program
            # input spec, prune input spec here and save with pruned input spec
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)
        else:
            static_model = None
            pruned_input_spec = input_spec

        # TODO: Hard code, delete it when support prune input_spec.
        if self.cfg.architecture == 'PicoDet':
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec

    def export(self, output_dir='output_inference'):
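        """Export the model (plus infer_cfg.yml) for deployment under
        output_dir/<config_name>."""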
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # dy2st and save model
        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Exported model and saved it in {}".format(save_dir))

    def post_quant(self, output_dir='output_inference'):
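        """Feed a few batches through the model to collect quantization
        statistics, then save the post-quantized inference model."""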
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx, data in enumerate(self.loader):
            self.model(data)
            if idx == int(self.cfg.get('quant_batch_num', 10)):
                break

        # TODO: support prune input_spec
        _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)

        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
        logger.info("Exported Post-Quant model and saved it in {}".format(
            save_dir))

    def _flops(self, loader):
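        """Log the model FLOPs (requires paddleslim), using one batch from
        loader to build the input."""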
        self.model.eval()
        try:
            import paddleslim
        except Exception:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops as flops
        input_data = None
        for data in loader:
            input_data = data
            break

        input_spec = [{
            "image": input_data['image'][0].unsqueeze(0),
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        flops = flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            flops, input_data['image'][0].unsqueeze(0).shape))