未验证 提交 8622966f 编写于 作者: W wangguanzhong 提交者: GitHub

update pphuman for pptracking (#5419)

上级 caa23c5e
# config of tracker for MOT SDE Detector, use ByteTracker as default.
# The tracker of MOT JDE Detector is exported together with the model.
# config of tracker for MOT SDE Detector, use 'JDETracker' as default.
# The tracker of MOT JDE Detector (such as FairMOT) is exported together with the model.
# Here 'min_box_area' and 'vertical_ratio' are set for pedestrians; modify them when tracking other objects.
tracker:
use_byte: true
type: JDETracker # 'JDETracker' or 'DeepSORTTracker'
# BYTETracker
JDETracker:
use_byte: True
det_thresh: 0.3
conf_thres: 0.6
low_conf_thres: 0.1
match_thres: 0.9
min_box_area: 100
vertical_ratio: 1.6
vertical_ratio: 1.6 # for pedestrian
DeepSORTTracker:
input_size: [64, 192]
min_box_area: 0
vertical_ratio: -1
budget: 100
max_age: 70
n_init: 3
metric_type: cosine
matching_threshold: 0.2
max_iou_distance: 0.9
......@@ -28,7 +28,6 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from python.infer import Detector, DetectorPicoDet
from python.mot_sde_infer import SDE_Detector
from python.attr_infer import AttrDetector
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
......@@ -39,6 +38,8 @@ from pipe_utils import argsparser, print_arguments, merge_cfg, PipeTimer
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint
from python.preprocess import decode_image
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action
from pptracking.python.mot_sde_infer import SDE_Detector
from pptracking.python.visualize import plot_tracking
......@@ -374,6 +375,8 @@ class PipePredictor(object):
# det output format: class, score, xmin, ymin, xmax, ymax
det_res = self.det_predictor.predict_image(
batch_input, visual=False)
det_res = self.det_predictor.filter_box(det_res,
self.cfg['crop_thresh'])
if i > self.warmup_frame:
self.pipe_timer.module_time['det'].end()
self.pipeline_res.update(det_res, 'det')
......@@ -563,6 +566,8 @@ class PipePredictor(object):
det_res_i,
labels=['person'],
threshold=self.cfg['crop_thresh'])
im = np.ascontiguousarray(np.copy(im))
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
if attr_res is not None:
attr_res_i = attr_res['output'][start_idx:start_idx +
boxes_num_i]
......@@ -571,7 +576,7 @@ class PipePredictor(object):
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, img_name)
im.save(out_path, quality=95)
cv2.imwrite(out_path, im)
print("save result to: " + out_path)
start_idx += boxes_num_i
......
......@@ -24,8 +24,8 @@ import paddle
from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb, _is_valid_video, video2frames
from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
from .utils import argsparser, Timer, get_current_memory_mb, _is_valid_video, video2frames
from .det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
# add python path
import sys
......@@ -34,7 +34,7 @@ sys.path.insert(0, parent_path)
from mot.tracker import JDETracker, DeepSORTTracker
from mot.utils import MOTTimer, write_mot_results, flow_statistic, get_crops, clip_box
from visualize import plot_tracking, plot_tracking_dict
from .visualize import plot_tracking, plot_tracking_dict
from mot.mtmct.utils import parse_bias
from mot.mtmct.postprocess import trajectory_fusion, sub_cluster, gen_res, print_mtmct_result
......@@ -100,7 +100,7 @@ class SDE_Detector(Detector):
self.reid_predictor, self.config = load_predictor(
reid_model_dir,
run_mode=run_mode,
batch_size=50, # reid_batch_size
batch_size=50, # reid_batch_size
min_subgraph_size=self.reid_pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
......@@ -113,17 +113,19 @@ class SDE_Detector(Detector):
else:
self.reid_pred_config = None
self.reid_predictor = None
assert tracker_config is not None, 'Note that tracker_config should be set.'
self.tracker_config = tracker_config
tracker_cfg = yaml.safe_load(open(self.tracker_config))
cfg = tracker_cfg[tracker_cfg['type']]
# tracker config
self.use_deepsort_tracker = True if tracker_cfg['type'] == 'DeepSORTTracker' else False
self.use_deepsort_tracker = True if tracker_cfg[
'type'] == 'DeepSORTTracker' else False
if self.use_deepsort_tracker:
# use DeepSORTTracker
if self.reid_pred_config is not None and hasattr(self.reid_pred_config, 'tracker'):
if self.reid_pred_config is not None and hasattr(
self.reid_pred_config, 'tracker'):
cfg = self.reid_pred_config.tracker
budget = cfg.get('budget', 100)
max_age = cfg.get('max_age', 30)
......@@ -138,8 +140,7 @@ class SDE_Detector(Detector):
max_iou_distance=max_iou_distance,
matching_threshold=matching_threshold,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
)
vertical_ratio=vertical_ratio, )
else:
# use ByteTracker
use_byte = cfg.get('use_byte', False)
......@@ -158,9 +159,8 @@ class SDE_Detector(Detector):
vertical_ratio=vertical_ratio,
match_thres=match_thres,
conf_thres=conf_thres,
low_conf_thres=low_conf_thres,
)
low_conf_thres=low_conf_thres, )
self.do_mtmct = False if mtmct_dir is None else True
self.mtmct_dir = mtmct_dir
......@@ -193,7 +193,7 @@ class SDE_Detector(Detector):
crops = get_crops(pred_xyxys, ori_image, w, h)
# to keep fast speed, only use topk crops
crops = crops[:50] # reid_batch_size
crops = crops[:50] # reid_batch_size
det_results['crops'] = np.array(crops).astype('float32')
det_results['boxes'] = pred_dets[:50]
......@@ -206,7 +206,8 @@ class SDE_Detector(Detector):
for i in range(repeats):
self.reid_predictor.run()
output_names = self.reid_predictor.get_output_names()
feature_tensor = self.reid_predictor.get_output_handle(output_names[0])
feature_tensor = self.reid_predictor.get_output_handle(output_names[
0])
pred_embs = feature_tensor.copy_to_cpu()
det_results['embeddings'] = pred_embs
......@@ -249,7 +250,8 @@ class SDE_Detector(Detector):
frame_id = det_results['frame_id']
tracking_outs['feat_data'] = {}
for _tlbr, _id, _feat in zip(online_tlbrs, online_ids, online_feats):
for _tlbr, _id, _feat in zip(online_tlbrs, online_ids,
online_feats):
feat_data = {}
feat_data['bbox'] = _tlbr
feat_data['frame'] = f"{frame_id:06d}"
......@@ -265,7 +267,8 @@ class SDE_Detector(Detector):
online_scores = defaultdict(list)
online_ids = defaultdict(list)
if self.do_mtmct:
online_tlbrs, online_feats = defaultdict(list), defaultdict(list)
online_tlbrs, online_feats = defaultdict(list), defaultdict(
list)
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
for cls_id in range(self.num_classes):
online_targets = online_targets_dict[cls_id]
......@@ -295,7 +298,8 @@ class SDE_Detector(Detector):
seq_name = det_results['seq_name']
frame_id = det_results['frame_id']
tracking_outs['feat_data'] = {}
for _tlbr, _id, _feat in zip(online_tlbrs[0], online_ids[0], online_feats[0]):
for _tlbr, _id, _feat in zip(online_tlbrs[0], online_ids[0],
online_feats[0]):
feat_data = {}
feat_data['bbox'] = _tlbr
feat_data['frame'] = f"{frame_id:06d}"
......@@ -323,7 +327,7 @@ class SDE_Detector(Detector):
image_list.sort()
ids2names = self.pred_config.labels
if self.do_mtmct:
mot_features_dict = {} # cid_tid_fid feats
mot_features_dict = {} # cid_tid_fid feats
else:
mot_results = []
for frame_id, img_file in enumerate(image_list):
......@@ -429,7 +433,7 @@ class SDE_Detector(Detector):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
if self.do_mtmct:
return mot_features_dict
else:
......@@ -452,7 +456,7 @@ class SDE_Detector(Detector):
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 1
......@@ -469,14 +473,17 @@ class SDE_Detector(Detector):
timer.tic()
seq_name = video_out_name.split('.')[0]
mot_results = self.predict_image([frame], visual=False, seq_name=seq_name)
mot_results = self.predict_image(
[frame], visual=False, seq_name=seq_name)
timer.toc()
online_tlwhs, online_scores, online_ids = mot_results[0] # bs=1 in MOT model
online_tlwhs, online_scores, online_ids = mot_results[
0] # bs=1 in MOT model
fps = 1. / timer.duration
if num_classes == 1 and self.use_reid:
# use DeepSORTTracker, which only supports a single class
results[0].append((frame_id + 1, online_tlwhs, online_scores, online_ids))
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
im = plot_tracking(
frame,
online_tlwhs,
......@@ -488,8 +495,8 @@ class SDE_Detector(Detector):
# use ByteTracker, support multiple class
for cls_id in range(num_classes):
results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
(frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
im = plot_tracking_dict(
frame,
num_classes,
......@@ -549,13 +556,15 @@ class SDE_Detector(Detector):
continue
if os.path.exists(os.path.join(fpath, 'img1')):
fpath = os.path.join(fpath, 'img1')
assert os.path.isdir(fpath), '{} should be a directory'.format(fpath)
assert os.path.isdir(fpath), '{} should be a directory'.format(
fpath)
image_list = glob.glob(os.path.join(fpath, '*.jpg'))
image_list.sort()
assert len(image_list) > 0, '{} has no images.'.format(fpath)
print('start tracking seq: {}'.format(seq))
mot_features_dict = self.predict_image(image_list, visual=False, seq_name=seq)
mot_features_dict = self.predict_image(
image_list, visual=False, seq_name=seq)
cid = int(re.sub('[a-z,A-Z]', "", seq))
tid_data, mot_list_break = trajectory_fusion(
......@@ -627,8 +636,7 @@ def main():
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
reid_model_dir=FLAGS.reid_model_dir,
mtmct_dir=FLAGS.mtmct_dir,
)
mtmct_dir=FLAGS.mtmct_dir, )
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......@@ -643,7 +651,8 @@ def main():
assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
seq_name = FLAGS.image_dir.split('/')[-1]
detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)
detector.predict_image(
img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
......
......@@ -150,6 +150,25 @@ class Detector(object):
result = {k: v for k, v in result.items() if v is not None}
return result
def filter_box(self, result, threshold):
    """Drop detections whose score is not above ``threshold``.

    Args:
        result (dict): detection output holding ``'boxes'`` (a 2-D array
            with one row per detection; column 1 is compared against
            ``threshold``, i.e. the score in the
            ``class, score, xmin, ymin, xmax, ymax`` layout) and
            ``'boxes_num'`` (number of rows belonging to each image, in
            order).
        threshold (float): only rows with ``score > threshold`` are kept.

    Returns:
        dict: a new dict with exactly two keys, ``'boxes'`` (the kept rows
        of all images concatenated in order) and ``'boxes_num'`` (kept row
        count per image). Any other key of ``result`` is not carried over.
    """
    boxes_num = result['boxes_num']
    boxes = result['boxes']
    kept_boxes = []
    kept_num = []
    start_idx = 0
    for num in boxes_num:
        # rows [start_idx, start_idx + num) belong to the current image
        boxes_i = boxes[start_idx:start_idx + num, :]
        kept_i = boxes_i[boxes_i[:, 1] > threshold, :]
        kept_boxes.append(kept_i)
        kept_num.append(kept_i.shape[0])
        start_idx += num
    if kept_boxes:
        boxes = np.concatenate(kept_boxes)
    else:
        # np.concatenate raises ValueError on an empty list; an empty batch
        # simply yields an empty box array with the original column layout.
        boxes = boxes[:0]
    return {'boxes': boxes, 'boxes_num': np.array(kept_num, dtype=int)}
def predict(self, repeats=1):
'''
Args:
......@@ -736,19 +755,20 @@ def main():
elif arch == 'PicoDet':
detector_func = 'DetectorPicoDet'
detector = eval(detector_func)(FLAGS.model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir)
detector = eval(detector_func)(
FLAGS.model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......@@ -781,6 +801,8 @@ if __name__ == '__main__':
], "device should be CPU, GPU or XPU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
assert not (FLAGS.enable_mkldnn==False and FLAGS.enable_mkldnn_bfloat16==True), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'
assert not (
FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True
), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册