infer.py 31.5 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
G
Guanghua Yu 已提交
17
import glob
Q
qingqing01 已提交
18 19 20 21
from functools import reduce

import cv2
import numpy as np
C
cnn 已提交
22
import math
Q
qingqing01 已提交
23 24 25 26
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor

W
wangguanzhong 已提交
27 28 29 30 31
import sys
# add deploy path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)

32
from benchmark_utils import PaddleInferBenchmark
33
from picodet_postprocess import PicoDetPostProcess
34
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, decode_image
W
wangguanzhong 已提交
35
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
G
Guanghua Yu 已提交
36
from visualize import visualize_box_mask
37
from utils import argsparser, Timer, get_current_memory_mb
G
Guanghua Yu 已提交
38

Q
qingqing01 已提交
39 40
# Set of architecture-name fragments accepted by PredictConfig.check_model:
# an exported model is supported when any of these strings occurs as a
# substring of the `arch` field in its infer_cfg.yml.
SUPPORT_MODELS = {
    'YOLO',
    'RCNN',
    'SSD',
    'Face',
    'FCOS',
    'SOLOv2',
    'TTFNet',
    'S2ANet',
    'JDE',
    'FairMOT',
    'DeepSORT',
    'GFL',
    'PicoDet',
    'CenterNet',
    'TOOD',
    'StrongBaseline',
    'STGCN',
}


W
wangguanzhong 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    """Emit a PaddleInferBenchmark log for a finished benchmark run.

    Args:
        detector: detector whose `cpu_mem`, `gpu_mem` and `gpu_util`
            accumulators and `det_times` timer were filled during the run;
            its `config` is forwarded to the benchmark logger.
        img_list (list): images that were processed; its length is the divisor
            used to average the three accumulators.
        model_info (dict): model description forwarded to PaddleInferBenchmark.
        batch_size (int): batch size recorded in the log.
        name (str): name passed when the benchmark logger is invoked.
    """
    num_imgs = len(img_list)
    mems = dict(
        cpu_rss_mb=detector.cpu_mem / num_imgs,
        gpu_rss_mb=detector.gpu_mem / num_imgs,
        # utilisation is reported as a percentage
        gpu_util=detector.gpu_util * 100 / num_imgs)
    perf_info = detector.det_times.report(average=True)
    data_info = dict(
        batch_size=batch_size,
        shape="dynamic_shape",
        data_num=perf_info['img_num'])
    benchmark_logger = PaddleInferBenchmark(detector.config, model_info,
                                            data_info, perf_info, mems)
    benchmark_logger(name)


Q
qingqing01 已提交
64 65 66
class Detector(object):
    """Paddle Inference detector: preprocess, predict, postprocess, visualize.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): size of per batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to turn on mkldnn bfloat16
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
                                    Used by action model.
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 enable_mkldnn_bfloat16=False,
                 output_dir='output',
                 threshold=0.5,
                 delete_shuffle_pass=False):
        self.pred_config = self.set_config(model_dir)
        self.predictor, self.config = load_predictor(
            model_dir,
            run_mode=run_mode,
            batch_size=batch_size,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            delete_shuffle_pass=delete_shuffle_pass)
        self.det_times = Timer()
        # Summed once per batch by predict_image when run_benchmark=True.
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.threshold = threshold

    def set_config(self, model_dir):
        """Return the PredictConfig parsed from model_dir/infer_cfg.yml."""
        return PredictConfig(model_dir)

    def preprocess(self, image_list):
        """Preprocess a batch of images and copy them into the predictor inputs.

        Args:
            image_list (list): images handed one by one to `preprocess`
        Returns:
            inputs (dict): arrays built by create_inputs, keyed by input name
        Raises:
            ValueError: if infer_cfg.yml names a preprocess op that is not a
                class available in this module
        """
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # Resolve the op by name among this module's globals instead of
            # eval(): infer_cfg.yml travels with the model and must not be
            # able to run arbitrary expressions.
            op_cls = globals().get(op_type)
            if not isinstance(op_cls, type):
                raise ValueError(
                    "Unsupported preprocess op type: {}".format(op_type))
            preprocess_ops.append(op_cls(**new_op_info))

        input_im_lst = []
        input_im_info_lst = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            input_im_lst.append(im)
            input_im_info_lst.append(im_info)
        inputs = create_inputs(input_im_lst, input_im_info_lst)
        for input_name in self.predictor.get_input_names():
            input_tensor = self.predictor.get_input_handle(input_name)
            input_tensor.copy_from_cpu(inputs[input_name])

        return inputs

    def postprocess(self, inputs, result):
        """Clean up raw predictor output.

        When no image in the batch has a detection (total of 'boxes_num' is
        <= 0), 'boxes' is replaced with an empty [0, 6] array and 'boxes_num'
        with one zero per image, so per-image indexing downstream still works
        for every image in the batch. Entries whose value is None are dropped.
        """
        np_boxes_num = result['boxes_num']
        # Check the whole batch: an empty first image must not discard the
        # detections of the other images.
        if np.sum(np_boxes_num) <= 0:
            print('[WARNNING] No object detected.')
            result = {
                'boxes': np.zeros([0, 6]),
                'boxes_num': np.zeros_like(np.asarray(np_boxes_num))
            }
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def filter_box(self, result, threshold):
        """Keep, per image, only the boxes whose score (column 1) > threshold.

        Returns:
            dict: 'boxes' (kept rows, concatenated) and 'boxes_num' (kept count
            per image, as np.ndarray)
        """
        np_boxes_num = result['boxes_num']
        boxes = result['boxes']
        start_idx = 0
        filter_boxes = []
        filter_num = []
        for boxes_num in np_boxes_num:
            boxes_i = boxes[start_idx:start_idx + boxes_num, :]
            filter_boxes_i = boxes_i[boxes_i[:, 1] > threshold, :]
            filter_boxes.append(filter_boxes_i)
            filter_num.append(filter_boxes_i.shape[0])
            start_idx += boxes_num
        # np.concatenate rejects an empty list; keep an empty slice instead.
        boxes = np.concatenate(filter_boxes) if filter_boxes else boxes[:0]
        filter_num = np.array(filter_num)
        filter_res = {'boxes': boxes, 'boxes_num': filter_num}
        return filter_res

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeats number for prediction
        Returns:
            result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                            matix element:[class, score, x_min, y_min, x_max, y_max]
                            'boxes_num': number of boxes per image,
                            'masks': np.ndarray of shape [N, im_h, im_w] for
                            mask models (pred_config.mask), otherwise None.
                            All values are None when repeats < 1.
        '''
        # model prediction
        np_boxes, np_masks, np_boxes_num = None, None, None
        for _ in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            boxes_tensor = self.predictor.get_output_handle(output_names[0])
            np_boxes = boxes_tensor.copy_to_cpu()
            boxes_num = self.predictor.get_output_handle(output_names[1])
            np_boxes_num = boxes_num.copy_to_cpu()
            if self.pred_config.mask:
                masks_tensor = self.predictor.get_output_handle(output_names[2])
                np_masks = masks_tensor.copy_to_cpu()
        result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
        return result

    def merge_batch_result(self, batch_result):
        """Merge per-batch result dicts into one.

        Keys are taken from the first batch. Every value except 'masks' is
        concatenated along axis 0; 'masks' stays a list with one entry per
        batch.
        """
        if len(batch_result) == 1:
            return batch_result[0]
        res_key = batch_result[0].keys()
        results = {k: [] for k in res_key}
        for res in batch_result:
            for k, v in res.items():
                results[k].append(v)
        for k, v in results.items():
            if k != 'masks':
                results[k] = np.concatenate(v)
        return results

    def get_timer(self):
        """Return the Timer that accumulates preprocess/inference/postprocess time."""
        return self.det_times

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True):
        """Run detection over image_list in chunks of self.batch_size.

        Args:
            image_list (list): images accepted by `preprocess` (predict_video
                passes single decoded frames here)
            run_benchmark (bool): warm up and time every stage, run inference
                `repeats` times and accumulate memory stats; nothing is
                visualized in this mode
            repeats (int): inference repeats per batch when benchmarking
            visual (bool): save visualizations (non-benchmark mode only) and
                print progress
        Returns:
            dict: the per-batch results merged by merge_batch_result
        """
        batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
        results = []
        for i in range(batch_loop_cnt):
            start_index = i * self.batch_size
            end_index = min((i + 1) * self.batch_size, len(image_list))
            batch_image_list = image_list[start_index:end_index]
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result = self.predict(repeats=50)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                _ = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                # preprocess
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                # postprocess
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                if visual:
                    visualize(
                        batch_image_list,
                        result,
                        self.pred_config.labels,
                        output_dir=self.output_dir,
                        threshold=self.threshold)

            results.append(result)
            if visual:
                print('Test iter {}'.format(i))

        results = self.merge_batch_result(results)
        return results

    def predict_video(self, video_file, camera_id):
        """Detect on every frame of a video file or camera and write an mp4.

        Args:
            video_file (str): input video path, used when camera_id == -1
            camera_id (int): camera index; -1 means read video_file instead.
                Camera frames are also shown on screen; press 'q' to stop.
        """
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        writer = None
        try:
            # Get Video info : resolution, fps, frame count
            width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(capture.get(cv2.CAP_PROP_FPS))
            frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            print("fps: %d, frame_count: %d" % (fps, frame_count))

            os.makedirs(self.output_dir, exist_ok=True)
            out_path = os.path.join(self.output_dir, video_out_name)
            fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
            writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
            index = 1
            while True:
                ret, frame = capture.read()
                if not ret:
                    break
                print('detect frame: %d' % (index))
                index += 1
                results = self.predict_image([frame], visual=False)

                im = visualize_box_mask(
                    frame,
                    results,
                    self.pred_config.labels,
                    threshold=self.threshold)
                im = np.array(im)
                writer.write(im)
                if camera_id != -1:
                    cv2.imshow('Mask Detection', im)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
        finally:
            # Release the capture device/file and flush the writer even if a
            # frame fails to process.
            capture.release()
            if writer is not None:
                writer.release()
W
wangguanzhong 已提交
323

Q
qingqing01 已提交
324

G
Guanghua Yu 已提交
325 326 327 328
class DetectorSOLOv2(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of per batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        enable_mkldnn_bfloat16 (bool): Whether to turn on mkldnn bfloat16
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass
            in TensorRT; forwarded to Detector
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            enable_mkldnn_bfloat16=False,
            output_dir='./',
            threshold=0.5,
            delete_shuffle_pass=False, ):
        super(DetectorSOLOv2, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold,
            delete_shuffle_pass=delete_shuffle_pass, )

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
                            'label': label of segm, shape:[N]
                            'score': confidence score of segm, shape:[N]
                            'boxes_num': the predictor's first output
                            All values are None when repeats < 1.
        '''
        # Initialize every output so repeats < 1 returns Nones instead of
        # raising UnboundLocalError.
        np_boxes_num, np_label, np_score, np_segms = None, None, None, None
        for _ in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            np_boxes_num = self.predictor.get_output_handle(output_names[
                0]).copy_to_cpu()
            np_label = self.predictor.get_output_handle(output_names[
                1]).copy_to_cpu()
            np_score = self.predictor.get_output_handle(output_names[
                2]).copy_to_cpu()
            np_segms = self.predictor.get_output_handle(output_names[
                3]).copy_to_cpu()

        result = dict(
            segm=np_segms,
            label=np_label,
            score=np_score,
            boxes_num=np_boxes_num)
        return result
G
Guanghua Yu 已提交
403 404


405 406 407 408 409
class DetectorPicoDet(Detector):
    """Detector for PicoDet models, whose box decoding and NMS run on the host.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of per batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to turn on MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            enable_mkldnn_bfloat16=False,
            output_dir='./',
            threshold=0.5, ):
        super().__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold, )

    def postprocess(self, inputs, result):
        """Decode the raw head outputs into boxes with PicoDetPostProcess.

        NOTE: `predict` stores the per-level score tensors under 'boxes' and
        the per-level box tensors under 'boxes_num'; they are read back here
        under those keys.

        Returns:
            dict: 'boxes' and 'boxes_num' as produced by PicoDetPostProcess
        """
        scores = result['boxes']
        raw_boxes = result['boxes_num']
        decoder = PicoDetPostProcess(
            inputs['image'].shape[2:],
            inputs['im_shape'],
            inputs['scale_factor'],
            strides=self.pred_config.fpn_stride,
            nms_threshold=self.pred_config.nms['nms_threshold'])
        decoded_boxes, decoded_num = decoder(scores, raw_boxes)
        return dict(boxes=decoded_boxes, boxes_num=decoded_num)

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): raw outputs of the last run. The first half of the
                predictor outputs (score maps) is stored as a list under
                'boxes' and the second half (box maps) under 'boxes_num';
                `postprocess` turns them into [class, score, x_min, y_min,
                x_max, y_max] boxes. Both lists are empty when repeats < 1.
        '''
        scores, raw_boxes = [], []
        handle = self.predictor.get_output_handle
        for _ in range(repeats):
            self.predictor.run()
            scores.clear()
            raw_boxes.clear()
            output_names = self.predictor.get_output_names()
            half = len(output_names) // 2
            for score_name, box_name in zip(output_names[:half],
                                            output_names[half:2 * half]):
                scores.append(handle(score_name).copy_to_cpu())
                raw_boxes.append(handle(box_name).copy_to_cpu())
        return dict(boxes=scores, boxes_num=raw_boxes)
490 491


C
cnn 已提交
492
def create_inputs(imgs, im_info):
    """Pack preprocessed images and their metadata into model inputs.

    Args:
        imgs (list(numpy)): list of images (np.ndarray); when more than one
            is given each must be indexed as [channel, height, width]
        im_info (list(dict)): list of image info, each supplying 'im_shape'
            and 'scale_factor'
    Returns:
        inputs (dict): float32 arrays under 'image', 'im_shape' and
            'scale_factor'. A single image is wrapped as a batch of one.
            Several images are zero-padded at the bottom/right to the largest
            height and width in the list and then stacked.
    """

    def _stack_field(infos, key):
        # One float32 row per info dict, in the order given.
        rows = [np.array((info[key], )).astype('float32') for info in infos]
        return np.concatenate(rows, axis=0)

    if len(imgs) == 1:
        return {
            'image': np.array((imgs[0], )).astype('float32'),
            'im_shape': _stack_field([im_info[0]], 'im_shape'),
            'scale_factor': _stack_field([im_info[0]], 'scale_factor'),
        }

    inputs = {
        'im_shape': _stack_field(im_info, 'im_shape'),
        'scale_factor': _stack_field(im_info, 'scale_factor'),
    }

    max_h = max(img.shape[1] for img in imgs)
    max_w = max(img.shape[2] for img in imgs)
    padded = []
    for img in imgs:
        channels, height, width = img.shape[:]
        canvas = np.zeros((channels, max_h, max_w), dtype=np.float32)
        canvas[:, :height, :width] = img
        padded.append(canvas)
    inputs['image'] = np.stack(padded, axis=0)
    return inputs


class PredictConfig():
    """set config of preprocess, postprocess and visualize
    Args:
        model_dir (str): directory containing infer_cfg.yml
    """

    def __init__(self, model_dir):
        # parsing Yaml config for Preprocess
        deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self._parse_config(yml_conf)
        self.print_config()

    def _parse_config(self, yml_conf):
        """Copy the fields used by the deploy pipeline from the parsed yaml.

        Required keys: 'arch', 'Preprocess', 'min_subgraph_size',
        'label_list', 'use_dynamic_shape'. Optional keys default to:
        mask=False, tracker=None, nms=None, fpn_stride=None.
        """
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        self.mask = yml_conf.get('mask', False)
        self.tracker = yml_conf.get('tracker', None)
        # Always define the optional post-processing fields so that callers
        # get a value (None) rather than an AttributeError when the exported
        # config omits them.
        self.nms = yml_conf.get('NMS', None)
        self.fpn_stride = yml_conf.get('fpn_stride', None)
        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print(
                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
            )

    def check_model(self, yml_conf):
        """
        Accept the model when any SUPPORT_MODELS entry is a substring of
        yml_conf['arch'].

        Returns:
            True when the arch is supported
        Raises:
            ValueError: loaded model not in supported model type
        """
        if any(support_model in yml_conf['arch']
               for support_model in SUPPORT_MODELS):
            return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
            'arch'], SUPPORT_MODELS))

    def print_config(self):
        """Print the model arch and the preprocess op order."""
        print('-----------  Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')


def load_predictor(model_dir,
                   run_mode='paddle',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False,
                   enable_mkldnn_bfloat16=False,
                   delete_shuffle_pass=False):
    """Build a Paddle Inference Config for an exported model and create its predictor.

    Args:
        model_dir (str): directory containing model.pdmodel and model.pdiparams
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): TensorRT max batch size; also scales the TensorRT
            workspace size and the batch dim of the dynamic-shape ranges
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        min_subgraph_size (int): minimum subgraph size passed to TensorRT
        use_dynamic_shape (bool): use dynamic shape or not
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): CPU math-library threads (CPU device only)
        enable_mkldnn (bool): try to enable MKLDNN on CPU; if that raises, a
            message is printed and inference continues without it
        enable_mkldnn_bfloat16 (bool): also enable MKLDNN bfloat16 (only
            considered when enable_mkldnn is True)
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
                                    Used by action model.
    Returns:
        (predictor, config): the created predictor and the Config it was built from
    Raises:
        ValueError: predict by TensorRT need device == 'GPU'.
    """
    if run_mode != 'paddle' and device != 'GPU':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
            .format(run_mode, device))

    config = Config(
        os.path.join(model_dir, 'model.pdmodel'),
        os.path.join(model_dir, 'model.pdiparams'))

    if device == 'GPU':
        # initial GPU memory pool (MB) on device 0, then graph optimization
        config.enable_use_gpu(200, 0)
        config.switch_ir_optim(True)
    elif device == 'XPU':
        config.enable_lite_engine()
        config.enable_xpu(10 * 1024 * 1024)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            # Best effort: fall back to plain CPU inference if MKLDNN
            # cannot be enabled in this environment.
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if enable_mkldnn_bfloat16:
                    config.enable_mkldnn_bfloat16()
            except Exception:
                print(
                    "The current environment does not support `mkldnn`, so disable mkldnn."
                )

    trt_precision = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }.get(run_mode)

    if trt_precision is not None:
        config.enable_tensorrt_engine(
            workspace_size=(1 << 25) * batch_size,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=trt_precision,
            use_static=False,
            use_calib_mode=trt_calib_mode)

        if use_dynamic_shape:
            # Square NCHW ranges for the 'image' input with 3 channels.
            shape_ranges = [{
                'image': [batch_size, 3, side, side]
            } for side in (trt_min_shape, trt_max_shape, trt_opt_shape)]
            config.set_trt_dynamic_shape_info(*shape_ranges)
            print('trt set dynamic shape done!')

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    if delete_shuffle_pass:
        config.delete_pass("shuffle_channel_detect_pass")
    return create_predictor(config), config
Q
qingqing01 已提交
687 688


G
Guanghua Yu 已提交
689 690 691 692 693
def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode.

    Args:
        infer_dir (str|None): directory to scan for images; only used when
            ``infer_img`` is not given.
        infer_img (str|None): path of a single image; takes priority over
            ``infer_dir``.

    Returns:
        list[str]: ``[infer_img]`` when a single image is given. Otherwise,
        the sorted absolute paths of every file directly inside ``infer_dir``
        whose extension is jpg/jpeg/png/bmp (all-lower or all-upper case).

    Raises:
        AssertionError: if neither argument is set, if a given path is not a
            file/directory as expected, or if the directory holds no images.
    """
    assert infer_img is not None or infer_dir is not None, \
        "--image_file or --image_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
            "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
            "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)

    # Escape the directory part so glob metacharacters in it ('[', '*', '?')
    # are matched literally instead of being interpreted as a pattern.
    pattern_root = glob.escape(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    # A set avoids duplicates if the lower- and upper-case patterns ever
    # match the same file.
    images = set()
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(pattern_root, ext)))
    # Sort so the processing order is deterministic from run to run.
    images = sorted(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))

    return images


W
wangguanzhong 已提交
720
def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
    """
    Draw the predictions of every image and save the rendered images.

    Args:
        image_list (list[str]): image paths, in the same order that was used
            to build ``result``.
        result (dict): batched predictions. ``result['boxes_num'][i]`` is the
            number of detections belonging to ``image_list[i]``. The optional
            ``'boxes'``, ``'masks'``, ``'segm'``, ``'label'`` and ``'score'``
            entries hold the detections of all images, concatenated along the
            first axis.
        labels (list[str]): class names, forwarded to ``visualize_box_mask``.
        output_dir (str): directory the rendered images are written to;
            created if missing.
        threshold (float): score threshold, forwarded to
            ``visualize_box_mask``.
    """
    per_det_keys = ('boxes', 'masks', 'segm', 'label', 'score')
    start_idx = 0
    for idx, image_file in enumerate(image_list):
        end_idx = start_idx + result['boxes_num'][idx]
        # Take this image's slice of every per-detection entry (first axis).
        im_results = {
            key: result[key][start_idx:end_idx]
            for key in per_det_keys if key in result
        }
        start_idx = end_idx

        im = visualize_box_mask(
            image_file, im_results, labels, threshold=threshold)
        img_name = os.path.split(image_file)[-1]
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, img_name)
        im.save(out_path, quality=95)
        print("save result to: " + out_path)
Q
qingqing01 已提交
751 752 753 754 755 756 757 758 759 760


def print_arguments(args):
    """Print every parsed option as ``name: value``, sorted by name, between banners."""
    header = '-----------  Running Arguments -----------'
    footer = '------------------------------------------'
    print(header)
    for name, val in sorted(vars(args).items()):
        print(name, val, sep=': ')
    print(footer)


def main():
    """
    Build the detector for ``FLAGS.model_dir`` and run inference with it.

    The ``arch`` field of ``<model_dir>/infer_cfg.yml`` selects the detector
    class. Input comes from one of two sources:

    - a video file or camera stream, when ``FLAGS.video_file`` is set or
      ``FLAGS.camera_id != -1``;
    - otherwise, the images selected by ``FLAGS.image_dir`` /
      ``FLAGS.image_file``.

    For image input, ``FLAGS.run_benchmark`` controls the reporting:
    benchmark results are logged through ``bench_log``; otherwise average
    detection times are printed.

    Reads the module-level ``FLAGS`` set up in the ``__main__`` guard.
    """
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']

    # Architectures with a dedicated detector class; every other arch uses
    # the generic `Detector`.
    detector_func = {
        'SOLOv2': 'DetectorSOLOv2',
        'PicoDet': 'DetectorPicoDet',
    }.get(arch, 'Detector')
    # Resolve the class by name from the module namespace instead of eval(),
    # so no string is ever executed as code.
    detector_cls = globals()[detector_func]

    detector = detector_cls(
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
        threshold=FLAGS.threshold,
        output_dir=FLAGS.output_dir)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
        return

    # predict from image
    if FLAGS.image_dir is None and FLAGS.image_file is not None:
        assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    detector.predict_image(img_list, FLAGS.run_benchmark, repeats=100)
    if not FLAGS.run_benchmark:
        detector.det_times.info(average=True)
        return

    mode = FLAGS.run_mode
    # normpath + basename drops trailing separators and also works with
    # Windows-style paths, unlike splitting on '/'.
    model_name = os.path.basename(os.path.normpath(FLAGS.model_dir))
    model_info = {
        'model_name': model_name,
        'precision': mode.split('_')[-1]
    }
    bench_log(detector, img_list, model_info, name='DET')
Q
qingqing01 已提交
805 806 807 808


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()

    # Validate the CLI flags with explicit raises: `assert` statements are
    # stripped when Python runs with -O, which would skip these checks.
    if FLAGS.device not in ['CPU', 'GPU', 'XPU']:
        raise ValueError("device should be CPU, GPU or XPU")
    if FLAGS.use_gpu:
        raise ValueError("use_gpu has been deprecated, please use --device")

    # The bfloat16 option is only accepted together with enable_mkldnn.
    if not FLAGS.enable_mkldnn and FLAGS.enable_mkldnn_bfloat16:
        raise ValueError(
            'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'
        )

    main()