未验证 提交 f0e689e2 编写于 作者: G George Ni 提交者: GitHub

[MOT] add mot score (#3374)

* add mot score to txt results

* fix mot install typo

* add epoch for deepsort

* fix paddle no_grad

* deploy infer save image txt
上级 e8344214
......@@ -98,14 +98,17 @@ class MOT_Detector(object):
def postprocess(self, pred_dets, pred_embs):
online_targets = self.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh)
online_ids.append(tid)
return online_tlwhs, online_ids
online_scores.append(tscore)
return online_tlwhs, online_scores, online_ids
def predict(self, image, threshold=0.5, repeats=1):
'''
......@@ -136,10 +139,11 @@ class MOT_Detector(object):
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
online_tlwhs, online_ids = self.postprocess(pred_dets, pred_embs)
online_tlwhs, online_scores, online_ids = self.postprocess(pred_dets,
pred_embs)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return online_tlwhs, online_ids
return online_tlwhs, online_scores, online_ids
def create_inputs(im, im_info):
......@@ -290,6 +294,36 @@ def load_predictor(model_dir,
return predictor, config
def write_mot_results(filename, results, data_type='mot'):
if data_type in ['mot', 'mcmot', 'lab']:
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
with open(filename, 'w') as f:
for frame_id, tlwhs, tscores, track_ids in results:
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0:
continue
x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format(
frame=frame_id,
id=track_id,
x1=x1,
y1=y1,
x2=x2,
y2=y2,
w=w,
h=h,
score=score)
f.write(line)
def predict_video(detector, camera_id):
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
......@@ -311,20 +345,32 @@ def predict_video(detector, camera_id):
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
results = []
while (1):
ret, frame = capture.read()
if not ret:
break
timer.tic()
online_tlwhs, online_ids = detector.predict(frame, FLAGS.threshold)
online_tlwhs, online_scores, online_ids = detector.predict(
frame, FLAGS.threshold)
timer.toc()
results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
fps = 1. / timer.average_time
online_im = mot_vis.plot_tracking(
frame,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=1. / timer.average_time)
fps=fps)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
online_im)
frame_id += 1
print('detect frame:%d' % (frame_id))
im = np.array(online_im)
......@@ -333,6 +379,10 @@ def predict_video(detector, camera_id):
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if FLAGS.save_results:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
writer.release()
......
......@@ -100,7 +100,14 @@ def argsparser():
default=False,
help="If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True.")
parser.add_argument(
'--save_images',
action='store_true',
help='Save tracking results (image).')
parser.add_argument(
'--save_results',
action='store_true',
help='Save tracking results (txt).')
return parser
......
......@@ -135,19 +135,24 @@ class Tracker(object):
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh)
online_ids.append(tid)
online_scores.append(tscore)
timer.toc()
# save results
results.append((frame_id + 1, online_tlwhs, online_ids))
results.append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs,
timer.average_time, show_image, save_dir)
online_scores, timer.average_time, show_image,
save_dir)
frame_id += 1
return results, frame_id, timer.average_time, timer.calls
......@@ -206,20 +211,22 @@ class Tracker(object):
online_targets = self.model.tracker.update(detections)
online_tlwhs = []
online_scores = []
online_ids = []
for track in online_targets:
if not track.is_confirmed() or track.time_since_update > 1:
continue
tlwh = track.to_tlwh()
track_id = track.track_id
online_tlwhs.append(tlwh)
online_ids.append(track_id)
online_tlwhs.append(track.to_tlwh())
online_scores.append(1.0)
online_ids.append(track.track_id)
timer.toc()
# save results
results.append((frame_id + 1, online_tlwhs, online_ids))
results.append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs,
timer.average_time, show_image, save_dir)
online_scores, timer.average_time, show_image,
save_dir)
frame_id += 1
return results, frame_id, timer.average_time, timer.calls
......@@ -261,7 +268,7 @@ class Tracker(object):
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
frame_rate = int(meta_info[meta_info.find('frameRate') + 10:
meta_info.find('\nseqLength')])
with paddle.no_grad():
if model_type in ['JDE', 'FairMOT']:
results, nf, ta, tc = self._eval_seq_jde(
dataloader,
......@@ -356,6 +363,7 @@ class Tracker(object):
result_filename = os.path.join(result_root, '{}.txt'.format(seq))
frame_rate = self.dataset.frame_rate
with paddle.no_grad():
if model_type in ['JDE', 'FairMOT']:
results, nf, ta, tc = self._eval_seq_jde(
dataloader,
......@@ -368,7 +376,8 @@ class Tracker(object):
save_dir=save_dir,
show_image=show_image,
frame_rate=frame_rate,
det_file=os.path.join(det_results_dir, '{}.txt'.format(seq)))
det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq)))
else:
raise ValueError(model_type)
......@@ -384,17 +393,17 @@ class Tracker(object):
def write_mot_results(self, filename, results, data_type='mot'):
if data_type in ['mot', 'mcmot', 'lab']:
save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
elif data_type == 'kitti':
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else:
raise ValueError(data_type)
with open(filename, 'w') as f:
for frame_id, tlwhs, track_ids in results:
for frame_id, tlwhs, tscores, track_ids in results:
if data_type == 'kitti':
frame_id -= 1
for tlwh, track_id in zip(tlwhs, track_ids):
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0:
continue
x1, y1, w, h = tlwh
......@@ -407,12 +416,13 @@ class Tracker(object):
x2=x2,
y2=y2,
w=w,
h=h)
h=h,
score=score)
f.write(line)
logger.info('MOT results save in {}'.format(filename))
def save_results(self, data, frame_id, online_ids, online_tlwhs,
average_time, show_image, save_dir):
online_scores, average_time, show_image, save_dir):
if show_image or save_dir is not None:
assert 'ori_image' in data
img0 = data['ori_image'].numpy()[0]
......@@ -420,6 +430,7 @@ class Tracker(object):
img0,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=1. / average_time)
if show_image:
......
......@@ -115,7 +115,7 @@ class Trainer(object):
self.status = {}
self.start_epoch = 0
self.end_epoch = cfg.epoch
self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch
# initial default callbacks
self._init_callbacks()
......
......@@ -107,7 +107,6 @@ class JDE(BaseArch):
pred_dets = paddle.concat((bbox[:, 2:], bbox[:, 1:2]), axis=1)
boxes_idx = paddle.cast(boxes_idx, 'int64')
emb_valid = paddle.gather_nd(emb_outs, boxes_idx)
pred_embs = paddle.gather_nd(emb_valid, nms_keep_idx)
......
......@@ -76,10 +76,19 @@ def plot_tracking(image,
im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.putText(
im,
id_text, (intbox[0], intbox[1] + 30),
id_text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=text_thickness)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册