未验证 提交 7441fba7 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] fix picodet deepsort deploy, add cls_name visualization (#4513)

上级 9c0b62a7
...@@ -23,7 +23,6 @@ import paddle ...@@ -23,7 +23,6 @@ import paddle
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from preprocess import preprocess
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, PredictConfig from infer import Detector, get_test_images, print_arguments, PredictConfig
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
...@@ -167,6 +166,8 @@ def predict_image(detector, image_list): ...@@ -167,6 +166,8 @@ def predict_image(detector, image_list):
results = [] results = []
num_classes = detector.num_classes num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot' data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels
image_list.sort() image_list.sort()
for frame_id, img_file in enumerate(image_list): for frame_id, img_file in enumerate(image_list):
frame = cv2.imread(img_file) frame = cv2.imread(img_file)
...@@ -181,7 +182,8 @@ def predict_image(detector, image_list): ...@@ -181,7 +182,8 @@ def predict_image(detector, image_list):
online_tlwhs, online_scores, online_ids = detector.predict( online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold) [frame], FLAGS.threshold)
online_im = plot_tracking_dict(frame, num_classes, online_tlwhs, online_im = plot_tracking_dict(frame, num_classes, online_tlwhs,
online_ids, online_scores, frame_id) online_ids, online_scores, frame_id,
ids2names)
if FLAGS.save_images: if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
...@@ -216,6 +218,8 @@ def predict_video(detector, camera_id): ...@@ -216,6 +218,8 @@ def predict_video(detector, camera_id):
results = defaultdict(list) # support single class and multi classes results = defaultdict(list) # support single class and multi classes
num_classes = detector.num_classes num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot' data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
...@@ -237,7 +241,8 @@ def predict_video(detector, camera_id): ...@@ -237,7 +241,8 @@ def predict_video(detector, camera_id):
online_ids, online_ids,
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
fps=fps) fps=fps,
ids2names=ids2names)
if FLAGS.save_images: if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
......
...@@ -23,9 +23,9 @@ import paddle ...@@ -23,9 +23,9 @@ import paddle
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from preprocess import preprocess from picodet_postprocess import PicoDetPostProcess
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, PredictConfig from infer import Detector, DetectorPicoDet, get_test_images, print_arguments, PredictConfig
from infer import load_predictor from infer import load_predictor
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
...@@ -139,6 +139,7 @@ class SDE_Detector(Detector): ...@@ -139,6 +139,7 @@ class SDE_Detector(Detector):
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now" assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
self.pred_config = pred_config
def postprocess(self, boxes, input_shape, im_shape, scale_factor, threshold, def postprocess(self, boxes, input_shape, im_shape, scale_factor, threshold,
scaled): scaled):
...@@ -147,6 +148,8 @@ class SDE_Detector(Detector): ...@@ -147,6 +148,8 @@ class SDE_Detector(Detector):
pred_dets = np.zeros((1, 6), dtype=np.float32) pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32) pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys return pred_dets, pred_xyxys
else:
boxes = boxes[over_thres_idx]
if not scaled: if not scaled:
# scaled means whether the coords after detector outputs # scaled means whether the coords after detector outputs
...@@ -159,6 +162,11 @@ class SDE_Detector(Detector): ...@@ -159,6 +162,11 @@ class SDE_Detector(Detector):
pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape, pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape,
scale_factor) scale_factor)
if len(keep_idx[0]) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
pred_scores = boxes[:, 1:2][keep_idx[0]] pred_scores = boxes[:, 1:2][keep_idx[0]]
pred_cls_ids = boxes[:, 0:1][keep_idx[0]] pred_cls_ids = boxes[:, 0:1][keep_idx[0]]
pred_tlwhs = np.concatenate( pred_tlwhs = np.concatenate(
...@@ -168,7 +176,7 @@ class SDE_Detector(Detector): ...@@ -168,7 +176,7 @@ class SDE_Detector(Detector):
pred_dets = np.concatenate( pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1) (pred_tlwhs, pred_scores, pred_cls_ids), axis=1)
return pred_dets[over_thres_idx], pred_xyxys[over_thres_idx] return pred_dets, pred_xyxys
def predict(self, image, scaled, threshold=0.5, warmup=0, repeats=1): def predict(self, image, scaled, threshold=0.5, warmup=0, repeats=1):
''' '''
...@@ -220,6 +228,142 @@ class SDE_Detector(Detector): ...@@ -220,6 +228,142 @@ class SDE_Detector(Detector):
return pred_dets, pred_xyxys return pred_dets, pred_xyxys
class SDE_DetectorPicoDet(DetectorPicoDet):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self,
pred_config,
model_dir,
device='CPU',
run_mode='fluid',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
trt_opt_shape=608,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
super(SDE_DetectorPicoDet, self).__init__(
pred_config=pred_config,
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
self.pred_config = pred_config
def postprocess_bboxes(self, boxes, input_shape, im_shape, scale_factor, threshold):
over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0]
if len(over_thres_idx) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
else:
boxes = boxes[over_thres_idx]
pred_bboxes = boxes[:, 2:]
pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape,
scale_factor)
if len(keep_idx[0]) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
pred_scores = boxes[:, 1:2][keep_idx[0]]
pred_cls_ids = boxes[:, 0:1][keep_idx[0]]
pred_tlwhs = np.concatenate(
(pred_xyxys[:, 0:2], pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1),
axis=1)
pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1)
return pred_dets, pred_xyxys
def predict(self, image, scaled, threshold=0.5, warmup=0, repeats=1):
'''
Args:
image (np.ndarray): image numpy data
threshold (float): threshold of predicted box' score
scaled (bool): whether the coords after detector outputs are scaled,
default False in jde yolov3, set True in general detector.
Returns:
pred_dets (np.ndarray, [N, 6])
'''
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
self.det_times.preprocess_time_s.end()
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
np_score_list, np_boxes_list = [], []
for i in range(warmup):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
boxes = boxes_tensor.copy_to_cpu()
self.det_times.inference_time_s.start()
for i in range(repeats):
self.predictor.run()
np_score_list.clear()
np_boxes_list.clear()
output_names = self.predictor.get_output_names()
num_outs = int(len(output_names) / 2)
for out_idx in range(num_outs):
np_score_list.append(
self.predictor.get_output_handle(output_names[out_idx])
.copy_to_cpu())
np_boxes_list.append(
self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu())
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.img_num += 1
self.det_times.postprocess_time_s.start()
self.postprocess = PicoDetPostProcess(
inputs['image'].shape[2:],
inputs['im_shape'],
inputs['scale_factor'],
strides=self.pred_config.fpn_stride,
nms_threshold=self.pred_config.nms['nms_threshold'])
boxes, boxes_num = self.postprocess(np_score_list, np_boxes_list)
if len(boxes) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
else:
input_shape = inputs['image'].shape[2:]
im_shape = inputs['im_shape']
scale_factor = inputs['scale_factor']
pred_dets, pred_xyxys = self.postprocess_bboxes(
boxes, input_shape, im_shape, scale_factor, threshold)
return pred_dets, pred_xyxys
class SDE_ReID(object): class SDE_ReID(object):
def __init__(self, def __init__(self,
pred_config, pred_config,
...@@ -350,7 +494,7 @@ def predict_image(detector, reid_model, image_list): ...@@ -350,7 +494,7 @@ def predict_image(detector, reid_model, image_list):
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled, pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled,
FLAGS.threshold) FLAGS.threshold)
if len(pred_dets) == 1 and sum(pred_dets) == 0: if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'. print('Frame {} has no object, try to modify score threshold.'.
format(i)) format(i))
online_im = frame online_im = frame
...@@ -407,7 +551,7 @@ def predict_video(detector, reid_model, camera_id): ...@@ -407,7 +551,7 @@ def predict_video(detector, reid_model, camera_id):
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled, pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled,
FLAGS.threshold) FLAGS.threshold)
if len(pred_dets) == 1 and sum(pred_dets) == 0: if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'. print('Frame {} has no object, try to modify score threshold.'.
format(frame_id)) format(frame_id))
timer.toc() timer.toc()
...@@ -464,11 +608,15 @@ def predict_video(detector, reid_model, camera_id): ...@@ -464,11 +608,15 @@ def predict_video(detector, reid_model, camera_id):
def main(): def main():
pred_config = PredictConfig(FLAGS.model_dir) pred_config = PredictConfig(FLAGS.model_dir)
detector = SDE_Detector( detector_func = 'SDE_Detector'
pred_config, if pred_config.arch == 'PicoDet':
detector_func = 'SDE_DetectorPicoDet'
detector = eval(detector_func)(pred_config,
FLAGS.model_dir, FLAGS.model_dir,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
trt_min_shape=FLAGS.trt_min_shape, trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape, trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
......
...@@ -28,7 +28,7 @@ def plot_tracking(image, ...@@ -28,7 +28,7 @@ def plot_tracking(image,
scores=None, scores=None,
frame_id=0, frame_id=0,
fps=0., fps=0.,
ids2=None): ids2names=[]):
im = np.ascontiguousarray(np.copy(image)) im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2] im_h, im_w = im.shape[:2]
...@@ -52,15 +52,16 @@ def plot_tracking(image, ...@@ -52,15 +52,16 @@ def plot_tracking(image,
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h))) intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i]) obj_id = int(obj_ids[i])
id_text = '{}'.format(int(obj_id)) id_text = '{}'.format(int(obj_id))
if ids2 is not None: if ids2names != []:
id_text = id_text + ', {}'.format(int(ids2[i])) assert len(ids2names) == 1, "plot_tracking only supports single classes."
id_text = '{}_'.format(ids2names[0]) + id_text
_line_thickness = 1 if obj_id <= 0 else line_thickness _line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id)) color = get_color(abs(obj_id))
cv2.rectangle( cv2.rectangle(
im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness) im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.putText( cv2.putText(
im, im,
id_text, (intbox[0], intbox[1] + 10), id_text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN, cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255), text_scale, (0, 0, 255),
thickness=text_thickness) thickness=text_thickness)
...@@ -69,7 +70,7 @@ def plot_tracking(image, ...@@ -69,7 +70,7 @@ def plot_tracking(image,
text = '{:.2f}'.format(float(scores[i])) text = '{:.2f}'.format(float(scores[i]))
cv2.putText( cv2.putText(
im, im,
text, (intbox[0], intbox[1] - 10), text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN, cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255), text_scale, (0, 255, 255),
thickness=text_thickness) thickness=text_thickness)
...@@ -83,7 +84,7 @@ def plot_tracking_dict(image, ...@@ -83,7 +84,7 @@ def plot_tracking_dict(image,
scores_dict, scores_dict,
frame_id=0, frame_id=0,
fps=0., fps=0.,
ids2=None): ids2names=[]):
im = np.ascontiguousarray(np.copy(image)) im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2] im_h, im_w = im.shape[:2]
...@@ -111,10 +112,12 @@ def plot_tracking_dict(image, ...@@ -111,10 +112,12 @@ def plot_tracking_dict(image,
x1, y1, w, h = tlwh x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h))) intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i]) obj_id = int(obj_ids[i])
if num_classes == 1:
id_text = '{}'.format(int(obj_id)) id_text = '{}'.format(int(obj_id))
if ids2names != []:
id_text = '{}_{}'.format(ids2names[cls_id], id_text)
else: else:
id_text = 'class{}_id{}'.format(cls_id, int(obj_id)) id_text = 'class{}_{}'.format(cls_id, id_text)
_line_thickness = 1 if obj_id <= 0 else line_thickness _line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id)) color = get_color(abs(obj_id))
...@@ -126,7 +129,7 @@ def plot_tracking_dict(image, ...@@ -126,7 +129,7 @@ def plot_tracking_dict(image,
thickness=line_thickness) thickness=line_thickness)
cv2.putText( cv2.putText(
im, im,
id_text, (intbox[0], intbox[1] + 10), id_text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN, cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255), text_scale, (0, 0, 255),
thickness=text_thickness) thickness=text_thickness)
...@@ -135,7 +138,7 @@ def plot_tracking_dict(image, ...@@ -135,7 +138,7 @@ def plot_tracking_dict(image,
text = '{:.2f}'.format(float(scores[i])) text = '{:.2f}'.format(float(scores[i]))
cv2.putText( cv2.putText(
im, im,
text, (intbox[0], intbox[1] - 10), text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN, cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255), text_scale, (0, 255, 255),
thickness=text_thickness) thickness=text_thickness)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册