diff --git a/deploy/pphuman/config/infer_cfg.yml b/deploy/pphuman/config/infer_cfg.yml
index 039c6ec817be7a051992a2cf010b24d474218d00..9e53523aed7cedaa7f208d960bcbf28d09ae1e92 100644
--- a/deploy/pphuman/config/infer_cfg.yml
+++ b/deploy/pphuman/config/infer_cfg.yml
@@ -27,3 +27,7 @@ ACTION:
   max_frames: 50
   display_frames: 80
   coord_size: [384, 512]
+
+REID:
+  model_dir: output_inference/reid_model/
+  batch_size: 16
diff --git a/deploy/pphuman/datacollector.py b/deploy/pphuman/datacollector.py
new file mode 100644
index 0000000000000000000000000000000000000000..62bd68bfc575481bea17da321b4030644379858f
--- /dev/null
+++ b/deploy/pphuman/datacollector.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+
+
+class Result(object):
+    def __init__(self):
+        self.res_dict = {
+            'det': dict(),
+            'mot': dict(),
+            'attr': dict(),
+            'kpt': dict(),
+            'action': dict(),
+            'reid': dict()
+        }
+
+    def update(self, res, name):
+        self.res_dict[name].update(res)
+
+    def get(self, name):
+        if name in self.res_dict and len(self.res_dict[name]) > 0:
+            return self.res_dict[name]
+        return None
+
+
+class DataCollector(object):
+    """
+    DataCollector of the PP-Human pipeline: it collects the results of every
+    frame and groups them by track id, mainly for MTMCT.
+
+    Data structure:
+    collector:
+      - [id1]: (all results of N frames)
+        - frames(list of int): Nx[int]
+        - rects(list of rect): Nx[rect(conf, xmin, ymin, xmax, ymax)]
+        - features(list of array(256,)): Nx[array(256,)]
+        - qualities(list of float): Nx[float]
+        - attrs(list of attr): refer to attrs for details
+        - kpts(list of kpts): refer to kpts for details
+        - actions(list of actions): refer to actions for details
+      ...
+      - [idN]
+    """
+
+    def __init__(self):
+        # id, frame, rect, score, label, attrs, kpts, actions
+        self.mots = {
+            "frames": [],
+            "rects": [],
+            "attrs": [],
+            "kpts": [],
+            "features": [],
+            "qualities": [],
+            "actions": []
+        }
+        self.collector = {}
+
+    def append(self, frameid, result):
+        mot_res = result.get('mot')
+        attr_res = result.get('attr')
+        kpt_res = result.get('kpt')
+        action_res = result.get('action')
+        reid_res = result.get('reid')
+
+        for idx, mot_item in enumerate(reid_res['rects']):
+            ids = int(mot_item[0])
+            if ids not in self.collector:
+                self.collector[ids] = copy.deepcopy(self.mots)
+
+            self.collector[ids]["frames"].append(frameid)
+            self.collector[ids]["rects"].append([mot_item[2:]])
+            if attr_res:
+                self.collector[ids]["attrs"].append(attr_res['output'][idx])
+            if kpt_res:
+                self.collector[ids]["kpts"].append(kpt_res['output'][idx])
+            if action_res:
+                self.collector[ids]["actions"].append(action_res['output'][idx])
+            else:
+                # the action model only emits a result every few frames, not every frame
+                self.collector[ids]["actions"].append(None)
+            if reid_res:
+                self.collector[ids]["features"].append(reid_res['features'][
+                    idx])
+                self.collector[ids]["qualities"].append(reid_res['qualities'][
+                    idx])
+
+    def get_res(self):
+        return self.collector
diff --git a/deploy/pphuman/mtmct.py b/deploy/pphuman/mtmct.py
new file mode 100644
index 0000000000000000000000000000000000000000..f67da49a08e88e2d524cc7c86cbda692795fb159
--- /dev/null
+++ b/deploy/pphuman/mtmct.py
@@ -0,0 +1,342 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
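+
+"""
+MTMCT (multi-target multi-camera tracking) post-processing for PP-Human.
+
+Given the per-camera track data collected by DataCollector, this module
+distills one mean ReID feature per (camera, track) pair from the highest
+quality crops (distill_idfeat), computes a pairwise cosine-distance matrix
+in which tracks from the same camera are kept apart (get_sim_matrix_new),
+groups tracks across cameras with agglomerative clustering to assign global
+ids (sub_cluster), writes the matched results to mtmct_result.txt
+(gen_restxt), and optionally saves per-camera visualization videos
+(save_mtmct_vis_results).
+"""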
+ +import motmetrics as mm +from pptracking.python.mot.visualize import plot_tracking +import os +import re +import cv2 +import gc +import numpy as np +from sklearn import preprocessing +from sklearn.cluster import AgglomerativeClustering +import pandas as pd +from tqdm import tqdm +from functools import reduce +import warnings +warnings.filterwarnings("ignore") + + +def gen_restxt(output_dir_filename, map_tid, cid_tid_dict): + pattern = re.compile(r'c(\d)_t(\d)') + f_w = open(output_dir_filename, 'w') + for key, res in cid_tid_dict.items(): + cid, tid = pattern.search(key).groups() + cid = int(cid) + 1 + rects = res["rects"] + frames = res["frames"] + for idx, bbox in enumerate(rects): + bbox[0][3:] -= bbox[0][1:3] + fid = frames[idx] + 1 + rect = [max(int(x), 0) for x in bbox[0][1:]] + if key in map_tid: + new_tid = map_tid[key] + f_w.write( + str(cid) + ' ' + str(new_tid) + ' ' + str(fid) + ' ' + + ' '.join(map(str, rect)) + '\n') + print('gen_res: write file in {}'.format(output_dir_filename)) + f_w.close() + + +def get_mtmct_matching_results(pred_mtmct_file, secs_interval=0.5, + video_fps=20): + res = np.loadtxt(pred_mtmct_file) # 'cid, tid, fid, x1, y1, w, h, -1, -1' + camera_ids = list(map(int, np.unique(res[:, 0]))) + + res = res[:, :7] + # each line in res: 'cid, tid, fid, x1, y1, w, h' + + camera_tids = [] + camera_results = dict() + for c_id in camera_ids: + camera_results[c_id] = res[res[:, 0] == c_id] + tids = np.unique(camera_results[c_id][:, 1]) + tids = list(map(int, tids)) + camera_tids.append(tids) + + # select common tids throughout each video + common_tids = reduce(np.intersect1d, camera_tids) + + # get mtmct matching results by cid_tid_fid_results[c_id][t_id][f_id] + cid_tid_fid_results = dict() + cid_tid_to_fids = dict() + interval = int(secs_interval * video_fps) # preferably less than 10 + for c_id in camera_ids: + cid_tid_fid_results[c_id] = dict() + cid_tid_to_fids[c_id] = dict() + for t_id in common_tids: + tid_mask = camera_results[c_id][:, 1] == t_id + cid_tid_fid_results[c_id][t_id] = dict() + + camera_trackid_results = camera_results[c_id][tid_mask] + fids = np.unique(camera_trackid_results[:, 2]) + fids = fids[fids % interval == 0] + fids = list(map(int, fids)) + cid_tid_to_fids[c_id][t_id] = fids + + for f_id in fids: + st_frame = f_id + ed_frame = f_id + interval + + st_mask = camera_trackid_results[:, 2] >= st_frame + ed_mask = camera_trackid_results[:, 2] < ed_frame + frame_mask = np.logical_and(st_mask, ed_mask) + cid_tid_fid_results[c_id][t_id][f_id] = camera_trackid_results[ + frame_mask] + + return camera_results, cid_tid_fid_results + + +def save_mtmct_vis_results(camera_results, captures, output_dir): + # camera_results: 'cid, tid, fid, x1, y1, w, h' + camera_ids = list(camera_results.keys()) + + import shutil + save_dir = os.path.join(output_dir, 'mtmct_vis') + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + os.makedirs(save_dir) + + for idx, video_file in enumerate(captures): + capture = cv2.VideoCapture(video_file) + cid = camera_ids[idx] + video_out_name = "mtmct_vis_c" + str(cid) + ".mp4" + print("Start visualizing output video: {}".format(video_out_name)) + out_path = os.path.join(save_dir, video_out_name) + + # Get Video info : resolution, fps, frame count + width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = int(capture.get(cv2.CAP_PROP_FPS)) + frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) + fourcc = cv2.VideoWriter_fourcc(* 'mp4v') + writer = 
cv2.VideoWriter(out_path, fourcc, fps, (width, height)) + frame_id = 0 + while (1): + if frame_id % 50 == 0: + print('frame id: ', frame_id) + ret, frame = capture.read() + frame_id += 1 + if not ret: + if frame_id == 1: + print("video read failed!") + break + frame_results = camera_results[cid][camera_results[cid][:, 2] == + frame_id] + boxes = frame_results[:, -4:] + ids = frame_results[:, 1] + image = plot_tracking(frame, boxes, ids, frame_id=frame_id, fps=fps) + writer.write(image) + writer.release() + + +def get_euclidean(x, y, **kwargs): + m = x.shape[0] + n = y.shape[0] + distmat = (np.power(x, 2).sum(axis=1, keepdims=True).repeat( + n, axis=1) + np.power(y, 2).sum(axis=1, keepdims=True).repeat( + m, axis=1).T) + distmat -= np.dot(2 * x, y.T) + return distmat + + +def cosine_similarity(x, y, eps=1e-12): + """ + Computes cosine similarity between two tensors. + Value == 1 means the same vector + Value == 0 means perpendicular vectors + """ + x_n, y_n = np.linalg.norm( + x, axis=1, keepdims=True), np.linalg.norm( + y, axis=1, keepdims=True) + x_norm = x / np.maximum(x_n, eps * np.ones_like(x_n)) + y_norm = y / np.maximum(y_n, eps * np.ones_like(y_n)) + sim_mt = np.dot(x_norm, y_norm.T) + return sim_mt + + +def get_cosine(x, y, eps=1e-12): + """ + Computes cosine distance between two tensors. + The cosine distance is the inverse cosine similarity + -> cosine_distance = abs(-cosine_distance) to make it + similar in behaviour to euclidean distance + """ + sim_mt = cosine_similarity(x, y, eps) + return sim_mt + + +def get_dist_mat(x, y, func_name="euclidean"): + if func_name == "cosine": + dist_mat = get_cosine(x, y) + elif func_name == "euclidean": + dist_mat = get_euclidean(x, y) + print("Using {func_name} as distance function during evaluation") + return dist_mat + + +def intracam_ignore(st_mask, cid_tids): + count = len(cid_tids) + for i in range(count): + for j in range(count): + if cid_tids[i][1] == cid_tids[j][1]: + st_mask[i, j] = 0. + return st_mask + + +def get_sim_matrix_new(cid_tid_dict, cid_tids): + # Note: camera independent get_sim_matrix function, + # which is different from the one in camera_utils.py. + count = len(cid_tids) + + q_arr = np.array( + [cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)]) + g_arr = np.array( + [cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)]) + #compute distmat + distmat = get_dist_mat(q_arr, g_arr, func_name="cosine") + + #mask the element which belongs to same video + st_mask = np.ones((count, count), dtype=np.float32) + st_mask = intracam_ignore(st_mask, cid_tids) + + sim_matrix = distmat * st_mask + np.fill_diagonal(sim_matrix, 0.) + return 1. 
- sim_matrix + + +def get_match(cluster_labels): + cluster_dict = dict() + cluster = list() + for i, l in enumerate(cluster_labels): + if l in list(cluster_dict.keys()): + cluster_dict[l].append(i) + else: + cluster_dict[l] = [i] + for idx in cluster_dict: + cluster.append(cluster_dict[idx]) + return cluster + + +def get_cid_tid(cluster_labels, cid_tids): + cluster = list() + for labels in cluster_labels: + cid_tid_list = list() + for label in labels: + cid_tid_list.append(cid_tids[label]) + cluster.append(cid_tid_list) + return cluster + + +def get_labels(cid_tid_dict, cid_tids): + #compute cost matrix between features + cost_matrix = get_sim_matrix_new(cid_tid_dict, cid_tids) + + #cluster all the features + cluster1 = AgglomerativeClustering( + n_clusters=None, + distance_threshold=0.5, + affinity='precomputed', + linkage='complete') + cluster_labels1 = cluster1.fit_predict(cost_matrix) + labels = get_match(cluster_labels1) + + sub_cluster = get_cid_tid(labels, cid_tids) + return labels + + +def sub_cluster(cid_tid_dict): + ''' + cid_tid_dict: all camera_id and track_id + ''' + #get all keys + cid_tids = sorted([key for key in cid_tid_dict.keys()]) + + #cluster all trackid + clu = get_labels(cid_tid_dict, cid_tids) + + #relabel every cluster groups + new_clu = list() + for c_list in clu: + new_clu.append([cid_tids[c] for c in c_list]) + cid_tid_label = dict() + for i, c_list in enumerate(new_clu): + for c in c_list: + cid_tid_label[c] = i + 1 + return cid_tid_label + + +def distill_idfeat(mot_res): + qualities_list = mot_res["qualities"] + feature_list = mot_res["features"] + rects = mot_res["rects"] + + qualities_new = [] + feature_new = [] + #filter rect less than 100*20 + for idx, rect in enumerate(rects): + conf, xmin, ymin, xmax, ymax = rect[0] + if (xmax - xmin) * (ymax - ymin) and (xmax > xmin) > 2000: + qualities_new.append(qualities_list[idx]) + feature_new.append(feature_list[idx]) + #take all features if available rect is less than 2 + if len(qualities_new) < 2: + qualities_new = qualities_list + feature_new = feature_list + + #if available frames number is more than 200, take one frame data per 20 frames + if len(qualities_new) > 200: + skipf = 20 + else: + skipf = max(10, len(qualities_new) // 10) + quality_skip = np.array(qualities_new[::skipf]) + feature_skip = np.array(feature_new[::skipf]) + + #sort features with image qualities, take the most trustworth features + topk_argq = np.argsort(quality_skip)[::-1] + if (quality_skip > 0.6).sum() > 1: + topk_feat = feature_skip[topk_argq[quality_skip > 0.6]] + else: + topk_feat = feature_skip[topk_argq] + + #get final features by mean or cluster, at most take five + mean_feat = np.mean(topk_feat[:5], axis=0) + return mean_feat + + +def res2dict(multi_res): + cid_tid_dict = {} + for cid, c_res in enumerate(multi_res): + for tid, res in c_res.items(): + key = "c" + str(cid) + "_t" + str(tid) + if key not in cid_tid_dict: + cid_tid_dict[key] = res + cid_tid_dict[key]['mean_feat'] = distill_idfeat(res) + return cid_tid_dict + + +def mtmct_process(multi_res, captures, mtmct_vis=True, output_dir="output"): + cid_tid_dict = res2dict(multi_res) + map_tid = sub_cluster(cid_tid_dict) + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + pred_mtmct_file = os.path.join(output_dir, 'mtmct_result.txt') + gen_restxt(pred_mtmct_file, map_tid, cid_tid_dict) + + if mtmct_vis: + camera_results, cid_tid_fid_res = get_mtmct_matching_results( + pred_mtmct_file) + + save_mtmct_vis_results(camera_results, captures, output_dir=output_dir) diff 
--git a/deploy/pphuman/pipe_utils.py b/deploy/pphuman/pipe_utils.py index 25d3ad0733b043b8fd11fffd82669ea607bf7bc9..094cb6a72fe3f04382ec5228760d81ca22e0847f 100644 --- a/deploy/pphuman/pipe_utils.py +++ b/deploy/pphuman/pipe_utils.py @@ -45,6 +45,11 @@ def argsparser(): default=None, help="Path of video file, `video_file` or `camera_id` has a highest priority." ) + parser.add_argument( + "--video_dir", + type=str, + default=None, + help="Dir of video file, `video_file` has a higher priority.") parser.add_argument( "--model_dir", nargs='*', help="set model dir in pipeline") parser.add_argument( @@ -143,6 +148,7 @@ class PipeTimer(Times): 'attr': Times(), 'kpt': Times(), 'action': Times(), + 'reid': Times() } self.img_num = 0 @@ -268,7 +274,7 @@ def get_test_images(infer_dir, infer_img): return images -def crop_image_with_det(batch_input, det_res): +def crop_image_with_det(batch_input, det_res, thresh=0.3): boxes = det_res['boxes'] score = det_res['boxes'][:, 1] boxes_num = det_res['boxes_num'] @@ -279,21 +285,38 @@ def crop_image_with_det(batch_input, det_res): boxes_i = boxes[start_idx:start_idx + boxes_num_i, :] score_i = score[start_idx:start_idx + boxes_num_i] res = [] - for box in boxes_i: - crop_image, new_box, ori_box = expand_crop(input, box) - if crop_image is not None: - res.append(crop_image) + for box, s in zip(boxes_i, score_i): + if s > thresh: + crop_image, new_box, ori_box = expand_crop(input, box) + if crop_image is not None: + res.append(crop_image) crop_res.append(res) return crop_res -def crop_image_with_mot(input, mot_res): +def normal_crop(image, rect): + imgh, imgw, c = image.shape + label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()] + org_rect = [xmin, ymin, xmax, ymax] + if label != 0: + return None, None, None + xmin = max(0, xmin) + ymin = max(0, ymin) + xmax = min(imgw, xmax) + ymax = min(imgh, ymax) + return image[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect + + +def crop_image_with_mot(input, mot_res, expand=True): res = mot_res['boxes'] crop_res = [] new_bboxes = [] ori_bboxes = [] for box in res: - crop_image, new_bbox, ori_bbox = expand_crop(input, box[1:]) + if expand: + crop_image, new_bbox, ori_bbox = expand_crop(input, box[1:]) + else: + crop_image, new_bbox, ori_bbox = normal_crop(input, box[1:]) if crop_image is not None: crop_res.append(crop_image) new_bboxes.append(new_bbox) diff --git a/deploy/pphuman/pipeline.py b/deploy/pphuman/pipeline.py index 22322ce7c6123b229a1b703eb01200ba48d61ca5..bbbf4da956b9ad3790c6f9787c493fd435d65a9c 100644 --- a/deploy/pphuman/pipeline.py +++ b/deploy/pphuman/pipeline.py @@ -21,7 +21,11 @@ import numpy as np import math import paddle import sys +import copy from collections import Sequence +from reid import ReID +from datacollector import DataCollector, Result +from mtmct import mtmct_process # add deploy path of PadleDetection to sys.path parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) @@ -32,7 +36,7 @@ from python.attr_infer import AttrDetector from python.keypoint_infer import KeyPointDetector from python.keypoint_postprocess import translate_to_ori_images from python.action_infer import ActionRecognizer -from python.action_utils import KeyPointCollector, ActionVisualCollector +from python.action_utils import KeyPointBuff, ActionVisualHelper from pipe_utils import argsparser, print_arguments, merge_cfg, PipeTimer from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint @@ -75,6 +79,7 @@ class 
Pipeline(object): image_file=None, image_dir=None, video_file=None, + video_dir=None, camera_id=-1, enable_attr=False, enable_action=True, @@ -89,8 +94,10 @@ class Pipeline(object): output_dir='output'): self.multi_camera = False self.is_video = False + self.output_dir = output_dir + self.vis_result = cfg['visual'] self.input = self._parse_input(image_file, image_dir, video_file, - camera_id) + video_dir, camera_id) if self.multi_camera: self.predictor = [ PipePredictor( @@ -126,7 +133,8 @@ class Pipeline(object): if self.is_video: self.predictor.set_file_name(video_file) - def _parse_input(self, image_file, image_dir, video_file, camera_id): + def _parse_input(self, image_file, image_dir, video_file, video_dir, + camera_id): # parse input as is_video and multi_camera @@ -136,19 +144,23 @@ class Pipeline(object): self.multi_camera = False elif video_file is not None: - if isinstance(video_file, list): + self.multi_camera = False + input = video_file + self.is_video = True + + elif video_dir is not None: + videof = [os.path.join(video_dir, x) for x in os.listdir(video_dir)] + if len(videof) > 1: self.multi_camera = True - input = [cv2.VideoCapture(v) for v in video_file] + videof.sort() + input = videof else: - input = cv2.VideoCapture(video_file) + input = videof[0] self.is_video = True elif camera_id != -1: - if isinstance(camera_id, Sequence): - self.multi_camera = True - input = [cv2.VideoCapture(i) for i in camera_id] - else: - input = cv2.VideoCapture(camera_id) + self.multi_camera = False + input = camera_id self.is_video = True else: @@ -163,34 +175,18 @@ class Pipeline(object): multi_res = [] for predictor, input in zip(self.predictor, self.input): predictor.run(input) - res = predictor.get_result() - multi_res.append(res) - - mtmct_process(multi_res) + collector_data = predictor.get_result() + multi_res.append(collector_data) + mtmct_process( + multi_res, + self.input, + mtmct_vis=self.vis_result, + output_dir=self.output_dir) else: self.predictor.run(self.input) -class Result(object): - def __init__(self): - self.res_dict = { - 'det': dict(), - 'mot': dict(), - 'attr': dict(), - 'kpt': dict(), - 'action': dict() - } - - def update(self, res, name): - self.res_dict[name].update(res) - - def get(self, name): - if name in self.res_dict and len(self.res_dict[name]) > 0: - return self.res_dict[name] - return None - - class PipePredictor(object): """ Predictor in single camera @@ -255,10 +251,18 @@ class PipePredictor(object): self.with_attr = cfg.get('ATTR', False) and enable_attr self.with_action = cfg.get('ACTION', False) and enable_action + self.with_mtmct = cfg.get('REID', False) and multi_camera if self.with_attr: print('Attribute Recognition enabled') if self.with_action: print('Action Recognition enabled') + if multi_camera: + if not self.with_mtmct: + print( + 'Warning!!! MTMCT enabled, but cannot find REID config in [infer_cfg.yml], please check!' 
+ ) + else: + print("MTMCT enabled") self.is_video = is_video self.multi_camera = multi_camera @@ -269,6 +273,7 @@ class PipePredictor(object): self.pipeline_res = Result() self.pipe_timer = PipeTimer() self.file_name = None + self.collector = DataCollector() if not is_video: det_cfg = self.cfg['DET'] @@ -327,7 +332,7 @@ class PipePredictor(object): cpu_threads, enable_mkldnn, use_dark=False) - self.kpt_collector = KeyPointCollector(action_frames) + self.kpt_buff = KeyPointBuff(action_frames) self.action_predictor = ActionRecognizer( action_model_dir, @@ -342,14 +347,22 @@ class PipePredictor(object): enable_mkldnn, window_size=action_frames) - self.action_visual_collector = ActionVisualCollector( - display_frames) + self.action_visual_helper = ActionVisualHelper(display_frames) + + if self.with_mtmct: + reid_cfg = self.cfg['REID'] + model_dir = reid_cfg['model_dir'] + batch_size = reid_cfg['batch_size'] + self.reid_predictor = ReID(model_dir, device, run_mode, batch_size, + trt_min_shape, trt_max_shape, + trt_opt_shape, trt_calib_mode, + cpu_threads, enable_mkldnn) def set_file_name(self, path): self.file_name = os.path.split(path)[-1] def get_result(self): - return self.pipeline_res + return self.collector.get_res() def run(self, input): if self.is_video: @@ -406,10 +419,11 @@ class PipePredictor(object): if self.cfg['visual']: self.visualize_image(batch_file, batch_input, self.pipeline_res) - def predict_video(self, capture): + def predict_video(self, video_file): # mot # mot -> attr # mot -> pose -> action + capture = cv2.VideoCapture(video_file) video_out_name = 'output.mp4' if self.file_name is None else self.file_name # Get Video info : resolution, fps, frame count @@ -434,7 +448,8 @@ class PipePredictor(object): if frame_id > self.warmup_frame: self.pipe_timer.total_time.start() self.pipe_timer.module_time['mot'].start() - res = self.mot_predictor.predict_image([frame], visual=False) + res = self.mot_predictor.predict_image( + [copy.deepcopy(frame)], visual=False) if frame_id > self.warmup_frame: self.pipe_timer.module_time['mot'].end() @@ -485,16 +500,15 @@ class PipePredictor(object): self.pipeline_res.update(kpt_res, 'kpt') - self.kpt_collector.update(kpt_res, - mot_res) # collect kpt output - state = self.kpt_collector.get_state( + self.kpt_buff.update(kpt_res, mot_res) # collect kpt output + state = self.kpt_buff.get_state( ) # whether frame num is enough or lost tracker action_res = {} if state: if frame_id > self.warmup_frame: self.pipe_timer.module_time['action'].start() - collected_keypoint = self.kpt_collector.get_collected_keypoint( + collected_keypoint = self.kpt_buff.get_collected_keypoint( ) # reoragnize kpt output with ID action_input = parse_mot_keypoint(collected_keypoint, self.coord_size) @@ -505,18 +519,32 @@ class PipePredictor(object): self.pipeline_res.update(action_res, 'action') if self.cfg['visual']: - self.action_visual_collector.update(action_res) + self.action_visual_helper.update(action_res) + + if self.with_mtmct: + crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot( + frame, mot_res) + if frame_id > self.warmup_frame: + self.pipe_timer.module_time['reid'].start() + reid_res = self.reid_predictor.predict_batch(crop_input) + + if frame_id > self.warmup_frame: + self.pipe_timer.module_time['reid'].end() + + reid_res_dict = { + 'features': reid_res, + "qualities": img_qualities, + "rects": rects + } + self.pipeline_res.update(reid_res_dict, 'reid') + + self.collector.append(frame_id, self.pipeline_res) if frame_id > 
self.warmup_frame: self.pipe_timer.img_num += 1 self.pipe_timer.total_time.end() frame_id += 1 - if self.multi_camera: - self.get_valid_instance( - frame, - self.pipeline_res) # parse output result for multi-camera - if self.cfg['visual']: _, _, fps = self.pipe_timer.get_total_time() im = self.visualize_video(frame, self.pipeline_res, frame_id, @@ -527,7 +555,7 @@ class PipePredictor(object): print('save result to {}'.format(out_path)) def visualize_video(self, image, result, frame_id, fps): - mot_res = result.get('mot') + mot_res = copy.deepcopy(result.get('mot')) if mot_res is not None: ids = mot_res['boxes'][:, 0] scores = mot_res['boxes'][:, 2] @@ -559,7 +587,7 @@ class PipePredictor(object): action_res = result.get('action') if action_res is not None: image = visualize_action(image, mot_res['boxes'], - self.action_visual_collector, "Falling") + self.action_visual_helper, "Falling") return image @@ -598,10 +626,10 @@ def main(): print_arguments(cfg) pipeline = Pipeline( cfg, FLAGS.image_file, FLAGS.image_dir, FLAGS.video_file, - FLAGS.camera_id, FLAGS.enable_attr, FLAGS.enable_action, FLAGS.device, - FLAGS.run_mode, FLAGS.trt_min_shape, FLAGS.trt_max_shape, - FLAGS.trt_opt_shape, FLAGS.trt_calib_mode, FLAGS.cpu_threads, - FLAGS.enable_mkldnn, FLAGS.output_dir) + FLAGS.video_dir, FLAGS.camera_id, FLAGS.enable_attr, + FLAGS.enable_action, FLAGS.device, FLAGS.run_mode, FLAGS.trt_min_shape, + FLAGS.trt_max_shape, FLAGS.trt_opt_shape, FLAGS.trt_calib_mode, + FLAGS.cpu_threads, FLAGS.enable_mkldnn, FLAGS.output_dir) pipeline.run() diff --git a/deploy/pphuman/reid.py b/deploy/pphuman/reid.py new file mode 100644 index 0000000000000000000000000000000000000000..cef4029239f7e0f635547506282c2527bf687353 --- /dev/null +++ b/deploy/pphuman/reid.py @@ -0,0 +1,191 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
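+
+"""
+ReID feature extraction for the PP-Human MTMCT pipeline.
+
+The ReID class crops pedestrian boxes out of the MOT results
+(crop_image_with_mot), scores each crop's quality from its overlap with
+other detections and its brightness (check_img_quality), and runs the
+exported ReID model in batches to produce one embedding per crop
+(predict_batch). The embeddings and quality scores are later consumed by
+mtmct.py to match tracks across cameras.
+"""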
+ +import os +import sys +import cv2 +import numpy as np +# add deploy path of PadleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) +sys.path.insert(0, parent_path) + +from python.infer import PredictConfig +from pptracking.python.det_infer import load_predictor +from python.utils import Timer + + +class ReID(object): + """ + ReID of SDE methods + + Args: + pred_config (object): config of model, defined by `Config(model_dir)` + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of per batch in inference, default 50 means at most + 50 sub images can be made a batch and send into ReID model + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + model_dir, + device='CPU', + run_mode='paddle', + batch_size=50, + trt_min_shape=1, + trt_max_shape=1088, + trt_opt_shape=608, + trt_calib_mode=False, + cpu_threads=4, + enable_mkldnn=False): + self.pred_config = self.set_config(model_dir) + self.predictor, self.config = load_predictor( + model_dir, + run_mode=run_mode, + batch_size=batch_size, + min_subgraph_size=self.pred_config.min_subgraph_size, + device=device, + use_dynamic_shape=self.pred_config.use_dynamic_shape, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) + self.det_times = Timer() + self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 + self.batch_size = batch_size + self.input_wh = [128, 256] + + def set_config(self, model_dir): + return PredictConfig(model_dir) + + def check_img_quality(self, crop, bbox, xyxy): + if crop is None: + return None + #eclipse + eclipse_quality = 1.0 + inner_rect = np.zeros(xyxy.shape) + inner_rect[:, :2] = np.maximum(xyxy[:, :2], bbox[None, :2]) + inner_rect[:, 2:] = np.minimum(xyxy[:, 2:], bbox[None, 2:]) + wh_array = inner_rect[:, 2:] - inner_rect[:, :2] + filt = np.logical_and(wh_array[:, 0] > 0, wh_array[:, 1] > 0) + wh_array = wh_array[filt] + if wh_array.shape[0] > 1: + eclipse_ratio = wh_array / (bbox[2:] - bbox[:2]) + eclipse_area_ratio = eclipse_ratio[:, 0] * eclipse_ratio[:, 1] + ear_lst = eclipse_area_ratio.tolist() + ear_lst.sort(reverse=True) + eclipse_quality = 1.0 - ear_lst[1] + bbox_wh = (bbox[2:] - bbox[:2]) + height_quality = bbox_wh[1] / (bbox_wh[0] * 2) + eclipse_quality = min(eclipse_quality, height_quality) + + #definition + cropgray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY) + definition = int(cv2.Laplacian(cropgray, cv2.CV_64F, ksize=3).var()) + brightness = int(cropgray.mean()) + bd_quality = min(1., brightness / 50.) 
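+        # the final score below blends the occlusion term (eclipse_quality)
+        # with the brightness term (bd_quality); mtmct.distill_idfeat later
+        # keeps the highest-scoring crops when averaging ReID features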
+ + eclipse_weight = 0.7 + return eclipse_quality * eclipse_weight + bd_quality * (1 - + eclipse_weight) + + def normal_crop(self, image, rect): + imgh, imgw, c = image.shape + label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()] + xmin = max(0, xmin) + ymin = max(0, ymin) + xmax = min(imgw, xmax) + ymax = min(imgh, ymax) + if label != 0 or xmax <= xmin or ymax <= ymin: + print("Warning! label missed!!") + return None, None, None + return image[ymin:ymax, xmin:xmax, :] + + def crop_image_with_mot(self, image, mot_res): + res = mot_res['boxes'] + crop_res = [] + img_quality = [] + rects = [] + for box in res: + crop_image = self.normal_crop(image, box[1:]) + quality_item = self.check_img_quality(crop_image, box[3:], + res[:, 3:]) + if crop_image is not None: + crop_res.append(crop_image) + img_quality.append(quality_item) + rects.append(box) + return crop_res, img_quality, rects + + def preprocess(self, + imgs, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]): + im_batch = [] + for img in imgs: + img = cv2.resize(img, self.input_wh) + img = img.astype('float32') / 255. + img -= np.array(mean) + img /= np.array(std) + im_batch.append(img.transpose((2, 0, 1))) + inputs = {} + inputs['x'] = np.array(im_batch).astype('float32') + return inputs + + def predict(self, crops, repeats=1, add_timer=True, seq_name=''): + # preprocess + if add_timer: + self.det_times.preprocess_time_s.start() + inputs = self.preprocess(crops) + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + if add_timer: + self.det_times.preprocess_time_s.end() + self.det_times.inference_time_s.start() + + # model prediction + for i in range(repeats): + self.predictor.run() + output_names = self.predictor.get_output_names() + feature_tensor = self.predictor.get_output_handle(output_names[0]) + pred_embs = feature_tensor.copy_to_cpu() + if add_timer: + self.det_times.inference_time_s.end(repeats=repeats) + self.det_times.postprocess_time_s.start() + + if add_timer: + self.det_times.postprocess_time_s.end() + self.det_times.img_num += 1 + return pred_embs + + def predict_batch(self, imgs, batch_size=4): + batch_feat = [] + for b in range(0, len(imgs), batch_size): + b_end = min(len(imgs), b + batch_size) + batch_imgs = imgs[b:b_end] + feat = self.predict(batch_imgs) + batch_feat.extend(feat.tolist()) + + return batch_feat diff --git a/deploy/python/action_utils.py b/deploy/python/action_utils.py index d9da8b6e7176c510396e0d5b7c587622effe6916..0fbc92a8aa842dbe92ee61b119be9e8be2ebfac1 100644 --- a/deploy/python/action_utils.py +++ b/deploy/python/action_utils.py @@ -29,7 +29,7 @@ class KeyPointSequence(object): return False -class KeyPointCollector(object): +class KeyPointBuff(object): def __init__(self, max_size=100): self.flag_track_interrupt = False self.keypoint_saver = dict() @@ -80,7 +80,7 @@ class KeyPointCollector(object): return output -class ActionVisualCollector(object): +class ActionVisualHelper(object): def __init__(self, frame_life=20): self.frame_life = frame_life self.action_history = {}
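
Usage sketch (assumes the pipeline's existing --config and --device flags; ./mtmct_videos/ is a placeholder directory holding one video per camera):

    python deploy/pphuman/pipeline.py \
        --config deploy/pphuman/config/infer_cfg.yml \
        --video_dir=./mtmct_videos/ \
        --device=gpu

With more than one video in --video_dir the pipeline enters multi-camera mode; the REID section of infer_cfg.yml must then point to an exported ReID model (output_inference/reid_model/ by default), otherwise a warning is printed and MTMCT is skipped. Matching results are written to <output_dir>/mtmct_result.txt as "cid tid fid x1 y1 w h" lines, and per-camera visualization videos go to <output_dir>/mtmct_vis/ when visual is enabled in the config.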