未验证 提交 0e487626 编写于 作者: G Guanghua Yu 提交者: GitHub

fix paddle_trt infer (#1387)

上级 aa122039
...@@ -21,7 +21,7 @@ import time ...@@ -21,7 +21,7 @@ import time
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from post_process import YOLOv7PostProcess from post_process import YOLOPostProcess
CLASS_LABEL = [ CLASS_LABEL = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
...@@ -165,6 +165,7 @@ def load_predictor(model_dir, ...@@ -165,6 +165,7 @@ def load_predictor(model_dir,
Raises: Raises:
ValueError: predict by TensorRT need device == 'GPU'. ValueError: predict by TensorRT need device == 'GPU'.
""" """
rerun_flag = False
if device != 'GPU' and run_mode != 'paddle': if device != 'GPU' and run_mode != 'paddle':
raise ValueError( raise ValueError(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}" "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
...@@ -211,18 +212,16 @@ def load_predictor(model_dir, ...@@ -211,18 +212,16 @@ def load_predictor(model_dir,
use_calib_mode=trt_calib_mode) use_calib_mode=trt_calib_mode)
if use_dynamic_shape: if use_dynamic_shape:
min_input_shape = { dynamic_shape_file = os.path.join(args.model_path,
'image': [batch_size, 3, trt_min_shape, trt_min_shape] 'dynamic_shape.txt')
} if os.path.exists(dynamic_shape_file):
max_input_shape = { config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
'image': [batch_size, 3, trt_max_shape, trt_max_shape] True)
} print('trt set dynamic shape done!')
opt_input_shape = { else:
'image': [batch_size, 3, trt_opt_shape, trt_opt_shape] config.collect_shape_range_info(dynamic_shape_file)
} print('Start collect dynamic shape...')
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, rerun_flag = True
opt_input_shape)
print('trt set dynamic shape done!')
# disable print log when predict # disable print log when predict
config.disable_glog_info() config.disable_glog_info()
...@@ -233,7 +232,7 @@ def load_predictor(model_dir, ...@@ -233,7 +232,7 @@ def load_predictor(model_dir,
if delete_shuffle_pass: if delete_shuffle_pass:
config.delete_pass("shuffle_channel_detect_pass") config.delete_pass("shuffle_channel_detect_pass")
predictor = create_predictor(config) predictor = create_predictor(config)
return predictor return predictor, rerun_flag
def predict_image(predictor, def predict_image(predictor,
...@@ -244,6 +243,7 @@ def predict_image(predictor, ...@@ -244,6 +243,7 @@ def predict_image(predictor,
threshold=0.5, threshold=0.5,
arch='YOLOv5'): arch='YOLOv5'):
img, scale_factor = image_preprocess(image_file, image_shape) img, scale_factor = image_preprocess(image_file, image_shape)
inputs = {}
if arch == 'YOLOv6': if arch == 'YOLOv6':
inputs['x2paddle_image_arrays'] = img inputs['x2paddle_image_arrays'] = img
else: else:
...@@ -276,7 +276,7 @@ def predict_image(predictor, ...@@ -276,7 +276,7 @@ def predict_image(predictor,
print('Inference time(ms): min={}, max={}, avg={}'.format( print('Inference time(ms): min={}, max={}, avg={}'.format(
round(time_min * 1000, 2), round(time_min * 1000, 2),
round(time_max * 1000, 1), round(time_avg * 1000, 1))) round(time_max * 1000, 1), round(time_avg * 1000, 1)))
postprocess = YOLOv7PostProcess( postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True) score_threshold=0.001, nms_threshold=0.65, multi_label=True)
res = postprocess(np_boxes, scale_factor) res = postprocess(np_boxes, scale_factor)
res_img = draw_box( res_img = draw_box(
...@@ -296,6 +296,11 @@ if __name__ == '__main__': ...@@ -296,6 +296,11 @@ if __name__ == '__main__':
type=bool, type=bool,
default=False, default=False,
help="Whether run benchmark or not.") help="Whether run benchmark or not.")
parser.add_argument(
'--use_dynamic_shape',
type=bool,
default=True,
help="Whether use dynamic shape or not.")
parser.add_argument( parser.add_argument(
'--run_mode', '--run_mode',
type=str, type=str,
...@@ -312,11 +317,15 @@ if __name__ == '__main__': ...@@ -312,11 +317,15 @@ if __name__ == '__main__':
parser.add_argument('--img_shape', type=int, default=640, help="input_size") parser.add_argument('--img_shape', type=int, default=640, help="input_size")
args = parser.parse_args() args = parser.parse_args()
predictor = load_predictor(
args.model_path, run_mode=args.run_mode, device=args.device)
warmup, repeats = 1, 1 warmup, repeats = 1, 1
if args.benchmark: if args.benchmark:
warmup, repeats = 50, 100 warmup, repeats = 50, 100
predictor, rerun_flag = load_predictor(
args.model_path,
run_mode=args.run_mode,
device=args.device,
use_dynamic_shape=args.use_dynamic_shape)
predict_image( predict_image(
predictor, predictor,
args.image_file, args.image_file,
...@@ -324,3 +333,8 @@ if __name__ == '__main__': ...@@ -324,3 +333,8 @@ if __name__ == '__main__':
warmup=warmup, warmup=warmup,
repeats=repeats, repeats=repeats,
arch=args.arch) arch=args.arch)
if rerun_flag:
print(
"***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册