未验证 提交 e066d8d1 编写于 作者: XYZ_916's avatar XYZ_916 提交者: GitHub

fix bug for mtmct attr vis (#7165)

上级 7fb2318a
...@@ -118,6 +118,7 @@ def save_mtmct_vis_results(camera_results, captures, output_dir, ...@@ -118,6 +118,7 @@ def save_mtmct_vis_results(camera_results, captures, output_dir,
for idx, video_file in enumerate(captures): for idx, video_file in enumerate(captures):
capture = cv2.VideoCapture(video_file) capture = cv2.VideoCapture(video_file)
cid = camera_ids[idx] cid = camera_ids[idx]
basename = os.path.basename(video_file) basename = os.path.basename(video_file)
video_out_name = "vis_" + basename video_out_name = "vis_" + basename
out_path = os.path.join(save_dir, video_out_name) out_path = os.path.join(save_dir, video_out_name)
...@@ -151,16 +152,22 @@ def save_mtmct_vis_results(camera_results, captures, output_dir, ...@@ -151,16 +152,22 @@ def save_mtmct_vis_results(camera_results, captures, output_dir,
tid_list = multi_res.keys() # c0_t1, c0_t2... tid_list = multi_res.keys() # c0_t1, c0_t2...
all_attr_result = [multi_res[i]["attrs"] all_attr_result = [multi_res[i]["attrs"]
for i in tid_list] # all cid_tid result for i in tid_list] # all cid_tid result
if any( if any(
all_attr_result all_attr_result
): # at least one cid_tid[attrs] is not None will goes to attrs_vis ): # at least one cid_tid[attrs] is not None will goes to attrs_vis
attr_res = [] attr_res = []
cid_str = 'c' + str(cid - 1) + "_"
for k in tid_list: for k in tid_list:
if not k.startswith(cid_str):
continue
if (frame_id - 1) >= len(multi_res[k]['attrs']): if (frame_id - 1) >= len(multi_res[k]['attrs']):
t_attr = None t_attr = None
else: else:
t_attr = multi_res[k]['attrs'][frame_id - 1] t_attr = multi_res[k]['attrs'][frame_id - 1]
attr_res.append(t_attr) attr_res.append(t_attr)
assert len(attr_res) == len(boxes)
image = visualize_attr( image = visualize_attr(
image, attr_res, boxes, is_mtmct=True) image, attr_res, boxes, is_mtmct=True)
...@@ -347,7 +354,7 @@ def res2dict(multi_res): ...@@ -347,7 +354,7 @@ def res2dict(multi_res):
for tid, res in c_res.items(): for tid, res in c_res.items():
key = "c" + str(cid) + "_t" + str(tid) key = "c" + str(cid) + "_t" + str(tid)
if key not in cid_tid_dict: if key not in cid_tid_dict:
if len(res["features"])==0: if len(res["features"]) == 0:
continue continue
cid_tid_dict[key] = res cid_tid_dict[key] = res
cid_tid_dict[key]['mean_feat'] = distill_idfeat(res) cid_tid_dict[key]['mean_feat'] = distill_idfeat(res)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册