trainer.py 26.9 KB
Newer Older
F
Feng Ni 已提交
1 2 3 4 5 6 7 8 9 10 11 12
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
K
Kaipeng Deng 已提交
13 14 15 16 17 18 19
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
G
George Ni 已提交
20
import sys
21
import copy
K
Kaipeng Deng 已提交
22
import time
M
Manuel Garcia 已提交
23

K
Kaipeng Deng 已提交
24 25 26 27
import numpy as np
from PIL import Image

import paddle
W
wangguanzhong 已提交
28 29
import paddle.distributed as dist
from paddle.distributed import fleet
30
from paddle import amp
K
Kaipeng Deng 已提交
31
from paddle.static import InputSpec
32
from ppdet.optimizer import ModelEMA
K
Kaipeng Deng 已提交
33 34 35

from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
C
cnn 已提交
36
from ppdet.utils.visualizer import visualize_results, save_result
Z
zhiboniu 已提交
37
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
G
George Ni 已提交
38
from ppdet.metrics import RBoxMetric, JDEDetMetric
K
Kaipeng Deng 已提交
39
from ppdet.data.source.category import get_categories
K
Kaipeng Deng 已提交
40
import ppdet.utils.stats as stats
41
from ppdet.utils import profiler
K
Kaipeng Deng 已提交
42

43
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
G
Guanghua Yu 已提交
44
from .export_utils import _dump_infer_config, _prune_input_spec
K
Kaipeng Deng 已提交
45 46

from ppdet.utils.logger import setup_logger
47
# Module-level logger, registered under the 'ppdet.engine' name.
logger = setup_logger('ppdet.engine')
K
Kaipeng Deng 已提交
48 49 50

# Public API of this module: only the Trainer class is exported.
__all__ = ['Trainer']

51 52
# Architectures handled as multi-object-tracking models: for these, Trainer
# reads the '{Eval,Test}MOTDataset' config entry in eval/test mode and uses
# 'TestMOTReader' when building the export input spec.
MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']

K
Kaipeng Deng 已提交
53 54 55 56 57 58 59

class Trainer(object):
    def __init__(self, cfg, mode='train'):
        """Set up dataset, loader, model, optimizer, callbacks and metrics.

        Args:
            cfg: configuration object; it is read both by attribute
                (``cfg.architecture``) and by key (``cfg['TestReader']``).
            mode (str): 'train', 'eval' or 'test' (compared case-insensitively).
        """
        self.cfg = cfg
        run_mode = mode.lower()
        assert run_mode in ['train', 'eval', 'test'], \
            "mode should be 'train', 'eval' or 'test'"
        self.mode = run_mode
        self.optimizer = None
        self.is_loaded_weights = False

        arch = cfg.architecture
        is_train = run_mode == 'train'
        mode_prefix = run_mode.capitalize()
        reader_name = '{}Reader'.format(mode_prefix)

        # dataset: MOT architectures use a dedicated dataset entry for
        # eval/test, everything else uses '<Mode>Dataset'
        if arch in MOT_ARCH and run_mode in ['eval', 'test']:
            dataset_key = '{}MOTDataset'.format(mode_prefix)
        else:
            dataset_key = '{}Dataset'.format(mode_prefix)
        self.dataset = cfg[dataset_key]

        if is_train and arch == 'DeepSORT':
            logger.error('DeepSORT has no need of training on mot dataset.')
            sys.exit(1)

        if is_train:
            # the train loader is needed before the optimizer is built
            self.loader = create(reader_name)(self.dataset, cfg.worker_num)

            # embedding heads need the number of identities of the train
            # dataset, so the config is patched before the model is created
            embedding_heads = {
                'JDE': 'JDEEmbeddingHead',
                'FairMOT': 'FairMOTEmbeddingHead',
            }
            if arch in embedding_heads:
                cfg[embedding_heads[arch]][
                    'num_identifiers'] = self.dataset.total_identities

        # model: either created from the architecture name or taken
        # ready-made from the config
        if 'model' in self.cfg:
            self.model = self.cfg.model
            self.is_loaded_weights = True
        else:
            self.model = create(arch)

        # normalize params for deploy
        self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = 'use_ema' in cfg and cfg['use_ema']
        if self.use_ema:
            self.ema = ModelEMA(
                self.model,
                decay=self.cfg.get('ema_decay', 0.9998),
                use_thres_step=True,
                cycle_epoch=self.cfg.get('cycle_epoch', -1))

        # EvalDataset is built with a BatchSampler to evaluate on a single
        # device
        # TODO: multi-device evaluate
        if run_mode == 'eval':
            self._eval_batch_sampler = paddle.io.BatchSampler(
                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
            self.loader = create(reader_name)(
                self.dataset, cfg.worker_num, self._eval_batch_sampler)
        # TestDataset is built after the user sets images, so no loader is
        # created here for 'test' mode

        # optimizer and learning-rate schedule only exist in train mode
        if is_train:
            self.lr = create('LearningRate')(len(self.loader))
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = cfg.epoch if 'epoch' in cfg else 0

        # default callbacks, then default metrics
        self._init_callbacks()
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        """Install the default callbacks for the current mode.

        train: LogPrinter and Checkpointer, plus VisualDLWriter when
        ``use_vdl`` is set. eval: LogPrinter, plus WiferFaceEval when the
        metric is 'WiderFace'. test: VisualDLWriter only when ``use_vdl`` is
        set; otherwise no callbacks and ``_compose_callback`` is None.
        """
        mode = self.mode
        if mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
            if self.cfg.get('use_vdl', False):
                self._callbacks.append(VisualDLWriter(self))
        elif mode == 'eval':
            self._callbacks = [LogPrinter(self)]
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
        elif mode == 'test' and self.cfg.get('use_vdl', False):
            self._callbacks = [VisualDLWriter(self)]
        else:
            self._callbacks = []

        # every non-empty callback list gets a composer; an empty one gets None
        if self._callbacks:
            self._compose_callback = ComposeCallback(self._callbacks)
        else:
            self._compose_callback = None

K
Kaipeng Deng 已提交
153 154
    def _init_metrics(self, validate=False):
        if self.mode == 'test' or (self.mode == 'train' and not validate):
G
Guanghua Yu 已提交
155 156
            self._metrics = []
            return
157
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
K
Kaipeng Deng 已提交
158
        if self.cfg.metric == 'COCO':
W
wangxinxin08 已提交
159
            # TODO: bias should be unified
160
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
S
shangliang Xu 已提交
161 162
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
163
            save_prediction_only = self.cfg.get('save_prediction_only', False)
164 165 166

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
K
Kaipeng Deng 已提交
167 168
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None
169 170 171 172 173 174 175 176 177

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

178
            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
W
wangxinxin08 已提交
179 180
            self._metrics = [
                COCOMetric(
181
                    anno_file=anno_file,
K
Kaipeng Deng 已提交
182
                    clsid2catid=clsid2catid,
183
                    classwise=classwise,
S
shangliang Xu 已提交
184
                    output_eval=output_eval,
185
                    bias=bias,
186
                    IouType=IouType,
187
                    save_prediction_only=save_prediction_only)
W
wangxinxin08 已提交
188
            ]
189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
K
Kaipeng Deng 已提交
218 219 220
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
221
                    label_list=self.dataset.get_label_list(),
K
Kaipeng Deng 已提交
222
                    class_num=self.cfg.num_classes,
223 224
                    map_type=self.cfg.map_type,
                    classwise=classwise)
K
Kaipeng Deng 已提交
225
            ]
226 227 228 229 230 231 232 233 234
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
235 236 237 238
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
239
            save_prediction_only = self.cfg.get('save_prediction_only', False)
240
            self._metrics = [
241 242 243 244 245 246
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
247
            ]
Z
zhiboniu 已提交
248 249 250 251
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
252
            save_prediction_only = self.cfg.get('save_prediction_only', False)
Z
zhiboniu 已提交
253
            self._metrics = [
254 255 256 257 258 259
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
Z
zhiboniu 已提交
260
            ]
G
George Ni 已提交
261 262
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
K
Kaipeng Deng 已提交
263
        else:
264
            logger.warning("Metric not support for metric type {}".format(
K
Kaipeng Deng 已提交
265
                self.cfg.metric))
K
Kaipeng Deng 已提交
266 267 268 269 270 271 272
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        """Append user callbacks to the trainer and rebuild the composer.

        ``None`` entries are ignored, so optional callbacks can be passed
        without filtering them first.

        Args:
            callbacks (iterable): Callback instances (``None`` entries allowed).

        Raises:
            AssertionError: if an entry is neither ``None`` nor a Callback.
        """
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            # the message used to talk about metrics, which pointed users at
            # the wrong argument when a bad callback was passed
            assert isinstance(c, Callback), \
                    "callbacks should be instances of subclass of Callback"
        self._callbacks.extend(callbacks)
        # the composer holds the list it was built from, so rebuild it
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                    "metrics shoule be instances of subclass of Metric"
        self._metrics.extend(metrics)

K
Kaipeng Deng 已提交
287
    def load_weights(self, weights):
288 289
        if self.is_loaded_weights:
            return
K
Kaipeng Deng 已提交
290
        self.start_epoch = 0
291
        load_pretrain_weight(self.model, weights)
K
Kaipeng Deng 已提交
292 293
        logger.debug("Load weights {} to start training".format(weights))

294 295 296 297 298 299 300
    def load_weights_sde(self, det_weights, reid_weights):
        """Load detector and re-identification weights into the model.

        Detector weights are only loaded when ``self.model.detector`` is
        truthy; the reid weights are always loaded, after the detector ones.

        Args:
            det_weights: weights for ``self.model.detector``.
            reid_weights: weights for ``self.model.reid``.
        """
        if self.model.detector:
            load_weight(self.model.detector, det_weights)
        load_weight(self.model.reid, reid_weights)

K
Kaipeng Deng 已提交
301
    def resume_weights(self, weights):
        """Restore model and optimizer state and set ``self.start_epoch``.

        For models exposing a ``student_model`` attribute (distillation), the
        weights are loaded into that student model instead of the wrapper.

        Args:
            weights: checkpoint location, forwarded to ``load_weight``.
        """
        # support Distill resume weights
        if hasattr(self.model, 'student_model'):
            target = self.model.student_model
        else:
            target = self.model
        self.start_epoch = load_weight(target, weights, self.optimizer)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))
K
Kaipeng Deng 已提交
309

K
Kaipeng Deng 已提交
310
    def train(self, validate=False):
        """Run the training loop from ``self.start_epoch`` to ``cfg.epoch``.

        Args:
            validate (bool): when True, an evaluation pass over
                ``cfg.EvalDataset`` is run every ``cfg.snapshot_epoch`` epochs
                and at ``epoch_id == self.end_epoch - 1``, but only when
                ``self._nranks < 2`` or ``self._local_rank == 0``.

        Raises:
            AssertionError: if the trainer was not created in 'train' mode.
        """
        assert self.mode == 'train', "Model not in 'train' mode"
        # guards the one-time metric re-initialization done before the first
        # validation pass
        Init_mark = False

        # `model` is the (possibly wrapped) module used for forward/backward;
        # self.model stays the unwrapped one used for EMA and evaluation
        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # initial fp16
        # NOTE: `scaler` only exists when cfg 'fp16' is set; the training step
        # below reads it under the same condition
        if self.cfg.get('fp16', False):
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu, init_loss_scaling=1024)

        # shared status dict handed to every callback hook
        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        # NOTE: the key is spelled 'training_staus'; it is a runtime key, so
        # any reader of the status dict must use the same spelling
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            self._flops(self.loader)
        profiler_options = self.cfg.get('profiler_options', None)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            # _flops() and validation switch the model to eval mode, so train
            # mode is set again at the start of every epoch
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                # time spent waiting for the batch since the last step ended
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                # the current epoch is passed to the model inside the batch
                data['epoch_id'] = epoch_id

                if self.cfg.get('fp16', False):
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()

                # read the learning rate used for this step before the
                # scheduler advances
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                # training statistics are only accumulated on rank 0 (or in
                # single-process runs)
                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update(self.model)
                iter_tic = time.time()

            # apply ema weight on model
            # the raw weights are kept in `weight` and restored at the end of
            # the epoch, so epoch-end callbacks and validation see EMA weights
            if self.use_ema:
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())

            self._compose_callback.on_epoch_end(self.status)

            if validate and (self._nranks < 2 or self._local_rank == 0) \
                    and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
                             or epoch_id == self.end_epoch - 1):
                # the eval loader is built lazily on the first validation and
                # cached on the instance
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be re-init
                # Init_mark makes sure this code will only execute once
                if validate and Init_mark == False:
                    Init_mark = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()
                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            # restore origin weight on model
            if self.use_ema:
                self.model.set_dict(weight)

K
Kaipeng Deng 已提交
427
    def _eval_with_loader(self, loader):
K
Kaipeng Deng 已提交
428 429 430
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
K
Kaipeng Deng 已提交
431 432
        self.status['mode'] = 'eval'
        self.model.eval()
G
Guanghua Yu 已提交
433 434
        if self.cfg.get('print_flops', False):
            self._flops(loader)
K
Kaipeng Deng 已提交
435
        for step_id, data in enumerate(loader):
K
Kaipeng Deng 已提交
436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
455
        self._compose_callback.on_epoch_end(self.status)
K
Kaipeng Deng 已提交
456 457 458
        # reset metric states for metric may performed multiple times
        self._reset_metrics()

K
Kaipeng Deng 已提交
459
    def evaluate(self):
        """Run one evaluation pass over ``self.loader`` inside ``paddle.no_grad()``."""
        no_grad_ctx = paddle.no_grad()
        with no_grad_ctx:
            self._eval_with_loader(self.loader)
K
Kaipeng Deng 已提交
462

C
cnn 已提交
463 464 465 466 467
    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
        """Run inference on ``images`` and save one visualized image each.

        Args:
            images: passed unchanged to ``self.dataset.set_images``
                (presumably a list of image paths - confirm against the
                dataset class).
            draw_threshold (float): forwarded to ``visualize_results`` and
                ``save_result``.
            output_dir (str): directory the result images are written to;
                created if missing.
            save_txt (bool): when True, a ``.txt`` file with the same stem as
                the saved image is also written through ``save_result``.
        """
        self.dataset.set_images(images)
        # worker_num is fixed to 0 for the test reader
        loader = create('TestReader')(self.dataset, 0)

        # maps image id -> source path, used to reopen the original image
        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            self._flops(loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            # copy the input fields into the outputs so they travel with them
            # into get_infer_results
            for key in ['im_shape', 'scale_factor', 'im_id']:
                outs[key] = data[key]
            # convert every value that offers numpy() to a numpy array
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()

            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            # results of the whole batch are stored flat; bbox_num[i] gives
            # the number of entries belonging to image i
            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                image = Image.open(image_path).convert('RGB')
                self.status['original_image'] = np.array(image.copy())

                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                        if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                # _compose_callback is None in test mode unless use_vdl is set
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    results = {}
                    results["im_id"] = im_id
                    # only non-empty result lists are written
                    if bbox_res:
                        results["bbox_res"] = bbox_res
                    if keypoint_res:
                        results["keypoint_res"] = keypoint_res
                    save_result(save_path, results, catid2name, draw_threshold)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext

G
Guanghua Yu 已提交
543
    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
        """Dump ``infer_cfg.yml`` and build the input spec used for export.

        Args:
            save_dir (str): directory that receives ``infer_cfg.yml``.
            prune_input (bool): when True the model is converted with
                ``paddle.jit.to_static`` and the spec is pruned against the
                resulting program; when False no static model is built.

        Returns:
            tuple: ``(static_model, pruned_input_spec)``; ``static_model`` is
            None when ``prune_input`` is False.
        """
        arch = self.cfg.architecture
        reader_key = 'TestMOTReader' if arch in MOT_ARCH else 'TestReader'
        reader_cfg = self.cfg[reader_key]

        image_shape = None
        if 'inputs_def' in reader_cfg:
            image_shape = reader_cfg['inputs_def'].get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        # a 3-element shape has no batch dimension: prepend a dynamic one;
        # otherwise the given batch dimension is reused for the other inputs
        batch_dim = None
        if len(image_shape) == 3:
            image_shape = [None] + image_shape
        else:
            batch_dim = image_shape[0]
        im_shape = [batch_dim, 2]
        scale_factor = [batch_dim, 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True
        if hasattr(self.model, 'fuse_norm'):
            self.model.fuse_norm = self.cfg['TestReader'].get(
                'fuse_normalize', False)

        # Save infer cfg
        infer_cfg_path = os.path.join(save_dir, 'infer_cfg.yml')
        _dump_infer_config(self.cfg, infer_cfg_path, image_shape, self.model)

        spec = {
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor'),
        }
        if self.cfg.architecture == 'DeepSORT':
            spec["crops"] = InputSpec(shape=[None, 3, 192, 64], name='crops')
        input_spec = [spec]

        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st does not prune the program, but jit.save prunes the
            # program input spec, so prune it here and save with the pruned one
            forward = static_model.forward
            pruned_input_spec = _prune_input_spec(
                input_spec, forward.main_program, forward.outputs)
        else:
            static_model = None
            pruned_input_spec = input_spec

        # TODO: Hard code, delete it when support prune input_spec.
        if self.cfg.architecture == 'PicoDet':
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec

    def export(self, output_dir='output_inference'):
        """Export the model for inference under ``output_dir/<config name>``.

        The sub-directory is named after the config file without its
        extension. Models configured with ``slim`` and ``slim_type == 'QAT'``
        are saved through ``cfg.slim.save_quantized_model``; all others
        through ``paddle.jit.save``.

        Args:
            output_dir (str): root directory for exported models.
        """
        self.model.eval()
        cfg_basename = os.path.basename(self.cfg.filename)
        model_name = os.path.splitext(cfg_basename)[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # dy2st and save model
        model_prefix = os.path.join(save_dir, 'model')
        is_qat = 'slim' in self.cfg and self.cfg['slim_type'] == 'QAT'
        if is_qat:
            self.cfg.slim.save_quantized_model(
                self.model, model_prefix, input_spec=pruned_input_spec)
        else:
            paddle.jit.save(
                static_model, model_prefix, input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))
631

G
Guanghua Yu 已提交
632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651
    def post_quant(self, output_dir='output_inference'):
        """Run forward passes over ``self.loader`` and save a quantized model.

        Batches are fed to the model until the batch whose index equals
        ``cfg['quant_batch_num']`` (default 10) has been processed, so up to
        ``quant_batch_num + 1`` batches are run. The model is then saved
        through ``cfg.slim.save_quantized_model`` with an unpruned input spec.

        Args:
            output_dir (str): root directory; the model goes into a
                sub-directory named after the config file.
        """
        cfg_basename = os.path.basename(self.cfg.filename)
        save_dir = os.path.join(output_dir, os.path.splitext(cfg_basename)[0])
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for batch_idx, batch in enumerate(self.loader):
            self.model(batch)
            if batch_idx == int(self.cfg.get('quant_batch_num', 10)):
                break

        # TODO: support prune input_spec
        input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)[1]

        model_prefix = os.path.join(save_dir, 'model')
        self.cfg.slim.save_quantized_model(
            self.model, model_prefix, input_spec=input_spec)
        logger.info("Export Post-Quant model and saved in {}".format(save_dir))
G
Guanghua Yu 已提交
652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676

    def _flops(self, loader):
        """Log the model FLOPs measured on the first sample of ``loader``.

        This is a best-effort diagnostic: when paddleslim cannot be imported
        or the loader yields no batch, a warning is logged and nothing is
        computed. The model is switched to eval mode as a side effect.

        Args:
            loader: iterable of batches providing the ``'image'``,
                ``'im_shape'`` and ``'scale_factor'`` entries.
        """
        self.model.eval()
        try:
            # availability probe only; any import-time failure disables the
            # FLOPs report instead of aborting the run
            import paddleslim  # noqa: F401
        except Exception:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops

        # only the first batch is needed
        input_data = None
        for data in loader:
            input_data = data
            break
        if input_data is None:
            # an empty loader used to crash with a TypeError on None below
            logger.warning(
                'Unable to calculate flops, the data loader yields no batch.')
            return

        # keep a single sample, with the batch dimension restored
        image = input_data['image'][0].unsqueeze(0)
        input_spec = [{
            "image": image,
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        gflops = dygraph_flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            gflops, image.shape))