From 9039cca26d3fcf35fa432e39b299da50df6342b3 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Fri, 18 Dec 2020 15:27:44 +0800 Subject: [PATCH] add tensorrt args --- tools/infer/predict_det.py | 1 + tools/infer/utility.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index ba0adaee..1720369d 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -35,6 +35,7 @@ logger = get_logger() class TextDetector(object): def __init__(self, args): + self.args = args self.det_algorithm = args.det_algorithm self.use_zero_copy_run = args.use_zero_copy_run pre_process_list = [{ diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 162f6cb7..b793254d 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -33,6 +33,8 @@ def parse_args(): parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) + parser.add_argument("--use_fp16", type=str2bool, default=False) + parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--gpu_mem", type=int, default=8000) # params for text detector @@ -46,7 +48,7 @@ def parse_args(): parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) - + parser.add_argument("--max_batch_size", type=int, default=10) # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) @@ -113,6 +115,11 @@ def create_predictor(args, mode, logger): if args.use_gpu: config.enable_use_gpu(args.gpu_mem, 0) + if args.use_tensorrt: + config.enable_tensorrt_engine( + precision_mode=AnalysisConfig.Precision.Half + if args.use_fp16 else AnalysisConfig.Precision.Float32, + max_batch_size=args.max_batch_size) else: config.disable_gpu() config.set_cpu_math_library_num_threads(6) -- GitLab