infer.py 29.9 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
G
Guanghua Yu 已提交
17
import glob
Q
qingqing01 已提交
18 19 20 21
from functools import reduce

import cv2
import numpy as np
C
cnn 已提交
22
import math
Q
qingqing01 已提交
23 24 25 26
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor

27
from benchmark_utils import PaddleInferBenchmark
28
from picodet_postprocess import PicoDetPostProcess
G
George Ni 已提交
29
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize
G
Guanghua Yu 已提交
30
from visualize import visualize_box_mask
31
from utils import argsparser, Timer, get_current_memory_mb
G
Guanghua Yu 已提交
32

Q
qingqing01 已提交
33 34 35 36 37
# Set of supported architecture names. An `arch` read from infer_cfg.yml is
# accepted when any of these names occurs in it as a substring.
SUPPORT_MODELS = {
    'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet',
}


class Detector(object):
    """
    Generic detector wrapping a Paddle Inference predictor.

    Args:
        pred_config (object): config of model, defined by `PredictConfig(model_dir)`
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
    """

    def __init__(self,
                 pred_config,
                 model_dir,
                 device='CPU',
                 run_mode='fluid',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False):
        self.pred_config = pred_config
        self.predictor, self.config = load_predictor(
            model_dir,
            run_mode=run_mode,
            batch_size=batch_size,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn)
        self.det_times = Timer()
        # accumulated by the benchmark loop in `predict_image`
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0

    def preprocess(self, image_list):
        """Apply the preprocess ops listed in infer_cfg.yml to every image.

        Args:
            image_list (list): image paths or np.ndarray images
        Returns:
            inputs (dict): model feed dict produced by `create_inputs`
        """
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # NOTE(review): `op_type` comes from infer_cfg.yml and is passed to
            # eval(); only load model directories from trusted sources.
            preprocess_ops.append(eval(op_type)(**new_op_info))

        input_im_lst = []
        input_im_info_lst = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            input_im_lst.append(im)
            input_im_info_lst.append(im_info)
        return create_inputs(input_im_lst, input_im_info_lst)

    def postprocess(self,
                    np_boxes,
                    np_masks,
                    inputs,
                    np_boxes_num,
                    threshold=0.5):
        """Pack raw predictor outputs into the result dict.

        `inputs` and `threshold` are accepted for subclass compatibility and
        are not used here (boxes are returned unfiltered).
        """
        results = {}
        results['boxes'] = np_boxes
        results['boxes_num'] = np_boxes_num
        if np_masks is not None:
            results['masks'] = np_masks
        return results

    def _run_detection(self):
        """Run the predictor once and fetch (boxes, boxes_num, masks).

        `masks` is None unless the config says the model predicts masks.
        """
        self.predictor.run()
        output_names = self.predictor.get_output_names()
        np_boxes = self.predictor.get_output_handle(output_names[0]).copy_to_cpu()
        np_boxes_num = self.predictor.get_output_handle(
            output_names[1]).copy_to_cpu()
        np_masks = None
        if self.pred_config.mask:
            np_masks = self.predictor.get_output_handle(
                output_names[2]).copy_to_cpu()
        return np_boxes, np_boxes_num, np_masks

    def predict(self, image_list, threshold=0.5, warmup=0, repeats=1):
        '''
        Args:
            image_list (list): list of image
            threshold (float): threshold of predicted box' score
            warmup (int): untimed predictor runs before the timed ones
            repeats (int): timed predictor runs; the last run's outputs are used
        Returns:
            results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                            matix element:[class, score, x_min, y_min, x_max, y_max]
                            'boxes_num': number of boxes of each image
                            MaskRCNN's results include 'masks': np.ndarray:
                            shape: [N, im_h, im_w]
        '''
        self.det_times.preprocess_time_s.start()
        inputs = self.preprocess(image_list)
        self.det_times.preprocess_time_s.end()
        for name in self.predictor.get_input_names():
            self.predictor.get_input_handle(name).copy_from_cpu(inputs[name])

        np_boxes, np_boxes_num, np_masks = None, None, None
        for _ in range(warmup):
            np_boxes, np_boxes_num, np_masks = self._run_detection()

        self.det_times.inference_time_s.start()
        for _ in range(repeats):
            np_boxes, np_boxes_num, np_masks = self._run_detection()
        self.det_times.inference_time_s.end(repeats=repeats)

        self.det_times.postprocess_time_s.start()
        # fewer than 6 values means not even one [class, score, box] row
        if np_boxes is None or np_boxes.size < 6:
            print('[WARNNING] No object detected.')
            # one zero count per image so callers (e.g. `visualize`) can
            # index boxes_num by image position for any batch size
            results = {
                'boxes': np.array([[]]),
                'boxes_num': [0] * len(image_list)
            }
        else:
            results = self.postprocess(
                np_boxes, np_masks, inputs, np_boxes_num, threshold=threshold)
        self.det_times.postprocess_time_s.end()
        self.det_times.img_num += len(image_list)
        return results

    def get_timer(self):
        """Return the Timer that accumulates per-stage timings and image count."""
        return self.det_times

Q
qingqing01 已提交
183

G
Guanghua Yu 已提交
184 185 186 187 188
class DetectorSOLOv2(Detector):
    """
    Detector for SOLOv2 models. The constructor is inherited unchanged from
    `Detector`, so it takes the same arguments:

    Args:
        pred_config (object): config of model, defined by `PredictConfig(model_dir)`
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
    """

    def _run_solov2(self):
        """Run the predictor once; return (boxes_num, label, score, segm).

        The four values are the first four predictor outputs, in order.
        """
        self.predictor.run()
        output_names = self.predictor.get_output_names()
        return tuple(
            self.predictor.get_output_handle(output_names[idx]).copy_to_cpu()
            for idx in range(4))

    def predict(self, image, threshold=0.5, warmup=0, repeats=1):
        '''
        Args:
            image (list): image paths / np.ndarray images read by cv2; every
                element is fed through `Detector.preprocess`
            threshold (float): accepted for interface compatibility; not used
                by this method
            warmup (int): untimed predictor runs before the timed ones
            repeats (int): timed predictor runs; the last run's outputs are returned
        Returns:
            results (dict): 'segm': np.ndarray, shape:[N, im_h, im_w]
                            'label': label of segm, shape:[N]
                            'score': confidence score of segm, shape:[N]
                            'boxes_num': number of instances of each image
        '''
        self.det_times.preprocess_time_s.start()
        inputs = self.preprocess(image)
        self.det_times.preprocess_time_s.end()
        for name in self.predictor.get_input_names():
            self.predictor.get_input_handle(name).copy_from_cpu(inputs[name])

        # initialized so repeats == 0 returns Nones instead of a NameError
        np_boxes_num, np_label, np_score, np_segms = None, None, None, None
        for _ in range(warmup):
            np_boxes_num, np_label, np_score, np_segms = self._run_solov2()

        self.det_times.inference_time_s.start()
        for _ in range(repeats):
            np_boxes_num, np_label, np_score, np_segms = self._run_solov2()
        self.det_times.inference_time_s.end(repeats=repeats)
        # count every image of the batch, not one per call
        self.det_times.img_num += len(image)

        return dict(
            segm=np_segms,
            label=np_label,
            score=np_score,
            boxes_num=np_boxes_num)
G
Guanghua Yu 已提交
279 280


281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385
class DetectorPicoDet(Detector):
    """
    Detector for PicoDet models, whose NMS runs on the host through
    `PicoDetPostProcess`. The constructor is inherited unchanged from
    `Detector`, so it takes the same arguments:

    Args:
        pred_config (object): config of model, defined by `PredictConfig(model_dir)`
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
    """

    def _run_picodet(self):
        """Run the predictor once; return (score_list, boxes_list).

        The first half of the predictor outputs are scores and the second half
        boxes; the two lists are paired by position.
        """
        self.predictor.run()
        output_names = self.predictor.get_output_names()
        num_outs = len(output_names) // 2
        np_score_list = [
            self.predictor.get_output_handle(output_names[idx]).copy_to_cpu()
            for idx in range(num_outs)
        ]
        np_boxes_list = [
            self.predictor.get_output_handle(output_names[idx + num_outs])
            .copy_to_cpu() for idx in range(num_outs)
        ]
        return np_score_list, np_boxes_list

    def predict(self, image, threshold=0.5, warmup=0, repeats=1):
        '''
        Args:
            image (list): image paths / np.ndarray images read by cv2; every
                element is fed through `Detector.preprocess`
            threshold (float): accepted for interface compatibility; the NMS
                threshold comes from `pred_config.nms['nms_threshold']`
            warmup (int): untimed predictor runs before the timed ones
            repeats (int): timed predictor runs; the last run's outputs are used
        Returns:
            results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                            matix element:[class, score, x_min, y_min, x_max, y_max]
                            and 'boxes_num'
        '''
        self.det_times.preprocess_time_s.start()
        inputs = self.preprocess(image)
        self.det_times.preprocess_time_s.end()
        for name in self.predictor.get_input_names():
            self.predictor.get_input_handle(name).copy_from_cpu(inputs[name])

        np_score_list, np_boxes_list = [], []
        for _ in range(warmup):
            np_score_list, np_boxes_list = self._run_picodet()

        self.det_times.inference_time_s.start()
        for _ in range(repeats):
            np_score_list, np_boxes_list = self._run_picodet()
        self.det_times.inference_time_s.end(repeats=repeats)
        # count every image of the batch, not one per call
        self.det_times.img_num += len(image)

        self.det_times.postprocess_time_s.start()
        # a local object: assigning to self.postprocess would shadow the
        # inherited Detector.postprocess method on this instance
        postprocessor = PicoDetPostProcess(
            inputs['image'].shape[2:],
            inputs['im_shape'],
            inputs['scale_factor'],
            strides=self.pred_config.fpn_stride,
            nms_threshold=self.pred_config.nms['nms_threshold'])
        np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
        self.det_times.postprocess_time_s.end()
        return dict(boxes=np_boxes, boxes_num=np_boxes_num)


C
cnn 已提交
386
def create_inputs(imgs, im_info):
    """generate input for different model type
    Args:
        imgs (list(numpy)): list of images (np.ndarray); in the multi-image
            path each must be 3-D (C, H, W)
        im_info (list(dict)): list of image info, each holding 'im_shape' and
            'scale_factor', in the same order as `imgs`
    Returns:
        inputs (dict): input of model with float32 arrays
            'image': batch of images; for more than one image every image is
                zero-padded at the bottom/right to the largest H and W
            'im_shape', 'scale_factor': one row per image
    Raises:
        ValueError: `imgs` is empty
    """
    if len(imgs) == 0:
        raise ValueError("create_inputs expects at least one image")

    inputs = {
        'im_shape': np.array([e['im_shape'] for e in im_info]).astype('float32'),
        'scale_factor':
        np.array([e['scale_factor'] for e in im_info]).astype('float32'),
    }

    if len(imgs) == 1:
        # add the batch dimension; no padding needed
        inputs['image'] = np.array((imgs[0], )).astype('float32')
        return inputs

    max_shape_h = max(img.shape[1] for img in imgs)
    max_shape_w = max(img.shape[2] for img in imgs)
    padding_imgs = []
    for img in imgs:
        im_c, im_h, im_w = img.shape
        padding_im = np.zeros(
            (im_c, max_shape_h, max_shape_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = img
        padding_imgs.append(padding_im)
    inputs['image'] = np.stack(padding_imgs, axis=0)
    return inputs


class PredictConfig():
    """set config of preprocess, postprocess and visualize
    Args:
        model_dir (str): root path of infer_cfg.yml
    """

    def __init__(self, model_dir):
        # parsing Yaml config for Preprocess
        deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        # explicit utf-8: label lists may contain non-ASCII names and the
        # platform default encoding is not always utf-8
        with open(deploy_file, encoding='utf-8') as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        # optional keys: always define the attributes so readers get
        # False/None rather than an AttributeError
        self.mask = yml_conf.get('mask', False)
        self.tracker = yml_conf.get('tracker')
        self.nms = yml_conf.get('NMS')
        self.fpn_stride = yml_conf.get('fpn_stride')
        self.print_config()

    def check_model(self, yml_conf):
        """
        Accept the config when any SUPPORT_MODELS name occurs in its `arch`.

        Returns:
            True when the arch is supported
        Raises:
            ValueError: loaded model not in supported model type
        """
        if any(support_model in yml_conf['arch']
               for support_model in SUPPORT_MODELS):
            return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
            'arch'], SUPPORT_MODELS))

    def print_config(self):
        """Print the model arch and the ordered preprocess op types."""
        print('-----------  Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')


def load_predictor(model_dir,
                   run_mode='fluid',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False):
    """set AnalysisConfig, generate AnalysisPredictor
    Args:
        model_dir (str): root path of model.pdmodel and model.pdiparams
        run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): max batch size of the TensorRT engine, also the batch
            dim of the TensorRT dynamic-shape ranges
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        min_subgraph_size (int): min subgraph size passed to TensorRT
        use_dynamic_shape (bool): use dynamic shape or not
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): math library threads when running on CPU
        enable_mkldnn (bool): try to enable MKLDNN on CPU (best effort)
    Returns:
        predictor (PaddlePredictor): AnalysisPredictor
        config (Config): the inference config the predictor was built from
    Raises:
        ValueError: predict by TensorRT need device == 'GPU'.
    """
    if device != 'GPU' and run_mode != 'fluid':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
            .format(run_mode, device))
    config = Config(
        os.path.join(model_dir, 'model.pdmodel'),
        os.path.join(model_dir, 'model.pdiparams'))
    if device == 'GPU':
        # initial GPU memory(M), device ID
        config.enable_use_gpu(200, 0)
        # optimize graph and fuse op
        config.switch_ir_optim(True)
    elif device == 'XPU':
        config.enable_xpu(10 * 1024 * 1024)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            except Exception as e:
                # best effort: keep running on plain CPU, but report why
                print(
                    "The current environment does not support `mkldnn`, so disable mkldnn."
                )
                print("mkldnn error: {}".format(e))

    precision_map = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }
    if run_mode in precision_map:
        config.enable_tensorrt_engine(
            workspace_size=1 << 10,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)

        if use_dynamic_shape:
            # square spatial ranges for the 3-channel 'image' input
            min_input_shape = {
                'image': [batch_size, 3, trt_min_shape, trt_min_shape]
            }
            max_input_shape = {
                'image': [batch_size, 3, trt_max_shape, trt_max_shape]
            }
            opt_input_shape = {
                'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
            }
            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                              opt_input_shape)
            print('trt set dynamic shape done!')

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    predictor = create_predictor(config)
    return predictor, config
Q
qingqing01 已提交
568 569


G
Guanghua Yu 已提交
570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600
def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode

    Args:
        infer_dir (str|None): directory scanned (non-recursively) for
            .jpg/.jpeg/.png/.bmp files, extension matched case-insensitively
        infer_img (str|None): single image path; takes priority over infer_dir
    Returns:
        list[str]: image paths, sorted for a deterministic order
    Raises:
        AssertionError: neither argument is set, a given path has the wrong
            type, or no image is found
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
            "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
            "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = {'.jpg', '.jpeg', '.png', '.bmp'}
    # os.listdir instead of glob: directory names containing glob
    # metacharacters such as '[' must still work. Hidden files are skipped
    # the same way glob's '*' skips them.
    images = sorted(
        os.path.join(infer_dir, name) for name in os.listdir(infer_dir)
        if not name.startswith('.') and
        os.path.splitext(name)[1].lower() in exts and
        os.path.isfile(os.path.join(infer_dir, name)))

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))

    return images


C
cnn 已提交
601
def visualize(image_list, results, labels, output_dir='output/', threshold=0.5):
    """Draw the predictions of every image and save them under `output_dir`.

    `results` holds the batch-concatenated outputs of `predict`;
    `results['boxes_num'][i]` is the number of rows that belong to image i,
    so image i gets the consecutive slice right after image i-1's slice.

    Args:
        image_list (list): image paths (also used to name the output files)
        results (dict): prediction dict with 'boxes_num' and any of 'boxes',
            'masks', 'segm', 'label', 'score'
        labels (list): label names passed to `visualize_box_mask`
        output_dir (str): directory the drawn images are written to
        threshold (float): score threshold passed to `visualize_box_mask`
    """
    if image_list:
        # create once up front; exist_ok avoids the exists/makedirs race
        os.makedirs(output_dir, exist_ok=True)
    start_idx = 0
    for idx, image_file in enumerate(image_list):
        end_idx = start_idx + results['boxes_num'][idx]
        im_results = {
            key: results[key][start_idx:end_idx]
            for key in ('boxes', 'masks', 'segm', 'label', 'score')
            if key in results
        }
        start_idx = end_idx
        im = visualize_box_mask(
            image_file, im_results, labels, threshold=threshold)
        img_name = os.path.split(image_file)[-1]
        out_path = os.path.join(output_dir, img_name)
        im.save(out_path, quality=95)
        print("save result to: " + out_path)
Q
qingqing01 已提交
632 633 634 635 636 637 638 639 640


def print_arguments(args):
    """Dump the parsed command-line arguments, one `name: value` per line,
    sorted by name and framed by banner lines."""
    print('-----------  Running Arguments -----------')
    rows = [
        '%s: %s' % (name, value)
        for name, value in sorted(vars(args).items())
    ]
    if rows:
        print('\n'.join(rows))
    print('------------------------------------------')


C
cnn 已提交
641 642 643 644 645 646
def predict_image(detector, image_list, batch_size=1):
    """Run `detector` over `image_list` in consecutive batches of `batch_size`.

    With `FLAGS.run_benchmark` every batch is predicted with 10 warmup and 10
    timed runs and the current memory usage is added onto the detector's
    counters; otherwise each batch's results are drawn into `FLAGS.output_dir`.
    """
    total = len(image_list)
    num_batches = math.ceil(float(total) / batch_size)
    for batch_idx in range(num_batches):
        begin = batch_idx * batch_size
        batch = image_list[begin:min(begin + batch_size, total)]
        if not FLAGS.run_benchmark:
            results = detector.predict(batch, FLAGS.threshold)
            visualize(
                batch,
                results,
                detector.pred_config.labels,
                output_dir=FLAGS.output_dir,
                threshold=FLAGS.threshold)
            continue

        detector.predict(batch, FLAGS.threshold, warmup=10, repeats=10)
        cpu_mem, gpu_mem, gpu_util = get_current_memory_mb()
        detector.cpu_mem += cpu_mem
        detector.gpu_mem += gpu_mem
        detector.gpu_util += gpu_util
        print('Test iter {}'.format(batch_idx))
Q
qingqing01 已提交
663 664 665 666 667 668 669 670 671 672


def predict_video(detector, camera_id):
    """Run detection frame by frame on a camera stream or ``FLAGS.video_file``.

    Each frame is passed to ``detector.predict`` and drawn with
    ``visualize_box_mask``. The annotated frames are written as an mp4v video
    into ``FLAGS.output_dir``: named ``output.mp4`` for a camera, otherwise
    the input file's base name. For a camera, annotated frames are also shown
    in a window, and pressing 'q' stops the loop.

    Args:
        detector: detector providing ``predict`` and ``pred_config.labels``.
        camera_id (int): camera device id, or -1 to read ``FLAGS.video_file``.
    """
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
        video_name = 'output.mp4'
    else:
        capture = cv2.VideoCapture(FLAGS.video_file)
        video_name = os.path.split(FLAGS.video_file)[-1]

    # Keep the source frame rate so the output plays back at the original
    # speed. Some backends (often cameras) report 0 or NaN; fall back to 30.
    fps = capture.get(cv2.CAP_PROP_FPS)
    if not math.isfinite(fps) or fps <= 0:
        fps = 30
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    print('frame_count', frame_count)
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # yapf: disable
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # yapf: enable
    os.makedirs(FLAGS.output_dir, exist_ok=True)
    out_path = os.path.join(FLAGS.output_dir, video_name)
    writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
    index = 1
    # Release both the reader and the writer even if prediction raises or
    # the user quits early; previously the capture was never released.
    try:
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame:%d' % (index))
            index += 1
            results = detector.predict([frame], FLAGS.threshold)
            im = visualize_box_mask(
                frame,
                results,
                detector.pred_config.labels,
                threshold=FLAGS.threshold)
            im = np.array(im)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        writer.release()
        capture.release()


def main():
    """Build a detector from ``FLAGS`` and run it on a video/camera or images.

    If ``FLAGS.video_file`` is given or ``FLAGS.camera_id`` is not -1, frames
    are processed by ``predict_video``. Otherwise the images found from
    ``FLAGS.image_dir`` / ``FLAGS.image_file`` are processed by
    ``predict_image``, which is followed by one of two outputs:

    - without ``FLAGS.run_benchmark``, the averaged timing info is printed;
    - with it, a ``PaddleInferBenchmark`` report is emitted that includes
      the averaged memory/GPU-utilisation samples.
    """
    pred_config = PredictConfig(FLAGS.model_dir)

    # Architectures with a dedicated detector subclass; anything else uses
    # the generic Detector. A dict lookup replaces eval() on a class name.
    detector_classes = {
        'SOLOv2': DetectorSOLOv2,
        'PicoDet': DetectorPicoDet,
    }
    detector_cls = detector_classes.get(pred_config.arch, Detector)
    detector = detector_cls(pred_config,
                            FLAGS.model_dir,
                            device=FLAGS.device,
                            run_mode=FLAGS.run_mode,
                            batch_size=FLAGS.batch_size,
                            trt_min_shape=FLAGS.trt_min_shape,
                            trt_max_shape=FLAGS.trt_max_shape,
                            trt_opt_shape=FLAGS.trt_opt_shape,
                            trt_calib_mode=FLAGS.trt_calib_mode,
                            cpu_threads=FLAGS.cpu_threads,
                            enable_mkldnn=FLAGS.enable_mkldnn)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        predict_video(detector, FLAGS.camera_id)
        return

    # predict from image
    if FLAGS.image_dir is None and FLAGS.image_file is not None:
        assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    predict_image(detector, img_list, FLAGS.batch_size)
    if not FLAGS.run_benchmark:
        detector.det_times.info(average=True)
        return

    # predict_image adds exactly one memory/utilisation sample per batch, so
    # average over the number of batches, not images; otherwise the values
    # are too small by a factor of batch_size. max(1, ...) guards against an
    # empty image list.
    num_samples = max(
        1, math.ceil(float(len(img_list)) / FLAGS.batch_size))
    mems = {
        'cpu_rss_mb': detector.cpu_mem / num_samples,
        'gpu_rss_mb': detector.gpu_mem / num_samples,
        'gpu_util': detector.gpu_util * 100 / num_samples
    }

    perf_info = detector.det_times.report(average=True)
    model_dir = FLAGS.model_dir
    mode = FLAGS.run_mode
    model_info = {
        'model_name': model_dir.strip('/').split('/')[-1],
        'precision': mode.split('_')[-1]
    }
    data_info = {
        'batch_size': FLAGS.batch_size,
        'shape': "dynamic_shape",
        'data_num': perf_info['img_num']
    }
    det_log = PaddleInferBenchmark(detector.config, model_info, data_info,
                                   perf_info, mems)
    det_log('Det')
Q
qingqing01 已提交
759 760 761 762


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    # Validate flags with parser.error instead of assert: asserts are removed
    # when Python runs with -O, which would let invalid flags through.
    # parser.error prints usage plus the message and exits with status 2.
    if FLAGS.device not in ['CPU', 'GPU', 'XPU']:
        parser.error("device should be CPU, GPU or XPU")
    if FLAGS.use_gpu:
        parser.error("use_gpu has been deprecated, please use --device")

    main()