Unverified commit 0e487626 authored by Guanghua Yu and committed by GitHub

fix paddle_trt infer (#1387)

Parent aa122039
@@ -21,7 +21,7 @@ import time
from paddle.inference import Config
from paddle.inference import create_predictor
from post_process import YOLOv7PostProcess
from post_process import YOLOPostProcess
CLASS_LABEL = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
@@ -165,6 +165,7 @@ def load_predictor(model_dir,
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
rerun_flag = False
if device != 'GPU' and run_mode != 'paddle':
raise ValueError(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
@@ -211,18 +212,16 @@
use_calib_mode=trt_calib_mode)
if use_dynamic_shape:
min_input_shape = {
'image': [batch_size, 3, trt_min_shape, trt_min_shape]
}
max_input_shape = {
'image': [batch_size, 3, trt_max_shape, trt_max_shape]
}
opt_input_shape = {
'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape)
print('trt set dynamic shape done!')
dynamic_shape_file = os.path.join(args.model_path,
'dynamic_shape.txt')
if os.path.exists(dynamic_shape_file):
config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
True)
print('trt set dynamic shape done!')
else:
config.collect_shape_range_info(dynamic_shape_file)
print('Start collect dynamic shape...')
rerun_flag = True
# disable print log when predict
config.disable_glog_info()
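
For context on the hunk above (an explanatory note, not part of the commit): the hard-coded min/max/opt input shapes are replaced by Paddle Inference's tuned dynamic shape workflow. On the first run, Config.collect_shape_range_info() records the observed shape range of every tensor into dynamic_shape.txt; once that file exists, Config.enable_tuned_tensorrt_dynamic_shape() feeds the recorded ranges back to the TensorRT engine, and rerun_flag tells the caller that a second run is needed. A minimal sketch of the same pattern, with model file names and TensorRT settings assumed for illustration:

# --- Sketch (not part of the commit): tuned dynamic shape workflow ---
import os
from paddle.inference import Config, PrecisionType, create_predictor

def build_trt_predictor(model_dir, shape_file="dynamic_shape.txt"):
    # Model file names below are assumptions; the script in this commit
    # derives its paths from --model_path instead.
    config = Config(
        os.path.join(model_dir, "model.pdmodel"),
        os.path.join(model_dir, "model.pdiparams"))
    config.enable_use_gpu(100, 0)
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=PrecisionType.Half,
        use_static=False,
        use_calib_mode=False)

    shape_path = os.path.join(model_dir, shape_file)
    rerun_flag = False
    if os.path.exists(shape_path):
        # Later runs: reuse the tuned shape ranges collected earlier.
        config.enable_tuned_tensorrt_dynamic_shape(shape_path, True)
    else:
        # First run: record every tensor's shape range into shape_path.
        config.collect_shape_range_info(shape_path)
        rerun_flag = True
    return create_predictor(config), rerun_flag
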
@@ -233,7 +232,7 @@ def load_predictor(model_dir,
if delete_shuffle_pass:
config.delete_pass("shuffle_channel_detect_pass")
predictor = create_predictor(config)
return predictor
return predictor, rerun_flag
def predict_image(predictor,
@@ -244,6 +243,7 @@ def predict_image(predictor,
threshold=0.5,
arch='YOLOv5'):
img, scale_factor = image_preprocess(image_file, image_shape)
inputs = {}
if arch == 'YOLOv6':
inputs['x2paddle_image_arrays'] = img
else:
@@ -276,7 +276,7 @@
print('Inference time(ms): min={}, max={}, avg={}'.format(
round(time_min * 1000, 2),
round(time_max * 1000, 1), round(time_avg * 1000, 1)))
postprocess = YOLOv7PostProcess(
postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np_boxes, scale_factor)
res_img = draw_box(
@@ -296,6 +296,11 @@ if __name__ == '__main__':
type=bool,
default=False,
help="Whether run benchmark or not.")
parser.add_argument(
'--use_dynamic_shape',
type=bool,
default=True,
help="Whether use dynamic shape or not.")
parser.add_argument(
'--run_mode',
type=str,
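
A side note on the new flag above (an observation about argparse, not part of the commit): with type=bool, any non-empty string parses as True, so passing --use_dynamic_shape False on the command line still enables dynamic shape; only the default value really matters here. A commonly used workaround, shown purely as an assumed alternative, is an explicit string-to-bool converter:

# --- Sketch (not part of the commit): explicit bool parsing for argparse ---
import argparse

def str2bool(v):
    # bool("False") is True, so map the usual textual spellings explicitly.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_dynamic_shape",
    type=str2bool,
    default=True,
    help="Whether use dynamic shape or not.")
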
@@ -312,11 +317,15 @@
parser.add_argument('--img_shape', type=int, default=640, help="input_size")
args = parser.parse_args()
predictor = load_predictor(
args.model_path, run_mode=args.run_mode, device=args.device)
warmup, repeats = 1, 1
if args.benchmark:
warmup, repeats = 50, 100
predictor, rerun_flag = load_predictor(
args.model_path,
run_mode=args.run_mode,
device=args.device,
use_dynamic_shape=args.use_dynamic_shape)
predict_image(
predictor,
args.image_file,
@@ -324,3 +333,8 @@ if __name__ == '__main__':
warmup=warmup,
repeats=repeats,
arch=args.arch)
if rerun_flag:
print(
"***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
)
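
Taken together, the changes turn the script into a two-pass tool: the first invocation only collects TensorRT shape ranges and prints the message above, and a second invocation produces the real results. A hypothetical launcher that performs both passes; the script name, flags, and values below are assumptions based on this diff, not part of the commit:

# --- Sketch (not part of the commit): run the collect pass, then the real pass ---
import subprocess
import sys

cmd = [
    sys.executable, "paddle_trt_infer.py",   # assumed script name
    "--model_path", "yolov5_model_dir",      # assumed paths and values
    "--image_file", "demo.jpg",
    "--device", "GPU",
    "--run_mode", "trt_fp16",
]
subprocess.run(cmd, check=True)  # pass 1: writes dynamic_shape.txt
subprocess.run(cmd, check=True)  # pass 2: uses the tuned shape file
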