未验证 提交 23e9cb95 编写于 作者: J JYChen 提交者: GitHub

add frame-skip to boost inference (#6383)

上级 9e5f22ae
......@@ -50,6 +50,7 @@ ID_BASED_DETACTION:
basemode: "idbased"
threshold: 0.6
display_frames: 80
skip_frame_num: 2
enable: False
ID_BASED_CLSACTION:
......@@ -58,6 +59,7 @@ ID_BASED_CLSACTION:
basemode: "idbased"
threshold: 0.8
display_frames: 80
skip_frame_num: 2
enable: False
REID:
......
......@@ -342,7 +342,9 @@ class PipePredictor(object):
basemode = idbased_detaction_cfg['basemode']
threshold = idbased_detaction_cfg['threshold']
display_frames = idbased_detaction_cfg['display_frames']
skip_frame_num = idbased_detaction_cfg['skip_frame_num']
self.modebase[basemode] = True
self.det_action_predictor = DetActionRecognizer(
model_dir,
device,
......@@ -355,7 +357,8 @@ class PipePredictor(object):
cpu_threads,
enable_mkldnn,
threshold=threshold,
display_frames=display_frames)
display_frames=display_frames,
skip_frame_num=skip_frame_num)
self.det_action_visual_helper = ActionVisualHelper(1)
if self.with_idbased_clsaction:
......@@ -366,6 +369,8 @@ class PipePredictor(object):
threshold = idbased_clsaction_cfg['threshold']
self.modebase[basemode] = True
display_frames = idbased_clsaction_cfg['display_frames']
skip_frame_num = idbased_clsaction_cfg['skip_frame_num']
self.cls_action_predictor = ClsActionRecognizer(
model_dir,
device,
......@@ -378,7 +383,8 @@ class PipePredictor(object):
cpu_threads,
enable_mkldnn,
threshold=threshold,
display_frames=display_frames)
display_frames=display_frames,
skip_frame_num=skip_frame_num)
self.cls_action_visual_helper = ActionVisualHelper(1)
if self.with_skeleton_action:
......
......@@ -280,6 +280,10 @@ class DetActionRecognizer(object):
enable_mkldnn (bool): whether to open MKLDNN
threshold (float): The threshold of score for action feature object detection.
display_frames (int): The duration for corresponding detected action.
skip_frame_num (int): The number of frames for interval prediction. A skipped frame will
reuse the result of its last frame. If it is set to 0, no frame will be skipped. Default
is 0.
"""
def __init__(self,
......@@ -295,7 +299,8 @@ class DetActionRecognizer(object):
enable_mkldnn=False,
output_dir='output',
threshold=0.5,
display_frames=20):
display_frames=20,
skip_frame_num=0):
super(DetActionRecognizer, self).__init__()
self.detector = Detector(
model_dir=model_dir,
......@@ -313,10 +318,21 @@ class DetActionRecognizer(object):
self.threshold = threshold
self.frame_life = display_frames
self.result_history = {}
self.skip_frame_num = skip_frame_num
self.skip_frame_cnt = 0
self.id_in_last_frame = []
def predict(self, images, mot_result):
if self.skip_frame_cnt == 0 or (not self.check_id_is_same(mot_result)):
det_result = self.detector.predict_image(images, visual=False)
result = self.postprocess(det_result, mot_result)
else:
result = self.reuse_result(mot_result)
self.skip_frame_cnt += 1
if self.skip_frame_cnt >= self.skip_frame_num:
self.skip_frame_cnt = 0
return result
def postprocess(self, det_result, mot_result):
......@@ -343,10 +359,11 @@ class DetActionRecognizer(object):
if valid_boxes.shape[0] >= 1:
action_ret['class'] = valid_boxes[0, 0]
action_ret['score'] = valid_boxes[0, 1]
self.result_history[tracker_id] = [0, self.frame_life]
self.result_history[
tracker_id] = [0, self.frame_life, valid_boxes[0, 1]]
else:
history_det, life_remain = self.result_history.get(tracker_id,
[1, 0])
history_det, life_remain, history_score = self.result_history.get(
tracker_id, [1, self.frame_life, -1.0])
action_ret['class'] = history_det
action_ret['score'] = -1.0
life_remain -= 1
......@@ -354,10 +371,48 @@ class DetActionRecognizer(object):
del (self.result_history[tracker_id])
elif tracker_id in self.result_history:
self.result_history[tracker_id][1] = life_remain
else:
self.result_history[tracker_id] = [
history_det, life_remain, history_score
]
mot_id.append(tracker_id)
act_res.append(action_ret)
result = list(zip(mot_id, act_res))
self.id_in_last_frame = mot_id
return result
def check_id_is_same(self, mot_result):
mot_bboxes = mot_result.get('boxes')
for idx in range(len(mot_bboxes)):
tracker_id = mot_bboxes[idx, 0]
if tracker_id not in self.id_in_last_frame:
return False
return True
def reuse_result(self, mot_result):
# This function reusing previous results of the same ID directly.
mot_bboxes = mot_result.get('boxes')
mot_id = []
act_res = []
for idx in range(len(mot_bboxes)):
tracker_id = mot_bboxes[idx, 0]
history_cls, life_remain, history_score = self.result_history.get(
tracker_id, [1, 0, -1.0])
life_remain -= 1
if tracker_id in self.result_history:
self.result_history[tracker_id][1] = life_remain
action_ret = {'class': history_cls, 'score': history_score}
mot_id.append(tracker_id)
act_res.append(action_ret)
result = list(zip(mot_id, act_res))
self.id_in_last_frame = mot_id
return result
......@@ -378,6 +433,9 @@ class ClsActionRecognizer(AttrDetector):
enable_mkldnn (bool): whether to open MKLDNN
threshold (float): The threshold of score for action feature object detection.
display_frames (int): The duration for corresponding detected action.
skip_frame_num (int): The number of frames for interval prediction. A skipped frame will
reuse the result of its last frame. If it is set to 0, no frame will be skipped. Default
is 0.
"""
def __init__(self,
......@@ -393,7 +451,8 @@ class ClsActionRecognizer(AttrDetector):
enable_mkldnn=False,
output_dir='output',
threshold=0.5,
display_frames=80):
display_frames=80,
skip_frame_num=0):
super(ClsActionRecognizer, self).__init__(
model_dir=model_dir,
device=device,
......@@ -410,11 +469,22 @@ class ClsActionRecognizer(AttrDetector):
self.threshold = threshold
self.frame_life = display_frames
self.result_history = {}
self.skip_frame_num = skip_frame_num
self.skip_frame_cnt = 0
self.id_in_last_frame = []
def predict_with_mot(self, images, mot_result):
if self.skip_frame_cnt == 0 or (not self.check_id_is_same(mot_result)):
images = self.crop_half_body(images)
cls_result = self.predict_image(images, visual=False)["output"]
result = self.match_action_with_id(cls_result, mot_result)
else:
result = self.reuse_result(mot_result)
self.skip_frame_cnt += 1
if self.skip_frame_cnt >= self.skip_frame_num:
self.skip_frame_cnt = 0
return result
def crop_half_body(self, images):
......@@ -456,8 +526,8 @@ class ClsActionRecognizer(AttrDetector):
# Current now, class 0 is positive, class 1 is negative.
if cls_id_res == 1 or (cls_id_res == 0 and
cls_score_res < self.threshold):
history_cls, life_remain = self.result_history.get(tracker_id,
[1, 0])
history_cls, life_remain, history_score = self.result_history.get(
tracker_id, [1, self.frame_life, -1.0])
cls_id_res = history_cls
cls_score_res = 1 - cls_score_res
life_remain -= 1
......@@ -466,12 +536,50 @@ class ClsActionRecognizer(AttrDetector):
elif tracker_id in self.result_history:
self.result_history[tracker_id][1] = life_remain
else:
self.result_history[tracker_id] = [cls_id_res, self.frame_life]
self.result_history[
tracker_id] = [cls_id_res, life_remain, cls_score_res]
else:
self.result_history[
tracker_id] = [cls_id_res, self.frame_life, cls_score_res]
action_ret = {'class': cls_id_res, 'score': cls_score_res}
mot_id.append(tracker_id)
act_res.append(action_ret)
result = list(zip(mot_id, act_res))
self.id_in_last_frame = mot_id
return result
def check_id_is_same(self, mot_result):
mot_bboxes = mot_result.get('boxes')
for idx in range(len(mot_bboxes)):
tracker_id = mot_bboxes[idx, 0]
if tracker_id not in self.id_in_last_frame:
return False
return True
def reuse_result(self, mot_result):
# This function reusing previous results of the same ID directly.
mot_bboxes = mot_result.get('boxes')
mot_id = []
act_res = []
for idx in range(len(mot_bboxes)):
tracker_id = mot_bboxes[idx, 0]
history_cls, life_remain, history_score = self.result_history.get(
tracker_id, [1, 0, -1.0])
life_remain -= 1
if tracker_id in self.result_history:
self.result_history[tracker_id][1] = life_remain
action_ret = {'class': history_cls, 'score': history_score}
mot_id.append(tracker_id)
act_res.append(action_ret)
result = list(zip(mot_id, act_res))
self.id_in_last_frame = mot_id
return result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册