infer.py 23.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

G
Guanghua Yu 已提交
15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

M
Manuel Garcia 已提交
19 20 21
import os
import sys

G
Guanghua Yu 已提交
22 23 24 25 26
# add the python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 3)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

27
import argparse
28
import time
29
import yaml
C
channings 已提交
30 31
import ast
from functools import reduce
32

33 34
import cv2
import numpy as np
35
import paddle
36
import paddle.fluid as fluid
G
Guanghua Yu 已提交
37
from preprocess import preprocess, Resize, Normalize, Permute, PadStride
38
from visualize import visualize_box_mask, lmk2out
39

40 41 42 43 44 45 46 47 48 49
# Set of architecture names this script can run. `Config.check_model` accepts
# a model when any of these names is contained in the configured `arch`.
SUPPORT_MODELS = {
    'YOLO',
    'SSD',
    'RetinaNet',
    'EfficientDet',
    'RCNN',
    'Face',
    'TTF',
    'FCOS',
    'SOLOv2',
}

53

54
class Detector(object):
    """Run detection inference with an exported model.

    Depending on `config.use_python_inference` the model is executed either
    through a fluid executor (`load_executor`) or through an analysis
    predictor (`load_predictor`).

    Args:
        config (object): config of model, defined by `Config(model_dir)`
        model_dir (str): root path of __model__, __params__ and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8)
        threshold (float): threshold to reserve the result for output.
            Accepted for interface compatibility; the constructor does not
            store it, `predict` takes its own `threshold`.
        trt_calib_mode (bool): forwarded to `load_predictor`.
    """

    # Names of the preprocess operators that `infer_cfg.yml` may reference.
    # They are exactly the operator classes this file imports from
    # `preprocess`, i.e. the only names the former `eval(op_type)` could
    # legitimately resolve.
    SUPPORTED_PREPROCESS_OPS = ('Resize', 'Normalize', 'Permute', 'PadStride')

    def __init__(self,
                 config,
                 model_dir,
                 device='CPU',
                 run_mode='fluid',
                 threshold=0.5,
                 trt_calib_mode=False):
        self.config = config
        if self.config.use_python_inference:
            # NOTE: the attribute name `fecth_targets` (sic) is kept because
            # subclasses in this file read it under that spelling.
            self.executor, self.program, self.fecth_targets = load_executor(
                model_dir, device=device)
        else:
            self.predictor = load_predictor(
                model_dir,
                run_mode=run_mode,
                min_subgraph_size=self.config.min_subgraph_size,
                device=device,
                trt_calib_mode=trt_calib_mode)

    def preprocess(self, im):
        """Build the preprocess pipeline from the config and apply it.

        Args:
            im (str/np.ndarray): path of image/ np.ndarray read by cv2
        Returns:
            inputs (dict): feed dict produced by `create_inputs`
            im_info (dict): image info returned by `preprocess`
        Raises:
            ValueError: an operator type in the config is not one of
                `SUPPORTED_PREPROCESS_OPS`.
        """
        preprocess_ops = []
        for op_info in self.config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            # The operator name comes from a YAML file, so it is resolved by
            # a whitelisted name lookup instead of eval(), which would
            # execute arbitrary expressions found in the config.
            if op_type not in self.SUPPORTED_PREPROCESS_OPS:
                raise ValueError(
                    "Unsupported preprocess op: {}, expect one of {}".format(
                        op_type, self.SUPPORTED_PREPROCESS_OPS))
            if op_type == 'Resize':
                new_op_info['arch'] = self.config.arch
            preprocess_ops.append(globals()[op_type](**new_op_info))
        im, im_info = preprocess(im, preprocess_ops)
        inputs = create_inputs(im, im_info, self.config.arch)
        return inputs, im_info

    def postprocess(self, np_boxes, np_masks, np_lmk, im_info, threshold=0.5):
        """Filter raw predictor output by score and pack it into a dict.

        Args:
            np_boxes (np.ndarray): rows of
                [class, score, x_min, y_min, x_max, y_max]
            np_masks (np.ndarray|None): masks aligned with `np_boxes` rows
            np_lmk (list|None): landmark outputs, forwarded to `lmk2out`
            im_info (dict): image info; 'origin_shape' is read for SSD/Face
            threshold (float): boxes with score <= threshold are dropped
        Returns:
            results (dict): 'boxes' always, plus 'masks' / 'landmark' when
                the corresponding input is not None.
        """
        results = {}
        if np_lmk is not None:
            results['landmark'] = lmk2out(np_boxes, np_lmk, im_info, threshold)

        if self.config.arch in ['SSD', 'Face']:
            # Coordinates are multiplied in place by the original image size.
            # NOTE(review): columns 2/4 are scaled by the second element of
            # 'origin_shape' and columns 3/5 by the first; the local names
            # below follow the original code - confirm the element order.
            w, h = im_info['origin_shape']
            np_boxes[:, 2] *= h
            np_boxes[:, 3] *= w
            np_boxes[:, 4] *= h
            np_boxes[:, 5] *= w
        # keep rows above the score threshold whose class id is > -1
        expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
        np_boxes = np_boxes[expect_boxes, :]
        for box in np_boxes:
            print('class_id:{:d}, confidence:{:.4f},'
                  'left_top:[{:.2f},{:.2f}],'
                  ' right_bottom:[{:.2f},{:.2f}]'.format(
                      int(box[0]), box[1], box[2], box[3], box[4], box[5]))
        results['boxes'] = np_boxes
        if np_masks is not None:
            np_masks = np_masks[expect_boxes, :, :, :]
            results['masks'] = np_masks
        return results

    def _run_executor(self, inputs):
        """Run the fluid executor once and return its raw fetch results."""
        return self.executor.run(self.program,
                                 feed=inputs,
                                 fetch_list=self.fecth_targets,
                                 return_numpy=False)

    def _run_predictor(self):
        """Run the predictor once and copy its outputs to host memory.

        Returns:
            tuple: (np_boxes, np_masks, np_lmk); masks / landmarks are None
                unless enabled by `config.mask_resolution` /
                `config.with_lmk`.
        """
        self.predictor.zero_copy_run()
        output_names = self.predictor.get_output_names()
        boxes_tensor = self.predictor.get_output_tensor(output_names[0])
        np_boxes = boxes_tensor.copy_to_cpu()

        np_masks = None
        if self.config.mask_resolution is not None:
            masks_tensor = self.predictor.get_output_tensor(output_names[1])
            np_masks = masks_tensor.copy_to_cpu()

        np_lmk = None
        # `== True` is kept on purpose so only a true boolean (or 1)
        # enables the landmark outputs, exactly as before.
        if self.config.with_lmk is not None and self.config.with_lmk == True:
            face_index = self.predictor.get_output_tensor(output_names[1])
            landmark = self.predictor.get_output_tensor(output_names[2])
            prior_boxes = self.predictor.get_output_tensor(output_names[3])
            np_face_index = face_index.copy_to_cpu()
            np_prior_boxes = prior_boxes.copy_to_cpu()
            np_landmark = landmark.copy_to_cpu()
            np_lmk = [np_face_index, np_landmark, np_prior_boxes]
        return np_boxes, np_masks, np_lmk

    def predict(self,
                image,
                threshold=0.5,
                warmup=0,
                repeats=1,
                run_benchmark=False):
        '''
        Args:
            image (str/np.ndarray): path of image/ np.ndarray read by cv2
            threshold (float): threshold of predicted box' score
            warmup (int): number of untimed runs before measuring
            repeats (int): number of timed runs; the printed latency is the
                average over these runs
            run_benchmark (bool): when True, postprocess is skipped and an
                empty list is returned
        Returns:
            results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                            matrix element:[class, score, x_min, y_min, x_max, y_max]
                            MaskRCNN's results include 'masks': np.ndarray:
                            shape:[N, class_num, mask_resolution, mask_resolution]
        '''
        inputs, im_info = self.preprocess(image)
        np_boxes, np_masks, np_lmk = None, None, None
        if self.config.use_python_inference:
            for _ in range(warmup):
                outs = self._run_executor(inputs)
            t1 = time.time()
            for _ in range(repeats):
                outs = self._run_executor(inputs)
            t2 = time.time()
            ms = (t2 - t1) * 1000.0 / repeats
            print("Inference: {} ms per batch image".format(ms))
            np_boxes = np.array(outs[0])
            if self.config.mask_resolution is not None:
                np_masks = np.array(outs[1])
        else:
            # inputs are copied to the predictor once, outside the timed loop
            for name in self.predictor.get_input_names():
                input_tensor = self.predictor.get_input_tensor(name)
                input_tensor.copy_from_cpu(inputs[name])

            for _ in range(warmup):
                np_boxes, np_masks, np_lmk = self._run_predictor()

            t1 = time.time()
            for _ in range(repeats):
                np_boxes, np_masks, np_lmk = self._run_predictor()
            t2 = time.time()
            ms = (t2 - t1) * 1000.0 / repeats
            print("Inference: {} ms per batch image".format(ms))

        # do not perform postprocess in benchmark mode
        results = []
        if not run_benchmark:
            # fewer than 6 values means there is not even one complete box row
            if np_boxes.size < 6:
                print('[WARNNING] No object detected.')
                results = {'boxes': np.array([])}
            else:
                results = self.postprocess(
                    np_boxes, np_masks, np_lmk, im_info, threshold=threshold)

        return results
219 220


G
Guanghua Yu 已提交
221 222 223 224
class DetectorSOLOv2(Detector):
    """Detector whose `predict` returns the three SOLOv2 outputs.

    The constructor takes the same arguments as `Detector` and forwards
    them unchanged.
    """

    def __init__(self,
                 config,
                 model_dir,
                 device='CPU',
                 run_mode='fluid',
                 threshold=0.5,
                 trt_calib_mode=False):
        super(DetectorSOLOv2, self).__init__(
            config=config,
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            threshold=threshold,
            trt_calib_mode=trt_calib_mode)

    def predict(self,
                image,
                threshold=0.5,
                warmup=0,
                repeats=1,
                run_benchmark=False):
        """Run inference on one image.

        Args:
            image (str/np.ndarray): input handed to `self.preprocess`
            threshold (float): unused here, kept for a signature parallel
                to `Detector.predict`
            warmup (int): number of untimed runs
            repeats (int): number of timed runs used for the printed latency
            run_benchmark (bool): when True an empty list is returned
        Returns:
            dict with keys 'segm', 'label' and 'score', or `[]` in
            benchmark mode.
        """
        inputs, im_info = self.preprocess(image)
        np_label = np_score = np_segms = None

        if self.config.use_python_inference:

            def run_once():
                # one executor pass, raw (non-numpy) fetch results
                return self.executor.run(self.program,
                                         feed=inputs,
                                         fetch_list=self.fecth_targets,
                                         return_numpy=False)

            for _ in range(warmup):
                outs = run_once()
            t1 = time.time()
            for _ in range(repeats):
                outs = run_once()
            t2 = time.time()
            ms = (t2 - t1) * 1000.0 / repeats
            print("Inference: {} ms per batch image".format(ms))
            np_label = np.array(outs[0])
            np_score = np.array(outs[1])
            np_segms = np.array(outs[2])
        else:
            for name in self.predictor.get_input_names():
                self.predictor.get_input_tensor(name).copy_from_cpu(
                    inputs[name])

            def run_once():
                # one predictor pass; outputs 0/1/2 are label/score/segm
                self.predictor.zero_copy_run()
                out_names = self.predictor.get_output_names()
                return tuple(
                    self.predictor.get_output_tensor(out_names[k])
                    .copy_to_cpu() for k in range(3))

            for _ in range(warmup):
                np_label, np_score, np_segms = run_once()

            t1 = time.time()
            for _ in range(repeats):
                np_label, np_score, np_segms = run_once()
            t2 = time.time()
            ms = (t2 - t1) * 1000.0 / repeats
            print("Inference: {} ms per batch image".format(ms))

        # do not perform postprocess in benchmark mode
        if run_benchmark:
            return []
        return dict(segm=np_segms, label=np_label, score=np_score)
296 297 298 299 300 301 302 303 304 305 306 307 308 309 310


def create_inputs(im, im_info, model_arch='YOLO'):
    """Assemble the feed dict expected by a given model architecture.

    Args:
        im (np.ndarray): preprocessed image, stored under 'image'
        im_info (dict): image info; the keys 'origin_shape', 'resize_shape',
            'pad_shape' and 'scale' are read
        model_arch (str): architecture name, matched by substring
    Returns:
        inputs (dict): 'image' plus the extra entries the architecture needs
    """
    feed = {'image': im}

    origin_shape = list(im_info['origin_shape'])
    resize_shape = list(im_info['resize_shape'])
    # fall back to the resized shape when no padding was recorded
    if im_info['pad_shape'] is not None:
        pad_shape = list(im_info['pad_shape'])
    else:
        pad_shape = list(im_info['resize_shape'])
    scale_x, scale_y = im_info['scale']

    def with_scale(shape):
        # one float32 row: the shape entries followed by the x scale
        return np.array([shape + [scale_x]]).astype('float32')

    def matches(*names):
        return any(name in model_arch for name in names)

    if matches('YOLO'):
        feed['im_size'] = np.array([origin_shape]).astype('int32')
    elif matches('RetinaNet', 'EfficientDet'):
        feed['im_info'] = with_scale(pad_shape)
    elif matches('RCNN', 'FCOS'):
        feed['im_info'] = with_scale(pad_shape)
        feed['im_shape'] = np.array([origin_shape + [1.]]).astype('float32')
    elif matches('TTF'):
        feed['scale_factor'] = np.array(
            [scale_x, scale_y] * 2).astype('float32')
    elif matches('SOLOv2'):
        feed['im_info'] = with_scale(resize_shape)
    return feed


class Config():
    """Settings for preprocess, postprocess and visualization.

    Everything is read from `infer_cfg.yml` inside the model directory.

    Args:
        model_dir (str): directory that contains infer_cfg.yml
    """

    def __init__(self, model_dir):
        # load the deployment YAML that sits next to the exported model
        with open(os.path.join(model_dir, 'infer_cfg.yml')) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)

        # mandatory entries
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.use_python_inference = yml_conf['use_python_inference']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']

        # optional entries stay None when the key is absent
        self.mask_resolution = yml_conf.get('mask_resolution')
        self.with_lmk = yml_conf.get('with_lmk')

        self.print_config()

    def check_model(self, yml_conf):
        """Verify that the configured architecture is supported.

        Returns:
            True when a name from SUPPORT_MODELS is contained in
            yml_conf['arch'].
        Raises:
            ValueError: loaded model not in supported model type
        """
        arch = yml_conf['arch']
        if any(name in arch for name in SUPPORT_MODELS):
            return True
        raise ValueError("Unsupported arch: {}, expect {}".format(
            arch, SUPPORT_MODELS))

    def print_config(self):
        """Print the architecture, executor choice and transform order."""
        print('-----------  Model Configuration -----------')
        for title, value in (('Model Arch', self.arch),
                             ('Use Paddle Executor',
                              self.use_python_inference)):
            print('%s: %s' % (title, value))
        print('%s: ' % ('Transform Order'))
        for op in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op['type']))
        print('--------------------------------------------')

382 383 384 385

def load_predictor(model_dir,
                   run_mode='fluid',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   trt_calib_mode=False):
    """Configure an AnalysisConfig and create the predictor from it.

    Args:
        model_dir (str): root path of __model__ and __params__
        run_mode (str): 'fluid', or one of the TensorRT modes
            trt_int8 / trt_fp32 / trt_fp16
        batch_size (int): passed to TensorRT as max_batch_size
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        min_subgraph_size (int): passed to TensorRT as min_subgraph_size
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
    Returns:
        predictor (PaddlePredictor): AnalysisPredictor
    Raises:
        ValueError: predict by TensorRT need device == GPU.
    """
    # any run mode other than plain fluid is only accepted on GPU
    if run_mode != 'fluid' and device != 'GPU':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device==GPU, but device == {}"
            .format(run_mode, device))

    precision = fluid.core.AnalysisConfig.Precision
    precision_map = {
        'trt_int8': precision.Int8,
        'trt_fp32': precision.Float32,
        'trt_fp16': precision.Half
    }

    model_file = os.path.join(model_dir, '__model__')
    params_file = os.path.join(model_dir, '__params__')
    config = fluid.core.AnalysisConfig(model_file, params_file)

    if device == 'XPU':
        config.enable_lite_engine()
        config.enable_xpu(10 * 1024 * 1024)
    elif device == 'GPU':
        # initial GPU memory(M), device ID
        config.enable_use_gpu(100, 0)
        # optimize graph and fuse op
        config.switch_ir_optim(True)
    else:
        config.disable_gpu()

    # TensorRT is enabled only for the run modes listed in precision_map
    if run_mode in precision_map:
        config.enable_tensorrt_engine(
            workspace_size=1 << 10,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    return fluid.core.create_paddle_predictor(config)


G
Guanghua Yu 已提交
442 443
def load_executor(model_dir, device='CPU'):
    """Load the inference program into a fluid Executor.

    Args:
        model_dir (str): directory holding __model__ and __params__
        device (str): 'GPU' selects CUDA device 0, anything else the CPU
    Returns:
        tuple: (executor, program, fetch_targets)
    """
    place = fluid.CUDAPlace(0) if device == 'GPU' else fluid.CPUPlace()
    executor = fluid.Executor(place)
    # the feed names returned by the loader are not needed by callers
    program, _, fetch_targets = fluid.io.load_inference_model(
        dirname=model_dir,
        executor=executor,
        model_filename='__model__',
        params_filename='__params__')
    return executor, program, fetch_targets


def visualize(image_file,
              results,
              labels,
              mask_resolution=14,
              output_dir='output/',
              threshold=0.5):
    """Draw the prediction on the image and save it under `output_dir`.

    Args:
        image_file (str): path of the source image; its file name is reused
            for the saved result
        results (dict): prediction passed to `visualize_box_mask`
        labels (list): label names passed to `visualize_box_mask`
        mask_resolution (int): passed to `visualize_box_mask`
        output_dir (str): target directory, created when missing
        threshold (float): passed to `visualize_box_mask`
    """
    rendered = visualize_box_mask(
        image_file,
        results,
        labels,
        mask_resolution=mask_resolution,
        threshold=threshold)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # keep the original file name, only the directory changes
    out_path = os.path.join(output_dir, os.path.basename(image_file))
    rendered.save(out_path, quality=95)
    print("save result to: " + out_path)


G
Guanghua Yu 已提交
477 478 479 480 481
def print_arguments(args):
    """Print every attribute of `args` as `name: value`, sorted by name."""
    values = vars(args)
    print('-----------  Running Arguments -----------')
    for name in sorted(values):
        print('%s: %s' % (name, values[name]))
    print('------------------------------------------')
482 483


G
Guanghua Yu 已提交
484
def predict_image(detector):
    """Predict on `FLAGS.image_file`, then either benchmark or visualize.

    In benchmark mode the detector is run with 100 warm-up and 100 timed
    repeats and nothing is saved; otherwise a single prediction is made and
    handed to `visualize`.
    """
    if not FLAGS.run_benchmark:
        results = detector.predict(FLAGS.image_file, FLAGS.threshold)
        visualize(
            FLAGS.image_file,
            results,
            detector.config.labels,
            mask_resolution=detector.config.mask_resolution,
            output_dir=FLAGS.output_dir,
            threshold=FLAGS.threshold)
        return

    detector.predict(
        FLAGS.image_file,
        FLAGS.threshold,
        warmup=100,
        repeats=100,
        run_benchmark=True)
501 502


G
Guanghua Yu 已提交
503
def predict_video(detector, camera_id):
    """Run detection frame by frame and write the annotated video.

    Args:
        detector: object with `predict(frame, threshold)` and a `config`
            carrying `labels` and `mask_resolution`
        camera_id (int): camera device id; -1 means read `FLAGS.video_file`
            instead of a camera

    The result is written to `FLAGS.output_dir` as 'output.mp4' for a camera
    stream, or under the input file's name for a video file. With a camera,
    every frame is also shown in a window and 'q' stops the loop.
    """
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
        video_name = 'output.mp4'
    else:
        capture = cv2.VideoCapture(FLAGS.video_file)
        video_name = os.path.split(FLAGS.video_file)[-1]
    writer = None
    try:
        # NOTE(review): the output frame rate is fixed at 30 regardless of
        # the source - confirm this is intended.
        fps = 30
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        if not os.path.exists(FLAGS.output_dir):
            os.makedirs(FLAGS.output_dir)
        out_path = os.path.join(FLAGS.output_dir, video_name)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame:%d' % (index))
            index += 1
            results = detector.predict(frame, FLAGS.threshold)
            im = visualize_box_mask(
                frame,
                results,
                detector.config.labels,
                mask_resolution=detector.config.mask_resolution,
                threshold=FLAGS.threshold)
            im = np.array(im)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release the capture device and finalize the output file even when
        # a frame fails mid-stream; the capture was previously never
        # released.
        if writer is not None:
            writer.release()
        capture.release()
        if camera_id != -1:
            # close the preview window opened by cv2.imshow
            cv2.destroyAllWindows()


G
Guanghua Yu 已提交
541 542 543
def main():
    """Build the detector for the configured model and run the requested job.

    Reads `FLAGS` for the model directory, device and run mode, then predicts
    on an image and/or a video file / camera stream depending on which flags
    are set.
    """
    config = Config(FLAGS.model_dir)
    # Choose the detector class first so the model is loaded exactly once;
    # previously a plain Detector was always built and then thrown away for
    # SOLOv2 models.
    detector_cls = DetectorSOLOv2 if config.arch == 'SOLOv2' else Detector
    detector = detector_cls(
        config,
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        trt_calib_mode=FLAGS.trt_calib_mode)
    # predict from image
    if FLAGS.image_file != '':
        predict_image(detector)
    # predict from video file or camera video stream
    if FLAGS.video_file != '' or FLAGS.camera_id != -1:
        predict_video(detector, FLAGS.camera_id)
C
channings 已提交
562 563


564
if __name__ == '__main__':
    # Best effort: switch Paddle to static-graph mode when the installed
    # release offers `enable_static` (presumably absent on older releases).
    # Only ordinary exceptions are swallowed, so Ctrl-C / SystemExit still
    # propagate.
    try:
        paddle.enable_static()
    except Exception:
        pass
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model_dir",
        type=str,
        default=None,
        help=("Directory include:'__model__', '__params__', "
              "'infer_cfg.yml', created by tools/export_model.py."),
        required=True)
    parser.add_argument(
        "--image_file", type=str, default='', help="Path of image file.")
    parser.add_argument(
        "--video_file", type=str, default='', help="Path of video file.")
    parser.add_argument(
        "--camera_id",
        type=int,
        default=-1,
        help="device id of camera to predict.")
    parser.add_argument(
        "--run_mode",
        type=str,
        default='fluid',
        help="mode of running(fluid/trt_fp32/trt_fp16/trt_int8)")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."
    )
    parser.add_argument(
        "--use_gpu",
        type=ast.literal_eval,
        default=False,
        help="Deprecated, please use `--device` to set the device you want to run."
    )
    parser.add_argument(
        "--run_benchmark",
        type=ast.literal_eval,
        default=False,
        help="Whether to predict a image_file repeatedly for benchmark")
    parser.add_argument(
        "--threshold", type=float, default=0.5, help="Threshold of score.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory of output visualization files.")
    # `type=bool` would turn every non-empty string (including "False") into
    # True, so the flag is parsed like the other boolean options.
    parser.add_argument(
        "--trt_calib_mode",
        type=ast.literal_eval,
        default=False,
        help="If the model is produced by TRT offline quantitative "
        "calibration, trt_calib_mode need to set True.")

    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    # Validation uses parser.error instead of assert: asserts are stripped
    # under `python -O`, and the former `assert "<message>"` on a non-empty
    # string could never fail.
    if FLAGS.image_file != '' and FLAGS.video_file != '':
        parser.error("Cannot predict image and video at the same time")
    FLAGS.device = FLAGS.device.upper()
    if FLAGS.device not in ['CPU', 'GPU', 'XPU']:
        parser.error("device should be CPU, GPU or XPU")
    if FLAGS.use_gpu:
        parser.error("use_gpu has been deprecated, please use --device")

    main()