Unverified commit 7018dad1, authored by JYChen, committed by GitHub

Pipeline with kpt and act (#5399)

* add keypoint infer and visualize into Pipeline

* add independent action model inference

* add action inference into pipeline, still work in progress

* test different display frames and normalization methods

* use bbox and scale normalization

* remove debug info and optimize code structure

* remove useless visual param

* make action parameters configurable
Parent 6a17524f
crop_thresh: 0.5
attr_thresh: 0.5
kpt_thresh: 0.2
visual: True
DET:
model_dir: output_inference/mot_ppyolov3//
model_dir: output_inference/mot_ppyolov3/
batch_size: 1
ATTR:
......@@ -14,3 +15,14 @@ MOT:
model_dir: output_inference/mot_ppyolov3/
tracker_config: deploy/pphuman/config/tracker_config.yml
batch_size: 1
KPT:
model_dir: output_inference/dark_hrnet_w32_256x192/
batch_size: 8
ACTION:
model_dir: output_inference/STGCN
batch_size: 1
max_frames: 50
display_frames: 80
coord_size: [384, 512]
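A minimal sketch of how the new KPT/ACTION entries above are consumed (assumptions: a standard yaml load, the config path shown in the comment, and the key names listed above; this is not the pipeline's actual loader):

import yaml

with open("deploy/pphuman/config/infer_cfg.yml") as f:   # path assumed
    cfg = yaml.safe_load(f)

kpt_cfg = cfg["KPT"]                           # model_dir, batch_size
action_cfg = cfg["ACTION"]
max_frames = action_cfg["max_frames"]          # skeleton frames collected per track before action inference
display_frames = action_cfg["display_frames"]  # how long a recognized action stays visualized
coord_size = action_cfg["coord_size"]          # [w, h] target scale for keypoint normalization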
......@@ -290,11 +290,15 @@ def crop_image_with_det(batch_input, det_res):
def crop_image_with_mot(input, mot_res):
res = mot_res['boxes']
crop_res = []
new_bboxes = []
ori_bboxes = []
for box in res:
crop_image, new_box, ori_box = expand_crop(input, box[1:])
crop_image, new_bbox, ori_bbox = expand_crop(input, box[1:])
if crop_image is not None:
crop_res.append(crop_image)
return crop_res
new_bboxes.append(new_bbox)
ori_bboxes.append(ori_bbox)
return crop_res, new_bboxes, ori_bboxes
def parse_mot_res(input):
......@@ -305,3 +309,33 @@ def parse_mot_res(input):
res = [i, 0, score, xmin, ymin, xmin + w, ymin + h]
mot_res.append(res)
return {'boxes': np.array(mot_res)}
def refine_keypoint_coordinary(kpts, bbox, coord_size):
"""
Normalize keypoint coordinates relative to each bbox and rescale them to a fixed coord_size.
"""
tl = bbox[:, 0:2]
wh = bbox[:, 2:] - tl
tl = np.expand_dims(np.transpose(tl, (1, 0)), (2, 3))
wh = np.expand_dims(np.transpose(wh, (1, 0)), (2, 3))
target_w, target_h = coord_size
res = (kpts - tl) / wh * np.expand_dims(
np.array([[target_w], [target_h]]), (2, 3))
return res
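A minimal usage sketch of refine_keypoint_coordinary on dummy data (assumptions: the function is importable from pipe_utils with deploy/pphuman on the path, kpts has shape (2, T, K, 1) with the x/y channels first, and each bbox row is [x1, y1, x2, y2]):

import numpy as np
from pipe_utils import refine_keypoint_coordinary  # assumes deploy/pphuman is on sys.path

T, K = 4, 17                                    # 4 frames, 17 keypoints per frame
kpts = np.random.rand(2, T, K, 1) * 100 + 50    # x/y image coordinates inside the boxes
bbox = np.tile(np.array([[50., 50., 150., 250.]]), (T, 1))  # one [x1, y1, x2, y2] box per frame

out = refine_keypoint_coordinary(kpts, bbox, coord_size=[384, 512])
print(out.shape)  # (2, 4, 17, 1): x rescaled to the 384 range, y to the 512 range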
def parse_mot_keypoint(input, coord_size):
parsed_skeleton_with_mot = {}
ids = []
skeleton = []
for tracker_id, kpt_seq in input:
ids.append(tracker_id)
kpts = np.array(kpt_seq.kpts, dtype=np.float32)[:, :, :2]
kpts = np.expand_dims(np.transpose(kpts, [2, 0, 1]),
-1) #T, K, C -> C, T, K, 1
bbox = np.array(kpt_seq.bboxes, dtype=np.float32)
skeleton.append(refine_keypoint_coordinary(kpts, bbox, coord_size))
parsed_skeleton_with_mot["mot_id"] = ids
parsed_skeleton_with_mot["skeleton"] = skeleton
return parsed_skeleton_with_mot
......@@ -30,10 +30,15 @@ sys.path.insert(0, parent_path)
from python.infer import Detector, DetectorPicoDet
from python.mot_sde_infer import SDE_Detector
from python.attr_infer import AttrDetector
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
from python.action_infer import ActionRecognizer
from python.action_utils import KeyPointCollector, ActionVisualCollector
from pipe_utils import argsparser, print_arguments, merge_cfg, PipeTimer
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint
from python.preprocess import decode_image
from python.visualize import visualize_box_mask, visualize_attr
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action
from pptracking.python.visualize import plot_tracking
......@@ -299,9 +304,45 @@ class PipePredictor(object):
trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
enable_mkldnn)
if self.with_action:
self.kpt_predictor = KeyPointDetector()
self.kpt_collector = KeyPointCollector()
self.action_predictor = ActionDetector()
kpt_cfg = self.cfg['KPT']
kpt_model_dir = kpt_cfg['model_dir']
kpt_batch_size = kpt_cfg['batch_size']
action_cfg = self.cfg['ACTION']
action_model_dir = action_cfg['model_dir']
action_batch_size = action_cfg['batch_size']
action_frames = action_cfg['max_frames']
display_frames = action_cfg['display_frames']
self.coord_size = action_cfg['coord_size']
self.kpt_predictor = KeyPointDetector(
kpt_model_dir,
device,
run_mode,
kpt_batch_size,
trt_min_shape,
trt_max_shape,
trt_opt_shape,
trt_calib_mode,
cpu_threads,
enable_mkldnn,
use_dark=False)
self.kpt_collector = KeyPointCollector(action_frames)
self.action_predictor = ActionRecognizer(
action_model_dir,
device,
run_mode,
action_batch_size,
trt_min_shape,
trt_max_shape,
trt_opt_shape,
trt_calib_mode,
cpu_threads,
enable_mkldnn,
window_size=action_frames)
self.action_visual_collector = ActionVisualCollector(
display_frames)
def set_file_name(self, path):
self.file_name = os.path.split(path)[-1]
......@@ -412,7 +453,8 @@ class PipePredictor(object):
self.pipeline_res.update(mot_res, 'mot')
if self.with_attr or self.with_action:
crop_input = crop_image_with_mot(frame, mot_res)
crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
frame, mot_res)
if self.with_attr:
if frame_id > self.warmup_frame:
......@@ -424,17 +466,34 @@ class PipePredictor(object):
self.pipeline_res.update(attr_res, 'attr')
if self.with_action:
kpt_result = self.kpt_predictor.predict_image(crop_input)
self.pipeline_res.update(kpt_result, 'kpt')
self.kpt_collector.update(kpt_result) # collect kpt output
state = self.kpt_collector.state() # whether frame num is enough
kpt_pred = self.kpt_predictor.predict_image(
crop_input, visual=False)
keypoint_vector, score_vector = translate_to_ori_images(
kpt_pred, np.array(new_bboxes))
kpt_res = {}
kpt_res['keypoint'] = [
keypoint_vector.tolist(), score_vector.tolist()
] if len(keypoint_vector) > 0 else [[], []]
kpt_res['bbox'] = ori_bboxes
self.pipeline_res.update(kpt_res, 'kpt')
self.kpt_collector.update(kpt_res,
mot_res) # collect kpt output
state = self.kpt_collector.get_state(
) # whether frame num is enough or a tracker was lost
action_res = {}
if state:
action_input = self.kpt_collector.collate(
) # reorganize kpt output by ID
action_res = self.action_predictor.predict_kpt(action_input)
self.pipeline_res.update(action, 'action')
collected_keypoint = self.kpt_collector.get_collected_keypoint(
) # reorganize kpt output by ID
action_input = parse_mot_keypoint(collected_keypoint,
self.coord_size)
action_res = self.action_predictor.predict_skeleton_with_mot(
action_input)
self.pipeline_res.update(action_res, 'action')
if self.cfg['visual']:
self.action_visual_collector.update(action_res)
if frame_id > self.warmup_frame:
self.pipe_timer.img_num += 1
......@@ -474,6 +533,19 @@ class PipePredictor(object):
image = visualize_attr(image, attr_res, boxes)
image = np.array(image)
kpt_res = result.get('kpt')
if kpt_res is not None:
image = visualize_pose(
image,
kpt_res,
visual_thresh=self.cfg['kpt_thresh'],
returnimg=True)
action_res = result.get('action')
if action_res is not None:
image = visualize_action(image, mot_res['boxes'],
self.action_visual_collector, "Falling")
return image
def visualize_image(self, im_files, images, result):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import glob
import cv2
import numpy as np
import math
import random  # used by AutoPadding when random_pad=True
import paddle
import sys
from collections import Sequence
# add deploy path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from paddle.inference import Config, create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from benchmark_utils import PaddleInferBenchmark
from infer import Detector, print_arguments
class ActionRecognizer(Detector):
"""
Args:
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run on; it can be CPU/GPU/XPU, default is CPU
run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
batch_size (int): batch size used in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantization
calibration, trt_calib_mode needs to be set to True
cpu_threads (int): number of cpu threads
enable_mkldnn (bool): whether to enable MKLDNN
threshold (float): score threshold for visualization
window_size (int): Temporal size of skeleton feature.
random_pad (bool): Whether to do random padding when frame length < window_size.
"""
def __init__(self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
output_dir='output',
threshold=0.5,
window_size=100,
random_pad=False):
assert batch_size == 1, "ActionRecognizer only supports batch_size=1 for now."
super(ActionRecognizer, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold)
def predict(self, repeats=1):
'''
Args:
repeats (int): repeat number for prediction
Returns:
results (dict):
'''
# model prediction
output_names = self.predictor.get_output_names()
for i in range(repeats):
self.predictor.run()
output_tensor = self.predictor.get_output_handle(output_names[0])
np_output = output_tensor.copy_to_cpu()
result = dict(output=np_output)
return result
def predict_skeleton(self, skeleton_list, run_benchmark=False, repeats=1):
results = []
for i, skeleton in enumerate(skeleton_list):
if run_benchmark:
# preprocess
inputs = self.preprocess(skeleton) # warmup
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(skeleton)
self.det_times.preprocess_time_s.end()
# model prediction
result = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats)
# postprocess
result_warmup = self.postprocess(inputs, result) # warmup
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(skeleton)
cm, gm, gu = get_current_memory_mb()
self.cpu_mem += cm
self.gpu_mem += gm
self.gpu_util += gu
else:
# preprocess
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(skeleton)
self.det_times.preprocess_time_s.end()
# model prediction
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# postprocess
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(skeleton)
results.append(result)
return results
def predict_skeleton_with_mot(self, skeleton_with_mot, run_benchmark=False):
"""
skeleton_with_mot (dict): includes individual skeleton sequences, whose shape is [C, T, K, 1],
and their corresponding track ids.
"""
skeleton_list = skeleton_with_mot["skeleton"]
mot_id = skeleton_with_mot["mot_id"]
act_res = self.predict_skeleton(skeleton_list, run_benchmark, repeats=1)
results = list(zip(mot_id, act_res))
return results
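A usage sketch of predict_skeleton_with_mot on random skeleton data (assumptions: an exported skeleton-based action model directory such as the STGCN path in the config above; the input dict layout mirrors parse_mot_keypoint):

import numpy as np

recognizer = ActionRecognizer("output_inference/STGCN", device="GPU", window_size=50)  # model dir assumed

skeleton_with_mot = {
    "mot_id": [3, 7],
    "skeleton": [np.random.rand(2, 50, 17, 1).astype("float32") for _ in range(2)],
}
results = recognizer.predict_skeleton_with_mot(skeleton_with_mot)
# results: [(3, {'class': ..., 'score': ...}), (7, {'class': ..., 'score': ...})]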
def preprocess(self, data):
preprocess_ops = []
for op_info in self.pred_config.preprocess_infos:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
input_lst = []
data = action_preprocess(data, preprocess_ops)
input_lst.append(data)
input_names = self.predictor.get_input_names()
inputs = {}
inputs['data_batch_0'] = np.stack(input_lst, axis=0).astype('float32')
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
return inputs
def postprocess(self, inputs, result):
# postprocess output of predictor
output_logit = result['output'][0]
classes = np.argpartition(output_logit, -1)[-1:]
classes = classes[np.argsort(-output_logit[classes])]
scores = output_logit[classes]
result = {'class': classes, 'score': scores}
return result
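The postprocess above keeps only the top-1 class; the argpartition/argsort idiom it uses generalizes to top-k, e.g.:

import numpy as np

output_logit = np.array([0.05, 0.80, 0.15])            # toy per-class scores
k = 1
classes = np.argpartition(output_logit, -k)[-k:]       # indices of the k highest scores
classes = classes[np.argsort(-output_logit[classes])]  # sorted by descending score
print(classes, output_logit[classes])                  # [1] [0.8]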
def action_preprocess(input, preprocess_ops):
"""
input (str | numpy.array): if input is a str, it should be a valid file path to a saved numpy array.
Otherwise it should be a numpy.array passed in directly.
return (numpy.array)
"""
if isinstance(input, str):
assert os.path.isfile(input), "{0} does not exist".format(input)
data = np.load(input)
else:
data = input
for operator in preprocess_ops:
data = operator(data)
return data
class AutoPadding(object):
"""
Sample or pad the frame-level skeleton feature.
Args:
window_size (int): Temporal size of skeleton feature.
random_pad (bool): Whether to do random padding when frame length < window size. Default: False.
"""
def __init__(self, window_size=100, random_pad=False):
self.window_size = window_size
self.random_pad = random_pad
def get_frame_num(self, data):
C, T, V, M = data.shape
for i in range(T - 1, -1, -1):
tmp = np.sum(data[:, i, :, :])
if tmp > 0:
T = i + 1
break
return T
def __call__(self, results):
data = results
C, T, V, M = data.shape
T = self.get_frame_num(data)
if T == self.window_size:
data_pad = data[:, :self.window_size, :, :]
elif T < self.window_size:
begin = random.randint(
0, self.window_size - T) if self.random_pad else 0
data_pad = np.zeros((C, self.window_size, V, M))
data_pad[:, begin:begin + T, :, :] = data[:, :T, :, :]
else:
if self.random_pad:
index = np.random.choice(
T, self.window_size, replace=False).astype('int64')
else:
index = np.linspace(0, T, self.window_size).astype("int64")
data_pad = data[:, index, :, :]
return data_pad
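A minimal sketch of AutoPadding on toy data (assumptions: this module is importable; data follows the C, T, V, M layout, and trailing all-zero frames are treated as padding by get_frame_num):

import numpy as np

pad = AutoPadding(window_size=50, random_pad=False)

short_seq = np.ones((2, 30, 17, 1), dtype="float32")   # 30 valid frames < window_size
print(pad(short_seq).shape)                            # (2, 50, 17, 1): zero-padded in time

long_seq = np.zeros((2, 130, 17, 1), dtype="float32")
long_seq[:, :120] = 1.0                                # 120 valid frames > window_size
print(pad(long_seq).shape)                             # (2, 50, 17, 1): 50 frames sampled uniformly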
def get_test_skeletons(input_file):
assert input_file is not None, "--action_file can not be None"
input_data = np.load(input_file)
if input_data.ndim == 4:
return [input_data]
elif input_data.ndim == 5:
output = list(
map(lambda x: np.squeeze(x, 0),
np.split(input_data, input_data.shape[0], 0)))
return output
else:
raise ValueError(
"Now only support input with shape: (N, C, T, K, M) or (C, T, K, M)")
def main():
detector = ActionRecognizer(
FLAGS.model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
window_size=FLAGS.window_size,
random_pad=FLAGS.random_pad)
# predict from numpy array
input_list = get_test_skeletons(FLAGS.action_file)
detector.predict_skeleton(input_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
mems = {
'cpu_rss_mb': detector.cpu_mem / len(input_list),
'gpu_rss_mb': detector.gpu_mem / len(input_list),
'gpu_util': detector.gpu_util * 100 / len(input_list)
}
perf_info = detector.det_times.report(average=True)
model_dir = FLAGS.model_dir
mode = FLAGS.run_mode
model_info = {
'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
data_info = {
'batch_size': FLAGS.batch_size,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
det_log = PaddleInferBenchmark(detector.config, model_info, data_info,
perf_info, mems)
det_log('Action')
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
FLAGS.device = FLAGS.device.upper()
assert FLAGS.device in ['CPU', 'GPU', 'XPU'
], "device should be CPU, GPU or XPU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class KeyPointSequence(object):
def __init__(self, max_size=100):
self.frames = 0
self.kpts = []
self.bboxes = []
self.max_size = max_size
def save(self, kpt, bbox):
self.kpts.append(kpt)
self.bboxes.append(bbox)
self.frames += 1
if self.frames == self.max_size:
return True
return False
class KeyPointCollector(object):
def __init__(self, max_size=100):
self.flag_track_interrupt = False
self.keypoint_saver = dict()
self.max_size = max_size
self.id_to_pop = set()
self.flag_to_pop = False
def get_state(self):
return self.flag_to_pop
def update(self, kpt_res, mot_res):
kpts = kpt_res.get('keypoint')[0]
bboxes = kpt_res.get('bbox')
mot_bboxes = mot_res.get('boxes')
updated_id = set()
for idx in range(len(kpts)):
tracker_id = mot_bboxes[idx, 0]
updated_id.add(tracker_id)
kpt_seq = self.keypoint_saver.get(tracker_id,
KeyPointSequence(self.max_size))
is_full = kpt_seq.save(kpts[idx], bboxes[idx])
self.keypoint_saver[tracker_id] = kpt_seq
# Scene 1: results should be popped once the number of frames reaches max_size
if is_full:
self.id_to_pop.add(tracker_id)
self.flag_to_pop = True
# Scene 2: results of a lost tracker should be popped
interrupted_id = set(self.keypoint_saver.keys()) - updated_id
if len(interrupted_id) > 0:
self.flag_to_pop = True
self.id_to_pop.update(interrupted_id)
def get_collected_keypoint(self):
"""
Output (List): List of keypoint results for Action Recognition task, where
the format of each element is [tracker_id, KeyPointSequence of tracker_id]
"""
output = []
for tracker_id in self.id_to_pop:
output.append([tracker_id, self.keypoint_saver[tracker_id]])
del (self.keypoint_saver[tracker_id])
self.flag_to_pop = False
self.id_to_pop.clear()
return output
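A minimal sketch of the collector lifecycle as driven per frame by pipeline.py above (a single dummy track; the kpt_res/mot_res layouts follow the ones built there):

import numpy as np

collector = KeyPointCollector(max_size=50)

for frame_id in range(50):
    kpt_res = {
        "keypoint": [[np.zeros((17, 3)).tolist()], [np.ones((17, 1)).tolist()]],  # dummy keypoints/scores
        "bbox": [np.array([10, 10, 100, 200])],
    }
    mot_res = {"boxes": np.array([[1, 0, 0.9, 10, 10, 100, 200]])}   # tracker id 1
    collector.update(kpt_res, mot_res)
    if collector.get_state():                                        # True once 50 frames are stored
        collected = collector.get_collected_keypoint()
        # collected: [[1.0, KeyPointSequence]], ready for parse_mot_keypoint()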
class ActionVisualCollector(object):
def __init__(self, frame_life=20):
self.frame_life = frame_life
self.action_history = {}
def get_visualize_ids(self):
id_detected = self.check_detected()
return id_detected
def check_detected(self):
id_detected = set()
deperate_id = []
for mot_id in self.action_history:
self.action_history[mot_id]["life_remain"] -= 1
if int(self.action_history[mot_id]["class"]) == 0:
id_detected.add(mot_id)
if self.action_history[mot_id]["life_remain"] == 0:
deperate_id.append(mot_id)
for mot_id in deperate_id:
del (self.action_history[mot_id])
return id_detected
def update(self, action_res_list):
for mot_id, action_res in action_res_list:
action_info = self.action_history.get(mot_id, {})
action_info["class"] = action_res["class"]
action_info["life_remain"] = self.frame_life
self.action_history[mot_id] = action_info
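A small sketch of how the visual collector keeps a recognized action alive for frame_life frames (assumption: action results arrive as the (mot_id, result) pairs returned by predict_skeleton_with_mot, and class 0 is the detected action):

collector = ActionVisualCollector(frame_life=20)
collector.update([(1.0, {"class": 0, "score": 0.9})])  # track 1 recognized as class 0

for _ in range(20):
    ids = collector.get_visualize_ids()                # {1.0} while life_remain > 0
print(collector.get_visualize_ids())                   # set(): entry expired after 20 frames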
......@@ -41,7 +41,6 @@ from PIL import Image, ImageDraw, ImageFont
class AttrDetector(Detector):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run on; it can be CPU/GPU/XPU, default is CPU
run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
......
......@@ -38,22 +38,9 @@ from utils import argsparser, Timer, get_current_memory_mb
# Global dictionary
SUPPORT_MODELS = {
'YOLO',
'RCNN',
'SSD',
'Face',
'FCOS',
'SOLOv2',
'TTFNet',
'S2ANet',
'JDE',
'FairMOT',
'DeepSORT',
'GFL',
'PicoDet',
'CenterNet',
'TOOD',
'StrongBaseline',
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD',
'StrongBaseline', 'STGCN'
}
......@@ -287,7 +274,7 @@ class Detector(object):
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
......
......@@ -141,6 +141,21 @@ def argsparser():
type=ast.literal_eval,
default=True,
help='whether to use DarkPose to get better keypoint position predictions')
parser.add_argument(
"--action_file",
type=str,
default=None,
help="Path of input file for action recognition.")
parser.add_argument(
"--window_size",
type=int,
default=50,
help="Temporal size of skeleton feature for action recognition.")
parser.add_argument(
"--random_pad",
type=ast.literal_eval,
default=False,
help="Whether do random padding for action recognition.")
return parser
......@@ -237,7 +252,7 @@ class Timer(Times):
total_time = pre_time + infer_time + post_time
if self.with_tracker:
dic['tracking_time_s'] = round(track_time / max(1, self.img_num),
4) if average else track_time
4) if average else track_time
total_time = total_time + track_time
dic['total_time_s'] = round(total_time, 4)
return dic
......
......@@ -361,3 +361,17 @@ def visualize_attr(im, results, boxes=None):
text_scale, (0, 0, 255),
thickness=text_thickness)
return im
def visualize_action(im, mot_boxes, action_visual_collector, action_text=""):
im = cv2.imread(im) if isinstance(im, str) else im
id_detected = action_visual_collector.get_visualize_ids()
text_scale = max(1, im.shape[1] / 1600.)
for mot_box in mot_boxes:
# mot_box layout: [mot_id, class, score, xmin, ymin, w, h]
if mot_box[0] in id_detected:
text_position = (int(mot_box[3] + mot_box[5] * 0.75),
int(mot_box[4] - 10))
cv2.putText(im, action_text, text_position, cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255), 2)
return im
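A usage sketch of visualize_action on a blank frame with a hand-built collector (assumptions: ActionVisualCollector is importable from the action utils module above, and mot_boxes follow the layout in the comment inside the function):

import numpy as np

collector = ActionVisualCollector(frame_life=20)
collector.update([(1.0, {"class": 0, "score": 0.9})])     # track 1 flagged, e.g. "Falling"

frame = np.zeros((720, 1280, 3), dtype=np.uint8)
mot_boxes = np.array([[1, 0, 0.9, 100, 100, 80, 200]])    # [mot_id, class, score, xmin, ymin, w, h]
frame = visualize_action(frame, mot_boxes, collector, action_text="Falling")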