From 1b07e429123ace494dbe5a0567524ee8b77c46db Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Fri, 16 Sep 2022 20:11:18 +0800 Subject: [PATCH] fix trt --- deploy/cpp_infer/src/ocr_cls.cpp | 7 +++++-- deploy/cpp_infer/src/ocr_det.cpp | 7 +++++-- deploy/cpp_infer/src/ocr_rec.cpp | 8 +++++--- tools/infer/utility.py | 15 +++++++-------- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp index 827ddef5..92d83600 100644 --- a/deploy/cpp_infer/src/ocr_cls.cpp +++ b/deploy/cpp_infer/src/ocr_cls.cpp @@ -112,8 +112,11 @@ void Classifier::LoadModel(const std::string &model_dir) { precision = paddle_infer::Config::Precision::kInt8; } config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); - config.CollectShapeRangeInfo("./trt_shape.txt"); - config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); + if (!Utility::PathExists("./trt_cls_shape.txt")){ + config.CollectShapeRangeInfo("./trt_cls_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_cls_shape.txt", true); + } } } else { config.DisableGpu(); diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 41a97c81..030d5c2f 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -33,8 +33,11 @@ void DBDetector::LoadModel(const std::string &model_dir) { precision = paddle_infer::Config::Precision::kInt8; } config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false); - config.CollectShapeRangeInfo("./trt_shape.txt"); - config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); + if (!Utility::PathExists("./trt_det_shape.txt")){ + config.CollectShapeRangeInfo("./trt_det_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_det_shape.txt", true); + } } } else { diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index 36bfaf19..088cb942 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -147,9 +147,11 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { if (this->precision_ == "int8") { precision = paddle_infer::Config::Precision::kInt8; } - config.EnableTensorRtEngine(1 << 20, 10, 15, precision, false, false); - config.CollectShapeRangeInfo("./trt_shape.txt"); - config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); + if (!Utility::PathExists("./trt_rec_shape.txt")){ + config.CollectShapeRangeInfo("./trt_rec_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_rec_shape.txt", true); + } } } else { diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 46336009..07b2172c 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -228,19 +228,18 @@ def create_predictor(args, mode, logger): use_calib_mode=False) # collect shape - if args.shape_info_filename is not None: - if not os.path.exists(args.shape_info_filename): - config.collect_shape_range_info( - args.shape_info_filename) + trt_shape_f = f"{os.path.dirname(args.shape_info_filename)}/{mode}_{os.path.basename(args.shape_info_filename)}" + if trt_shape_f is not None: + if not os.path.exists(trt_shape_f): + config.collect_shape_range_info(trt_shape_f) logger.info( - f"collect dynamic shape info into : {args.shape_info_filename}" + f"collect dynamic shape info into : {trt_shape_f}" ) else: logger.info( - f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again." + f"dynamic shape info file( {trt_shape_f} ) already exists, not need to generate again." ) - config.enable_tuned_tensorrt_dynamic_shape( - args.shape_info_filename, True) + config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True) else: logger.info( f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dygnamic shape tuning" -- GitLab