From 838eecbb57a0788eb9431ad9031a35139af94f27 Mon Sep 17 00:00:00 2001
From: Double_V
Date: Mon, 19 Sep 2022 17:40:05 +0800
Subject: [PATCH] delete shape_info_name (#7640)

* delete shape_info_name

* Update utility.py
---
 tools/infer/utility.py | 27 +++++++++++++--------------
 1 file changed, 13 insertions(+), 14 deletions(-)

diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index b9793123..e6adad3d 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -38,7 +38,6 @@ def init_args():
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--min_subgraph_size", type=int, default=15)
-    parser.add_argument("--shape_info_filename", type=str, default=None)
     parser.add_argument("--precision", type=str, default="fp32")
     parser.add_argument("--gpu_mem", type=int, default=500)
 
@@ -226,22 +225,22 @@ def create_predictor(args, mode, logger):
                 use_calib_mode=False)
 
             # collect shape
-            trt_shape_f = f"{os.path.dirname(args.shape_info_filename)}/{mode}_{os.path.basename(args.shape_info_filename)}"
-            if trt_shape_f is not None:
-                if not os.path.exists(trt_shape_f):
-                    config.collect_shape_range_info(trt_shape_f)
-                    logger.info(
-                        f"collect dynamic shape info into : {trt_shape_f}"
-                    )
-                else:
-                    logger.info(
-                        f"dynamic shape info file( {trt_shape_f} ) already exists, not need to generate again."
-                    )
-                config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
+            trt_shape_f = os.path.join(model_dir, f"{mode}_trt_dynamic_shape.txt")
+
+            if not os.path.exists(trt_shape_f):
+                config.collect_shape_range_info(trt_shape_f)
+                logger.info(
+                    f"collect dynamic shape info into : {trt_shape_f}")
             else:
                 logger.info(
-                    f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dygnamic shape tuning"
+                    f"dynamic shape info file( {trt_shape_f} ) already exists, not need to generate again."
                 )
+            try:
+                config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f,
+                                                           True)
+            except Exception as E:
+                logger.info(E)
+                logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
 
         elif args.use_xpu:
             config.enable_xpu(10 * 1024 * 1024)
--
GitLab