trainer.py 29.4 KB
Newer Older
F
Feng Ni 已提交
1 2 3 4 5 6 7 8 9 10 11 12
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
G
George Ni 已提交
20
import sys
21
import copy
K
Kaipeng Deng 已提交
22
import time
M
Manuel Garcia 已提交
23

K
Kaipeng Deng 已提交
24
import numpy as np
F
Feng Ni 已提交
25 26
from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
K
Kaipeng Deng 已提交
27 28

import paddle
W
wangguanzhong 已提交
29 30
import paddle.distributed as dist
from paddle.distributed import fleet
31
from paddle import amp
K
Kaipeng Deng 已提交
32
from paddle.static import InputSpec
33
from ppdet.optimizer import ModelEMA
K
Kaipeng Deng 已提交
34 35 36

from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
C
cnn 已提交
37
from ppdet.utils.visualizer import visualize_results, save_result
Z
zhiboniu 已提交
38
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
39 40
from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from ppdet.data.source.sniper_coco import SniperCOCODataSet
K
Kaipeng Deng 已提交
41
from ppdet.data.source.category import get_categories
K
Kaipeng Deng 已提交
42
import ppdet.utils.stats as stats
43
from ppdet.utils import profiler
K
Kaipeng Deng 已提交
44

45
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
G
Guanghua Yu 已提交
46
from .export_utils import _dump_infer_config, _prune_input_spec
K
Kaipeng Deng 已提交
47 48

from ppdet.utils.logger import setup_logger
49
logger = setup_logger('ppdet.engine')
K
Kaipeng Deng 已提交
50 51 52

__all__ = ['Trainer']

# Multi-object-tracking architectures. In eval/test mode they read data from
# the '<Mode>MOTDataset' config entries, and export uses 'TestMOTReader'.
MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']


class Trainer(object):
    """Train / evaluate / infer / export driver for ppdet models.

    A ``Trainer`` is built for exactly one ``mode`` ('train', 'eval' or
    'test'). The mode decides which dataset, loader, optimizer, callbacks
    and metrics ``__init__`` constructs.
    """

    def __init__(self, cfg, mode='train'):
        """Build dataset, loader, model and (in train mode) optimizer.

        Args:
            cfg: global config, accessed both as a mapping and through
                attributes (e.g. ``cfg.architecture``).
            mode (str): 'train', 'eval' or 'test' (case-insensitive).
        """
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False

        # build data loader
        capital_mode = self.mode.capitalize()
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            # MOT architectures evaluate/infer on the '<Mode>MOTDataset' entry
            self.dataset = self.cfg['{}MOTDataset'.format(
                capital_mode)] = create('{}MOTDataset'.format(capital_mode))()
        else:
            self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create(
                '{}Dataset'.format(capital_mode))()

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT has no need of training on mot dataset.')
            sys.exit(1)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(capital_mode))(
                self.dataset, cfg.worker_num)

        if cfg.architecture == 'JDE' and self.mode == 'train':
            # JDE only supports single-class MOT for now, so only the
            # identity count of class 0 is used.
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            # FairMOT supports both single-class and multi-class MOT.
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict

        # build model; a model already attached to the config is reused and
        # marked as loaded, which makes load_weights() a no-op
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # normalize params for deploy
        self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                use_thres_step=True,
                cycle_epoch=cycle_epoch)

        # EvalDataset build with BatchSampler to evaluate in single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            self._eval_batch_sampler = paddle.io.BatchSampler(
                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num, self._eval_batch_sampler)
        # TestDataset build after user set images, skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            # The pruner needs steps_per_epoch, which only exists in train
            # mode (building it outside this branch raised UnboundLocalError
            # in eval/test). It is only stepped inside train().
            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        """Create the default callbacks for the current mode.

        Sets ``self._callbacks`` and ``self._compose_callback``. The latter
        is None when there are no callbacks (test mode without VisualDL).
        """
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
            if self.cfg.get('use_vdl', False):
                self._callbacks.append(VisualDLWriter(self))
            if self.cfg.get('save_proposals', False):
                self._callbacks.append(SniperProposalsGenerator(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
        else:
            self._callbacks = []
            self._compose_callback = None

    def _init_metrics(self, validate=False):
        """Create ``self._metrics`` according to ``cfg.metric``.

        Args:
            validate (bool): in train mode, build metrics for in-training
                validation against ``cfg['EvalDataset']``. Without it, train
                mode (like test mode) gets no metrics.
        """
        if self.mode == 'test' or (self.mode == 'train' and not validate):
            self._metrics = []
            return
        classwise = self.cfg.get('classwise', False)
        metric_type = self.cfg.metric
        if metric_type in ('COCO', 'SNIPERCOCO'):
            common, dataset = self._coco_style_metric_args(validate)
            IouType = self.cfg.get('IouType', 'bbox')
            if metric_type == 'COCO':
                self._metrics = [
                    COCOMetric(classwise=classwise, IouType=IouType, **common)
                ]
            else:  # sniper
                self._metrics = [
                    SNIPERCOCOMetric(
                        dataset=dataset,
                        classwise=classwise,
                        IouType=IouType,
                        **common)
                ]
        elif metric_type == 'RBOX':
            common, _ = self._coco_style_metric_args(validate)
            self._metrics = [RBoxMetric(classwise=classwise, **common)]
        elif metric_type == 'VOC':
            self._metrics = [
                VOCMetric(
                    label_list=self.dataset.get_label_list(),
                    class_num=self.cfg.num_classes,
                    map_type=self.cfg.map_type,
                    classwise=classwise)
            ]
        elif metric_type == 'WiderFace':
            multi_scale = self.cfg.get('multi_scale_eval', True)
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
        elif metric_type in ('KeyPointTopDownCOCOEval',
                             'KeyPointTopDownMPIIEval'):
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            if metric_type == 'KeyPointTopDownCOCOEval':
                metric_cls = KeyPointTopDownCOCOEval
            else:
                metric_cls = KeyPointTopDownMPIIEval
            self._metrics = [
                metric_cls(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif metric_type == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
        else:
            logger.warning("Metric not support for metric type {}".format(
                metric_type))
            self._metrics = []

    def _coco_style_metric_args(self, validate):
        """Collect the keyword arguments shared by COCO/SNIPERCOCO/RBOX metrics.

        Returns:
            tuple: ``(kwargs, dataset)``. ``kwargs`` holds ``anno_file``,
            ``clsid2catid``, ``output_eval``, ``bias`` and
            ``save_prediction_only``. ``dataset`` is the dataset the
            annotation file was taken from.
        """
        # TODO: bias should be unified
        bias = self.cfg.get('bias', 0)
        output_eval = self.cfg.get('output_eval', None)
        save_prediction_only = self.cfg.get('save_prediction_only', False)

        # pass clsid2catid info to the metric instance to avoid loading the
        # annotation file multiple times
        clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
            if self.mode == 'eval' else None

        # during validation inside training, the annotation file must come
        # from EvalDataset rather than self.dataset (the training set)
        if self.mode == 'train' and validate:
            dataset = self.cfg['EvalDataset']
            dataset.check_or_download_dataset()
        else:
            dataset = self.dataset
        anno_file = dataset.get_anno()

        kwargs = dict(
            anno_file=anno_file,
            clsid2catid=clsid2catid,
            output_eval=output_eval,
            bias=bias,
            save_prediction_only=save_prediction_only)
        return kwargs, dataset

    def _reset_metrics(self):
        """Reset the accumulated state of every registered metric."""
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        """Append user callbacks (None entries are dropped) and rebuild the
        composed callback.

        Raises:
            AssertionError: if an entry is not a ``Callback`` instance.
        """
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            assert isinstance(c, Callback), \
                    "callbacks should be instances of subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        """Append user metrics to ``self._metrics`` (None entries are dropped).

        Raises:
            AssertionError: if an entry is not a ``Metric`` instance.
        """
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                    "metrics should be instances of subclass of Metric"
        self._metrics.extend(metrics)

    def load_weights(self, weights):
        """Load pretrained ``weights`` into the model and reset ``start_epoch``.

        This is a no-op when the model came from ``cfg.model``
        (``is_loaded_weights`` is True).
        """
        if self.is_loaded_weights:
            return
        self.start_epoch = 0
        load_pretrain_weight(self.model, weights)
        logger.debug("Load weights {} to start training".format(weights))

    def load_weights_sde(self, det_weights, reid_weights):
        """Load detector and ReID weights into ``model.detector``/``model.reid``.

        Detector weights are only loaded when ``model.detector`` is truthy.
        ReID weights are always loaded (after the detector).
        """
        if self.model.detector:
            load_weight(self.model.detector, det_weights)
        load_weight(self.model.reid, reid_weights)

    def resume_weights(self, weights):
        """Resume model and optimizer state from ``weights``.

        ``start_epoch`` is set from the value returned by ``load_weight``.
        For distillation models (with a ``student_model``), only the student
        is resumed.
        """
        # support Distill resume weights
        if hasattr(self.model, 'student_model'):
            target = self.model.student_model
        else:
            target = self.model
        self.start_epoch = load_weight(target, weights, self.optimizer)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))

    def train(self, validate=False):
        """Run the training loop from ``start_epoch`` up to ``cfg.epoch``.

        Args:
            validate (bool): also evaluate on ``EvalDataset`` every
                ``cfg.snapshot_epoch`` epochs and after the last epoch
                (single-process, or on local rank 0 only).
        """
        assert self.mode == 'train', "Model not in 'train' mode"
        # validation metrics are (re)built lazily, only once per train() call
        metrics_initialized = False
        if validate:
            self.cfg['EvalDataset'] = self.cfg.EvalDataset = create(
                "EvalDataset")()

        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg.get('find_unused_parameters',
                                                  False)
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # initial fp16
        if self.cfg.get('fp16', False):
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu, init_loss_scaling=1024)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        # NOTE(review): the misspelled key 'training_staus' is kept as-is;
        # presumably callbacks read it under this name - verify before renaming.
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num)
            self._flops(flops_loader)
        profiler_options = self.cfg.get('profiler_options', None)

        self._compose_callback.on_train_begin(self.status)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                data['epoch_id'] = epoch_id

                if self.cfg.get('fp16', False):
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                if self.cfg.get('unstructured_prune'):
                    self.pruner.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                # only a single process / local rank 0 records training stats
                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update(self.model)
                iter_tic = time.time()

            # apply EMA weights on the model; the raw weights are backed up
            # here and restored at the end of the epoch
            if self.use_ema:
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())
            if self.cfg.get('unstructured_prune'):
                self.pruner.update_params()

            self._compose_callback.on_epoch_end(self.status)

            if validate and (self._nranks < 2 or self._local_rank == 0) \
                    and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
                             or epoch_id == self.end_epoch - 1):
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # metrics built in __init__ are empty for train mode, so
                # re-init them for validation (only once)
                if not metrics_initialized:
                    metrics_initialized = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()
                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            # restore origin weight on model
            if self.use_ema:
                self.model.set_dict(weight)

        self._compose_callback.on_train_end(self.status)

    def _eval_with_loader(self, loader):
        """Run the model over ``loader`` and accumulate and log all metrics.

        Sets ``status['sample_num']`` and ``status['cost_time']``. Metrics are
        reset at the end, so this can be called repeatedly.
        """
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            # NOTE(review): when called from train(validate=True) this builds
            # a 'TrainReader' over the training dataset with the eval batch
            # sampler - confirm that is intended.
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
            self._flops(flops_loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states since evaluation may be performed multiple times
        self._reset_metrics()

    def evaluate(self):
        """Evaluate the model on ``self.loader`` without gradient tracking."""
        with paddle.no_grad():
            self._eval_with_loader(self.loader)

    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
        """Run inference on ``images`` and save the visualized results.

        Args:
            images: passed to ``self.dataset.set_images``; presumably a list
                of image file paths (they are later opened with PIL).
            draw_threshold (float): score threshold used when drawing and
                when writing txt results.
            output_dir (str): directory for the visualized images.
            save_txt (bool): also write a ``.txt`` file next to each image.
        """
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('TestReader')(self.dataset, 0)
            self._flops(flops_loader)
        results = []
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            for key in ['im_shape', 'scale_factor', 'im_id']:
                outs[key] = data[key]
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()
            results.append(outs)
        # sniper: aggregate per-chip detections via the dataset's cropper
        if isinstance(self.dataset, SniperCOCODataSet):
            results = self.dataset.anno_cropper.aggregate_chips_detections(
                results)

        for outs in results:
            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            # per-image results are consecutive slices of the batch results
            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB')
                image = ImageOps.exif_transpose(image)
                self.status['original_image'] = np.array(image.copy())

                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                        if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    # separate name: this used to rebind `results`, the list
                    # the enclosing loop iterates over
                    txt_results = {"im_id": im_id}
                    if bbox_res:
                        txt_results["bbox_res"] = bbox_res
                    if keypoint_res:
                        txt_results["keypoint_res"] = keypoint_res
                    save_result(save_path, txt_results, catid2name,
                                draw_threshold)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.

        Returns ``output_dir`` joined with the base name of ``image_path``.
        Creates ``output_dir`` (if non-empty) when it does not exist yet.
        """
        if output_dir:
            # exist_ok avoids the race between an exists() check and makedirs()
            os.makedirs(output_dir, exist_ok=True)
        return os.path.join(output_dir, os.path.basename(image_path))

    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
        """Dump ``infer_cfg.yml`` into ``save_dir`` and build export input specs.

        Args:
            save_dir (str): directory where ``infer_cfg.yml`` is written.
            prune_input (bool): convert the model with
                ``paddle.jit.to_static`` and prune the input specs against
                the resulting program.

        Returns:
            tuple: ``(static_model, input_spec)``. ``static_model`` is None
            when ``prune_input`` is False.
        """
        image_shape = None
        im_shape = [None, 2]
        scale_factor = [None, 2]
        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'
        else:
            test_reader_name = 'TestReader'
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        if len(image_shape) == 3:
            # no batch dimension configured: prepend a dynamic one
            # (list() also accepts a tuple shape from the config)
            image_shape = [None] + list(image_shape)
        else:
            # reuse the configured batch dimension for the auxiliary inputs
            im_shape = [image_shape[0], 2]
            scale_factor = [image_shape[0], 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True
        if hasattr(self.model, 'fuse_norm'):
            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                              False)

        # Save infer cfg
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
                           self.model)

        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor')
        }]
        if self.cfg.architecture == 'DeepSORT':
            input_spec[0].update({
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })
        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st does not prune the program, but jit.save will prune
            # the program's inputs, so prune the input spec here and save with
            # the pruned input spec
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)
        else:
            static_model = None
            pruned_input_spec = input_spec

        # TODO: Hard code, delete it when support prune input_spec.
        if self.cfg.architecture == 'PicoDet':
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec

    def export(self, output_dir='output_inference'):
        """Export the model for inference to ``output_dir/<config name>``.

        Non-QAT models are saved with ``paddle.jit.save``. QAT slim models
        are saved through ``cfg.slim.save_quantized_model``.
        """
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        os.makedirs(save_dir, exist_ok=True)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # dy2st and save model; .get() because a config with `slim` but no
        # `slim_type` used to raise KeyError here
        if 'slim' not in self.cfg or self.cfg.get('slim_type') != 'QAT':
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))

    def post_quant(self, output_dir='output_inference'):
        """Run calibration batches and save a post-quantized model.

        Feeds ``cfg.quant_batch_num`` batches (default 10, at least one) from
        ``self.loader`` through the model. Presumably this lets the slim
        wrapper collect quantization statistics; confirm against
        ``cfg.slim``. The model is then saved to ``output_dir/<config name>``.
        """
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        os.makedirs(save_dir, exist_ok=True)

        # run exactly `quant_batch_num` batches; the former `idx == n` check
        # ran n + 1 of them
        quant_batch_num = int(self.cfg.get('quant_batch_num', 10))
        for idx, data in enumerate(self.loader):
            self.model(data)
            if idx + 1 >= quant_batch_num:
                break

        # TODO: support prune input_spec
        _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)

        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
        logger.info("Export Post-Quant model and saved in {}".format(save_dir))

    def _flops(self, loader):
        """Log the model FLOPs (in G) for the first sample of ``loader``.

        This is best-effort. It logs a warning and returns if ``paddleslim``
        cannot be imported or if ``loader`` yields nothing.
        """
        self.model.eval()
        try:
            import paddleslim  # noqa: F401  (availability check only)
        except Exception:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops

        input_data = next(iter(loader), None)
        if input_data is None:
            logger.warning('Unable to calculate flops, the loader is empty')
            return

        # use only the first sample of the batch, keeping a batch dim of 1
        image = input_data['image'][0].unsqueeze(0)
        input_spec = [{
            "image": image,
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        flops = dygraph_flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            flops, image.shape))