diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 793ff28a22dddffa937dc3cb042c12123cbc90fb..a8c59fac6fbc4e3782974c588e405cd515ed92dc 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -227,19 +227,21 @@ def create_predictor(args, mode, logger):
                     use_calib_mode=False)
 
                 # collect shape
-                model_name = os.path.basename(model_dir[:-1]) if model_dir.endswith("/") else os.path.basename(model_dir)
+                model_name = os.path.basename(
+                    model_dir[:-1]) if model_dir.endswith(
+                        "/") else os.path.basename(model_dir)
                 trt_shape_f = f"{mode}_{model_name}.txt"
-                if trt_shape_f is not None:
-                    if not os.path.exists(trt_shape_f):
-                        config.collect_shape_range_info(trt_shape_f)
-                        logger.info(
-                            f"collect dynamic shape info into : {trt_shape_f}"
-                        )
-                    else:
-                        logger.info(
-                            f"dynamic shape info file( {trt_shape_f} ) already exists, not need to generate again."
-                        )
-                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
+
+                if not os.path.exists(trt_shape_f):
+                    config.collect_shape_range_info(trt_shape_f)
+                    logger.info(
+                        f"collect dynamic shape info into : {trt_shape_f}")
+                try:
+                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f,
+                                                               True)
+                except Exception as E:
+                    logger.info(E)
+                    logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
 
             elif args.use_npu:
                 config.enable_npu()