diff --git a/deploy/python/infer.py b/deploy/python/infer.py index c60aae4b4e1487766885cdfec2129f0aff62813c..5ad8a3512bf488f424d725d262445d4f5b69fb8d 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -425,9 +425,15 @@ def load_predictor(model_dir, use_calib_mode=trt_calib_mode) if use_dynamic_shape: - min_input_shape = {'image': [1, 3, trt_min_shape, trt_min_shape]} - max_input_shape = {'image': [1, 3, trt_max_shape, trt_max_shape]} - opt_input_shape = {'image': [1, 3, trt_opt_shape, trt_opt_shape]} + min_input_shape = { + 'image': [batch_size, 3, trt_min_shape, trt_min_shape] + } + max_input_shape = { + 'image': [batch_size, 3, trt_max_shape, trt_max_shape] + } + opt_input_shape = { + 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape] + } config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, opt_input_shape) print('trt set dynamic shape done!')