trainer.py 25.3 KB
Newer Older
F
Feng Ni 已提交
1 2 3 4 5 6 7 8 9 10 11 12
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
K
Kaipeng Deng 已提交
13 14 15 16 17 18 19
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
20
import copy
K
Kaipeng Deng 已提交
21
import time
M
Manuel Garcia 已提交
22

K
Kaipeng Deng 已提交
23 24 25 26
import numpy as np
from PIL import Image

import paddle
W
wangguanzhong 已提交
27 28
import paddle.distributed as dist
from paddle.distributed import fleet
29
from paddle import amp
K
Kaipeng Deng 已提交
30
from paddle.static import InputSpec
31
from ppdet.optimizer import ModelEMA
K
Kaipeng Deng 已提交
32 33 34

from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
C
cnn 已提交
35
from ppdet.utils.visualizer import visualize_results, save_result
Z
zhiboniu 已提交
36
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
G
George Ni 已提交
37
from ppdet.metrics import RBoxMetric, JDEDetMetric
K
Kaipeng Deng 已提交
38
from ppdet.data.source.category import get_categories
K
Kaipeng Deng 已提交
39 40
import ppdet.utils.stats as stats

41
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
K
Kaipeng Deng 已提交
42 43 44
from .export_utils import _dump_infer_config

from ppdet.utils.logger import setup_logger
45
logger = setup_logger('ppdet.engine')
K
Kaipeng Deng 已提交
46 47 48

__all__ = ['Trainer']

49 50
# Tracking (MOT) architectures: in eval/test mode they read the
# `{Eval,Test}MOTDataset` configs, and export uses `TestMOTReader` for them.
MOT_ARCH = [
    'DeepSORT',
    'JDE',
    'FairMOT',
]

K
Kaipeng Deng 已提交
51 52 53 54 55 56 57

class Trainer(object):
    def __init__(self, cfg, mode='train'):
        """Build everything needed to run ``mode`` from the global config.

        Depending on ``mode`` this selects the dataset, builds the data
        loader, the model (or reuses ``cfg.model``), the optional EMA copy,
        the learning-rate scheduler and optimizer (train only), and finally
        the default callbacks and metrics.

        Args:
            cfg: global config; accessed both as a mapping (``cfg['key']``,
                ``'key' in cfg``, ``cfg.get``) and through attributes
                (``cfg.architecture``, ``cfg.worker_num``).
            mode (str): 'train', 'eval' or 'test' (case-insensitive).

        Raises:
            AssertionError: if ``mode`` is not one of the supported modes.
            SystemExit: if a DeepSORT model is asked to train.
        """
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False

        # build data loader
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
        else:
            self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT has no need of training on mot dataset.')
            # `sys` is not imported at module level; without this local
            # import the exit below raised NameError instead of SystemExit.
            import sys
            sys.exit(1)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        # embedding heads take their `num_identifiers` from the training
        # dataset's `total_identities`
        if cfg.architecture == 'JDE' and self.mode == 'train':
            cfg['JDEEmbeddingHead'][
                'num_identifiers'] = self.dataset.total_identities

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            cfg['FairMOTEmbeddingHead'][
                'num_identifiers'] = self.dataset.total_identities

        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            # a model handed in through cfg is treated as already loaded,
            # which turns load_weights() into a no-op
            self.model = self.cfg.model
            self.is_loaded_weights = True

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            self.ema = ModelEMA(
                cfg['ema_decay'], self.model, use_thres_step=True)

        # EvalDataset build with BatchSampler to evaluate in single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            self._eval_batch_sampler = paddle.io.BatchSampler(
                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num, self._eval_batch_sampler)
        # TestDataset build after user set images, skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr,
                                                        self.model.parameters())

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        # shared state dict read and written by the callbacks
        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
129
            if self.cfg.get('use_vdl', False):
130
                self._callbacks.append(VisualDLWriter(self))
K
Kaipeng Deng 已提交
131 132 133
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
134 135
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
K
Kaipeng Deng 已提交
136
            self._compose_callback = ComposeCallback(self._callbacks)
137
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
138 139
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
K
Kaipeng Deng 已提交
140 141 142 143
        else:
            self._callbacks = []
            self._compose_callback = None

K
Kaipeng Deng 已提交
144 145
    def _init_metrics(self, validate=False):
        if self.mode == 'test' or (self.mode == 'train' and not validate):
G
Guanghua Yu 已提交
146 147
            self._metrics = []
            return
148
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
K
Kaipeng Deng 已提交
149
        if self.cfg.metric == 'COCO':
W
wangxinxin08 已提交
150
            # TODO: bias should be unified
151
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
S
shangliang Xu 已提交
152 153
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
154
            save_prediction_only = self.cfg.get('save_prediction_only', False)
155 156 157

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
K
Kaipeng Deng 已提交
158 159
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None
160 161 162 163 164 165 166 167 168

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

169
            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
W
wangxinxin08 已提交
170 171
            self._metrics = [
                COCOMetric(
172
                    anno_file=anno_file,
K
Kaipeng Deng 已提交
173
                    clsid2catid=clsid2catid,
174
                    classwise=classwise,
S
shangliang Xu 已提交
175
                    output_eval=output_eval,
176
                    bias=bias,
177
                    IouType=IouType,
178
                    save_prediction_only=save_prediction_only)
W
wangxinxin08 已提交
179
            ]
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                                if self.mode == 'eval' else None

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
K
Kaipeng Deng 已提交
209 210 211
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
212
                    label_list=self.dataset.get_label_list(),
K
Kaipeng Deng 已提交
213
                    class_num=self.cfg.num_classes,
214 215
                    map_type=self.cfg.map_type,
                    classwise=classwise)
K
Kaipeng Deng 已提交
216
            ]
217 218 219 220 221 222 223 224 225
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
226 227 228 229
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
230
            save_prediction_only = self.cfg.get('save_prediction_only', False)
231
            self._metrics = [
232 233 234 235 236 237
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
238
            ]
Z
zhiboniu 已提交
239 240 241 242
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
243
            save_prediction_only = self.cfg.get('save_prediction_only', False)
Z
zhiboniu 已提交
244
            self._metrics = [
245 246 247 248 249 250
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
Z
zhiboniu 已提交
251
            ]
G
George Ni 已提交
252 253
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
K
Kaipeng Deng 已提交
254
        else:
255
            logger.warning("Metric not support for metric type {}".format(
K
Kaipeng Deng 已提交
256
                self.cfg.metric))
K
Kaipeng Deng 已提交
257 258 259 260 261 262 263
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        """Append user-supplied callbacks and rebuild the composed dispatcher.

        Args:
            callbacks (iterable): callback objects; ``None`` entries are
                dropped, every other entry must be a ``Callback`` instance.

        Raises:
            AssertionError: if an entry is not a ``Callback``.
        """
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            # the previous message talked about metrics, which was misleading
            assert isinstance(c, Callback), \
                    "callbacks should be instances of subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                    "metrics shoule be instances of subclass of Metric"
        self._metrics.extend(metrics)

K
Kaipeng Deng 已提交
278
    def load_weights(self, weights):
279 280
        if self.is_loaded_weights:
            return
K
Kaipeng Deng 已提交
281
        self.start_epoch = 0
282
        load_pretrain_weight(self.model, weights)
K
Kaipeng Deng 已提交
283 284
        logger.debug("Load weights {} to start training".format(weights))

285 286 287 288 289 290 291
    def load_weights_sde(self, det_weights, reid_weights):
        """Load weights into a model made of ``detector`` and ``reid`` parts.

        The ReID weights are always loaded. The detector weights are loaded
        only when ``self.model.detector`` is truthy.
        """
        if self.model.detector:
            load_weight(self.model.detector, det_weights)
        load_weight(self.model.reid, reid_weights)

K
Kaipeng Deng 已提交
292
    def resume_weights(self, weights):
        """Resume training from ``weights`` and set ``self.start_epoch``.

        ``load_weight`` also receives ``self.optimizer``, presumably so that
        it restores the optimizer state too. Its return value becomes the
        epoch training resumes from. For distillation models that expose
        ``student_model``, only the student is resumed.
        """
        model = self.model
        # support Distill resume weights
        target = model.student_model \
            if hasattr(model, 'student_model') else model
        self.start_epoch = load_weight(target, weights, self.optimizer)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))
K
Kaipeng Deng 已提交
300

K
Kaipeng Deng 已提交
301
    def train(self, validate=False):
        """Run the training loop from ``self.start_epoch`` to ``cfg.epoch``.

        The model is wrapped for fleet or multi-card data parallel training
        when configured, optionally uses fp16 with a GradScaler, and, with
        EMA enabled, evaluates/checkpoints using EMA weights at each epoch
        end before restoring the raw weights.

        Args:
            validate (bool): when True, rank 0 evaluates on ``cfg.EvalDataset``
                every ``cfg.snapshot_epoch`` epochs and after the last epoch.
        """
        assert self.mode == 'train', "Model not in 'train' mode"
        # validation metrics are (re)built once, on the first validation pass
        metrics_initialized = False

        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # initial fp16
        use_fp16 = self.cfg.get('fp16', False)
        if use_fp16:
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu, init_loss_scaling=1024)

        # only rank 0 (or the single process) records training stats / validates
        is_main_rank = self._nranks < 2 or self._local_rank == 0

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            self._flops(self.loader)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                self._compose_callback.on_step_begin(self.status)

                if use_fp16:
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()

                # read the LR used by this step before the scheduler advances
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                if is_main_rank:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update(self.model)
                iter_tic = time.time()

            # apply ema weight on model; the raw weights are restored below
            if self.use_ema:
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())

            self._compose_callback.on_epoch_end(self.status)

            is_snapshot_epoch = (epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
                or epoch_id == self.end_epoch - 1
            if validate and is_main_rank and is_snapshot_epoch:
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be
                # re-init; do it only on the first validation pass
                if not metrics_initialized:
                    metrics_initialized = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()
                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            # restore origin weight on model
            if self.use_ema:
                self.model.set_dict(weight)

K
Kaipeng Deng 已提交
415
    def _eval_with_loader(self, loader):
K
Kaipeng Deng 已提交
416 417 418
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
K
Kaipeng Deng 已提交
419 420
        self.status['mode'] = 'eval'
        self.model.eval()
G
Guanghua Yu 已提交
421 422
        if self.cfg.get('print_flops', False):
            self._flops(loader)
K
Kaipeng Deng 已提交
423
        for step_id, data in enumerate(loader):
K
Kaipeng Deng 已提交
424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
443
        self._compose_callback.on_epoch_end(self.status)
K
Kaipeng Deng 已提交
444 445 446
        # reset metric states for metric may performed multiple times
        self._reset_metrics()

K
Kaipeng Deng 已提交
447
    def evaluate(self):
        """Evaluate the model on ``self.loader`` with gradients disabled."""
        eval_loader = self.loader
        with paddle.no_grad():
            self._eval_with_loader(eval_loader)
K
Kaipeng Deng 已提交
450

C
cnn 已提交
451 452 453 454 455
    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
        """Run inference on ``images``, draw the results and save them.

        Args:
            images: passed to ``self.dataset.set_images`` (presumably a list
                of image paths — confirm against the dataset implementation).
            draw_threshold (float): score threshold used for drawing and for
                the optional txt output.
            output_dir (str): directory the visualized images are saved to.
            save_txt (bool): also write ``<image name>.txt`` with the bbox /
                keypoint results next to every saved image.
        """
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)
        imid2path = self.dataset.get_imid2path()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=self.dataset.get_anno())

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            self._flops(loader)

        result_kinds = ('bbox', 'mask', 'segm', 'keypoint')
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            # carry the meta info along, then turn tensors into numpy arrays
            for meta_key in ('im_shape', 'scale_factor', 'im_id'):
                outs[meta_key] = data[meta_key]
            for out_key, out_val in list(outs.items()):
                if hasattr(out_val, 'numpy'):
                    outs[out_key] = out_val.numpy()

            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            offset = 0
            for idx, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                image = Image.open(image_path).convert('RGB')
                self.status['original_image'] = np.array(image.copy())

                # results of this image are batch_res[kind][offset:stop]
                stop = offset + bbox_num[idx]
                bbox_res, mask_res, segm_res, keypoint_res = [
                    batch_res[kind][offset:stop] if kind in batch_res else None
                    for kind in result_kinds
                ]
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)

                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)

                if save_txt:
                    txt_path = os.path.splitext(save_name)[0] + '.txt'
                    record = {"im_id": im_id}
                    if bbox_res:
                        record["bbox_res"] = bbox_res
                    if keypoint_res:
                        record["keypoint_res"] = keypoint_res
                    save_result(txt_path, record, catid2name, draw_threshold)
                offset = stop

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext

    def export(self, output_dir='output_inference'):
        """Export the model for inference into ``output_dir/<config stem>/``.

        Writes ``infer_cfg.yml`` and the static model saved under the
        ``model`` prefix, either through ``paddle.jit.save`` or, for QAT slim
        configs, through ``cfg.slim.save_quantized_model``.

        Args:
            output_dir (str): root directory of the exported models.
        """
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        # exist_ok avoids the race between checking for and creating the dir
        os.makedirs(save_dir, exist_ok=True)

        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'
        else:
            test_reader_name = 'TestReader'
        image_shape = None
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[3, -1, -1] as default
        if image_shape is None:
            image_shape = [3, -1, -1]

        # the model is already in eval mode (set above)
        if hasattr(self.model, 'deploy'):
            self.model.deploy = True

        # Save infer cfg
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
                           self.model)

        input_spec = [{
            "image": InputSpec(
                shape=[None] + image_shape, name='image'),
            "im_shape": InputSpec(
                shape=[None, 2], name='im_shape'),
            "scale_factor": InputSpec(
                shape=[None, 2], name='scale_factor')
        }]
        if self.cfg.architecture == 'DeepSORT':
            input_spec[0].update({
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })

        static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
        # NOTE: dy2st do not pruned program, but jit.save will prune program
        # input spec, prune input spec here and save with pruned input spec
        pruned_input_spec = self._prune_input_spec(
            input_spec, static_model.forward.main_program,
            static_model.forward.outputs)

        # dy2st and save model; a slim config without `slim_type` falls back
        # to the plain jit.save path instead of raising KeyError
        if 'slim' not in self.cfg or self.cfg.get('slim_type') != 'QAT':
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))
590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606

    def _prune_input_spec(self, input_spec, program, targets):
        """Keep only the input specs whose variables survive program pruning.

        dy2st does not prune the program but ``paddle.jit.save`` does, so the
        input spec is pruned the same way before saving.

        Args:
            input_spec (list[dict]): single-element list mapping an input
                name to its ``InputSpec``.
            program: static program to prune (``export`` passes
                ``static_model.forward.main_program``). It is cloned, not
                modified.
            targets: variables the program is pruned towards.

        Returns:
            list[dict]: ``[{name: spec}]`` restricted to inputs that remain in
            the pruned program's global block.
        """
        pruned_input_spec = [{}]
        # pruning must run in static mode; always switch back to dygraph, even
        # when pruning raises, so the global mode does not leak to the caller
        paddle.enable_static()
        try:
            program = program.clone()
            program = program._prune(targets=targets)
            global_block = program.global_block()
            for name, spec in input_spec[0].items():
                try:
                    global_block.var(name)
                except ValueError:
                    # Block.var raises ValueError when the variable was pruned
                    continue
                pruned_input_spec[0][name] = spec
        finally:
            paddle.disable_static()
        return pruned_input_spec
G
Guanghua Yu 已提交
607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631

    def _flops(self, loader):
        """Log the model FLOPs (in G) for one image from the first batch.

        Switches the model to eval mode. Requires ``paddleslim``; when it
        cannot be imported, or ``loader`` yields no batch, a warning is
        logged and nothing is computed.

        Args:
            loader: iterable of batches with 'image', 'im_shape' and
                'scale_factor' entries.
        """
        self.model.eval()
        try:
            import paddleslim  # noqa: F401  (only checks availability)
        except Exception:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops as flops

        input_data = next(iter(loader), None)
        if input_data is None:
            # previously an empty loader crashed with a TypeError below
            logger.warning('Unable to calculate flops, the loader is empty')
            return

        # measure a single image: keep the first sample and re-add batch dim
        image = input_data['image'][0].unsqueeze(0)
        input_spec = [{
            "image": image,
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        gflops = flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            gflops, image.shape))