提交 1b07e429 编写于 作者: L LDOUBLEV

fix trt

上级 31a84a33
...@@ -112,8 +112,11 @@ void Classifier::LoadModel(const std::string &model_dir) { ...@@ -112,8 +112,11 @@ void Classifier::LoadModel(const std::string &model_dir) {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} }
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
config.CollectShapeRangeInfo("./trt_shape.txt"); if (!Utility::PathExists("./trt_cls_shape.txt")){
config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); config.CollectShapeRangeInfo("./trt_cls_shape.txt");
} else {
config.EnableTunedTensorRtDynamicShape("./trt_cls_shape.txt", true);
}
} }
} else { } else {
config.DisableGpu(); config.DisableGpu();
......
...@@ -33,8 +33,11 @@ void DBDetector::LoadModel(const std::string &model_dir) { ...@@ -33,8 +33,11 @@ void DBDetector::LoadModel(const std::string &model_dir) {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} }
config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false); config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false);
config.CollectShapeRangeInfo("./trt_shape.txt"); if (!Utility::PathExists("./trt_det_shape.txt")){
config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); config.CollectShapeRangeInfo("./trt_det_shape.txt");
} else {
config.EnableTunedTensorRtDynamicShape("./trt_det_shape.txt", true);
}
} }
} else { } else {
......
...@@ -147,9 +147,11 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { ...@@ -147,9 +147,11 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
if (this->precision_ == "int8") { if (this->precision_ == "int8") {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} }
config.EnableTensorRtEngine(1 << 20, 10, 15, precision, false, false); if (!Utility::PathExists("./trt_rec_shape.txt")){
config.CollectShapeRangeInfo("./trt_shape.txt"); config.CollectShapeRangeInfo("./trt_rec_shape.txt");
config.EnableTunedTensorRtDynamicShape("./trt_shape.txt", true); } else {
config.EnableTunedTensorRtDynamicShape("./trt_rec_shape.txt", true);
}
} }
} else { } else {
......
...@@ -228,19 +228,18 @@ def create_predictor(args, mode, logger): ...@@ -228,19 +228,18 @@ def create_predictor(args, mode, logger):
use_calib_mode=False) use_calib_mode=False)
# collect shape # collect shape
if args.shape_info_filename is not None: trt_shape_f = f"{os.path.dirname(args.shape_info_filename)}/{mode}_{os.path.basename(args.shape_info_filename)}"
if not os.path.exists(args.shape_info_filename): if trt_shape_f is not None:
config.collect_shape_range_info( if not os.path.exists(trt_shape_f):
args.shape_info_filename) config.collect_shape_range_info(trt_shape_f)
logger.info( logger.info(
f"collect dynamic shape info into : {args.shape_info_filename}" f"collect dynamic shape info into : {trt_shape_f}"
) )
else: else:
logger.info( logger.info(
f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again." f"dynamic shape info file( {trt_shape_f} ) already exists, not need to generate again."
) )
config.enable_tuned_tensorrt_dynamic_shape( config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
args.shape_info_filename, True)
else: else:
logger.info( logger.info(
f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dygnamic shape tuning" f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dygnamic shape tuning"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册