From bf89554169a57240c5ab6fd39f112a51dd440e8c Mon Sep 17 00:00:00 2001 From: JYChen Date: Mon, 20 Jun 2022 11:22:55 +0800 Subject: [PATCH] More action for pphuman (#6142) * add det based action, still in working * add det_based action * add id-based class action into pipeline * remove clsactioninferer * return history result, optimize visualization for multiple actions * use history result for det-based action model * batchinfer for cls&det models --- deploy/pphuman/config/infer_cfg.yml | 12 +- deploy/pphuman/datacollector.py | 4 +- deploy/pphuman/pipe_utils.py | 4 +- deploy/pphuman/pipeline.py | 128 +++++++++++++---- deploy/python/action_infer.py | 213 ++++++++++++++++++++++++++++ deploy/python/action_utils.py | 6 +- deploy/python/infer.py | 22 +-- deploy/python/visualize.py | 12 +- 8 files changed, 341 insertions(+), 60 deletions(-) diff --git a/deploy/pphuman/config/infer_cfg.yml b/deploy/pphuman/config/infer_cfg.yml index 31b347ec0..0d73906f1 100644 --- a/deploy/pphuman/config/infer_cfg.yml +++ b/deploy/pphuman/config/infer_cfg.yml @@ -44,15 +44,19 @@ SKELETON_ACTION: enable: False ID_BASED_DETACTION: - model_dir: output_inference/detector - batch_size: 1 + model_dir: output_inference/ppyoloe_crn_s_300e_smoking/ + batch_size: 8 basemode: "idbased" + threshold: 0.4 + display_frames: 80 enable: False ID_BASED_CLSACTION: - model_dir: output_inference/classification - batch_size: 1 + model_dir: output_inference/PPHGNet_tiny_calling_halfbody + batch_size: 8 basemode: "idbased" + threshold: 0.45 + display_frames: 80 enable: False REID: diff --git a/deploy/pphuman/datacollector.py b/deploy/pphuman/datacollector.py index 22f0cc116..06a18b68c 100644 --- a/deploy/pphuman/datacollector.py +++ b/deploy/pphuman/datacollector.py @@ -25,7 +25,9 @@ class Result(object): 'kpt': dict(), 'video_action': dict(), 'skeleton_action': dict(), - 'reid': dict() + 'reid': dict(), + 'det_action': dict(), + 'cls_action': dict(), } def update(self, res, name): diff --git a/deploy/pphuman/pipe_utils.py b/deploy/pphuman/pipe_utils.py index a8968543c..8f39d8133 100644 --- a/deploy/pphuman/pipe_utils.py +++ b/deploy/pphuman/pipe_utils.py @@ -154,7 +154,9 @@ class PipeTimer(Times): 'kpt': Times(), 'video_action': Times(), 'skeleton_action': Times(), - 'reid': Times() + 'reid': Times(), + 'det_action': Times(), + 'cls_action': Times() } self.img_num = 0 diff --git a/deploy/pphuman/pipeline.py b/deploy/pphuman/pipeline.py index e40b96182..5abd7a779 100644 --- a/deploy/pphuman/pipeline.py +++ b/deploy/pphuman/pipeline.py @@ -36,9 +36,10 @@ from python.infer import Detector, DetectorPicoDet from python.attr_infer import AttrDetector from python.keypoint_infer import KeyPointDetector from python.keypoint_postprocess import translate_to_ori_images + from python.video_action_infer import VideoActionRecognizer -from python.action_infer import SkeletonActionRecognizer -from python.action_utils import KeyPointBuff, SkeletonActionVisualHelper +from python.action_infer import SkeletonActionRecognizer, DetActionRecognizer, ClsActionRecognizer +from python.action_utils import KeyPointBuff, ActionVisualHelper from pipe_utils import argsparser, print_arguments, merge_cfg, PipeTimer from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint @@ -366,17 +367,51 @@ class PipePredictor(object): trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads, enable_mkldnn) if self.with_idbased_detaction: - idbased_detaction_cfg = self.cfg['SKELETON_ACTION'] - idbased_detaction_model_dir = idbased_detaction_cfg['model_dir'] - idbased_detaction_batch_size = idbased_detaction_cfg[ - 'batch_size'] - # IDBasedDetActionRecognizer = IDBasedDetActionRecognizer() + idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION'] + model_dir = idbased_detaction_cfg['model_dir'] + batch_size = idbased_detaction_cfg['batch_size'] + basemode = idbased_detaction_cfg['basemode'] + threshold = idbased_detaction_cfg['threshold'] + display_frames = idbased_detaction_cfg['display_frames'] + self.modebase[basemode] = True + self.det_action_predictor = DetActionRecognizer( + model_dir, + device, + run_mode, + batch_size, + trt_min_shape, + trt_max_shape, + trt_opt_shape, + trt_calib_mode, + cpu_threads, + enable_mkldnn, + threshold=threshold, + display_frames=display_frames) + self.det_action_visual_helper = ActionVisualHelper(1) + if self.with_idbased_clsaction: - idbased_clsaction_cfg = self.cfg['SKELETON_ACTION'] - idbased_clsaction_model_dir = idbased_clsaction_cfg['model_dir'] - idbased_clsaction_batch_size = idbased_clsaction_cfg[ - 'batch_size'] - # IDBasedDetActionRecognizer = IDBasedClsActionRecognizer() + idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION'] + model_dir = idbased_clsaction_cfg['model_dir'] + batch_size = idbased_clsaction_cfg['batch_size'] + basemode = idbased_clsaction_cfg['basemode'] + threshold = idbased_clsaction_cfg['threshold'] + self.modebase[basemode] = True + display_frames = idbased_clsaction_cfg['display_frames'] + self.cls_action_predictor = ClsActionRecognizer( + model_dir, + device, + run_mode, + batch_size, + trt_min_shape, + trt_max_shape, + trt_opt_shape, + trt_calib_mode, + cpu_threads, + enable_mkldnn, + threshold=threshold, + display_frames=display_frames) + self.cls_action_visual_helper = ActionVisualHelper(1) + if self.with_skeleton_action: skeleton_action_cfg = self.cfg['SKELETON_ACTION'] skeleton_action_model_dir = skeleton_action_cfg['model_dir'] @@ -399,7 +434,7 @@ class PipePredictor(object): cpu_threads, enable_mkldnn, window_size=skeleton_action_frames) - self.skeleton_action_visual_helper = SkeletonActionVisualHelper( + self.skeleton_action_visual_helper = ActionVisualHelper( display_frames) if self.modebase["skeletonbased"]: @@ -609,10 +644,10 @@ class PipePredictor(object): continue self.pipeline_res.update(mot_res, 'mot') - if self.with_attr or self.with_skeleton_action: - #todo: move this code to each class's predeal function - crop_input, new_bboxes, ori_bboxes = crop_image_with_mot( - frame, mot_res) + + #todo: move this code to each class's predeal function + crop_input, new_bboxes, ori_bboxes = crop_image_with_mot( + frame, mot_res) if self.with_attr: if frame_id > self.warmup_frame: @@ -624,16 +659,28 @@ class PipePredictor(object): self.pipeline_res.update(attr_res, 'attr') if self.with_idbased_detaction: - #predeal, get what your model need - #predict, model preprocess\run\postprocess - #postdeal, interact with pipeline - pass + if frame_id > self.warmup_frame: + self.pipe_timer.module_time['det_action'].start() + det_action_res = self.det_action_predictor.predict( + crop_input, mot_res) + if frame_id > self.warmup_frame: + self.pipe_timer.module_time['det_action'].end() + self.pipeline_res.update(det_action_res, 'det_action') + + if self.cfg['visual']: + self.det_action_visual_helper.update(det_action_res) if self.with_idbased_clsaction: - #predeal, get what your model need - #predict, model preprocess\run\postprocess - #postdeal, interact with pipeline - pass + if frame_id > self.warmup_frame: + self.pipe_timer.module_time['cls_action'].start() + cls_action_res = self.cls_action_predictor.predict_with_mot( + crop_input, mot_res) + if frame_id > self.warmup_frame: + self.pipe_timer.module_time['cls_action'].end() + self.pipeline_res.update(cls_action_res, 'cls_action') + + if self.cfg['visual']: + self.cls_action_visual_helper.update(cls_action_res) if self.with_skeleton_action: if frame_id > self.warmup_frame: @@ -805,23 +852,42 @@ class PipePredictor(object): visual_thresh=self.cfg['kpt_thresh'], returnimg=True) - skeleton_action_res = result.get('skeleton_action') video_action_res = result.get('video_action') - if skeleton_action_res is not None or video_action_res is not None: + if video_action_res is not None: video_action_score = None - action_visual_helper = None if video_action_res and video_action_res["class"] == 1: video_action_score = video_action_res["score"] - if skeleton_action_res: - action_visual_helper = self.skeleton_action_visual_helper image = visualize_action( image, mot_res['boxes'], - action_visual_collector=action_visual_helper, + action_visual_collector=None, action_text="SkeletonAction", video_action_score=video_action_score, video_action_text="Fight") + visual_helper_for_display = [] + action_to_display = [] + + skeleton_action_res = result.get('skeleton_action') + if skeleton_action_res is not None: + visual_helper_for_display.append(self.skeleton_action_visual_helper) + action_to_display.append("Falling") + + det_action_res = result.get('det_action') + if det_action_res is not None: + visual_helper_for_display.append(self.det_action_visual_helper) + action_to_display.append("Smoking") + + cls_action_res = result.get('cls_action') + if cls_action_res is not None: + visual_helper_for_display.append(self.cls_action_visual_helper) + action_to_display.append("Calling") + + if len(visual_helper_for_display) > 0: + image = visualize_action(image, mot_res['boxes'], + visual_helper_for_display, + action_to_display) + return image def visualize_image(self, im_files, images, result): diff --git a/deploy/python/action_infer.py b/deploy/python/action_infer.py index bf91f0687..4c61f06c2 100644 --- a/deploy/python/action_infer.py +++ b/deploy/python/action_infer.py @@ -31,6 +31,7 @@ from paddle.inference import Config, create_predictor from utils import argsparser, Timer, get_current_memory_mb from benchmark_utils import PaddleInferBenchmark from infer import Detector, print_arguments +from attr_infer import AttrDetector class SkeletonActionRecognizer(Detector): @@ -263,6 +264,218 @@ def get_test_skeletons(input_file): "Now only support input with shape: (N, C, T, K, M) or (C, T, K, M)") +class DetActionRecognizer(object): + """ + Args: + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + threshold (float): The threshold of score for action feature object detection. + display_frames (int): The duration for corresponding detected action. + """ + + def __init__(self, + model_dir, + device='CPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + output_dir='output', + threshold=0.5, + display_frames=20): + super(DetActionRecognizer, self).__init__() + self.detector = Detector( + model_dir=model_dir, + device=device, + run_mode=run_mode, + batch_size=batch_size, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn, + output_dir=output_dir, + threshold=threshold) + self.threshold = threshold + self.frame_life = display_frames + self.result_history = {} + + def predict(self, images, mot_result): + det_result = self.detector.predict_image(images, visual=False) + result = self.postprocess(det_result, mot_result) + return result + + def postprocess(self, det_result, mot_result): + np_boxes_num = det_result['boxes_num'] + if np_boxes_num[0] <= 0: + return [[], []] + + mot_bboxes = mot_result.get('boxes') + + cur_box_idx = 0 + mot_id = [] + act_res = [] + for idx in range(len(mot_bboxes)): + tracker_id = mot_bboxes[idx, 0] + + # Current now, class 0 is positive, class 1 is negative. + action_ret = {'class': 1.0, 'score': -1.0} + box_num = np_boxes_num[idx] + boxes = det_result['boxes'][cur_box_idx:cur_box_idx + box_num] + cur_box_idx += box_num + isvalid = (boxes[:, 1] > self.threshold) & (boxes[:, 0] == 0) + valid_boxes = boxes[isvalid, :] + + if valid_boxes.shape[0] >= 1: + action_ret['class'] = valid_boxes[0, 0] + action_ret['score'] = valid_boxes[0, 1] + self.result_history[tracker_id] = [0, self.frame_life] + else: + history_det, life_remain = self.result_history.get(tracker_id, + [1, 0]) + action_ret['class'] = history_det + action_ret['score'] = -1.0 + life_remain -= 1 + if life_remain <= 0 and tracker_id in self.result_history: + del (self.result_history[tracker_id]) + elif tracker_id in self.result_history: + self.result_history[tracker_id][1] = life_remain + + mot_id.append(tracker_id) + act_res.append(action_ret) + result = list(zip(mot_id, act_res)) + + return result + + +class ClsActionRecognizer(AttrDetector): + """ + Args: + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + threshold (float): The threshold of score for action feature object detection. + display_frames (int): The duration for corresponding detected action. + """ + + def __init__(self, + model_dir, + device='CPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + output_dir='output', + threshold=0.5, + display_frames=80): + super(ClsActionRecognizer, self).__init__( + model_dir=model_dir, + device=device, + run_mode=run_mode, + batch_size=batch_size, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn, + output_dir=output_dir, + threshold=threshold) + self.threshold = threshold + self.frame_life = display_frames + self.result_history = {} + + def predict_with_mot(self, images, mot_result): + images = self.crop_half_body(images) + cls_result = self.predict_image(images, visual=False)["output"] + result = self.match_action_with_id(cls_result, mot_result) + return result + + def crop_half_body(self, images): + crop_images = [] + for image in images: + h = image.shape[0] + crop_images.append(image[:h // 2 + 1, :, :]) + return crop_images + + def postprocess(self, inputs, result): + # postprocess output of predictor + im_results = result['output'] + batch_res = [] + for res in im_results: + action_res = res.tolist() + for cid, score in enumerate(action_res): + action_res[cid] = score + batch_res.append(action_res) + result = {'output': batch_res} + return result + + def match_action_with_id(self, cls_result, mot_result): + mot_bboxes = mot_result.get('boxes') + + mot_id = [] + act_res = [] + + for idx in range(len(mot_bboxes)): + tracker_id = mot_bboxes[idx, 0] + + cls_id_res = 1 + cls_score_res = -1.0 + for cls_id in range(len(cls_result[idx])): + score = cls_result[idx][cls_id] + if score > cls_score_res: + cls_id_res = cls_id + cls_score_res = score + + # Current now, class 0 is positive, class 1 is negative. + if cls_id_res == 1 or (cls_id_res == 0 and + cls_score_res < self.threshold): + history_cls, life_remain = self.result_history.get(tracker_id, + [1, 0]) + cls_id_res = history_cls + cls_score_res = 1 - cls_score_res + life_remain -= 1 + if life_remain <= 0 and tracker_id in self.result_history: + del (self.result_history[tracker_id]) + elif tracker_id in self.result_history: + self.result_history[tracker_id][1] = life_remain + else: + self.result_history[tracker_id] = [cls_id_res, self.frame_life] + + action_ret = {'class': cls_id_res, 'score': cls_score_res} + mot_id.append(tracker_id) + act_res.append(action_ret) + result = list(zip(mot_id, act_res)) + + return result + + def main(): detector = SkeletonActionRecognizer( FLAGS.model_dir, diff --git a/deploy/python/action_utils.py b/deploy/python/action_utils.py index e62d202fa..483116584 100644 --- a/deploy/python/action_utils.py +++ b/deploy/python/action_utils.py @@ -80,7 +80,7 @@ class KeyPointBuff(object): return output -class SkeletonActionVisualHelper(object): +class ActionVisualHelper(object): def __init__(self, frame_life=20): self.frame_life = frame_life self.action_history = {} @@ -104,6 +104,10 @@ class SkeletonActionVisualHelper(object): def update(self, action_res_list): for mot_id, action_res in action_res_list: + if mot_id in self.action_history: + if int(action_res["class"]) != 0 and int(self.action_history[ + mot_id]["class"]) == 0: + continue action_info = self.action_history.get(mot_id, {}) action_info["class"] = action_res["class"] action_info["life_remain"] = self.frame_life diff --git a/deploy/python/infer.py b/deploy/python/infer.py index e16c3e95c..4ce496da4 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -40,25 +40,9 @@ from utils import argsparser, Timer, get_current_memory_mb # Global dictionary SUPPORT_MODELS = { - 'YOLO', - 'RCNN', - 'SSD', - 'Face', - 'FCOS', - 'SOLOv2', - 'TTFNet', - 'S2ANet', - 'JDE', - 'FairMOT', - 'DeepSORT', - 'GFL', - 'PicoDet', - 'CenterNet', - 'TOOD', - 'RetinaNet', - 'StrongBaseline', - 'STGCN', - 'YOLOX', + 'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE', + 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet', + 'StrongBaseline', 'STGCN', 'YOLOX', 'PPHGNet' } diff --git a/deploy/python/visualize.py b/deploy/python/visualize.py index 4bffc4d80..427fac641 100644 --- a/deploy/python/visualize.py +++ b/deploy/python/visualize.py @@ -378,13 +378,19 @@ def visualize_action(im, text_thickness = 2 if action_visual_collector: - id_detected = action_visual_collector.get_visualize_ids() + id_action_dict = {} + for collector, action_type in zip(action_visual_collector, action_text): + id_detected = collector.get_visualize_ids() + for pid in id_detected: + id_action_dict[pid] = id_action_dict.get(pid, []) + id_action_dict[pid].append(action_type) for mot_box in mot_boxes: # mot_box is a format with [mot_id, class, score, xmin, ymin, w, h] - if mot_box[0] in id_detected: + if mot_box[0] in id_action_dict: text_position = (int(mot_box[3] + mot_box[5] * 0.75), int(mot_box[4] - 10)) - cv2.putText(im, action_text, text_position, + display_text = ', '.join(id_action_dict[mot_box[0]]) + cv2.putText(im, display_text, text_position, cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), 2) if video_action_score: -- GitLab