infer.py 22.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

G
Guanghua Yu 已提交
15 16 17 18 19 20 21 22 23 24
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os, sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 3)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

25
import argparse
26
import time
27
import yaml
C
channings 已提交
28 29
import ast
from functools import reduce
30

31 32 33
from PIL import Image
import cv2
import numpy as np
34
import paddle
35
import paddle.fluid as fluid
G
Guanghua Yu 已提交
36
from preprocess import preprocess, Resize, Normalize, Permute, PadStride
37
from visualize import visualize_box_mask, lmk2out
38

39 40 41 42 43 44 45 46 47 48
# Global set of supported architecture names. `Config.check_model` accepts a
# model when its `arch` string contains at least one of these entries.
SUPPORT_MODELS = {
    'YOLO', 'SSD', 'RetinaNet', 'EfficientDet', 'RCNN',
    'Face', 'TTF', 'FCOS', 'SOLOv2',
}

52

53
class Detector(object):
    """Run an exported detection model and post-process its raw outputs.

    Depending on `config.use_python_inference`, inference goes through a
    fluid Executor (see `load_executor`) or through an AnalysisPredictor
    (see `load_predictor`).

    Args:
        config (object): config of model, defined by `Config(model_dir)`
        model_dir (str): root path of __model__, __params__ and infer_cfg.yml
        use_gpu (bool): whether use gpu
        run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8),
            only forwarded to `load_predictor`
        threshold (float): threshold to reserve the result for output.
            Accepted for interface compatibility; the constructor does not
            read it, pass the threshold to `predict` instead.
        trt_calib_mode (bool): forwarded to `load_predictor`
    """

    def __init__(self,
                 config,
                 model_dir,
                 use_gpu=False,
                 run_mode='fluid',
                 threshold=0.5,
                 trt_calib_mode=False):
        self.config = config
        if self.config.use_python_inference:
            # NOTE: the attribute name `fecth_targets` (sic) is kept because
            # subclasses in this file read it under that spelling.
            self.executor, self.program, self.fecth_targets = load_executor(
                model_dir, use_gpu=use_gpu)
        else:
            self.predictor = load_predictor(
                model_dir,
                run_mode=run_mode,
                min_subgraph_size=self.config.min_subgraph_size,
                use_gpu=use_gpu,
                trt_calib_mode=trt_calib_mode)

    def preprocess(self, im):
        """Build the preprocess ops listed in the config and apply them.

        Args:
            im (str/np.ndarray): image handed over to `preprocess`
        Returns:
            inputs (dict): model feed built by `create_inputs`
            im_info (dict): image info returned by `preprocess`
        """
        preprocess_ops = []
        for op_info in self.config.preprocess_infos:
            # copy first so that popping 'type' does not alter the config
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            if op_type == 'Resize':
                new_op_info['arch'] = self.config.arch
            # NOTE(review): `eval` resolves the op class from a name read out
            # of infer_cfg.yml, so a crafted config can execute arbitrary
            # code. Only load model directories from trusted sources.
            preprocess_ops.append(eval(op_type)(**new_op_info))
        im, im_info = preprocess(im, preprocess_ops)
        inputs = create_inputs(im, im_info, self.config.arch)
        return inputs, im_info

    def postprocess(self, np_boxes, np_masks, np_lmk, im_info, threshold=0.5):
        """Filter raw predictor outputs by score and print the kept boxes.

        Args:
            np_boxes (np.ndarray): raw boxes, one row per box laid out as
                [class, score, x_min, y_min, x_max, y_max]. For 'SSD' and
                'Face' archs the coordinate columns are rescaled in place.
            np_masks (np.ndarray|None): raw masks, indexed like `np_boxes`
            np_lmk (list|None): landmark outputs handed to `lmk2out`
            im_info (dict): image info; 'origin_shape' is read for
                'SSD'/'Face' archs
            threshold (float): boxes scoring at or below it are dropped
        Returns:
            results (dict): 'boxes' always; 'landmark' when `np_lmk` is given;
                'masks' when `np_masks` is given
        """
        results = {}
        if np_lmk is not None:
            results['landmark'] = lmk2out(np_boxes, np_lmk, im_info, threshold)

        if self.config.arch in ['SSD', 'Face']:
            # NOTE(review): presumably 'origin_shape' is (height, width), so
            # the x columns (2, 4) are scaled by its second entry and the
            # y columns (3, 5) by its first one -- confirm against preprocess.
            first_dim, second_dim = im_info['origin_shape']
            np_boxes[:, 2] *= second_dim
            np_boxes[:, 3] *= first_dim
            np_boxes[:, 4] *= second_dim
            np_boxes[:, 5] *= first_dim
        # keep rows whose score exceeds the threshold and whose class id
        # (column 0) is greater than -1
        expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
        np_boxes = np_boxes[expect_boxes, :]
        for box in np_boxes:
            print('class_id:{:d}, confidence:{:.4f},'
                  'left_top:[{:.2f},{:.2f}],'
                  ' right_bottom:[{:.2f},{:.2f}]'.format(
                      int(box[0]), box[1], box[2], box[3], box[4], box[5]))
        results['boxes'] = np_boxes
        if np_masks is not None:
            np_masks = np_masks[expect_boxes, :, :, :]
            results['masks'] = np_masks
        return results

    def _run_executor(self, inputs):
        """Run the fluid Executor once and return its fetched outputs."""
        return self.executor.run(self.program,
                                 feed=inputs,
                                 fetch_list=self.fecth_targets,
                                 return_numpy=False)

    def _run_predictor(self):
        """Run the AnalysisPredictor once and copy its outputs to the host.

        Returns:
            tuple: (np_boxes, np_masks, np_lmk); `np_masks` is None unless
                `config.mask_resolution` is set, `np_lmk` is None unless
                `config.with_lmk` is true.
        """
        self.predictor.zero_copy_run()
        output_names = self.predictor.get_output_names()
        np_boxes = self.predictor.get_output_tensor(output_names[
            0]).copy_to_cpu()

        np_masks = None
        if self.config.mask_resolution is not None:
            np_masks = self.predictor.get_output_tensor(output_names[
                1]).copy_to_cpu()

        np_lmk = None
        # `None == True` is False, so a missing `with_lmk` skips this branch
        if self.config.with_lmk == True:
            np_face_index = self.predictor.get_output_tensor(output_names[
                1]).copy_to_cpu()
            np_landmark = self.predictor.get_output_tensor(output_names[
                2]).copy_to_cpu()
            np_prior_boxes = self.predictor.get_output_tensor(output_names[
                3]).copy_to_cpu()
            np_lmk = [np_face_index, np_landmark, np_prior_boxes]
        return np_boxes, np_masks, np_lmk

    def predict(self,
                image,
                threshold=0.5,
                warmup=0,
                repeats=1,
                run_benchmark=False):
        '''
        Args:
            image (str/np.ndarray): path of image/ np.ndarray read by cv2
            threshold (float): threshold of predicted box' score
            warmup (int): number of untimed runs executed before timing
            repeats (int): number of timed runs, must be at least 1
            run_benchmark (bool): when True, postprocess is skipped and an
                empty list is returned
        Returns:
            results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                            matix element:[class, score, x_min, y_min, x_max, y_max]
                            MaskRCNN's results include 'masks': np.ndarray:
                            shape:[N, class_num, mask_resolution, mask_resolution]
        Raises:
            ValueError: `repeats` is smaller than 1
        '''
        # Without a timed run there is neither an output nor a valid average.
        if repeats < 1:
            raise ValueError(
                "repeats must be >= 1, but got {}".format(repeats))

        inputs, im_info = self.preprocess(image)
        np_boxes, np_masks, np_lmk = None, None, None
        if self.config.use_python_inference:
            for _ in range(warmup):
                self._run_executor(inputs)
            t1 = time.time()
            for _ in range(repeats):
                outs = self._run_executor(inputs)
            t2 = time.time()
            ms = (t2 - t1) * 1000.0 / repeats
            print("Inference: {} ms per batch image".format(ms))
            np_boxes = np.array(outs[0])
            if self.config.mask_resolution is not None:
                np_masks = np.array(outs[1])
        else:
            input_names = self.predictor.get_input_names()
            for input_name in input_names:
                input_tensor = self.predictor.get_input_tensor(input_name)
                input_tensor.copy_from_cpu(inputs[input_name])

            for _ in range(warmup):
                self._run_predictor()

            # the timed section includes copying the outputs back to the host
            t1 = time.time()
            for _ in range(repeats):
                np_boxes, np_masks, np_lmk = self._run_predictor()
            t2 = time.time()
            ms = (t2 - t1) * 1000.0 / repeats
            print("Inference: {} ms per batch image".format(ms))

        # do not perform postprocess in benchmark mode
        results = []
        if not run_benchmark:
            # fewer than 6 values cannot hold a single [N, 6] box row
            if np_boxes.size < 6:
                print('[WARNNING] No object detected.')
                results = {'boxes': np.array([])}
            else:
                results = self.postprocess(
                    np_boxes, np_masks, np_lmk, im_info, threshold=threshold)

        return results
218 219


G
Guanghua Yu 已提交
220 221 222 223 224 225
class DetectorSOLOv2(Detector):
    """Detector variant for SOLOv2 models.

    The model outputs are read positionally as (label, score, segm) and are
    returned as-is, without the box postprocess of the base class.

    Args:
        config (object): config of model, defined by `Config(model_dir)`
        model_dir (str): root path of __model__, __params__ and infer_cfg.yml
        use_gpu (bool): whether use gpu
        run_mode (str): mode of running, forwarded to the base class
        threshold (float): forwarded to the base class
        trt_calib_mode (bool): forwarded to the base class
    """

    def __init__(self,
                 config,
                 model_dir,
                 use_gpu=False,
                 run_mode='fluid',
                 threshold=0.5,
                 trt_calib_mode=False):
        super(DetectorSOLOv2, self).__init__(
            config=config,
            model_dir=model_dir,
            use_gpu=use_gpu,
            run_mode=run_mode,
            threshold=threshold,
            trt_calib_mode=trt_calib_mode)

    def _run_solov2_executor(self, inputs):
        """Run the fluid Executor once and return its fetched outputs."""
        return self.executor.run(self.program,
                                 feed=inputs,
                                 fetch_list=self.fecth_targets,
                                 return_numpy=False)

    def _run_solov2_predictor(self):
        """Run the predictor once; return host copies of its first 3 outputs."""
        self.predictor.zero_copy_run()
        output_names = self.predictor.get_output_names()
        return tuple(
            self.predictor.get_output_tensor(output_names[idx]).copy_to_cpu()
            for idx in range(3))

    def predict(self,
                image,
                threshold=0.5,
                warmup=0,
                repeats=1,
                run_benchmark=False):
        '''
        Args:
            image (str/np.ndarray): path of image/ np.ndarray read by cv2
            threshold (float): accepted for interface parity with
                `Detector.predict`; this method does not read it
            warmup (int): number of untimed runs executed before timing
            repeats (int): number of timed runs, must be at least 1
            run_benchmark (bool): when True an empty list is returned
        Returns:
            results (dict): 'segm', 'label' and 'score' entries holding the
                model outputs 2, 0 and 1 respectively
        Raises:
            ValueError: `repeats` is smaller than 1
        '''
        # Without a timed run there is neither an output nor a valid average.
        if repeats < 1:
            raise ValueError(
                "repeats must be >= 1, but got {}".format(repeats))

        inputs, im_info = self.preprocess(image)
        np_label, np_score, np_segms = None, None, None
        if self.config.use_python_inference:
            for _ in range(warmup):
                self._run_solov2_executor(inputs)
            t1 = time.time()
            for _ in range(repeats):
                outs = self._run_solov2_executor(inputs)
            t2 = time.time()
            ms = (t2 - t1) * 1000.0 / repeats
            print("Inference: {} ms per batch image".format(ms))
            np_label, np_score, np_segms = np.array(outs[0]), np.array(outs[
                1]), np.array(outs[2])
        else:
            input_names = self.predictor.get_input_names()
            for input_name in input_names:
                input_tensor = self.predictor.get_input_tensor(input_name)
                input_tensor.copy_from_cpu(inputs[input_name])

            for _ in range(warmup):
                self._run_solov2_predictor()

            # the timed section includes copying the outputs back to the host
            t1 = time.time()
            for _ in range(repeats):
                np_label, np_score, np_segms = self._run_solov2_predictor()
            t2 = time.time()
            ms = (t2 - t1) * 1000.0 / repeats
            print("Inference: {} ms per batch image".format(ms))

        # do not perform postprocess in benchmark mode
        if run_benchmark:
            return []
        return dict(segm=np_segms, label=np_label, score=np_score)
295 296 297 298 299 300 301 302 303 304 305 306 307 308 309


def create_inputs(im, im_info, model_arch='YOLO'):
    """Assemble the feed dict expected by the given model architecture.

    Args:
        im (np.ndarray): preprocessed image, stored under 'image'
        im_info (dict): image info; 'origin_shape', 'resize_shape',
            'pad_shape' and 'scale' are read
        model_arch (str): model type; matched by substring, first hit wins
    Returns:
        inputs (dict): 'image' plus the arch-specific entries
    """

    def _float_row(values):
        # one-row float32 matrix built from a flat list
        return np.array([values]).astype('float32')

    shape_origin = list(im_info['origin_shape'])
    shape_resized = list(im_info['resize_shape'])
    # without padding the padded shape equals the resized shape
    if im_info['pad_shape'] is None:
        shape_padded = list(im_info['resize_shape'])
    else:
        shape_padded = list(im_info['pad_shape'])
    scale_x, scale_y = im_info['scale']

    feed = {'image': im}
    if 'YOLO' in model_arch:
        feed['im_size'] = np.array([shape_origin]).astype('int32')
    elif any(tag in model_arch for tag in ('RetinaNet', 'EfficientDet')):
        feed['im_info'] = _float_row(shape_padded + [scale_x])
    elif any(tag in model_arch for tag in ('RCNN', 'FCOS')):
        feed['im_info'] = _float_row(shape_padded + [scale_x])
        feed['im_shape'] = _float_row(shape_origin + [1.])
    elif 'TTF' in model_arch:
        feed['scale_factor'] = np.array([scale_x, scale_y] * 2).astype(
            'float32')
    elif 'SOLOv2' in model_arch:
        feed['im_info'] = _float_row(shape_resized + [scale_x])
    return feed


class Config():
    """Settings for preprocess, postprocess and visualize of one model.

    The values are read from the `infer_cfg.yml` file found in `model_dir`.

    Args:
        model_dir (str): directory that contains infer_cfg.yml
    """

    def __init__(self, model_dir):
        # load the deploy config exported next to the model
        cfg_path = os.path.join(model_dir, 'infer_cfg.yml')
        with open(cfg_path) as cfg_file:
            yml_conf = yaml.safe_load(cfg_file)
        self.check_model(yml_conf)

        # mandatory entries: a missing key raises KeyError
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.use_python_inference = yml_conf['use_python_inference']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']

        # optional entries: None when the key is absent
        self.mask_resolution = yml_conf[
            'mask_resolution'] if 'mask_resolution' in yml_conf else None
        self.with_lmk = yml_conf[
            'with_lmk'] if 'with_lmk' in yml_conf else None

        self.print_config()

    def check_model(self, yml_conf):
        """Check that the configured arch matches a supported model.

        Returns:
            True when `yml_conf['arch']` contains an entry of SUPPORT_MODELS
        Raises:
            ValueError: loaded model not in supported model type
        """
        arch = yml_conf['arch']
        if any(name in arch for name in SUPPORT_MODELS):
            return True
        raise ValueError("Unsupported arch: {}, expect {}".format(
            arch, SUPPORT_MODELS))

    def print_config(self):
        """Print the arch, the inference backend flag and the transform ops."""
        summary = [
            '-----------  Model Configuration -----------',
            '%s: %s' % ('Model Arch', self.arch),
            '%s: %s' % ('Use Paddle Executor', self.use_python_inference),
            '%s: ' % ('Transform Order'),
        ]
        for line in summary:
            print(line)
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')

381 382 383 384 385

def load_predictor(model_dir,
                   run_mode='fluid',
                   batch_size=1,
                   use_gpu=False,
                   min_subgraph_size=3,
                   trt_calib_mode=False):
    """Configure an AnalysisConfig and create the predictor from it.

    Args:
        model_dir (str): root path of __model__ and __params__
        run_mode (str): 'fluid', or one of trt_int8/trt_fp32/trt_fp16 to
            enable the TensorRT engine with that precision
        batch_size (int): passed as `max_batch_size` to the TensorRT engine
        use_gpu (bool): whether use gpu
        min_subgraph_size (int): passed to the TensorRT engine
        trt_calib_mode (bool): If the model is produced by TRT offline
            quantitative calibration, trt_calib_mode need to set True
    Returns:
        predictor (PaddlePredictor): AnalysisPredictor
    Raises:
        ValueError: predict by TensorRT need use_gpu == True.
    """
    # every mode except plain fluid requires the GPU
    if run_mode != 'fluid' and not use_gpu:
        raise ValueError(
            "Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
            .format(run_mode, use_gpu))

    precision = fluid.core.AnalysisConfig.Precision
    trt_precisions = {
        'trt_int8': precision.Int8,
        'trt_fp32': precision.Float32,
        'trt_fp16': precision.Half
    }

    model_file = os.path.join(model_dir, '__model__')
    params_file = os.path.join(model_dir, '__params__')
    analysis_config = fluid.core.AnalysisConfig(model_file, params_file)

    if not use_gpu:
        analysis_config.disable_gpu()
    else:
        # initial GPU memory(M), device ID
        analysis_config.enable_use_gpu(100, 0)
        # optimize graph and fuse op
        analysis_config.switch_ir_optim(True)

    if run_mode in trt_precisions:
        analysis_config.enable_tensorrt_engine(
            workspace_size=1 << 10,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=trt_precisions[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)

    # disable print log when predict
    analysis_config.disable_glog_info()
    # enable shared memory
    analysis_config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    analysis_config.switch_use_feed_fetch_ops(False)
    return fluid.core.create_paddle_predictor(analysis_config)


def load_executor(model_dir, use_gpu=False):
    """Create a fluid Executor and load the inference model into it.

    Args:
        model_dir (str): root path of __model__ and __params__
        use_gpu (bool): run on CUDA device 0 when True, on CPU otherwise
    Returns:
        tuple: (executor, inference program, fetch targets)
    """
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # the feed names (second entry) are not needed by the callers
    loaded = fluid.io.load_inference_model(
        dirname=model_dir,
        executor=exe,
        model_filename='__model__',
        params_filename='__params__')
    program, _, fetch_targets = loaded
    return exe, program, fetch_targets


def visualize(image_file,
              results,
              labels,
              mask_resolution=14,
              output_dir='output/',
              threshold=0.5):
    """Draw the prediction results and save the picture into `output_dir`.

    Args:
        image_file (str): path of the image; its base name is reused for
            the saved file
        results (dict): prediction results passed to `visualize_box_mask`
        labels (list): label list passed to `visualize_box_mask`
        mask_resolution (int): passed to `visualize_box_mask`
        output_dir (str): target directory, created when missing
        threshold (float): passed to `visualize_box_mask`
    """
    rendered = visualize_box_mask(
        image_file,
        results,
        labels,
        mask_resolution=mask_resolution,
        threshold=threshold)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    out_path = os.path.join(output_dir, os.path.basename(image_file))
    rendered.save(out_path, quality=95)
    print("save result to: " + out_path)


G
Guanghua Yu 已提交
473 474 475 476 477
def print_arguments(args):
    """Print every attribute of `args` as `name: value`, sorted by name."""
    print('-----------  Running Arguments -----------')
    attributes = vars(args)
    for arg in sorted(attributes):
        print('%s: %s' % (arg, attributes[arg]))
    print('------------------------------------------')
478 479


G
Guanghua Yu 已提交
480
def predict_image(detector):
    """Predict on FLAGS.image_file, as a benchmark or with visualization.

    Args:
        detector (Detector): detector used for the prediction
    """
    if not FLAGS.run_benchmark:
        # regular mode: a single prediction whose result is drawn and saved
        results = detector.predict(FLAGS.image_file, FLAGS.threshold)
        visualize(
            FLAGS.image_file,
            results,
            detector.config.labels,
            mask_resolution=detector.config.mask_resolution,
            output_dir=FLAGS.output_dir,
            threshold=FLAGS.threshold)
        return

    # benchmark mode: 100 warmup runs, 100 timed runs, result discarded
    detector.predict(
        FLAGS.image_file,
        FLAGS.threshold,
        warmup=100,
        repeats=100,
        run_benchmark=True)
497 498


G
Guanghua Yu 已提交
499
def predict_video(detector, camera_id):
    """Run detection on every frame of a video and write the drawn frames.

    Frames are read from the camera `camera_id`, or from FLAGS.video_file
    when `camera_id` is -1. The result is written into FLAGS.output_dir
    with the 'mp4v' codec at 30 fps. In camera mode each frame is also
    shown in a window and pressing 'q' stops the loop.

    Args:
        detector (Detector): detector used for the prediction
        camera_id (int): device id of the camera, -1 to read the video file
    """
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
        video_name = 'output.mp4'
    else:
        capture = cv2.VideoCapture(FLAGS.video_file)
        video_name = os.path.split(FLAGS.video_file)[-1]

    writer = None
    try:
        fps = 30
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
        if not os.path.exists(FLAGS.output_dir):
            os.makedirs(FLAGS.output_dir)
        out_path = os.path.join(FLAGS.output_dir, video_name)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame:%d' % (index))
            index += 1
            results = detector.predict(frame, FLAGS.threshold)
            im = visualize_box_mask(
                frame,
                results,
                detector.config.labels,
                mask_resolution=detector.config.mask_resolution,
                threshold=FLAGS.threshold)
            im = np.array(im)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release the capture device and finalize the output file even when
        # prediction or drawing raises in the middle of the stream.
        capture.release()
        if writer is not None:
            writer.release()


G
Guanghua Yu 已提交
537 538 539
def main():
    """Build the detector for FLAGS.model_dir and run the requested inputs."""
    config = Config(FLAGS.model_dir)
    # Choose the detector class before instantiating it, so that the model
    # is loaded exactly once (SOLOv2 needs its own predict implementation).
    detector_cls = DetectorSOLOv2 if config.arch == 'SOLOv2' else Detector
    detector = detector_cls(
        config,
        FLAGS.model_dir,
        use_gpu=FLAGS.use_gpu,
        run_mode=FLAGS.run_mode,
        trt_calib_mode=FLAGS.trt_calib_mode)
    # predict from image
    if FLAGS.image_file != '':
        predict_image(detector)
    # predict from video file or camera video stream
    if FLAGS.video_file != '' or FLAGS.camera_id != -1:
        predict_video(detector, FLAGS.camera_id)
C
channings 已提交
558 559


560
if __name__ == '__main__':
    # Best effort switch to static graph mode.
    # NOTE(review): presumably `paddle.enable_static` is missing on older
    # Paddle releases, which is why a failure here is ignored -- confirm.
    try:
        paddle.enable_static()
    except Exception:
        pass
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model_dir",
        type=str,
        default=None,
        help=("Directory include:'__model__', '__params__', "
              "'infer_cfg.yml', created by tools/export_model.py."),
        required=True)
    parser.add_argument(
        "--image_file", type=str, default='', help="Path of image file.")
    parser.add_argument(
        "--video_file", type=str, default='', help="Path of video file.")
    parser.add_argument(
        "--camera_id",
        type=int,
        default=-1,
        help="device id of camera to predict.")
    parser.add_argument(
        "--run_mode",
        type=str,
        default='fluid',
        help="mode of running(fluid/trt_fp32/trt_fp16/trt_int8)")
    parser.add_argument(
        "--use_gpu",
        type=ast.literal_eval,
        default=False,
        help="Whether to predict with GPU.")
    parser.add_argument(
        "--run_benchmark",
        type=ast.literal_eval,
        default=False,
        help="Whether to predict a image_file repeatedly for benchmark")
    parser.add_argument(
        "--threshold", type=float, default=0.5, help="Threshold of score.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory of output visualization files.")
    # `type=bool` would turn every non-empty string (including "False") into
    # True; parse the literal like the other boolean flags do.
    parser.add_argument(
        "--trt_calib_mode",
        type=ast.literal_eval,
        default=False,
        help="If the model is produced by TRT offline quantitative "
        "calibration, trt_calib_mode need to set True.")

    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    # An `assert` on a non-empty string can never fail, so reject the
    # conflicting options explicitly.
    if FLAGS.image_file != '' and FLAGS.video_file != '':
        parser.error("Cannot predict image and video at the same time")

    main()