未验证 提交 0492ba4f 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #1503 from LDOUBLEV/dyg_db

add tensorrt args
...@@ -35,6 +35,7 @@ logger = get_logger() ...@@ -35,6 +35,7 @@ logger = get_logger()
class TextDetector(object): class TextDetector(object):
def __init__(self, args): def __init__(self, args):
self.args = args
self.det_algorithm = args.det_algorithm self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run self.use_zero_copy_run = args.use_zero_copy_run
pre_process_list = [{ pre_process_list = [{
......
...@@ -33,6 +33,8 @@ def parse_args(): ...@@ -33,6 +33,8 @@ def parse_args():
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector # params for text detector
...@@ -46,7 +48,7 @@ def parse_args(): ...@@ -46,7 +48,7 @@ def parse_args():
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--max_batch_size", type=int, default=10)
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
...@@ -113,6 +115,11 @@ def create_predictor(args, mode, logger): ...@@ -113,6 +115,11 @@ def create_predictor(args, mode, logger):
if args.use_gpu: if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0) config.enable_use_gpu(args.gpu_mem, 0)
if args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=AnalysisConfig.Precision.Half
if args.use_fp16 else AnalysisConfig.Precision.Float32,
max_batch_size=args.max_batch_size)
else: else:
config.disable_gpu() config.disable_gpu()
config.set_cpu_math_library_num_threads(6) config.set_cpu_math_library_num_threads(6)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册