infer.py 29.3 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
G
Guanghua Yu 已提交
17
import glob
Q
qingqing01 已提交
18 19 20 21
from functools import reduce

import cv2
import numpy as np
C
cnn 已提交
22
import math
Q
qingqing01 已提交
23 24 25 26
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor

W
wangguanzhong 已提交
27 28 29 30 31
import sys
# Add the directory containing this script (the PaddleDetection deploy path)
# to sys.path so the sibling helper modules imported below can be found.
parent_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, parent_path)

32
from benchmark_utils import PaddleInferBenchmark
33
from picodet_postprocess import PicoDetPostProcess
W
wangguanzhong 已提交
34
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine
W
wangguanzhong 已提交
35
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
G
Guanghua Yu 已提交
36
from visualize import visualize_box_mask
37
from utils import argsparser, Timer, get_current_memory_mb
G
Guanghua Yu 已提交
38

Q
qingqing01 已提交
39 40 41 42 43
# Set of architecture-name substrings this script can deploy.
# PredictConfig.check_model accepts a model whose `arch` field contains any
# of these strings.
SUPPORT_MODELS = {
    'YOLO',
    'RCNN',
    'SSD',
    'Face',
    'FCOS',
    'SOLOv2',
    'TTFNet',
    'S2ANet',
    'JDE',
    'FairMOT',
    'DeepSORT',
    'GFL',
    'PicoDet',
    'CenterNet',
    'TOOD',
    'StrongBaseline',
}


W
wangguanzhong 已提交
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    """Emit a PaddleInferBenchmark report for a finished detector run.

    Args:
        detector: a Detector whose accumulated ``cpu_mem``, ``gpu_mem``,
            ``gpu_util``, ``det_times`` and ``config`` are reported
        img_list (list): the processed inputs; its length is the divisor used
            to average the accumulated memory / utilisation counters
        model_info: passed through unchanged to PaddleInferBenchmark
        batch_size (int): batch size recorded in the report
        name: passed to the benchmark logger call
    """
    num_images = len(img_list)
    mems = dict(
        cpu_rss_mb=detector.cpu_mem / num_images,
        gpu_rss_mb=detector.gpu_mem / num_images,
        gpu_util=detector.gpu_util * 100 / num_images)
    perf_info = detector.det_times.report(average=True)
    data_info = dict(
        batch_size=batch_size,
        shape="dynamic_shape",
        data_num=perf_info['img_num'])
    benchmark_logger = PaddleInferBenchmark(detector.config, model_info,
                                            data_info, perf_info, mems)
    benchmark_logger(name)


Q
qingqing01 已提交
77 78 79
class Detector(object):
    """
    Run a Paddle Inference detection model exported together with an
    ``infer_cfg.yml`` deploy config.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): number of images fed to the predictor per batch
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            output_dir='output',
            threshold=0.5, ):
        self.pred_config = self.set_config(model_dir)
        self.predictor, self.config = load_predictor(
            model_dir,
            run_mode=run_mode,
            batch_size=batch_size,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn)
        self.det_times = Timer()
        # Running totals of memory / GPU-utilisation samples, accumulated once
        # per benchmarked batch in predict_image.
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.threshold = threshold

    def set_config(self, model_dir):
        """Build the PredictConfig for ``model_dir`` (hook for subclasses)."""
        return PredictConfig(model_dir)

    def preprocess(self, image_list):
        """Preprocess ``image_list`` and copy the batch into the predictor.

        The op pipeline is built from ``pred_config.preprocess_infos``; each
        entry's 'type' names an op class available in this module.

        Returns:
            dict: the input arrays built by create_inputs, keyed by input name
        Raises:
            NameError: an op type in the config is not known to this module
        """
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # Look the op class up by name instead of eval()-ing a string read
            # from infer_cfg.yml, so a crafted config cannot run arbitrary
            # expressions. Legitimate op names resolve exactly as before.
            op_cls = globals().get(op_type)
            if op_cls is None:
                raise NameError(
                    "Unsupported preprocess op type: {}".format(op_type))
            preprocess_ops.append(op_cls(**new_op_info))

        input_im_lst = []
        input_im_info_lst = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            input_im_lst.append(im)
            input_im_info_lst.append(im_info)
        inputs = create_inputs(input_im_lst, input_im_info_lst)
        for input_name in self.predictor.get_input_names():
            input_tensor = self.predictor.get_input_handle(input_name)
            input_tensor.copy_from_cpu(inputs[input_name])

        return inputs

    def postprocess(self, inputs, result):
        """Normalise raw predictor output for one batch.

        If no image in the batch has a detection, the result is replaced by
        an empty ``[0, 6]`` box array with one zero count per image. Entries
        whose value is None (e.g. ``masks`` for non-mask models) are dropped.
        """
        np_boxes_num = result['boxes_num']
        # Check the whole batch: testing only the first image threw away the
        # detections of every other image whenever image 0 was empty.
        if np.sum(np_boxes_num) <= 0:
            print('[WARNNING] No object detected.')
            result = {
                'boxes': np.zeros([0, 6]),
                # one count per image so per-image indexing stays valid
                'boxes_num': [0] * len(np_boxes_num)
            }
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeats number for prediction
        Returns:
            result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                            matix element:[class, score, x_min, y_min, x_max, y_max]
                            MaskRCNN's result include 'masks': np.ndarray:
                            shape: [N, im_h, im_w]
                            Values stay None when repeats < 1.
        '''
        # model prediction; only the outputs of the last run are kept
        np_boxes, np_masks, np_boxes_num = None, None, None
        for _ in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            np_boxes = self.predictor.get_output_handle(output_names[
                0]).copy_to_cpu()
            np_boxes_num = self.predictor.get_output_handle(output_names[
                1]).copy_to_cpu()
            if self.pred_config.mask:
                np_masks = self.predictor.get_output_handle(output_names[
                    2]).copy_to_cpu()
        result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
        return result

    def merge_batch_result(self, batch_result):
        """Concatenate per-batch result dicts key by key.

        Returns an empty dict for no batches and the single dict unchanged
        for one batch. Keys are taken from the first batch.
        """
        if not batch_result:
            return {}
        if len(batch_result) == 1:
            return batch_result[0]
        results = {k: [] for k in batch_result[0].keys()}
        for res in batch_result:
            for k, v in res.items():
                results[k].append(v)
        return {k: np.concatenate(v) for k, v in results.items()}

    def get_timer(self):
        """Return the Timer collecting per-stage timings."""
        return self.det_times

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True):
        """Run preprocess -> predict -> postprocess over ``image_list`` in batches.

        Args:
            image_list (list): image paths (predict_video also passes raw frames)
            run_benchmark (bool): time every stage after one untimed warm-up
                pass and accumulate memory statistics; no visualisation is
                done in this mode
            repeats (int): predictor runs per timed inference (benchmark only)
            visual (bool): save visualised results to ``self.output_dir``
        Returns:
            dict: per-batch results merged by merge_batch_result
        """
        batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
        results = []
        for i in range(batch_loop_cnt):
            start_index = i * self.batch_size
            end_index = min((i + 1) * self.batch_size, len(image_list))
            batch_image_list = image_list[start_index:end_index]
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                _ = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                # preprocess
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                # postprocess
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                if visual:
                    visualize(
                        batch_image_list,
                        result,
                        self.pred_config.labels,
                        output_dir=self.output_dir,
                        threshold=self.threshold)

            results.append(result)
            if visual:
                print('Test iter {}'.format(i))

        results = self.merge_batch_result(results)
        return results

    def predict_video(self, video_file, camera_id):
        """Detect on every frame of a video file or camera stream.

        Annotated frames are written to ``self.output_dir``. When
        ``camera_id != -1`` the camera is used, frames are also shown in a
        window, and pressing 'q' stops. The capture and writer are always
        released, even if detection raises.
        """
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        writer = None
        try:
            # Get Video info : resolution, fps, frame count
            width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(capture.get(cv2.CAP_PROP_FPS))
            frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            print("fps: %d, frame_count: %d" % (fps, frame_count))

            os.makedirs(self.output_dir, exist_ok=True)
            out_path = os.path.join(self.output_dir, video_out_name)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
            index = 1
            while True:
                ret, frame = capture.read()
                if not ret:
                    break
                print('detect frame: %d' % (index))
                index += 1
                results = self.predict_image([frame], visual=False)

                im = visualize_box_mask(
                    frame,
                    results,
                    self.pred_config.labels,
                    threshold=self.threshold)
                im = np.array(im)
                writer.write(im)
                if camera_id != -1:
                    cv2.imshow('Mask Detection', im)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
        finally:
            if writer is not None:
                writer.release()
            # the capture was previously never released
            capture.release()
W
wangguanzhong 已提交
310

Q
qingqing01 已提交
311

G
Guanghua Yu 已提交
312 313 314 315
class DetectorSOLOv2(Detector):
    """
    Detector for SOLOv2 instance-segmentation models; only ``predict`` differs
    from the base class.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): number of images fed to the predictor per batch
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            output_dir='./',
            threshold=0.5, ):
        super(DetectorSOLOv2, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
                            'cate_label': label of segm, shape:[N]
                            'cate_score': confidence score of segm, shape:[N]
                            (stored under the keys 'label' and 'score'), plus
                            'boxes_num'. Values stay None when repeats < 1.
        '''
        # Initialise every output, including boxes_num (previously unbound
        # when repeats < 1), so the result dict can always be built.
        np_boxes_num, np_label, np_score, np_segms = None, None, None, None
        for _ in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            # Graph outputs, in order: boxes_num, label, score, segm.
            np_boxes_num, np_label, np_score, np_segms = [
                self.predictor.get_output_handle(output_names[idx])
                .copy_to_cpu() for idx in range(4)
            ]

        result = dict(
            segm=np_segms,
            label=np_label,
            score=np_score,
            boxes_num=np_boxes_num)
        return result
G
Guanghua Yu 已提交
387 388


389 390 391 392 393
class DetectorPicoDet(Detector):
    """
    Detector for PicoDet models. The raw per-level outputs of the graph are
    decoded, and NMS is applied, in Python by PicoDetPostProcess.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): number of images fed to the predictor per batch
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            output_dir='./',
            threshold=0.5, ):
        super(DetectorPicoDet, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )

    def postprocess(self, inputs, result):
        """Decode the raw per-level outputs into final boxes.

        ``result`` is the dict returned by ``predict``: its 'boxes' entry holds
        the per-level score outputs and its 'boxes_num' entry the per-level
        box outputs (key names are reused from the base-class result).
        Requires ``pred_config.fpn_stride`` and ``pred_config.nms`` to be set.
        """
        np_score_list = result['boxes']
        np_boxes_list = result['boxes_num']
        postprocessor = PicoDetPostProcess(
            inputs['image'].shape[2:],
            inputs['im_shape'],
            inputs['scale_factor'],
            strides=self.pred_config.fpn_stride,
            nms_threshold=self.pred_config.nms['nms_threshold'])
        np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
        result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
        return result

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): raw outputs of the last run: 'boxes' holds the list
                of per-level score arrays and 'boxes_num' the list of
                per-level box arrays; ``postprocess`` turns them into boxes of
                shape [N,6], element: [class, score, x_min, y_min, x_max, y_max]
        '''
        np_score_list, np_boxes_list = [], []
        for _ in range(repeats):
            self.predictor.run()
            np_score_list.clear()
            np_boxes_list.clear()
            output_names = self.predictor.get_output_names()
            # The first half of the outputs are gathered as scores and the
            # second half as the matching box outputs.
            num_outs = len(output_names) // 2
            for out_idx in range(num_outs):
                np_score_list.append(
                    self.predictor.get_output_handle(output_names[out_idx])
                    .copy_to_cpu())
                np_boxes_list.append(
                    self.predictor.get_output_handle(output_names[
                        out_idx + num_outs]).copy_to_cpu())
        result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
        return result
471 472


C
cnn 已提交
473
def create_inputs(imgs, im_info):
    """Assemble the model feed dict from preprocessed images.

    Args:
        imgs (list(numpy)): preprocessed images; the multi-image path unpacks
            each one as (C, H, W)
        im_info (list(dict)): per-image info with 'im_shape' and 'scale_factor'
    Returns:
        inputs (dict): 'image', 'im_shape' and 'scale_factor' as float32 arrays
            with a leading batch axis. With several images, each image is
            zero-padded at the bottom/right to the largest H and W in the batch.
    """
    if len(imgs) == 1:
        first_info = im_info[0]
        return {
            'image': np.array((imgs[0], )).astype('float32'),
            'im_shape':
            np.array((first_info['im_shape'], )).astype('float32'),
            'scale_factor':
            np.array((first_info['scale_factor'], )).astype('float32'),
        }

    shape_rows, scale_rows = [], []
    for info in im_info:
        shape_rows.append(np.array((info['im_shape'], )).astype('float32'))
        scale_rows.append(
            np.array((info['scale_factor'], )).astype('float32'))
    inputs = {
        'im_shape': np.concatenate(shape_rows, axis=0),
        'scale_factor': np.concatenate(scale_rows, axis=0),
    }

    target_h = max(img.shape[1] for img in imgs)
    target_w = max(img.shape[2] for img in imgs)
    canvases = []
    for img in imgs:
        channels, height, width = img.shape
        canvas = np.zeros((channels, target_h, target_w), dtype=np.float32)
        canvas[:, :height, :width] = img
        canvases.append(canvas)
    inputs['image'] = np.stack(canvases, axis=0)
    return inputs


class PredictConfig():
    """Deploy settings read from ``<model_dir>/infer_cfg.yml``.

    Args:
        model_dir (str): directory that contains infer_cfg.yml

    Required YAML keys become ``arch``, ``preprocess_infos`` ('Preprocess'),
    ``min_subgraph_size``, ``labels`` ('label_list') and
    ``use_dynamic_shape``. Optional keys become ``mask`` (default False) and
    ``tracker`` (default None). ``nms`` ('NMS') and ``fpn_stride`` are only
    set when present in the YAML.
    """

    def __init__(self, model_dir):
        # parse the YAML deploy config
        config_path = os.path.join(model_dir, 'infer_cfg.yml')
        with open(config_path) as stream:
            yml_conf = yaml.safe_load(stream)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        self.mask = yml_conf.get('mask', False)
        self.tracker = yml_conf.get('tracker')
        # Optional model-specific settings: left undefined when absent.
        for attr_name, yml_key in (('nms', 'NMS'), ('fpn_stride', 'fpn_stride')):
            if yml_key in yml_conf:
                setattr(self, attr_name, yml_conf[yml_key])
        self.print_config()

    def check_model(self, yml_conf):
        """
        Return True when ``yml_conf['arch']`` contains any SUPPORT_MODELS entry.

        Raises:
            ValueError: loaded model not in supported model type
        """
        arch = yml_conf['arch']
        if any(model_name in arch for model_name in SUPPORT_MODELS):
            return True
        raise ValueError("Unsupported arch: {}, expect {}".format(
            arch, SUPPORT_MODELS))

    def print_config(self):
        """Print the model arch and the ordered preprocess op types."""
        print('-----------  Model Configuration -----------')
        print('Model Arch: %s' % (self.arch, ))
        print('Transform Order: ')
        for op_info in self.preprocess_infos:
            print('--transform op: %s' % (op_info['type'], ))
        print('--------------------------------------------')


def load_predictor(model_dir,
                   run_mode='paddle',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False):
    """Build a Paddle Inference Config and create a predictor from it.

    Args:
        model_dir (str): directory holding model.pdmodel and model.pdiparams
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): max_batch_size for the TensorRT engine, also the
            batch dimension of the TensorRT dynamic-shape ranges
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        min_subgraph_size (int): min_subgraph_size passed to the TensorRT engine
        use_dynamic_shape (bool): use dynamic shape or not (applied only in
            TensorRT run modes)
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): math-library thread count (CPU device only)
        enable_mkldnn (bool): try to enable MKLDNN (CPU device only); if that
            fails, a message is printed and MKLDNN stays disabled
    Returns:
        tuple: (predictor, config) - the created predictor and the Config it
            was built from
    Raises:
        ValueError: run_mode is not 'paddle' while device is not 'GPU'
    """
    if device != 'GPU' and run_mode != 'paddle':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
            .format(run_mode, device))
    config = Config(
        os.path.join(model_dir, 'model.pdmodel'),
        os.path.join(model_dir, 'model.pdiparams'))
    if device == 'GPU':
        # initial GPU memory(M), device ID
        config.enable_use_gpu(200, 0)
        # optimize graph and fuse op
        config.switch_ir_optim(True)
    elif device == 'XPU':
        config.enable_lite_engine()
        config.enable_xpu(10 * 1024 * 1024)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            # Best effort: fall back to plain CPU execution when this build
            # does not support MKLDNN.
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            except Exception:
                print(
                    "The current environment does not support `mkldnn`, so disable mkldnn."
                )

    precision_map = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }
    if run_mode in precision_map:
        config.enable_tensorrt_engine(
            workspace_size=1 << 25,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)

        if use_dynamic_shape:
            # Square H/W ranges for the 'image' input; channel count fixed to 3.
            min_input_shape = {
                'image': [batch_size, 3, trt_min_shape, trt_min_shape]
            }
            max_input_shape = {
                'image': [batch_size, 3, trt_max_shape, trt_max_shape]
            }
            opt_input_shape = {
                'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
            }
            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                              opt_input_shape)
            print('trt set dynamic shape done!')

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    predictor = create_predictor(config)
    return predictor, config
Q
qingqing01 已提交
656 657


G
Guanghua Yu 已提交
658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688
def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode.

    Args:
        infer_dir (str|None): directory searched (non-recursively) for files
            with a jpg/jpeg/png/bmp extension, lower- or upper-case
        infer_img (str|None): a single image file; takes priority over infer_dir
    Returns:
        list[str]: image paths in sorted order
    Raises:
        AssertionError: neither argument is set, a given path is invalid, or
            no image is found
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
            "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
            "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    # Escape the directory so characters such as '[' or '*' in its name are
    # matched literally rather than treated as glob wildcards.
    pattern_dir = glob.escape(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    images = set()  # a set dedups files matched by both case variants
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(pattern_dir, ext)))
    # Sort for a deterministic processing order (set order is arbitrary).
    images = sorted(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))

    return images


W
wangguanzhong 已提交
689
def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
    """Draw one batch of predictions and save one image per input.

    Args:
        image_list (list): image file paths of the batch, in result order
        result (dict): batch result. ``boxes_num[i]`` is the number of rows
            belonging to image i in each per-detection entry ('boxes',
            'masks', 'segm', 'label', 'score'); these rows are stored
            consecutively, image by image.
        labels (list): class names, passed to visualize_box_mask
        output_dir (str): directory the images are saved to (created once if
            missing, and only when there is something to save)
        threshold (float): score threshold, passed to visualize_box_mask
    """
    # Per-detection entries that are split into consecutive per-image chunks.
    per_det_keys = ('boxes', 'masks', 'segm', 'label', 'score')
    if image_list:
        # Create the directory once instead of re-checking inside the loop.
        os.makedirs(output_dir, exist_ok=True)

    start_idx = 0
    for idx, image_file in enumerate(image_list):
        end_idx = start_idx + result['boxes_num'][idx]
        im_results = {
            key: result[key][start_idx:end_idx]
            for key in per_det_keys if key in result
        }
        start_idx = end_idx

        im = visualize_box_mask(
            image_file, im_results, labels, threshold=threshold)
        img_name = os.path.split(image_file)[-1]
        out_path = os.path.join(output_dir, img_name)
        im.save(out_path, quality=95)
        print("save result to: " + out_path)
Q
qingqing01 已提交
720 721 722 723 724 725 726 727 728 729


def print_arguments(args):
    """Echo every parsed command-line argument as ``name: value``.

    Arguments are listed alphabetically by name, framed by a banner line
    above and a dashed rule below.
    """
    banner = '-----------  Running Arguments -----------'
    rule = '------------------------------------------'
    body = ['%s: %s' % (name, val) for name, val in sorted(vars(args).items())]
    print('\n'.join([banner] + body + [rule]))


def main():
    """Build the detector described by the global ``FLAGS`` and run it.

    The detector class is chosen from the ``arch`` field of
    ``<model_dir>/infer_cfg.yml``. Inference then runs either on a video
    file / camera stream or on a set of images. Image runs are followed by
    timing output, or by benchmark logging when ``FLAGS.run_benchmark`` is
    set.

    Raises:
        ValueError: if only ``image_file`` is given (no ``image_dir``)
            while ``batch_size`` is not 1.
    """
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file, encoding='utf-8') as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']

    # Select the class explicitly rather than eval()-ing its name. Only the
    # chosen branch's name is looked up, as with the previous eval().
    if arch == 'SOLOv2':
        detector_cls = DetectorSOLOv2
    elif arch == 'PicoDet':
        detector_cls = DetectorPicoDet
    else:
        detector_cls = Detector

    detector = detector_cls(FLAGS.model_dir,
                            device=FLAGS.device,
                            run_mode=FLAGS.run_mode,
                            batch_size=FLAGS.batch_size,
                            trt_min_shape=FLAGS.trt_min_shape,
                            trt_max_shape=FLAGS.trt_max_shape,
                            trt_opt_shape=FLAGS.trt_opt_shape,
                            trt_calib_mode=FLAGS.trt_calib_mode,
                            cpu_threads=FLAGS.cpu_threads,
                            enable_mkldnn=FLAGS.enable_mkldnn,
                            threshold=FLAGS.threshold,
                            output_dir=FLAGS.output_dir)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
        return

    # predict from image
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if FLAGS.image_dir is None and FLAGS.image_file is not None:
        if FLAGS.batch_size != 1:
            raise ValueError(
                "batch_size should be 1, when image_file is not None")
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
    if not FLAGS.run_benchmark:
        detector.det_times.info(average=True)
    else:
        mode = FLAGS.run_mode
        model_dir = FLAGS.model_dir
        model_info = {
            # os.path handles trailing separators and Windows-style paths,
            # unlike splitting on '/'.
            'model_name': os.path.basename(os.path.normpath(model_dir)),
            'precision': mode.split('_')[-1]
        }
        bench_log(detector, img_list, model_info, name='DET')
Q
qingqing01 已提交
772 773 774 775


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    # Validate CLI input with explicit raises: ``assert`` statements are
    # removed when Python runs with ``-O``, which would skip these checks.
    if FLAGS.device not in ['CPU', 'GPU', 'XPU']:
        raise ValueError("device should be CPU, GPU or XPU")
    if FLAGS.use_gpu:
        raise ValueError("use_gpu has been deprecated, please use --device")

    main()