From e00cfc0a3ad225e12c9cfecfb1a44a791c9ebd85 Mon Sep 17 00:00:00 2001 From: David Nicolas <37790151+liyongchao911@users.noreply.github.com> Date: Fri, 5 Aug 2022 14:48:05 +0800 Subject: [PATCH] update mtmct.py and visualize.py for multi-camera attrs visualize (#6580) * update mtmct.py and visualize.py for multi-camera attrs visualize * update mtmct.py * update mcmct_attrvisualzie logic, without enable_attr will not goes to attrs visualize --- deploy/pipeline/pphuman/mtmct.py | 34 +++++++++++++++++++++++++++++--- deploy/python/visualize.py | 8 ++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/deploy/pipeline/pphuman/mtmct.py b/deploy/pipeline/pphuman/mtmct.py index 50c789c75..8ab72f4a4 100644 --- a/deploy/pipeline/pphuman/mtmct.py +++ b/deploy/pipeline/pphuman/mtmct.py @@ -13,6 +13,7 @@ # limitations under the License. from pptracking.python.mot.visualize import plot_tracking +from python.visualize import visualize_attr import os import re import cv2 @@ -103,7 +104,8 @@ def get_mtmct_matching_results(pred_mtmct_file, secs_interval=0.5, return camera_results, cid_tid_fid_results -def save_mtmct_vis_results(camera_results, captures, output_dir): +def save_mtmct_vis_results(camera_results, captures, output_dir, + multi_res=None): # camera_results: 'cid, tid, fid, x1, y1, w, h' camera_ids = list(camera_results.keys()) @@ -126,7 +128,7 @@ def save_mtmct_vis_results(camera_results, captures, output_dir): height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = int(capture.get(cv2.CAP_PROP_FPS)) frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) - fourcc = cv2.VideoWriter_fourcc(* 'mp4v') + fourcc = cv2.VideoWriter_fourcc(*'mp4v') writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) frame_id = 0 while (1): @@ -143,6 +145,28 @@ def save_mtmct_vis_results(camera_results, captures, output_dir): boxes = frame_results[:, -4:] ids = frame_results[:, 1] image = plot_tracking(frame, boxes, ids, frame_id=frame_id, fps=fps) + + # add attr vis + if multi_res: + tid_list = [ + 'c' + str(idx) + '_' + 't' + str(int(j)) + for j in range(1, len(ids) + 1) + ] # c0_t1, c0_t2... + all_attr_result = [multi_res[i]["attrs"] + for i in tid_list] # all cid_tid result + if any( + all_attr_result + ): # at least one cid_tid[attrs] is not None will goes to attrs_vis + attr_res = [] + for k in tid_list: + if (frame_id - 1) >= len(multi_res[k]['attrs']): + t_attr = None + else: + t_attr = multi_res[k]['attrs'][frame_id - 1] + attr_res.append(t_attr) + image = visualize_attr( + image, attr_res, boxes, is_mtmct=True) + writer.write(image) writer.release() @@ -349,4 +373,8 @@ def mtmct_process(multi_res, captures, mtmct_vis=True, output_dir="output"): camera_results, cid_tid_fid_res = get_mtmct_matching_results( pred_mtmct_file) - save_mtmct_vis_results(camera_results, captures, output_dir=output_dir) + save_mtmct_vis_results( + camera_results, + captures, + output_dir=output_dir, + multi_res=cid_tid_dict) diff --git a/deploy/python/visualize.py b/deploy/python/visualize.py index 8f0da233d..24aa40796 100644 --- a/deploy/python/visualize.py +++ b/deploy/python/visualize.py @@ -331,7 +331,7 @@ def visualize_pose(imgfile, plt.close() -def visualize_attr(im, results, boxes=None): +def visualize_attr(im, results, boxes=None, is_mtmct=False): if isinstance(im, str): im = Image.open(im) im = np.ascontiguousarray(np.copy(im)) @@ -348,8 +348,12 @@ def visualize_attr(im, results, boxes=None): if boxes is None: text_w = 3 text_h = 1 + elif is_mtmct: + box = boxes[i] # multi camera, bbox shape is x,y, w,h + text_w = int(box[0]) + 3 + text_h = int(box[1]) else: - box = boxes[i] + box = boxes[i] # single camera, bbox shape is 0, 0, x,y, w,h text_w = int(box[2]) + 3 text_h = int(box[3]) for text in res: -- GitLab