“8238164fe8ce6f4574c38302057646946d14e813”上不存在“doc/api/v2/fluid/profiler.html”
未验证 提交 f4ef20df 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] fix use trt dynamic shape to infer (#7889)

上级 1d07733b
......@@ -46,8 +46,6 @@ SUPPORT_MODELS = {
'PPLCNet', 'DETR', 'CenterTrack'
}
TUNED_TRT_DYNAMIC_MODELS = {'DETR'}
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
mems = {
......@@ -445,7 +443,7 @@ class Detector(object):
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
......@@ -823,8 +821,7 @@ def load_predictor(model_dir,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
delete_shuffle_pass=False,
tuned_trt_shape_file="shape_range_info.pbtxt"):
delete_shuffle_pass=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
......@@ -891,8 +888,6 @@ def load_predictor(model_dir,
'trt_fp16': Config.Precision.Half
}
if run_mode in precision_map.keys():
if arch in TUNED_TRT_DYNAMIC_MODELS:
config.collect_shape_range_info(tuned_trt_shape_file)
config.enable_tensorrt_engine(
workspace_size=(1 << 25) * batch_size,
max_batch_size=batch_size,
......@@ -900,9 +895,13 @@ def load_predictor(model_dir,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=trt_calib_mode)
if arch in TUNED_TRT_DYNAMIC_MODELS:
config.enable_tuned_tensorrt_dynamic_shape(tuned_trt_shape_file,
True)
if FLAGS.collect_trt_shape_info:
config.collect_shape_range_info(FLAGS.tuned_trt_shape_file)
elif os.path.exists(FLAGS.tuned_trt_shape_file):
print(f'Use dynamic shape file: '
f'{FLAGS.tuned_trt_shape_file} for TRT...')
config.enable_tuned_tensorrt_dynamic_shape(
FLAGS.tuned_trt_shape_file, True)
if use_dynamic_shape:
min_input_shape = {
......
......@@ -201,6 +201,16 @@ def argsparser():
type=str,
default='ios',
help="Combine method matching metric, choose in ['iou', 'ios'].")
parser.add_argument(
"--collect_trt_shape_info",
action='store_true',
default=False,
help="Whether to collect dynamic shape before using tensorrt.")
parser.add_argument(
"--tuned_trt_shape_file",
type=str,
default="shape_range_info.pbtxt",
help="Path of a dynamic shape file for tensorrt.")
return parser
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册