未验证 提交 23e9cb95 编写于 作者: J JYChen 提交者: GitHub

add frame-skip to boost inference (#6383)

上级 9e5f22ae
...@@ -50,6 +50,7 @@ ID_BASED_DETACTION: ...@@ -50,6 +50,7 @@ ID_BASED_DETACTION:
basemode: "idbased" basemode: "idbased"
threshold: 0.6 threshold: 0.6
display_frames: 80 display_frames: 80
skip_frame_num: 2
enable: False enable: False
ID_BASED_CLSACTION: ID_BASED_CLSACTION:
...@@ -58,6 +59,7 @@ ID_BASED_CLSACTION: ...@@ -58,6 +59,7 @@ ID_BASED_CLSACTION:
basemode: "idbased" basemode: "idbased"
threshold: 0.8 threshold: 0.8
display_frames: 80 display_frames: 80
skip_frame_num: 2
enable: False enable: False
REID: REID:
......
...@@ -342,7 +342,9 @@ class PipePredictor(object): ...@@ -342,7 +342,9 @@ class PipePredictor(object):
basemode = idbased_detaction_cfg['basemode'] basemode = idbased_detaction_cfg['basemode']
threshold = idbased_detaction_cfg['threshold'] threshold = idbased_detaction_cfg['threshold']
display_frames = idbased_detaction_cfg['display_frames'] display_frames = idbased_detaction_cfg['display_frames']
skip_frame_num = idbased_detaction_cfg['skip_frame_num']
self.modebase[basemode] = True self.modebase[basemode] = True
self.det_action_predictor = DetActionRecognizer( self.det_action_predictor = DetActionRecognizer(
model_dir, model_dir,
device, device,
...@@ -355,7 +357,8 @@ class PipePredictor(object): ...@@ -355,7 +357,8 @@ class PipePredictor(object):
cpu_threads, cpu_threads,
enable_mkldnn, enable_mkldnn,
threshold=threshold, threshold=threshold,
display_frames=display_frames) display_frames=display_frames,
skip_frame_num=skip_frame_num)
self.det_action_visual_helper = ActionVisualHelper(1) self.det_action_visual_helper = ActionVisualHelper(1)
if self.with_idbased_clsaction: if self.with_idbased_clsaction:
...@@ -366,6 +369,8 @@ class PipePredictor(object): ...@@ -366,6 +369,8 @@ class PipePredictor(object):
threshold = idbased_clsaction_cfg['threshold'] threshold = idbased_clsaction_cfg['threshold']
self.modebase[basemode] = True self.modebase[basemode] = True
display_frames = idbased_clsaction_cfg['display_frames'] display_frames = idbased_clsaction_cfg['display_frames']
skip_frame_num = idbased_clsaction_cfg['skip_frame_num']
self.cls_action_predictor = ClsActionRecognizer( self.cls_action_predictor = ClsActionRecognizer(
model_dir, model_dir,
device, device,
...@@ -378,7 +383,8 @@ class PipePredictor(object): ...@@ -378,7 +383,8 @@ class PipePredictor(object):
cpu_threads, cpu_threads,
enable_mkldnn, enable_mkldnn,
threshold=threshold, threshold=threshold,
display_frames=display_frames) display_frames=display_frames,
skip_frame_num=skip_frame_num)
self.cls_action_visual_helper = ActionVisualHelper(1) self.cls_action_visual_helper = ActionVisualHelper(1)
if self.with_skeleton_action: if self.with_skeleton_action:
......
...@@ -279,7 +279,11 @@ class DetActionRecognizer(object): ...@@ -279,7 +279,11 @@ class DetActionRecognizer(object):
cpu_threads (int): cpu threads cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
threshold (float): The threshold of score for action feature object detection. threshold (float): The threshold of score for action feature object detection.
display_frames (int): The duration for corresponding detected action. display_frames (int): The duration for corresponding detected action.
skip_frame_num (int): The number of frames for interval prediction. A skipped frame will
reuse the result of its last frame. If it is set to 0, no frame will be skipped. Default
is 0.
""" """
def __init__(self, def __init__(self,
...@@ -295,7 +299,8 @@ class DetActionRecognizer(object): ...@@ -295,7 +299,8 @@ class DetActionRecognizer(object):
enable_mkldnn=False, enable_mkldnn=False,
output_dir='output', output_dir='output',
threshold=0.5, threshold=0.5,
display_frames=20): display_frames=20,
skip_frame_num=0):
super(DetActionRecognizer, self).__init__() super(DetActionRecognizer, self).__init__()
self.detector = Detector( self.detector = Detector(
model_dir=model_dir, model_dir=model_dir,
...@@ -313,10 +318,21 @@ class DetActionRecognizer(object): ...@@ -313,10 +318,21 @@ class DetActionRecognizer(object):
self.threshold = threshold self.threshold = threshold
self.frame_life = display_frames self.frame_life = display_frames
self.result_history = {} self.result_history = {}
self.skip_frame_num = skip_frame_num
self.skip_frame_cnt = 0
self.id_in_last_frame = []
def predict(self, images, mot_result):
    """Return per-track action results, running the detector only when needed.

    The detector is executed when the skip counter is at 0, or when
    ``check_id_is_same`` reports that the current tracker IDs are not all
    known from the previous frame. Otherwise the cached results are
    returned through ``reuse_result``.

    Args:
        images: input forwarded unchanged to ``self.detector.predict_image``.
        mot_result: tracking result forwarded to the post-processing /
            reuse helpers.

    Returns:
        Whatever ``postprocess`` or ``reuse_result`` returns for this frame.
    """
    need_inference = self.skip_frame_cnt == 0
    if not need_inference:
        # Only consult the ID check when the counter alone would allow a skip.
        need_inference = not self.check_id_is_same(mot_result)

    if need_inference:
        result = self.postprocess(
            self.detector.predict_image(images, visual=False), mot_result)
    else:
        result = self.reuse_result(mot_result)

    # Advance the counter and wrap to 0 once it reaches skip_frame_num.
    advanced = self.skip_frame_cnt + 1
    self.skip_frame_cnt = 0 if advanced >= self.skip_frame_num else advanced
    return result
def postprocess(self, det_result, mot_result): def postprocess(self, det_result, mot_result):
...@@ -343,10 +359,11 @@ class DetActionRecognizer(object): ...@@ -343,10 +359,11 @@ class DetActionRecognizer(object):
if valid_boxes.shape[0] >= 1: if valid_boxes.shape[0] >= 1:
action_ret['class'] = valid_boxes[0, 0] action_ret['class'] = valid_boxes[0, 0]
action_ret['score'] = valid_boxes[0, 1] action_ret['score'] = valid_boxes[0, 1]
self.result_history[tracker_id] = [0, self.frame_life] self.result_history[
tracker_id] = [0, self.frame_life, valid_boxes[0, 1]]
else: else:
history_det, life_remain = self.result_history.get(tracker_id, history_det, life_remain, history_score = self.result_history.get(
[1, 0]) tracker_id, [1, self.frame_life, -1.0])
action_ret['class'] = history_det action_ret['class'] = history_det
action_ret['score'] = -1.0 action_ret['score'] = -1.0
life_remain -= 1 life_remain -= 1
...@@ -354,10 +371,48 @@ class DetActionRecognizer(object): ...@@ -354,10 +371,48 @@ class DetActionRecognizer(object):
del (self.result_history[tracker_id]) del (self.result_history[tracker_id])
elif tracker_id in self.result_history: elif tracker_id in self.result_history:
self.result_history[tracker_id][1] = life_remain self.result_history[tracker_id][1] = life_remain
else:
self.result_history[tracker_id] = [
history_det, life_remain, history_score
]
mot_id.append(tracker_id)
act_res.append(action_ret)
result = list(zip(mot_id, act_res))
self.id_in_last_frame = mot_id
return result
def check_id_is_same(self, mot_result):
    """Tell whether every tracker ID in this frame was seen in the last one.

    The tracker ID is read from column 0 of each row of
    ``mot_result['boxes']`` and looked up in ``self.id_in_last_frame``.
    IDs that disappeared since the last frame do not affect the result;
    only IDs that are new make it False.

    Args:
        mot_result: mapping with a ``'boxes'`` entry that supports
            ``len()`` and ``[row, 0]`` indexing.

    Returns:
        bool: True when no row carries an unknown ID (including when there
        are no rows at all), False otherwise.
    """
    boxes = mot_result.get('boxes')
    return all(boxes[row, 0] in self.id_in_last_frame
               for row in range(len(boxes)))
def reuse_result(self, mot_result):
    """Build this frame's results from cached history, without inference.

    For each row of ``mot_result['boxes']`` the tracker ID is taken from
    column 0. If the ID has an entry in ``self.result_history`` (a list of
    ``[class, life_remain, score]``), its class and score are reported and
    its remaining life is decremented in place. IDs without history are
    reported with class 1 and score -1.0 and no history entry is created.

    Side effects:
        ``self.id_in_last_frame`` is replaced by the list of IDs seen here.

    Returns:
        list of ``(tracker_id, {'class': ..., 'score': ...})`` tuples, in
        row order.
    """
    boxes = mot_result.get('boxes')
    history = self.result_history
    seen_ids = []
    reused = []
    for row in range(len(boxes)):
        track_id = boxes[row, 0]
        if track_id in history:
            entry = history[track_id]
            cached_cls, remaining, cached_score = entry
            # Skipped frames still consume display life for this track.
            entry[1] = remaining - 1
        else:
            cached_cls, cached_score = 1, -1.0
        seen_ids.append(track_id)
        reused.append((track_id, {'class': cached_cls, 'score': cached_score}))
    self.id_in_last_frame = seen_ids
    return reused
...@@ -378,6 +433,9 @@ class ClsActionRecognizer(AttrDetector): ...@@ -378,6 +433,9 @@ class ClsActionRecognizer(AttrDetector):
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
threshold (float): The threshold of score for action feature object detection. threshold (float): The threshold of score for action feature object detection.
display_frames (int): The duration for corresponding detected action. display_frames (int): The duration for corresponding detected action.
skip_frame_num (int): The number of frames for interval prediction. A skipped frame will
reuse the result of its last frame. If it is set to 0, no frame will be skipped. Default
is 0.
""" """
def __init__(self, def __init__(self,
...@@ -393,7 +451,8 @@ class ClsActionRecognizer(AttrDetector): ...@@ -393,7 +451,8 @@ class ClsActionRecognizer(AttrDetector):
enable_mkldnn=False, enable_mkldnn=False,
output_dir='output', output_dir='output',
threshold=0.5, threshold=0.5,
display_frames=80): display_frames=80,
skip_frame_num=0):
super(ClsActionRecognizer, self).__init__( super(ClsActionRecognizer, self).__init__(
model_dir=model_dir, model_dir=model_dir,
device=device, device=device,
...@@ -410,11 +469,22 @@ class ClsActionRecognizer(AttrDetector): ...@@ -410,11 +469,22 @@ class ClsActionRecognizer(AttrDetector):
self.threshold = threshold self.threshold = threshold
self.frame_life = display_frames self.frame_life = display_frames
self.result_history = {} self.result_history = {}
self.skip_frame_num = skip_frame_num
self.skip_frame_cnt = 0
self.id_in_last_frame = []
def predict_with_mot(self, images, mot_result):
    """Return per-track action results, classifying only when needed.

    Classification runs when the skip counter is at 0, or when
    ``check_id_is_same`` reports that the current tracker IDs are not all
    known from the previous frame. In that case the images go through
    ``crop_half_body``, then ``predict_image``, and the ``"output"`` entry
    is matched to tracker IDs. Otherwise cached results are returned
    through ``reuse_result``.

    Args:
        images: input handed to ``crop_half_body``.
        mot_result: tracking result forwarded to the matching / reuse
            helpers.

    Returns:
        Whatever ``match_action_with_id`` or ``reuse_result`` returns.
    """
    can_skip = self.skip_frame_cnt != 0 and self.check_id_is_same(mot_result)
    if can_skip:
        result = self.reuse_result(mot_result)
    else:
        cropped = self.crop_half_body(images)
        cls_output = self.predict_image(cropped, visual=False)["output"]
        result = self.match_action_with_id(cls_output, mot_result)

    # Advance the counter and wrap to 0 once it reaches skip_frame_num.
    advanced = self.skip_frame_cnt + 1
    self.skip_frame_cnt = 0 if advanced >= self.skip_frame_num else advanced
    return result
def crop_half_body(self, images): def crop_half_body(self, images):
...@@ -456,8 +526,8 @@ class ClsActionRecognizer(AttrDetector): ...@@ -456,8 +526,8 @@ class ClsActionRecognizer(AttrDetector):
# Current now, class 0 is positive, class 1 is negative. # Current now, class 0 is positive, class 1 is negative.
if cls_id_res == 1 or (cls_id_res == 0 and if cls_id_res == 1 or (cls_id_res == 0 and
cls_score_res < self.threshold): cls_score_res < self.threshold):
history_cls, life_remain = self.result_history.get(tracker_id, history_cls, life_remain, history_score = self.result_history.get(
[1, 0]) tracker_id, [1, self.frame_life, -1.0])
cls_id_res = history_cls cls_id_res = history_cls
cls_score_res = 1 - cls_score_res cls_score_res = 1 - cls_score_res
life_remain -= 1 life_remain -= 1
...@@ -465,13 +535,51 @@ class ClsActionRecognizer(AttrDetector): ...@@ -465,13 +535,51 @@ class ClsActionRecognizer(AttrDetector):
del (self.result_history[tracker_id]) del (self.result_history[tracker_id])
elif tracker_id in self.result_history: elif tracker_id in self.result_history:
self.result_history[tracker_id][1] = life_remain self.result_history[tracker_id][1] = life_remain
else:
self.result_history[
tracker_id] = [cls_id_res, life_remain, cls_score_res]
else: else:
self.result_history[tracker_id] = [cls_id_res, self.frame_life] self.result_history[
tracker_id] = [cls_id_res, self.frame_life, cls_score_res]
action_ret = {'class': cls_id_res, 'score': cls_score_res} action_ret = {'class': cls_id_res, 'score': cls_score_res}
mot_id.append(tracker_id) mot_id.append(tracker_id)
act_res.append(action_ret) act_res.append(action_ret)
result = list(zip(mot_id, act_res)) result = list(zip(mot_id, act_res))
self.id_in_last_frame = mot_id
return result
def check_id_is_same(self, mot_result):
    """Report whether the current frame introduces no new tracker IDs.

    Each row of ``mot_result['boxes']`` contributes the ID stored in its
    column 0; the frame passes when all of those IDs are contained in
    ``self.id_in_last_frame``. IDs that were present before but are gone
    now are not considered.

    Args:
        mot_result: mapping with a ``'boxes'`` entry that supports
            ``len()`` and ``[row, 0]`` indexing.

    Returns:
        bool: False as soon as one unknown ID is found, True otherwise
        (also for a frame with no rows).
    """
    rows = mot_result.get('boxes')
    row = 0
    while row < len(rows):
        if rows[row, 0] not in self.id_in_last_frame:
            return False
        row += 1
    return True
def reuse_result(self, mot_result):
    """Reuse the cached result of each tracker ID instead of re-classifying.

    The tracker ID is read from column 0 of every row of
    ``mot_result['boxes']``. When ``self.result_history`` holds an entry
    ``[class, life_remain, score]`` for the ID, that class and score are
    returned and ``life_remain`` is decremented in place. Otherwise class 1
    and score -1.0 are returned and the history is left untouched.

    Side effects:
        ``self.id_in_last_frame`` is replaced by the list of IDs seen here.

    Returns:
        list of ``(tracker_id, {'class': ..., 'score': ...})`` tuples, in
        row order.
    """
    boxes = mot_result.get('boxes')
    track_ids = [boxes[row, 0] for row in range(len(boxes))]

    actions = []
    for track_id in track_ids:
        cached = self.result_history.get(track_id, [1, 0, -1.0])
        cached_cls, remaining, cached_score = cached
        if track_id in self.result_history:
            # Skipped frames still consume display life for this track.
            self.result_history[track_id][1] = remaining - 1
        actions.append({'class': cached_cls, 'score': cached_score})

    self.id_in_last_frame = track_ids
    return list(zip(track_ids, actions))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册