未验证 提交 0fe525af 编写于 作者: J JYChen 提交者: GitHub

Add Timer for Action (#5438)

* 1. add timer for kpt& action 2. trt support for stgcn

* add parameter description

* make warm_frame configurable
上级 5815fbc9
......@@ -2,6 +2,7 @@ crop_thresh: 0.5
attr_thresh: 0.5
kpt_thresh: 0.2
visual: True
warmup_frame: 50
DET:
model_dir: output_inference/mot_ppyolov3/
......
......@@ -265,7 +265,7 @@ class PipePredictor(object):
self.cfg = cfg
self.output_dir = output_dir
self.warmup_frame = 1
self.warmup_frame = self.cfg['warmup_frame']
self.pipeline_res = Result()
self.pipe_timer = PipeTimer()
self.file_name = None
......@@ -469,6 +469,8 @@ class PipePredictor(object):
self.pipeline_res.update(attr_res, 'attr')
if self.with_action:
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['kpt'].start()
kpt_pred = self.kpt_predictor.predict_image(
crop_input, visual=False)
keypoint_vector, score_vector = translate_to_ori_images(
......@@ -478,6 +480,9 @@ class PipePredictor(object):
keypoint_vector.tolist(), score_vector.tolist()
] if len(keypoint_vector) > 0 else [[], []]
kpt_res['bbox'] = ori_bboxes
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['kpt'].end()
self.pipeline_res.update(kpt_res, 'kpt')
self.kpt_collector.update(kpt_res,
......@@ -487,12 +492,16 @@ class PipePredictor(object):
action_res = {}
if state:
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['action'].start()
collected_keypoint = self.kpt_collector.get_collected_keypoint(
) # reoragnize kpt output with ID
action_input = parse_mot_keypoint(collected_keypoint,
self.coord_size)
action_res = self.action_predictor.predict_skeleton_with_mot(
action_input)
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['action'].end()
self.pipeline_res.update(action_res, 'action')
if self.cfg['visual']:
......
......@@ -80,7 +80,8 @@ class ActionRecognizer(Detector):
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold)
threshold=threshold,
delete_shuffle_pass=True)
def predict(self, repeats=1):
'''
......
......@@ -79,23 +79,25 @@ class Detector(object):
enable_mkldnn_bfloat16 (bool): whether to turn on mkldnn bfloat16
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
Used by action model.
"""
def __init__(
self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='output',
threshold=0.5, ):
def __init__(self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='output',
threshold=0.5,
delete_shuffle_pass=False):
self.pred_config = self.set_config(model_dir)
self.predictor, self.config = load_predictor(
model_dir,
......@@ -110,7 +112,8 @@ class Detector(object):
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16)
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
delete_shuffle_pass=delete_shuffle_pass)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.batch_size = batch_size
......@@ -591,7 +594,8 @@ def load_predictor(model_dir,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False):
enable_mkldnn_bfloat16=False,
delete_shuffle_pass=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
......@@ -603,6 +607,8 @@ def load_predictor(model_dir,
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
Used by action model.
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
......@@ -673,6 +679,8 @@ def load_predictor(model_dir,
config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
if delete_shuffle_pass:
config.delete_pass("shuffle_channel_detect_pass")
predictor = create_predictor(config)
return predictor, config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册