You need to sign in or sign up before continuing.
未验证 提交 d015403e 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] fix deploy infer of pptracking (#4659)

上级 c4db4e7f
...@@ -45,14 +45,14 @@ class JDE_Detector(Detector): ...@@ -45,14 +45,14 @@ class JDE_Detector(Detector):
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference batch_size (int): size of per batch in inference, default is 1 in tracking models
trt_min_shape (int): min shape for dynamic shape in trt trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
""" """
def __init__(self, def __init__(self,
...@@ -111,7 +111,8 @@ class JDE_Detector(Detector): ...@@ -111,7 +111,8 @@ class JDE_Detector(Detector):
tid = t.track_id tid = t.track_id
tscore = t.score tscore = t.score
if tscore < threshold: continue if tscore < threshold: continue
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio: 3] > self.tracker.vertical_ratio:
continue continue
...@@ -123,7 +124,8 @@ class JDE_Detector(Detector): ...@@ -123,7 +124,8 @@ class JDE_Detector(Detector):
def predict(self, image_list, threshold=0.5, warmup=0, repeats=1): def predict(self, image_list, threshold=0.5, warmup=0, repeats=1):
''' '''
Args: Args:
image_list (list): list of image image_list (list[str]): path of images, only support one image path
(batch_size=1) in tracking model
threshold (float): threshold of predicted box' score threshold (float): threshold of predicted box' score
Returns: Returns:
online_tlwhs, online_scores, online_ids (dict[np.array]) online_tlwhs, online_scores, online_ids (dict[np.array])
...@@ -159,6 +161,7 @@ class JDE_Detector(Detector): ...@@ -159,6 +161,7 @@ class JDE_Detector(Detector):
pred_dets, pred_embs, threshold) pred_dets, pred_embs, threshold)
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1 self.det_times.img_num += 1
return online_tlwhs, online_scores, online_ids return online_tlwhs, online_scores, online_ids
...@@ -172,7 +175,7 @@ def predict_image(detector, image_list): ...@@ -172,7 +175,7 @@ def predict_image(detector, image_list):
for frame_id, img_file in enumerate(image_list): for frame_id, img_file in enumerate(image_list):
frame = cv2.imread(img_file) frame = cv2.imread(img_file)
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
detector.predict([frame], FLAGS.threshold, warmup=10, repeats=10) detector.predict([img_file], FLAGS.threshold, warmup=10, repeats=10)
cm, gm, gu = get_current_memory_mb() cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm detector.cpu_mem += cm
detector.gpu_mem += gm detector.gpu_mem += gm
...@@ -180,10 +183,15 @@ def predict_image(detector, image_list): ...@@ -180,10 +183,15 @@ def predict_image(detector, image_list):
print('Test iter {}, file name:{}'.format(frame_id, img_file)) print('Test iter {}, file name:{}'.format(frame_id, img_file))
else: else:
online_tlwhs, online_scores, online_ids = detector.predict( online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold) [img_file], FLAGS.threshold)
online_im = plot_tracking_dict(frame, num_classes, online_tlwhs, online_im = plot_tracking_dict(
online_ids, online_scores, frame_id, frame,
ids2names) num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
ids2names=ids2names)
if FLAGS.save_images: if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
......
...@@ -61,11 +61,14 @@ def bench_log(detector, img_list, model_info, batch_size=1, name=None): ...@@ -61,11 +61,14 @@ def bench_log(detector, img_list, model_info, batch_size=1, name=None):
class SDE_Detector(Detector): class SDE_Detector(Detector):
""" """
Detector of SDE methods
Args: Args:
pred_config (object): config of model, defined by `Config(model_dir)` pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference, default is 1 in tracking models
trt_min_shape (int): min shape for dynamic shape in trt trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt trt_opt_shape (int): opt shape for dynamic shape in trt
...@@ -99,10 +102,15 @@ class SDE_Detector(Detector): ...@@ -99,10 +102,15 @@ class SDE_Detector(Detector):
trt_calib_mode=trt_calib_mode, trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now" assert batch_size == 1, "The detector of tracking models only supports batch_size=1 now"
self.pred_config = pred_config self.pred_config = pred_config
def postprocess(self, boxes, ori_image_shape, threshold, scaled): def postprocess(self,
boxes,
ori_image_shape,
threshold,
inputs,
scaled=False):
over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0] over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0]
if len(over_thres_idx) == 0: if len(over_thres_idx) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32) pred_dets = np.zeros((1, 6), dtype=np.float32)
...@@ -115,6 +123,9 @@ class SDE_Detector(Detector): ...@@ -115,6 +123,9 @@ class SDE_Detector(Detector):
# scaled means whether the coords after detector outputs # scaled means whether the coords after detector outputs
# have been scaled back to the original image, set True # have been scaled back to the original image, set True
# in general detector, set False in JDE YOLOv3. # in general detector, set False in JDE YOLOv3.
input_shape = inputs['image'].shape[2:]
im_shape = inputs['im_shape'][0]
scale_factor = inputs['scale_factor'][0]
pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape, pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape,
scale_factor) scale_factor)
else: else:
...@@ -138,7 +149,13 @@ class SDE_Detector(Detector): ...@@ -138,7 +149,13 @@ class SDE_Detector(Detector):
return pred_dets, pred_xyxys return pred_dets, pred_xyxys
def predict(self, image_path, ori_image_shape, scaled, threshold=0.5, warmup=0, repeats=1): def predict(self,
image_path,
ori_image_shape,
threshold=0.5,
scaled=False,
warmup=0,
repeats=1):
''' '''
Args: Args:
image_path (list[str]): path of images, only support one image path image_path (list[str]): path of images, only support one image path
...@@ -148,7 +165,8 @@ class SDE_Detector(Detector): ...@@ -148,7 +165,8 @@ class SDE_Detector(Detector):
scaled (bool): whether the coords after detector outputs are scaled, scaled (bool): whether the coords after detector outputs are scaled,
default False in jde yolov3, set True in general detector. default False in jde yolov3, set True in general detector.
Returns: Returns:
pred_dets (np.ndarray, [N, 6]) pred_dets (np.ndarray, [N, 6]): 'x,y,w,h,score,cls_id'
pred_xyxys (np.ndarray, [N, 4]): 'x1,y1,x2,y2'
''' '''
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_path) inputs = self.preprocess(image_path)
...@@ -179,20 +197,24 @@ class SDE_Detector(Detector): ...@@ -179,20 +197,24 @@ class SDE_Detector(Detector):
pred_xyxys = np.zeros((1, 4), dtype=np.float32) pred_xyxys = np.zeros((1, 4), dtype=np.float32)
else: else:
pred_dets, pred_xyxys = self.postprocess( pred_dets, pred_xyxys = self.postprocess(
boxes, ori_image_shape, threshold, scaled) boxes, ori_image_shape, threshold, inputs, scaled=scaled)
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1 self.det_times.img_num += 1
return pred_dets, pred_xyxys return pred_dets, pred_xyxys
class SDE_DetectorPicoDet(DetectorPicoDet): class SDE_DetectorPicoDet(DetectorPicoDet):
""" """
PicoDet of SDE methods, the postprocess of PicoDet has not been exported as
other detectors, so do postprocess here.
Args: Args:
pred_config (object): config of model, defined by `Config(model_dir)` pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference, default is 1 in tracking models
trt_min_shape (int): min shape for dynamic shape in trt trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt trt_opt_shape (int): opt shape for dynamic shape in trt
...@@ -226,11 +248,10 @@ class SDE_DetectorPicoDet(DetectorPicoDet): ...@@ -226,11 +248,10 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
trt_calib_mode=trt_calib_mode, trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now" assert batch_size == 1, "The detector of tracking models only supports batch_size=1 now"
self.pred_config = pred_config self.pred_config = pred_config
def postprocess_bboxes(self, boxes, input_shape, im_shape, scale_factor, def postprocess(self, boxes, ori_image_shape, threshold):
threshold):
over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0] over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0]
if len(over_thres_idx) == 0: if len(over_thres_idx) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32) pred_dets = np.zeros((1, 6), dtype=np.float32)
...@@ -241,8 +262,7 @@ class SDE_DetectorPicoDet(DetectorPicoDet): ...@@ -241,8 +262,7 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
pred_bboxes = boxes[:, 2:] pred_bboxes = boxes[:, 2:]
pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape, pred_xyxys, keep_idx = clip_box(pred_bboxes, ori_image_shape)
scale_factor)
if len(keep_idx[0]) == 0: if len(keep_idx[0]) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32) pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32) pred_xyxys = np.zeros((1, 4), dtype=np.float32)
...@@ -256,20 +276,30 @@ class SDE_DetectorPicoDet(DetectorPicoDet): ...@@ -256,20 +276,30 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
pred_dets = np.concatenate( pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1) (pred_tlwhs, pred_scores, pred_cls_ids), axis=1)
return pred_dets, pred_xyxys return pred_dets, pred_xyxys
def predict(self, image, scaled, threshold=0.5, warmup=0, repeats=1): def predict(self,
image_path,
ori_image_shape,
threshold=0.5,
scaled=False,
warmup=0,
repeats=1):
''' '''
Args: Args:
image (np.ndarray): image numpy data image_path (list[str]): path of images, only support one image path
(batch_size=1) in tracking model
ori_image_shape (list[int]: original image shape
threshold (float): threshold of predicted box' score threshold (float): threshold of predicted box' score
scaled (bool): whether the coords after detector outputs are scaled, scaled (bool): whether the coords after detector outputs are scaled,
default False in jde yolov3, set True in general detector. default False in jde yolov3, set True in general detector.
Returns: Returns:
pred_dets (np.ndarray, [N, 6]) pred_dets (np.ndarray, [N, 6]): 'x,y,w,h,score,cls_id'
pred_xyxys (np.ndarray, [N, 4]): 'x1,y1,x2,y2'
''' '''
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image) inputs = self.preprocess(image_path)
self.det_times.preprocess_time_s.end() self.det_times.preprocess_time_s.end()
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
...@@ -298,32 +328,50 @@ class SDE_DetectorPicoDet(DetectorPicoDet): ...@@ -298,32 +328,50 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
np_boxes_list.append( np_boxes_list.append(
self.predictor.get_output_handle(output_names[ self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu()) out_idx + num_outs]).copy_to_cpu())
self.det_times.inference_time_s.end(repeats=repeats) self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.img_num += 1
self.det_times.postprocess_time_s.start() self.det_times.postprocess_time_s.start()
self.postprocess = PicoDetPostProcess( self.picodet_postprocess = PicoDetPostProcess(
inputs['image'].shape[2:], inputs['image'].shape[2:],
inputs['im_shape'], inputs['im_shape'],
inputs['scale_factor'], inputs['scale_factor'],
strides=self.pred_config.fpn_stride, strides=self.pred_config.fpn_stride,
nms_threshold=self.pred_config.nms['nms_threshold']) nms_threshold=self.pred_config.nms['nms_threshold'])
boxes, boxes_num = self.postprocess(np_score_list, np_boxes_list) boxes, boxes_num = self.picodet_postprocess(np_score_list,
np_boxes_list)
if len(boxes) == 0: if len(boxes) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32) pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32) pred_xyxys = np.zeros((1, 4), dtype=np.float32)
else: else:
input_shape = inputs['image'].shape[2:] pred_dets, pred_xyxys = self.postprocess(boxes, ori_image_shape,
im_shape = inputs['im_shape'] threshold)
scale_factor = inputs['scale_factor'] self.det_times.postprocess_time_s.end()
pred_dets, pred_xyxys = self.postprocess_bboxes( self.det_times.img_num += 1
boxes, input_shape, im_shape, scale_factor, threshold)
return pred_dets, pred_xyxys return pred_dets, pred_xyxys
class SDE_ReID(object): class SDE_ReID(object):
"""
ReID of SDE methods
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference, default 50 means at most
50 sub images can be made a batch and send into ReID model
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self, def __init__(self,
pred_config, pred_config,
model_dir, model_dir,
...@@ -394,7 +442,8 @@ class SDE_ReID(object): ...@@ -394,7 +442,8 @@ class SDE_ReID(object):
tlwh = t.to_tlwh() tlwh = t.to_tlwh()
tscore = t.score tscore = t.score
tid = t.track_id tid = t.track_id
if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue if tlwh[2] * tlwh[3] <= tracker.min_box_area:
continue
if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > tracker.vertical_ratio: 3] > tracker.vertical_ratio:
continue continue
...@@ -422,7 +471,8 @@ class SDE_ReID(object): ...@@ -422,7 +471,8 @@ class SDE_ReID(object):
tlwh = t.to_tlwh() tlwh = t.to_tlwh()
tscore = t.score tscore = t.score
tid = t.track_id tid = t.track_id
if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue if tlwh[2] * tlwh[3] <= tracker.min_box_area:
continue
if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > tracker.vertical_ratio: 3] > tracker.vertical_ratio:
continue continue
...@@ -497,17 +547,23 @@ def predict_image(detector, reid_model, image_list): ...@@ -497,17 +547,23 @@ def predict_image(detector, reid_model, image_list):
image_list.sort() image_list.sort()
for i, img_file in enumerate(image_list): for i, img_file in enumerate(image_list):
frame = cv2.imread(img_file) frame = cv2.imread(img_file)
ori_image_shape = list(frame.shape[:2])
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
pred_dets, pred_xyxys = detector.predict( pred_dets, pred_xyxys = detector.predict(
[frame], FLAGS.scaled, FLAGS.threshold, warmup=10, repeats=10) [img_file],
ori_image_shape,
FLAGS.threshold,
FLAGS.scaled,
warmup=10,
repeats=10)
cm, gm, gu = get_current_memory_mb() cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm detector.cpu_mem += cm
detector.gpu_mem += gm detector.gpu_mem += gm
detector.gpu_util += gu detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file)) print('Test iter {}, file name:{}'.format(i, img_file))
else: else:
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled, pred_dets, pred_xyxys = detector.predict(
FLAGS.threshold) [img_file], ori_image_shape, FLAGS.threshold, FLAGS.scaled)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0: if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'. print('Frame {} has no object, try to modify score threshold.'.
...@@ -577,8 +633,9 @@ def predict_video(detector, reid_model, camera_id): ...@@ -577,8 +633,9 @@ def predict_video(detector, reid_model, camera_id):
if not ret: if not ret:
break break
timer.tic() timer.tic()
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled, ori_image_shape = list(frame.shape[:2])
FLAGS.threshold) pred_dets, pred_xyxys = detector.predict([frame], ori_image_shape,
FLAGS.threshold, FLAGS.scaled)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0: if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'. print('Frame {} has no object, try to modify score threshold.'.
...@@ -674,7 +731,8 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir): ...@@ -674,7 +731,8 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
results = defaultdict(list) results = defaultdict(list)
mot_features_dict = {} # cid_tid_fid feats mot_features_dict = {} # cid_tid_fid feats
print('Totally {} frames found in seq {}.'.format(len(image_list), seq_name)) print('Totally {} frames found in seq {}.'.format(
len(image_list), seq_name))
for frame_id, img_file in enumerate(image_list): for frame_id, img_file in enumerate(image_list):
if frame_id % 40 == 0: if frame_id % 40 == 0:
...@@ -682,8 +740,8 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir): ...@@ -682,8 +740,8 @@ def predict_mtmct_seq(detector, reid_model, seq_name, output_dir):
frame = cv2.imread(os.path.join(fpath, img_file)) frame = cv2.imread(os.path.join(fpath, img_file))
ori_image_shape = list(frame.shape[:2]) ori_image_shape = list(frame.shape[:2])
frame_path = os.path.join(fpath, img_file) frame_path = os.path.join(fpath, img_file)
pred_dets, pred_xyxys = detector.predict([frame_path], ori_image_shape, FLAGS.scaled, pred_dets, pred_xyxys = detector.predict([frame_path], ori_image_shape,
FLAGS.threshold) FLAGS.threshold, FLAGS.scaled)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0: if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'. print('Frame {} has no object, try to modify score threshold.'.
...@@ -765,15 +823,16 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg): ...@@ -765,15 +823,16 @@ def predict_mtmct(detector, reid_model, mtmct_dir, mtmct_cfg):
ext = seq.split('.')[-1] ext = seq.split('.')[-1]
seq = seq.split('.')[-2] seq = seq.split('.')[-2]
print('ffmpeg processing of video {}'.format(fpath)) print('ffmpeg processing of video {}'.format(fpath))
frames_path = video2frames(video_path=fpath, outpath=mtmct_dir, frame_rate=25) frames_path = video2frames(
video_path=fpath, outpath=mtmct_dir, frame_rate=25)
fpath = os.path.join(mtmct_dir, seq) fpath = os.path.join(mtmct_dir, seq)
if os.path.isdir(fpath) == False: if os.path.isdir(fpath) == False:
print('{} is not a image folder.'.format(fpath)) print('{} is not a image folder.'.format(fpath))
continue continue
mot_features_dict = predict_mtmct_seq(detector, reid_model, mot_features_dict = predict_mtmct_seq(detector, reid_model, seq,
seq, output_dir) output_dir)
cid = int(re.sub('[a-z,A-Z]', "", seq)) cid = int(re.sub('[a-z,A-Z]', "", seq))
tid_data, mot_list_break = trajectory_fusion( tid_data, mot_list_break = trajectory_fusion(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册