未验证 提交 bf895541 编写于 作者: J JYChen 提交者: GitHub

More action for pphuman (#6142)

* add det based action, still in working

* add det_based action

* add id-based class action into pipeline

* remove clsactioninferer

* return history result, optimize visualization for multiple actions

* use history result for det-based action model

* batchinfer for cls&det models
上级 195ebcf0
......@@ -44,15 +44,19 @@ SKELETON_ACTION:
enable: False
ID_BASED_DETACTION:
model_dir: output_inference/detector
batch_size: 1
model_dir: output_inference/ppyoloe_crn_s_300e_smoking/
batch_size: 8
basemode: "idbased"
threshold: 0.4
display_frames: 80
enable: False
ID_BASED_CLSACTION:
model_dir: output_inference/classification
batch_size: 1
model_dir: output_inference/PPHGNet_tiny_calling_halfbody
batch_size: 8
basemode: "idbased"
threshold: 0.45
display_frames: 80
enable: False
REID:
......
......@@ -25,7 +25,9 @@ class Result(object):
'kpt': dict(),
'video_action': dict(),
'skeleton_action': dict(),
'reid': dict()
'reid': dict(),
'det_action': dict(),
'cls_action': dict(),
}
def update(self, res, name):
......
......@@ -154,7 +154,9 @@ class PipeTimer(Times):
'kpt': Times(),
'video_action': Times(),
'skeleton_action': Times(),
'reid': Times()
'reid': Times(),
'det_action': Times(),
'cls_action': Times()
}
self.img_num = 0
......
......@@ -36,9 +36,10 @@ from python.infer import Detector, DetectorPicoDet
from python.attr_infer import AttrDetector
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
from python.video_action_infer import VideoActionRecognizer
from python.action_infer import SkeletonActionRecognizer
from python.action_utils import KeyPointBuff, SkeletonActionVisualHelper
from python.action_infer import SkeletonActionRecognizer, DetActionRecognizer, ClsActionRecognizer
from python.action_utils import KeyPointBuff, ActionVisualHelper
from pipe_utils import argsparser, print_arguments, merge_cfg, PipeTimer
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint
......@@ -366,17 +367,51 @@ class PipePredictor(object):
trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
enable_mkldnn)
if self.with_idbased_detaction:
idbased_detaction_cfg = self.cfg['SKELETON_ACTION']
idbased_detaction_model_dir = idbased_detaction_cfg['model_dir']
idbased_detaction_batch_size = idbased_detaction_cfg[
'batch_size']
# IDBasedDetActionRecognizer = IDBasedDetActionRecognizer()
idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION']
model_dir = idbased_detaction_cfg['model_dir']
batch_size = idbased_detaction_cfg['batch_size']
basemode = idbased_detaction_cfg['basemode']
threshold = idbased_detaction_cfg['threshold']
display_frames = idbased_detaction_cfg['display_frames']
self.modebase[basemode] = True
self.det_action_predictor = DetActionRecognizer(
model_dir,
device,
run_mode,
batch_size,
trt_min_shape,
trt_max_shape,
trt_opt_shape,
trt_calib_mode,
cpu_threads,
enable_mkldnn,
threshold=threshold,
display_frames=display_frames)
self.det_action_visual_helper = ActionVisualHelper(1)
if self.with_idbased_clsaction:
idbased_clsaction_cfg = self.cfg['SKELETON_ACTION']
idbased_clsaction_model_dir = idbased_clsaction_cfg['model_dir']
idbased_clsaction_batch_size = idbased_clsaction_cfg[
'batch_size']
# IDBasedDetActionRecognizer = IDBasedClsActionRecognizer()
idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION']
model_dir = idbased_clsaction_cfg['model_dir']
batch_size = idbased_clsaction_cfg['batch_size']
basemode = idbased_clsaction_cfg['basemode']
threshold = idbased_clsaction_cfg['threshold']
self.modebase[basemode] = True
display_frames = idbased_clsaction_cfg['display_frames']
self.cls_action_predictor = ClsActionRecognizer(
model_dir,
device,
run_mode,
batch_size,
trt_min_shape,
trt_max_shape,
trt_opt_shape,
trt_calib_mode,
cpu_threads,
enable_mkldnn,
threshold=threshold,
display_frames=display_frames)
self.cls_action_visual_helper = ActionVisualHelper(1)
if self.with_skeleton_action:
skeleton_action_cfg = self.cfg['SKELETON_ACTION']
skeleton_action_model_dir = skeleton_action_cfg['model_dir']
......@@ -399,7 +434,7 @@ class PipePredictor(object):
cpu_threads,
enable_mkldnn,
window_size=skeleton_action_frames)
self.skeleton_action_visual_helper = SkeletonActionVisualHelper(
self.skeleton_action_visual_helper = ActionVisualHelper(
display_frames)
if self.modebase["skeletonbased"]:
......@@ -609,10 +644,10 @@ class PipePredictor(object):
continue
self.pipeline_res.update(mot_res, 'mot')
if self.with_attr or self.with_skeleton_action:
#todo: move this code to each class's predeal function
crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
frame, mot_res)
#todo: move this code to each class's predeal function
crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
frame, mot_res)
if self.with_attr:
if frame_id > self.warmup_frame:
......@@ -624,16 +659,28 @@ class PipePredictor(object):
self.pipeline_res.update(attr_res, 'attr')
if self.with_idbased_detaction:
#predeal, get what your model need
#predict, model preprocess\run\postprocess
#postdeal, interact with pipeline
pass
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['det_action'].start()
det_action_res = self.det_action_predictor.predict(
crop_input, mot_res)
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['det_action'].end()
self.pipeline_res.update(det_action_res, 'det_action')
if self.cfg['visual']:
self.det_action_visual_helper.update(det_action_res)
if self.with_idbased_clsaction:
#predeal, get what your model need
#predict, model preprocess\run\postprocess
#postdeal, interact with pipeline
pass
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['cls_action'].start()
cls_action_res = self.cls_action_predictor.predict_with_mot(
crop_input, mot_res)
if frame_id > self.warmup_frame:
self.pipe_timer.module_time['cls_action'].end()
self.pipeline_res.update(cls_action_res, 'cls_action')
if self.cfg['visual']:
self.cls_action_visual_helper.update(cls_action_res)
if self.with_skeleton_action:
if frame_id > self.warmup_frame:
......@@ -805,23 +852,42 @@ class PipePredictor(object):
visual_thresh=self.cfg['kpt_thresh'],
returnimg=True)
skeleton_action_res = result.get('skeleton_action')
video_action_res = result.get('video_action')
if skeleton_action_res is not None or video_action_res is not None:
if video_action_res is not None:
video_action_score = None
action_visual_helper = None
if video_action_res and video_action_res["class"] == 1:
video_action_score = video_action_res["score"]
if skeleton_action_res:
action_visual_helper = self.skeleton_action_visual_helper
image = visualize_action(
image,
mot_res['boxes'],
action_visual_collector=action_visual_helper,
action_visual_collector=None,
action_text="SkeletonAction",
video_action_score=video_action_score,
video_action_text="Fight")
visual_helper_for_display = []
action_to_display = []
skeleton_action_res = result.get('skeleton_action')
if skeleton_action_res is not None:
visual_helper_for_display.append(self.skeleton_action_visual_helper)
action_to_display.append("Falling")
det_action_res = result.get('det_action')
if det_action_res is not None:
visual_helper_for_display.append(self.det_action_visual_helper)
action_to_display.append("Smoking")
cls_action_res = result.get('cls_action')
if cls_action_res is not None:
visual_helper_for_display.append(self.cls_action_visual_helper)
action_to_display.append("Calling")
if len(visual_helper_for_display) > 0:
image = visualize_action(image, mot_res['boxes'],
visual_helper_for_display,
action_to_display)
return image
def visualize_image(self, im_files, images, result):
......
......@@ -31,6 +31,7 @@ from paddle.inference import Config, create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from benchmark_utils import PaddleInferBenchmark
from infer import Detector, print_arguments
from attr_infer import AttrDetector
class SkeletonActionRecognizer(Detector):
......@@ -263,6 +264,218 @@ def get_test_skeletons(input_file):
"Now only support input with shape: (N, C, T, K, M) or (C, T, K, M)")
class DetActionRecognizer(object):
"""
Args:
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
threshold (float): The threshold of score for action feature object detection.
display_frames (int): The duration for corresponding detected action.
"""
def __init__(self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
output_dir='output',
threshold=0.5,
display_frames=20):
super(DetActionRecognizer, self).__init__()
self.detector = Detector(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold)
self.threshold = threshold
self.frame_life = display_frames
self.result_history = {}
def predict(self, images, mot_result):
det_result = self.detector.predict_image(images, visual=False)
result = self.postprocess(det_result, mot_result)
return result
def postprocess(self, det_result, mot_result):
np_boxes_num = det_result['boxes_num']
if np_boxes_num[0] <= 0:
return [[], []]
mot_bboxes = mot_result.get('boxes')
cur_box_idx = 0
mot_id = []
act_res = []
for idx in range(len(mot_bboxes)):
tracker_id = mot_bboxes[idx, 0]
# Current now, class 0 is positive, class 1 is negative.
action_ret = {'class': 1.0, 'score': -1.0}
box_num = np_boxes_num[idx]
boxes = det_result['boxes'][cur_box_idx:cur_box_idx + box_num]
cur_box_idx += box_num
isvalid = (boxes[:, 1] > self.threshold) & (boxes[:, 0] == 0)
valid_boxes = boxes[isvalid, :]
if valid_boxes.shape[0] >= 1:
action_ret['class'] = valid_boxes[0, 0]
action_ret['score'] = valid_boxes[0, 1]
self.result_history[tracker_id] = [0, self.frame_life]
else:
history_det, life_remain = self.result_history.get(tracker_id,
[1, 0])
action_ret['class'] = history_det
action_ret['score'] = -1.0
life_remain -= 1
if life_remain <= 0 and tracker_id in self.result_history:
del (self.result_history[tracker_id])
elif tracker_id in self.result_history:
self.result_history[tracker_id][1] = life_remain
mot_id.append(tracker_id)
act_res.append(action_ret)
result = list(zip(mot_id, act_res))
return result
class ClsActionRecognizer(AttrDetector):
"""
Args:
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
threshold (float): The threshold of score for action feature object detection.
display_frames (int): The duration for corresponding detected action.
"""
def __init__(self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
output_dir='output',
threshold=0.5,
display_frames=80):
super(ClsActionRecognizer, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold)
self.threshold = threshold
self.frame_life = display_frames
self.result_history = {}
def predict_with_mot(self, images, mot_result):
images = self.crop_half_body(images)
cls_result = self.predict_image(images, visual=False)["output"]
result = self.match_action_with_id(cls_result, mot_result)
return result
def crop_half_body(self, images):
crop_images = []
for image in images:
h = image.shape[0]
crop_images.append(image[:h // 2 + 1, :, :])
return crop_images
def postprocess(self, inputs, result):
# postprocess output of predictor
im_results = result['output']
batch_res = []
for res in im_results:
action_res = res.tolist()
for cid, score in enumerate(action_res):
action_res[cid] = score
batch_res.append(action_res)
result = {'output': batch_res}
return result
def match_action_with_id(self, cls_result, mot_result):
mot_bboxes = mot_result.get('boxes')
mot_id = []
act_res = []
for idx in range(len(mot_bboxes)):
tracker_id = mot_bboxes[idx, 0]
cls_id_res = 1
cls_score_res = -1.0
for cls_id in range(len(cls_result[idx])):
score = cls_result[idx][cls_id]
if score > cls_score_res:
cls_id_res = cls_id
cls_score_res = score
# Current now, class 0 is positive, class 1 is negative.
if cls_id_res == 1 or (cls_id_res == 0 and
cls_score_res < self.threshold):
history_cls, life_remain = self.result_history.get(tracker_id,
[1, 0])
cls_id_res = history_cls
cls_score_res = 1 - cls_score_res
life_remain -= 1
if life_remain <= 0 and tracker_id in self.result_history:
del (self.result_history[tracker_id])
elif tracker_id in self.result_history:
self.result_history[tracker_id][1] = life_remain
else:
self.result_history[tracker_id] = [cls_id_res, self.frame_life]
action_ret = {'class': cls_id_res, 'score': cls_score_res}
mot_id.append(tracker_id)
act_res.append(action_ret)
result = list(zip(mot_id, act_res))
return result
def main():
detector = SkeletonActionRecognizer(
FLAGS.model_dir,
......
......@@ -80,7 +80,7 @@ class KeyPointBuff(object):
return output
class SkeletonActionVisualHelper(object):
class ActionVisualHelper(object):
def __init__(self, frame_life=20):
self.frame_life = frame_life
self.action_history = {}
......@@ -104,6 +104,10 @@ class SkeletonActionVisualHelper(object):
def update(self, action_res_list):
for mot_id, action_res in action_res_list:
if mot_id in self.action_history:
if int(action_res["class"]) != 0 and int(self.action_history[
mot_id]["class"]) == 0:
continue
action_info = self.action_history.get(mot_id, {})
action_info["class"] = action_res["class"]
action_info["life_remain"] = self.frame_life
......
......@@ -40,25 +40,9 @@ from utils import argsparser, Timer, get_current_memory_mb
# Global dictionary
SUPPORT_MODELS = {
'YOLO',
'RCNN',
'SSD',
'Face',
'FCOS',
'SOLOv2',
'TTFNet',
'S2ANet',
'JDE',
'FairMOT',
'DeepSORT',
'GFL',
'PicoDet',
'CenterNet',
'TOOD',
'RetinaNet',
'StrongBaseline',
'STGCN',
'YOLOX',
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
'StrongBaseline', 'STGCN', 'YOLOX', 'PPHGNet'
}
......
......@@ -378,13 +378,19 @@ def visualize_action(im,
text_thickness = 2
if action_visual_collector:
id_detected = action_visual_collector.get_visualize_ids()
id_action_dict = {}
for collector, action_type in zip(action_visual_collector, action_text):
id_detected = collector.get_visualize_ids()
for pid in id_detected:
id_action_dict[pid] = id_action_dict.get(pid, [])
id_action_dict[pid].append(action_type)
for mot_box in mot_boxes:
# mot_box is a format with [mot_id, class, score, xmin, ymin, w, h]
if mot_box[0] in id_detected:
if mot_box[0] in id_action_dict:
text_position = (int(mot_box[3] + mot_box[5] * 0.75),
int(mot_box[4] - 10))
cv2.putText(im, action_text, text_position,
display_text = ', '.join(id_action_dict[mot_box[0]])
cv2.putText(im, display_text, text_position,
cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), 2)
if video_action_score:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册