diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp index 827ddef59371bee018e202bceda9522cdc8bdc61..92d83600cea04419db231c0097caa53ed6fec58b 100644 --- a/deploy/cpp_infer/src/ocr_cls.cpp +++ b/deploy/cpp_infer/src/ocr_cls.cpp @@ -112,8 +112,11 @@ void Classifier::LoadModel(const std::string &model_dir) { precision = paddle_infer::Config::Precision::kInt8; } config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); - config.CollectShapeRangeInfo("./trt_shape.txt"); - config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); + if (!Utility::PathExists("./trt_cls_shape.txt")){ + config.CollectShapeRangeInfo("./trt_cls_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_cls_shape.txt", true); + } } } else { config.DisableGpu(); diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 41a97c81ce8e1b44618e32cff88888b2a8ffa076..030d5c2f359bba522662324d84c6ef1cc0bc83b8 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -33,8 +33,11 @@ void DBDetector::LoadModel(const std::string &model_dir) { precision = paddle_infer::Config::Precision::kInt8; } config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false); - config.CollectShapeRangeInfo("./trt_shape.txt"); - config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); + if (!Utility::PathExists("./trt_det_shape.txt")){ + config.CollectShapeRangeInfo("./trt_det_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_det_shape.txt", true); + } } } else { diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index 36bfaf19d7d080dec00fe163ace3d54563121c30..088cb942ba5ac4b09c9e8d1731a3b20d40967edf 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -147,9 +147,11 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { if (this->precision_ == "int8") { precision = paddle_infer::Config::Precision::kInt8; } - config.EnableTensorRtEngine(1 << 20, 10, 15, precision, false, false); - config.CollectShapeRangeInfo("./trt_shape.txt"); - config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); + if (!Utility::PathExists("./trt_rec_shape.txt")){ + config.CollectShapeRangeInfo("./trt_rec_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_rec_shape.txt", true); + } } } else { diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 463360091c4fc69092dddf2b1389de9974a11ee9..07b2172cd3c6a624d4b1026163dcb811edebde02 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -228,19 +228,18 @@ def create_predictor(args, mode, logger): use_calib_mode=False) # collect shape - if args.shape_info_filename is not None: - if not os.path.exists(args.shape_info_filename): - config.collect_shape_range_info( - args.shape_info_filename) + trt_shape_f = f"{os.path.dirname(args.shape_info_filename)}/{mode}_{os.path.basename(args.shape_info_filename)}" + if trt_shape_f is not None: + if not os.path.exists(trt_shape_f): + config.collect_shape_range_info(trt_shape_f) logger.info( - f"collect dynamic shape info into : {args.shape_info_filename}" + f"collect dynamic shape info into : {trt_shape_f}" ) else: logger.info( - f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again." + f"dynamic shape info file( {trt_shape_f} ) already exists, not need to generate again." ) - config.enable_tuned_tensorrt_dynamic_shape( - args.shape_info_filename, True) + config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True) else: logger.info( f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dygnamic shape tuning"