未验证 提交 8c46869c 编写于 作者: G Guanghua Yu 提交者: GitHub

fix detection paddle trt infer (#1390)

上级 0e487626
...@@ -80,7 +80,7 @@ def image_preprocess(img_path, target_shape): ...@@ -80,7 +80,7 @@ def image_preprocess(img_path, target_shape):
img -= img_mean img -= img_mean
img /= img_std img /= img_std
scale_factor = np.array([[im_scale_y, im_scale_x]]) scale_factor = np.array([[im_scale_y, im_scale_x]])
return img.astype(np.float32), scale_factor return img.astype(np.float32), scale_factor.astype(np.float32)
def get_color_map_list(num_classes): def get_color_map_list(num_classes):
...@@ -130,7 +130,7 @@ def load_predictor(model_dir, ...@@ -130,7 +130,7 @@ def load_predictor(model_dir,
device='CPU', device='CPU',
min_subgraph_size=3, min_subgraph_size=3,
use_dynamic_shape=False, use_dynamic_shape=False,
trt_min_shape=1, trt_min_shape=3,
trt_max_shape=1280, trt_max_shape=1280,
trt_opt_shape=640, trt_opt_shape=640,
trt_calib_mode=False, trt_calib_mode=False,
...@@ -215,8 +215,6 @@ def load_predictor(model_dir, ...@@ -215,8 +215,6 @@ def load_predictor(model_dir,
opt_input_shape) opt_input_shape)
print('trt set dynamic shape done!') print('trt set dynamic shape done!')
# disable print log when predict
config.disable_glog_info()
# enable shared memory # enable shared memory
config.enable_memory_optim() config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run # disable feed, fetch OP, needed by zero_copy_run
...@@ -233,10 +231,12 @@ def predict_image(predictor, ...@@ -233,10 +231,12 @@ def predict_image(predictor,
warmup=1, warmup=1,
repeats=1, repeats=1,
threshold=0.5, threshold=0.5,
arch='YOLOv5'): include_nms=True):
img, scale_factor = image_preprocess(image_file, image_shape) img, scale_factor = image_preprocess(image_file, image_shape)
inputs = {} inputs = {}
inputs['image'] = img inputs['image'] = img
if include_nms:
inputs['scale_factor'] = scale_factor
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
for i in range(len(input_names)): for i in range(len(input_names)):
input_tensor = predictor.get_input_handle(input_names[i]) input_tensor = predictor.get_input_handle(input_names[i])
...@@ -245,7 +245,7 @@ def predict_image(predictor, ...@@ -245,7 +245,7 @@ def predict_image(predictor,
for i in range(warmup): for i in range(warmup):
predictor.run() predictor.run()
np_boxes = None np_boxes, np_boxes_num = None, None
predict_time = 0. predict_time = 0.
time_min = float("inf") time_min = float("inf")
time_max = float('-inf') time_max = float('-inf')
...@@ -255,6 +255,9 @@ def predict_image(predictor, ...@@ -255,6 +255,9 @@ def predict_image(predictor,
output_names = predictor.get_output_names() output_names = predictor.get_output_names()
boxes_tensor = predictor.get_output_handle(output_names[0]) boxes_tensor = predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu() np_boxes = boxes_tensor.copy_to_cpu()
if include_nms:
boxes_num = predictor.get_output_handle(output_names[1])
np_boxes_num = boxes_num.copy_to_cpu()
end_time = time.time() end_time = time.time()
timed = end_time - start_time timed = end_time - start_time
time_min = min(time_min, timed) time_min = min(time_min, timed)
...@@ -265,8 +268,11 @@ def predict_image(predictor, ...@@ -265,8 +268,11 @@ def predict_image(predictor,
print('Inference time(ms): min={}, max={}, avg={}'.format( print('Inference time(ms): min={}, max={}, avg={}'.format(
round(time_min * 1000, 2), round(time_min * 1000, 2),
round(time_max * 1000, 1), round(time_avg * 1000, 1))) round(time_max * 1000, 1), round(time_avg * 1000, 1)))
postprocess = PPYOLOEPostProcess(score_threshold=0.3, nms_threshold=0.6) if not include_nms:
res = postprocess(np_boxes, scale_factor) postprocess = PPYOLOEPostProcess(score_threshold=0.3, nms_threshold=0.6)
res = postprocess(np_boxes, scale_factor)
else:
res = {'bbox': np_boxes, 'bbox_num': np_boxes_num}
res_img = draw_box( res_img = draw_box(
image_file, res['bbox'], CLASS_LABEL, threshold=threshold) image_file, res['bbox'], CLASS_LABEL, threshold=threshold)
cv2.imwrite('result.jpg', res_img) cv2.imwrite('result.jpg', res_img)
...@@ -296,6 +302,11 @@ if __name__ == '__main__': ...@@ -296,6 +302,11 @@ if __name__ == '__main__':
help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU" help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU"
) )
parser.add_argument('--img_shape', type=int, default=640, help="input_size") parser.add_argument('--img_shape', type=int, default=640, help="input_size")
parser.add_argument(
'--include_nms',
type=bool,
default=True,
help="Whether include nms or not.")
args = parser.parse_args() args = parser.parse_args()
predictor = load_predictor( predictor = load_predictor(
...@@ -308,4 +319,5 @@ if __name__ == '__main__': ...@@ -308,4 +319,5 @@ if __name__ == '__main__':
args.image_file, args.image_file,
image_shape=[args.img_shape, args.img_shape], image_shape=[args.img_shape, args.img_shape],
warmup=warmup, warmup=warmup,
repeats=repeats) repeats=repeats,
include_nms=args.include_nms)
...@@ -147,9 +147,9 @@ python eval.py --config_path=./configs/yolov7_tiny_qat_dis.yaml ...@@ -147,9 +147,9 @@ python eval.py --config_path=./configs/yolov7_tiny_qat_dis.yaml
#### 导出至ONNX使用TensorRT部署 #### 导出至ONNX使用TensorRT部署
加载`quant_model.onnx`和`calibration.cache`,可以直接使用TensorRT测试脚本进行验证,详细代码可参考[./TensorRT] 加载`quant_model.onnx`和`calibration.cache`,可以直接使用TensorRT测试脚本进行验证,详细代码可参考[TensorRT部署](./TensorRT)
- 进行测试: - Python测试:
```shell ```shell
cd TensorRT cd TensorRT
python trt_eval.py --onnx_model_file=output/ONNX/quant_model.onnx \ python trt_eval.py --onnx_model_file=output/ONNX/quant_model.onnx \
...@@ -158,6 +158,11 @@ python trt_eval.py --onnx_model_file=output/ONNX/quant_model.onnx \ ...@@ -158,6 +158,11 @@ python trt_eval.py --onnx_model_file=output/ONNX/quant_model.onnx \
--precision_mode=int8 --precision_mode=int8
``` ```
- 速度测试:
```shell
trtexec --onnx=output/ONNX/quant_model.onnx --avgRuns=1000 --workspace=1024 --calib=output/ONNX/calibration.cache --int8
```
#### Paddle-TensorRT部署 #### Paddle-TensorRT部署
- C++部署 - C++部署
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册