未验证 提交 e066d8d1 编写于 作者: XYZ_916's avatar XYZ_916 提交者: GitHub

fix bug for mtmct attr vis (#7165)

上级 7fb2318a
......@@ -118,6 +118,7 @@ def save_mtmct_vis_results(camera_results, captures, output_dir,
for idx, video_file in enumerate(captures):
capture = cv2.VideoCapture(video_file)
cid = camera_ids[idx]
basename = os.path.basename(video_file)
video_out_name = "vis_" + basename
out_path = os.path.join(save_dir, video_out_name)
......@@ -151,16 +152,22 @@ def save_mtmct_vis_results(camera_results, captures, output_dir,
tid_list = multi_res.keys() # c0_t1, c0_t2...
all_attr_result = [multi_res[i]["attrs"]
for i in tid_list] # all cid_tid result
if any(
all_attr_result
): # at least one cid_tid[attrs] is not None will goes to attrs_vis
attr_res = []
cid_str = 'c' + str(cid - 1) + "_"
for k in tid_list:
if not k.startswith(cid_str):
continue
if (frame_id - 1) >= len(multi_res[k]['attrs']):
t_attr = None
else:
t_attr = multi_res[k]['attrs'][frame_id - 1]
attr_res.append(t_attr)
assert len(attr_res) == len(boxes)
image = visualize_attr(
image, attr_res, boxes, is_mtmct=True)
......@@ -347,7 +354,7 @@ def res2dict(multi_res):
for tid, res in c_res.items():
key = "c" + str(cid) + "_t" + str(tid)
if key not in cid_tid_dict:
if len(res["features"])==0:
if len(res["features"]) == 0:
continue
cid_tid_dict[key] = res
cid_tid_dict[key]['mean_feat'] = distill_idfeat(res)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册