未验证 提交 2f3950ee 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] refine mtmct deploy (#4633)

* fix mtmct deploy

* add enable_static
上级 7fc9614a
...@@ -341,6 +341,7 @@ def visual_rerank(prb_feats, ...@@ -341,6 +341,7 @@ def visual_rerank(prb_feats,
prb_feats, gal_feats = run_fac(prb_feats, gal_feats, prb_labels, prb_feats, gal_feats = run_fac(prb_feats, gal_feats, prb_labels,
gal_labels, 0.08, 20, 0.5, 1, 1) gal_labels, 0.08, 20, 0.5, 1, 1)
if use_rerank: if use_rerank:
paddle.enable_static()
print('current use rerank finetuned parameters....') print('current use rerank finetuned parameters....')
# Step2: k-reciprocal. finetuned parameters: [k1,k2,lambda_value] # Step2: k-reciprocal. finetuned parameters: [k1,k2,lambda_value]
sims = ReRank2( sims = ReRank2(
......
...@@ -32,7 +32,7 @@ from benchmark_utils import PaddleInferBenchmark ...@@ -32,7 +32,7 @@ from benchmark_utils import PaddleInferBenchmark
from visualize import plot_tracking from visualize import plot_tracking
from mot.tracker import DeepSORTTracker from mot.tracker import DeepSORTTracker
from mot.utils import MOTTimer, write_mot_results, flow_statistic from mot.utils import MOTTimer, write_mot_results, flow_statistic, scale_coords, clip_box, preprocess_reid
from mot.mtmct.utils import parse_bias from mot.mtmct.utils import parse_bias
from mot.mtmct.postprocess import trajectory_fusion, sub_cluster, gen_res, print_mtmct_result from mot.mtmct.postprocess import trajectory_fusion, sub_cluster, gen_res, print_mtmct_result
...@@ -59,50 +59,6 @@ def bench_log(detector, img_list, model_info, batch_size=1, name=None): ...@@ -59,50 +59,6 @@ def bench_log(detector, img_list, model_info, batch_size=1, name=None):
log(name) log(name)
def scale_coords(coords, input_shape, im_shape, scale_factor):
im_shape = im_shape[0]
ratio = scale_factor[0][0]
pad_w = (input_shape[1] - int(im_shape[1])) / 2
pad_h = (input_shape[0] - int(im_shape[0])) / 2
coords[:, 0::2] -= pad_w
coords[:, 1::2] -= pad_h
coords[:, 0:4] /= ratio
coords[:, :4] = np.clip(coords[:, :4], a_min=0, a_max=coords[:, :4].max())
return coords.round()
def clip_box(xyxy, input_shape, im_shape, scale_factor):
im_shape = im_shape[0]
ratio = scale_factor[0][0]
img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]
xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], a_min=0, a_max=img0_shape[1])
xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], a_min=0, a_max=img0_shape[0])
w = xyxy[:, 2:3] - xyxy[:, 0:1]
h = xyxy[:, 3:4] - xyxy[:, 1:2]
mask = np.logical_and(h > 0, w > 0)
keep_idx = np.nonzero(mask)
return xyxy[keep_idx[0]], keep_idx
def preprocess_reid(imgs,
w=64,
h=192,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
im_batch = []
for img in imgs:
img = cv2.resize(img, (w, h))
img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
img = np.expand_dims(img, axis=0)
im_batch.append(img)
im_batch = np.concatenate(im_batch, 0)
return im_batch
class SDE_Detector(Detector): class SDE_Detector(Detector):
""" """
Args: Args:
...@@ -146,8 +102,7 @@ class SDE_Detector(Detector): ...@@ -146,8 +102,7 @@ class SDE_Detector(Detector):
assert batch_size == 1, "The JDE Detector only supports batch size=1 now" assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
self.pred_config = pred_config self.pred_config = pred_config
def postprocess(self, boxes, input_shape, im_shape, scale_factor, threshold, def postprocess(self, boxes, ori_image_shape, threshold, scaled):
scaled):
over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0] over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0]
if len(over_thres_idx) == 0: if len(over_thres_idx) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32) pred_dets = np.zeros((1, 6), dtype=np.float32)
...@@ -165,8 +120,8 @@ class SDE_Detector(Detector): ...@@ -165,8 +120,8 @@ class SDE_Detector(Detector):
else: else:
pred_bboxes = boxes[:, 2:] pred_bboxes = boxes[:, 2:]
pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape, pred_xyxys, keep_idx = clip_box(pred_bboxes, ori_image_shape)
scale_factor)
if len(keep_idx[0]) == 0: if len(keep_idx[0]) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32) pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32) pred_xyxys = np.zeros((1, 4), dtype=np.float32)
...@@ -183,10 +138,12 @@ class SDE_Detector(Detector): ...@@ -183,10 +138,12 @@ class SDE_Detector(Detector):
return pred_dets, pred_xyxys return pred_dets, pred_xyxys
def predict(self, image, scaled, threshold=0.5, warmup=0, repeats=1): def predict(self, image_path, ori_image_shape, scaled, threshold=0.5, warmup=0, repeats=1):
''' '''
Args: Args:
image (np.ndarray): image numpy data image_path (list[str]): path of images, only support one image path
(batch_size=1) in tracking model
ori_image_shape (list[int]: original image shape
threshold (float): threshold of predicted box' score threshold (float): threshold of predicted box' score
scaled (bool): whether the coords after detector outputs are scaled, scaled (bool): whether the coords after detector outputs are scaled,
default False in jde yolov3, set True in general detector. default False in jde yolov3, set True in general detector.
...@@ -194,7 +151,7 @@ class SDE_Detector(Detector): ...@@ -194,7 +151,7 @@ class SDE_Detector(Detector):
pred_dets (np.ndarray, [N, 6]) pred_dets (np.ndarray, [N, 6])
''' '''
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image) inputs = self.preprocess(image_path)
self.det_times.preprocess_time_s.end() self.det_times.preprocess_time_s.end()
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
...@@ -221,12 +178,8 @@ class SDE_Detector(Detector): ...@@ -221,12 +178,8 @@ class SDE_Detector(Detector):
pred_dets = np.zeros((1, 6), dtype=np.float32) pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32) pred_xyxys = np.zeros((1, 4), dtype=np.float32)
else: else:
input_shape = inputs['image'].shape[2:]
im_shape = inputs['im_shape']
scale_factor = inputs['scale_factor']
pred_dets, pred_xyxys = self.postprocess( pred_dets, pred_xyxys = self.postprocess(
boxes, input_shape, im_shape, scale_factor, threshold, scaled) boxes, ori_image_shape, threshold, scaled)
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1 self.det_times.img_num += 1
...@@ -727,7 +680,9 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir): ...@@ -727,7 +680,9 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
if frame_id % 40 == 0: if frame_id % 40 == 0:
print('Processing frame {} of seq {}.'.format(frame_id, seq_name)) print('Processing frame {} of seq {}.'.format(frame_id, seq_name))
frame = cv2.imread(os.path.join(fpath, img_file)) frame = cv2.imread(os.path.join(fpath, img_file))
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled, ori_image_shape = list(frame.shape[:2])
frame_path = os.path.join(fpath, img_file)
pred_dets, pred_xyxys = detector.predict([frame_path], ori_image_shape, FLAGS.scaled,
FLAGS.threshold) FLAGS.threshold)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0: if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
...@@ -855,8 +810,6 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg): ...@@ -855,8 +810,6 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
use_roi=use_roi, use_roi=use_roi,
roi_dir=roi_dir) roi_dir=roi_dir)
pred_mtmct_file = os.path.join(output_dir, 'mtmct_result.txt')
if FLAGS.save_images: if FLAGS.save_images:
carame_results, cid_tid_fid_res = get_mtmct_matching_results( carame_results, cid_tid_fid_res = get_mtmct_matching_results(
pred_mtmct_file) pred_mtmct_file)
...@@ -872,6 +825,11 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg): ...@@ -872,6 +825,11 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
save_dir=save_dir, save_dir=save_dir,
save_videos=FLAGS.save_images) save_videos=FLAGS.save_images)
# evalution metrics
data_root_gt = os.path.join(mtmct_dir, '..', 'gt', 'gt.txt')
if os.path.exists(data_root_gt):
print_mtmct_result(data_root_gt, pred_mtmct_file)
def main(): def main():
pred_config = PredictConfig(FLAGS.model_dir) pred_config = PredictConfig(FLAGS.model_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册