未验证 提交 12ae16e4 编写于 作者: G Guanghua Yu 提交者: GitHub

fix trt inference when bs > 1 (#3283)

上级 9f9db4b7
...@@ -425,9 +425,15 @@ def load_predictor(model_dir, ...@@ -425,9 +425,15 @@ def load_predictor(model_dir,
use_calib_mode=trt_calib_mode) use_calib_mode=trt_calib_mode)
if use_dynamic_shape: if use_dynamic_shape:
min_input_shape = {'image': [1, 3, trt_min_shape, trt_min_shape]} min_input_shape = {
max_input_shape = {'image': [1, 3, trt_max_shape, trt_max_shape]} 'image': [batch_size, 3, trt_min_shape, trt_min_shape]
opt_input_shape = {'image': [1, 3, trt_opt_shape, trt_opt_shape]} }
max_input_shape = {
'image': [batch_size, 3, trt_max_shape, trt_max_shape]
}
opt_input_shape = {
'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape) opt_input_shape)
print('trt set dynamic shape done!') print('trt set dynamic shape done!')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册