diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index ba0adaee258096ea9970425cc05ca7a8f1cf08c4..1720369d01c42b41a9daea6907afe7714d97025b 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -35,6 +35,7 @@ logger = get_logger()
 
 class TextDetector(object):
     def __init__(self, args):
+        self.args = args
         self.det_algorithm = args.det_algorithm
         self.use_zero_copy_run = args.use_zero_copy_run
         pre_process_list = [{
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 162f6cb757a6725e67a76409b6a50ae25b1a6dc8..b793254da688079c5a6782f2c071f1c3d8f992d4 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -33,6 +33,8 @@ def parse_args():
     parser.add_argument("--use_gpu", type=str2bool, default=True)
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
+    parser.add_argument("--use_fp16", type=str2bool, default=False)
+    parser.add_argument("--max_batch_size", type=int, default=10)
     parser.add_argument("--gpu_mem", type=int, default=8000)
 
     # params for text detector
@@ -113,6 +115,11 @@
 
     if args.use_gpu:
         config.enable_use_gpu(args.gpu_mem, 0)
+        if args.use_tensorrt:
+            config.enable_tensorrt_engine(
+                precision_mode=AnalysisConfig.Precision.Half
+                if args.use_fp16 else AnalysisConfig.Precision.Float32,
+                max_batch_size=args.max_batch_size)
     else:
         config.disable_gpu()
         config.set_cpu_math_library_num_threads(6)
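
For context, a minimal sketch of how the new flags are intended to flow into Paddle's inference config, assuming the Paddle 1.x AnalysisConfig API that tools/infer/utility.py already imports; the helper name build_det_config and the model/params file arguments are placeholders for illustration, not part of the patch.

# Hypothetical sketch, not part of the patch: mirrors the GPU/TensorRT branch
# added to create_predictor() so the flag plumbing is easy to follow.
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

def build_det_config(args, model_file, params_file):
    config = AnalysisConfig(model_file, params_file)
    if args.use_gpu:
        config.enable_use_gpu(args.gpu_mem, 0)
        if args.use_tensorrt:
            # Run the TensorRT subgraph in FP16 only when --use_fp16 is set;
            # otherwise keep FP32 kernels. --max_batch_size caps the batch
            # size the TensorRT engine is built for.
            config.enable_tensorrt_engine(
                precision_mode=AnalysisConfig.Precision.Half
                if args.use_fp16 else AnalysisConfig.Precision.Float32,
                max_batch_size=args.max_batch_size)
    else:
        config.disable_gpu()
    return create_paddle_predictor(config)

With the patch applied, the flags can be exercised roughly as follows (paths are illustrative): python3 tools/infer/predict_det.py --det_model_dir=./inference/det_db/ --image_dir=./doc/imgs/ --use_tensorrt=True --use_fp16=True --max_batch_size=10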