diff --git a/tools/infer/utility.py b/tools/infer/utility.py index a8c59fac6fbc4e3782974c588e405cd515ed92dc..dafbfbeaf3e25fdd402027190c92ab45cbe352b4 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -227,10 +227,7 @@ def create_predictor(args, mode, logger): use_calib_mode=False) # collect shape - model_name = os.path.basename( - model_dir[:-1]) if model_dir.endswith( - "/") else os.path.basename(model_dir) - trt_shape_f = f"{mode}_{model_name}.txt" + trt_shape_f = os.path.join(model_dir, f"{mode}_trt_dynamic_shape.txt") if not os.path.exists(trt_shape_f): config.collect_shape_range_info(trt_shape_f)