未验证 提交 f0e689e2 编写于 作者: G George Ni 提交者: GitHub

[MOT] add mot score (#3374)

* add mot score to txt results

* fix mot install typo

* add epoch for deepsort

* fix paddle no_grad

* deploy inference: add options to save tracking results as images and as a txt file
上级 e8344214
...@@ -98,14 +98,17 @@ class MOT_Detector(object): ...@@ -98,14 +98,17 @@ class MOT_Detector(object):
def postprocess(self, pred_dets, pred_embs): def postprocess(self, pred_dets, pred_embs):
online_targets = self.tracker.update(pred_dets, pred_embs) online_targets = self.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], [] online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets: for t in online_targets:
tlwh = t.tlwh tlwh = t.tlwh
tid = t.track_id tid = t.track_id
tscore = t.score
vertical = tlwh[2] / tlwh[3] > 1.6 vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical: if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh) online_tlwhs.append(tlwh)
online_ids.append(tid) online_ids.append(tid)
return online_tlwhs, online_ids online_scores.append(tscore)
return online_tlwhs, online_scores, online_ids
def predict(self, image, threshold=0.5, repeats=1): def predict(self, image, threshold=0.5, repeats=1):
''' '''
...@@ -136,10 +139,11 @@ class MOT_Detector(object): ...@@ -136,10 +139,11 @@ class MOT_Detector(object):
self.det_times.inference_time_s.end(repeats=repeats) self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start() self.det_times.postprocess_time_s.start()
online_tlwhs, online_ids = self.postprocess(pred_dets, pred_embs) online_tlwhs, online_scores, online_ids = self.postprocess(pred_dets,
pred_embs)
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1 self.det_times.img_num += 1
return online_tlwhs, online_ids return online_tlwhs, online_scores, online_ids
def create_inputs(im, im_info): def create_inputs(im, im_info):
...@@ -290,6 +294,36 @@ def load_predictor(model_dir, ...@@ -290,6 +294,36 @@ def load_predictor(model_dir,
return predictor, config return predictor, config
def write_mot_results(filename, results, data_type='mot'):
    """Dump per-frame tracking results to a text file, one line per track.

    Args:
        filename: Path of the output file. It is opened in ``'w'`` mode, so
            any existing content is overwritten.
        results: Iterable of ``(frame_id, tlwhs, tscores, track_ids)``
            tuples. The three sequences are walked in lockstep; each ``tlwh``
            must unpack into four values ``(x1, y1, w, h)``.
        data_type: ``'mot'``, ``'mcmot'`` or ``'lab'`` select the
            comma-separated line template that carries the score;
            ``'kitti'`` selects the space-separated template, which has no
            score field and uses ``frame_id - 1`` as the frame number.

    Raises:
        ValueError: If ``data_type`` is not one of the values above.
    """
    is_kitti = data_type == 'kitti'
    if data_type in ['mot', 'mcmot', 'lab']:
        template = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
    elif is_kitti:
        template = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)

    with open(filename, 'w') as out:
        for frame_id, tlwhs, tscores, track_ids in results:
            # The kitti template is written with the frame number shifted
            # down by one relative to the incoming frame_id.
            frame_no = frame_id - 1 if is_kitti else frame_id
            for box, score, track_id in zip(tlwhs, tscores, track_ids):
                # Negative ids are not written out.
                if track_id < 0:
                    continue
                x1, y1, w, h = box
                fields = dict(
                    frame=frame_no,
                    id=track_id,
                    x1=x1,
                    y1=y1,
                    x2=x1 + w,
                    y2=y1 + h,
                    w=w,
                    h=h,
                    score=score)
                # str.format ignores keyword arguments the template does not
                # reference, so one field set serves both templates.
                out.write(template.format(**fields))
def predict_video(detector, camera_id): def predict_video(detector, camera_id):
if camera_id != -1: if camera_id != -1:
capture = cv2.VideoCapture(camera_id) capture = cv2.VideoCapture(camera_id)
...@@ -311,20 +345,32 @@ def predict_video(detector, camera_id): ...@@ -311,20 +345,32 @@ def predict_video(detector, camera_id):
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer = MOTTimer() timer = MOTTimer()
results = []
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
break break
timer.tic() timer.tic()
online_tlwhs, online_ids = detector.predict(frame, FLAGS.threshold) online_tlwhs, online_scores, online_ids = detector.predict(
frame, FLAGS.threshold)
timer.toc() timer.toc()
results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
fps = 1. / timer.average_time
online_im = mot_vis.plot_tracking( online_im = mot_vis.plot_tracking(
frame, frame,
online_tlwhs, online_tlwhs,
online_ids, online_ids,
online_scores,
frame_id=frame_id, frame_id=frame_id,
fps=1. / timer.average_time) fps=fps)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
online_im)
frame_id += 1 frame_id += 1
print('detect frame:%d' % (frame_id)) print('detect frame:%d' % (frame_id))
im = np.array(online_im) im = np.array(online_im)
...@@ -333,6 +379,10 @@ def predict_video(detector, camera_id): ...@@ -333,6 +379,10 @@ def predict_video(detector, camera_id):
cv2.imshow('Tracking Detection', im) cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
if FLAGS.save_results:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
writer.release() writer.release()
......
...@@ -100,7 +100,14 @@ def argsparser(): ...@@ -100,7 +100,14 @@ def argsparser():
default=False, default=False,
help="If the model is produced by TRT offline quantitative " help="If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True.") "calibration, trt_calib_mode need to set True.")
parser.add_argument(
'--save_images',
action='store_true',
help='Save tracking results (image).')
parser.add_argument(
'--save_results',
action='store_true',
help='Save tracking results (txt).')
return parser return parser
......
...@@ -135,19 +135,24 @@ class Tracker(object): ...@@ -135,19 +135,24 @@ class Tracker(object):
online_targets = self.model.tracker.update(pred_dets, pred_embs) online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], [] online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets: for t in online_targets:
tlwh = t.tlwh tlwh = t.tlwh
tid = t.track_id tid = t.track_id
tscore = t.score
vertical = tlwh[2] / tlwh[3] > 1.6 vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical: if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
online_tlwhs.append(tlwh) online_tlwhs.append(tlwh)
online_ids.append(tid) online_ids.append(tid)
online_scores.append(tscore)
timer.toc() timer.toc()
# save results # save results
results.append((frame_id + 1, online_tlwhs, online_ids)) results.append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs, self.save_results(data, frame_id, online_ids, online_tlwhs,
timer.average_time, show_image, save_dir) online_scores, timer.average_time, show_image,
save_dir)
frame_id += 1 frame_id += 1
return results, frame_id, timer.average_time, timer.calls return results, frame_id, timer.average_time, timer.calls
...@@ -206,20 +211,22 @@ class Tracker(object): ...@@ -206,20 +211,22 @@ class Tracker(object):
online_targets = self.model.tracker.update(detections) online_targets = self.model.tracker.update(detections)
online_tlwhs = [] online_tlwhs = []
online_scores = []
online_ids = [] online_ids = []
for track in online_targets: for track in online_targets:
if not track.is_confirmed() or track.time_since_update > 1: if not track.is_confirmed() or track.time_since_update > 1:
continue continue
tlwh = track.to_tlwh() online_tlwhs.append(track.to_tlwh())
track_id = track.track_id online_scores.append(1.0)
online_tlwhs.append(tlwh) online_ids.append(track.track_id)
online_ids.append(track_id)
timer.toc() timer.toc()
# save results # save results
results.append((frame_id + 1, online_tlwhs, online_ids)) results.append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
self.save_results(data, frame_id, online_ids, online_tlwhs, self.save_results(data, frame_id, online_ids, online_tlwhs,
timer.average_time, show_image, save_dir) online_scores, timer.average_time, show_image,
save_dir)
frame_id += 1 frame_id += 1
return results, frame_id, timer.average_time, timer.calls return results, frame_id, timer.average_time, timer.calls
...@@ -261,23 +268,23 @@ class Tracker(object): ...@@ -261,23 +268,23 @@ class Tracker(object):
meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read() meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
frame_rate = int(meta_info[meta_info.find('frameRate') + 10: frame_rate = int(meta_info[meta_info.find('frameRate') + 10:
meta_info.find('\nseqLength')]) meta_info.find('\nseqLength')])
with paddle.no_grad():
if model_type in ['JDE', 'FairMOT']: if model_type in ['JDE', 'FairMOT']:
results, nf, ta, tc = self._eval_seq_jde( results, nf, ta, tc = self._eval_seq_jde(
dataloader, dataloader,
save_dir=save_dir, save_dir=save_dir,
show_image=show_image, show_image=show_image,
frame_rate=frame_rate) frame_rate=frame_rate)
elif model_type in ['DeepSORT']: elif model_type in ['DeepSORT']:
results, nf, ta, tc = self._eval_seq_sde( results, nf, ta, tc = self._eval_seq_sde(
dataloader, dataloader,
save_dir=save_dir, save_dir=save_dir,
show_image=show_image, show_image=show_image,
frame_rate=frame_rate, frame_rate=frame_rate,
det_file=os.path.join(det_results_dir, det_file=os.path.join(det_results_dir,
'{}.txt'.format(seq))) '{}.txt'.format(seq)))
else: else:
raise ValueError(model_type) raise ValueError(model_type)
self.write_mot_results(result_filename, results, data_type) self.write_mot_results(result_filename, results, data_type)
n_frame += nf n_frame += nf
...@@ -356,21 +363,23 @@ class Tracker(object): ...@@ -356,21 +363,23 @@ class Tracker(object):
result_filename = os.path.join(result_root, '{}.txt'.format(seq)) result_filename = os.path.join(result_root, '{}.txt'.format(seq))
frame_rate = self.dataset.frame_rate frame_rate = self.dataset.frame_rate
if model_type in ['JDE', 'FairMOT']: with paddle.no_grad():
results, nf, ta, tc = self._eval_seq_jde( if model_type in ['JDE', 'FairMOT']:
dataloader, results, nf, ta, tc = self._eval_seq_jde(
save_dir=save_dir, dataloader,
show_image=show_image, save_dir=save_dir,
frame_rate=frame_rate) show_image=show_image,
elif model_type in ['DeepSORT']: frame_rate=frame_rate)
results, nf, ta, tc = self._eval_seq_sde( elif model_type in ['DeepSORT']:
dataloader, results, nf, ta, tc = self._eval_seq_sde(
save_dir=save_dir, dataloader,
show_image=show_image, save_dir=save_dir,
frame_rate=frame_rate, show_image=show_image,
det_file=os.path.join(det_results_dir, '{}.txt'.format(seq))) frame_rate=frame_rate,
else: det_file=os.path.join(det_results_dir,
raise ValueError(model_type) '{}.txt'.format(seq)))
else:
raise ValueError(model_type)
self.write_mot_results(result_filename, results, data_type) self.write_mot_results(result_filename, results, data_type)
...@@ -384,17 +393,17 @@ class Tracker(object): ...@@ -384,17 +393,17 @@ class Tracker(object):
def write_mot_results(self, filename, results, data_type='mot'): def write_mot_results(self, filename, results, data_type='mot'):
if data_type in ['mot', 'mcmot', 'lab']: if data_type in ['mot', 'mcmot', 'lab']:
save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n' save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
elif data_type == 'kitti': elif data_type == 'kitti':
save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n' save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
else: else:
raise ValueError(data_type) raise ValueError(data_type)
with open(filename, 'w') as f: with open(filename, 'w') as f:
for frame_id, tlwhs, track_ids in results: for frame_id, tlwhs, tscores, track_ids in results:
if data_type == 'kitti': if data_type == 'kitti':
frame_id -= 1 frame_id -= 1
for tlwh, track_id in zip(tlwhs, track_ids): for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0: if track_id < 0:
continue continue
x1, y1, w, h = tlwh x1, y1, w, h = tlwh
...@@ -407,12 +416,13 @@ class Tracker(object): ...@@ -407,12 +416,13 @@ class Tracker(object):
x2=x2, x2=x2,
y2=y2, y2=y2,
w=w, w=w,
h=h) h=h,
score=score)
f.write(line) f.write(line)
logger.info('MOT results save in {}'.format(filename)) logger.info('MOT results save in {}'.format(filename))
def save_results(self, data, frame_id, online_ids, online_tlwhs, def save_results(self, data, frame_id, online_ids, online_tlwhs,
average_time, show_image, save_dir): online_scores, average_time, show_image, save_dir):
if show_image or save_dir is not None: if show_image or save_dir is not None:
assert 'ori_image' in data assert 'ori_image' in data
img0 = data['ori_image'].numpy()[0] img0 = data['ori_image'].numpy()[0]
...@@ -420,6 +430,7 @@ class Tracker(object): ...@@ -420,6 +430,7 @@ class Tracker(object):
img0, img0,
online_tlwhs, online_tlwhs,
online_ids, online_ids,
online_scores,
frame_id=frame_id, frame_id=frame_id,
fps=1. / average_time) fps=1. / average_time)
if show_image: if show_image:
......
...@@ -115,7 +115,7 @@ class Trainer(object): ...@@ -115,7 +115,7 @@ class Trainer(object):
self.status = {} self.status = {}
self.start_epoch = 0 self.start_epoch = 0
self.end_epoch = cfg.epoch self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch
# initial default callbacks # initial default callbacks
self._init_callbacks() self._init_callbacks()
......
...@@ -107,7 +107,6 @@ class JDE(BaseArch): ...@@ -107,7 +107,6 @@ class JDE(BaseArch):
pred_dets = paddle.concat((bbox[:, 2:], bbox[:, 1:2]), axis=1) pred_dets = paddle.concat((bbox[:, 2:], bbox[:, 1:2]), axis=1)
boxes_idx = paddle.cast(boxes_idx, 'int64')
emb_valid = paddle.gather_nd(emb_outs, boxes_idx) emb_valid = paddle.gather_nd(emb_outs, boxes_idx)
pred_embs = paddle.gather_nd(emb_valid, nms_keep_idx) pred_embs = paddle.gather_nd(emb_valid, nms_keep_idx)
......
...@@ -76,10 +76,19 @@ def plot_tracking(image, ...@@ -76,10 +76,19 @@ def plot_tracking(image,
im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness) im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.putText( cv2.putText(
im, im,
id_text, (intbox[0], intbox[1] + 30), id_text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN, cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255), text_scale, (0, 0, 255),
thickness=text_thickness) thickness=text_thickness)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im return im
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册