From 12ae16e4c40c649b67455b2357468795b05bc519 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Fri, 4 Jun 2021 20:13:51 +0800 Subject: [PATCH] fix trt inference when bs > 1 (#3283) --- deploy/python/infer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/deploy/python/infer.py b/deploy/python/infer.py index c60aae4b4..5ad8a3512 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -425,9 +425,15 @@ def load_predictor(model_dir, use_calib_mode=trt_calib_mode) if use_dynamic_shape: - min_input_shape = {'image': [1, 3, trt_min_shape, trt_min_shape]} - max_input_shape = {'image': [1, 3, trt_max_shape, trt_max_shape]} - opt_input_shape = {'image': [1, 3, trt_opt_shape, trt_opt_shape]} + min_input_shape = { + 'image': [batch_size, 3, trt_min_shape, trt_min_shape] + } + max_input_shape = { + 'image': [batch_size, 3, trt_max_shape, trt_max_shape] + } + opt_input_shape = { + 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape] + } config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, opt_input_shape) print('trt set dynamic shape done!') -- GitLab