未验证 提交 7dccc8f6 编写于 作者: W wangguanzhong 提交者: GitHub

Refactor python deploy (#5253)

* refactor det deploy

* refactor keypoint deploy

* fix solov2

* fit mot sde pipeline infer

* refine mot sde infer

* fit mot jde infer pipeline

* fit mot pose unite infer

* precommit for format

* refine keypoint detector name

* clean codes

* fix keypoint infer

* refine format
Co-authored-by: NFeng Ni <nemonameless@qq.com>
上级 ef83ab8a
......@@ -89,6 +89,8 @@ class PaddleInferBenchmark(object):
self.preprocess_time_s = perf_info.get('preprocess_time_s', 0)
self.postprocess_time_s = perf_info.get('postprocess_time_s', 0)
self.with_tracker = True if 'tracking_time_s' in perf_info else False
self.tracking_time_s = perf_info.get('tracking_time_s', 0)
self.total_time_s = perf_info.get('total_time_s', 0)
self.inference_time_s_90 = perf_info.get("inference_time_s_90", "")
......@@ -235,8 +237,18 @@ class PaddleInferBenchmark(object):
)
self.logger.info(
f"{identifier} total time spent(s): {self.total_time_s}")
if self.with_tracker:
self.logger.info(
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, "
f"inference_time(ms): {round(self.inference_time_s*1000, 1)}, "
f"postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}, "
f"tracking_time(ms): {round(self.tracking_time_s*1000, 1)}")
else:
self.logger.info(
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, "
f"inference_time(ms): {round(self.inference_time_s*1000, 1)}, "
f"postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
)
if self.inference_time_s_90:
self.looger.info(
......
......@@ -18,12 +18,13 @@ import cv2
import math
import numpy as np
import paddle
import yaml
from det_keypoint_unite_utils import argsparser
from preprocess import decode_image
from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from visualize import draw_pose
from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images
......@@ -34,24 +35,6 @@ KEYPOINT_SUPPORT_MODELS = {
}
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    """Report benchmark statistics of ``detector`` through PaddleInferBenchmark.

    Args:
        detector: object exposing ``cpu_mem``, ``gpu_mem``, ``gpu_util``,
            ``det_times`` and ``config``.
        img_list: processed images; only its length is used, to average
            the accumulated memory counters.
        model_info: forwarded unchanged to PaddleInferBenchmark.
        batch_size (int): recorded as ``batch_size`` in the data info.
        name: identifier passed to the benchmark log call.
    """
    num_images = len(img_list)
    # Memory counters are accumulated on the detector; average them here.
    mems = dict(
        cpu_rss_mb=detector.cpu_mem / num_images,
        gpu_rss_mb=detector.gpu_mem / num_images,
        gpu_util=detector.gpu_util * 100 / num_images)
    perf_info = detector.det_times.report(average=True)
    data_info = dict(
        batch_size=batch_size,
        shape="dynamic_shape",
        data_num=perf_info['img_num'])
    benchmark = PaddleInferBenchmark(detector.config, model_info, data_info,
                                     perf_info, mems)
    benchmark(name)
def predict_with_given_det(image, det_res, keypoint_detector,
keypoint_batch_size, det_threshold,
keypoint_threshold, run_benchmark):
......@@ -59,32 +42,15 @@ def predict_with_given_det(image, det_res, keypoint_detector,
image, det_res, det_threshold)
keypoint_vector = []
score_vector = []
rect_vector = det_rects
batch_loop_cnt = math.ceil(float(len(rec_images)) / keypoint_batch_size)
for i in range(batch_loop_cnt):
start_index = i * keypoint_batch_size
end_index = min((i + 1) * keypoint_batch_size, len(rec_images))
batch_images = rec_images[start_index:end_index]
batch_records = np.array(records[start_index:end_index])
if run_benchmark:
# warmup
keypoint_result = keypoint_detector.predict(
batch_images, keypoint_threshold, repeats=10, add_timer=False)
# run benchmark
keypoint_result = keypoint_detector.predict(
batch_images, keypoint_threshold, repeats=10, add_timer=True)
else:
keypoint_result = keypoint_detector.predict(batch_images,
keypoint_threshold)
orgkeypoints, scores = translate_to_ori_images(keypoint_result,
batch_records)
keypoint_vector.append(orgkeypoints)
score_vector.append(scores)
rect_vector = det_rects
keypoint_results = keypoint_detector.predict_image(
rec_images, run_benchmark, repeats=10, visual=False)
keypoint_vector, score_vector = translate_to_ori_images(keypoint_results,
np.array(records))
keypoint_res = {}
keypoint_res['keypoint'] = [
np.vstack(keypoint_vector).tolist(), np.vstack(score_vector).tolist()
keypoint_vector.tolist(), score_vector.tolist()
] if len(keypoint_vector) > 0 else [[], []]
keypoint_res['bbox'] = rect_vector
return keypoint_res
......@@ -104,18 +70,15 @@ def topdown_unite_predict(detector,
det_timer.preprocess_time_s.end()
if FLAGS.run_benchmark:
# warmup
results = detector.predict(
[image], FLAGS.det_threshold, repeats=10, add_timer=False)
# run benchmark
results = detector.predict(
[image], FLAGS.det_threshold, repeats=10, add_timer=True)
results = detector.predict_image(
[image], run_benchmark=True, repeats=10)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
else:
results = detector.predict([image], FLAGS.det_threshold)
results = detector.predict_image([image], visual=False)
if results['boxes_num'] == 0:
continue
......@@ -137,10 +100,10 @@ def topdown_unite_predict(detector,
else:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
draw_pose(
visualize_pose(
img_file,
keypoint_res,
visual_thread=FLAGS.keypoint_threshold,
visual_thresh=FLAGS.keypoint_threshold,
save_dir=FLAGS.output_dir)
if save_res:
"""
......@@ -164,8 +127,7 @@ def topdown_unite_predict_video(detector,
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.splitext(os.path.basename(FLAGS.video_file))[
0] + '.mp4'
video_name = os.path.split(FLAGS.video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
......@@ -176,7 +138,7 @@ def topdown_unite_predict_video(detector,
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 0
store_res = []
......@@ -188,16 +150,17 @@ def topdown_unite_predict_video(detector,
print('detect frame: %d' % (index))
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = detector.predict([frame2], FLAGS.det_threshold)
results = detector.predict_image([frame2], visual=False)
keypoint_res = predict_with_given_det(
frame2, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
im = draw_pose(
im = visualize_pose(
frame,
keypoint_res,
visual_thread=FLAGS.keypoint_threshold,
visual_thresh=FLAGS.keypoint_threshold,
returnimg=True)
if save_res:
store_res.append([
......@@ -211,6 +174,7 @@ def topdown_unite_predict_video(detector,
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
print('output_video saved to: {}'.format(out_path))
if save_res:
"""
1) store_res: a list of frame_data
......@@ -224,13 +188,15 @@ def topdown_unite_predict_video(detector,
def main():
pred_config = PredictConfig(FLAGS.det_model_dir)
deploy_file = os.path.join(FLAGS.det_model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
detector_func = 'Detector'
if pred_config.arch == 'PicoDet':
if arch == 'PicoDet':
detector_func = 'DetectorPicoDet'
detector = eval(detector_func)(pred_config,
FLAGS.det_model_dir,
detector = eval(detector_func)(FLAGS.det_model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
trt_min_shape=FLAGS.trt_min_shape,
......@@ -238,14 +204,10 @@ def main():
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.det_threshold)
pred_config = PredictConfig_KeyPoint(FLAGS.keypoint_model_dir)
assert KEYPOINT_SUPPORT_MODELS[
pred_config.
arch] == 'keypoint_topdown', 'Detection-Keypoint unite inference only supports topdown models.'
topdown_keypoint_detector = KeyPoint_Detector(
pred_config,
topdown_keypoint_detector = KeyPointDetector(
FLAGS.keypoint_model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
......@@ -257,6 +219,9 @@ def main():
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
use_dark=FLAGS.use_dark)
keypoint_arch = topdown_keypoint_detector.pred_config.arch
assert KEYPOINT_SUPPORT_MODELS[
keypoint_arch] == 'keypoint_topdown', 'Detection-Keypoint unite inference only supports topdown models.'
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......
......@@ -24,9 +24,15 @@ import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from visualize import visualize_box_mask
from utils import argsparser, Timer, get_current_memory_mb
......@@ -47,9 +53,27 @@ SUPPORT_MODELS = {
'PicoDet',
'CenterNet',
'TOOD',
'StrongBaseline',
}
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    """Report benchmark statistics of ``detector`` through PaddleInferBenchmark.

    Args:
        detector: object exposing ``cpu_mem``, ``gpu_mem``, ``gpu_util``,
            ``det_times`` and ``config``.
        img_list: processed images; only its length is used, to average
            the accumulated memory counters.
        model_info: forwarded unchanged to PaddleInferBenchmark.
        batch_size (int): recorded as ``batch_size`` in the data info.
        name: identifier passed to the benchmark log call.
    """
    num_images = len(img_list)
    # Memory counters are accumulated on the detector; average them here.
    mems = dict(
        cpu_rss_mb=detector.cpu_mem / num_images,
        gpu_rss_mb=detector.gpu_mem / num_images,
        gpu_util=detector.gpu_util * 100 / num_images)
    perf_info = detector.det_times.report(average=True)
    data_info = dict(
        batch_size=batch_size,
        shape="dynamic_shape",
        data_num=perf_info['img_num'])
    benchmark = PaddleInferBenchmark(detector.config, model_info, data_info,
                                     perf_info, mems)
    benchmark(name)
class Detector(object):
"""
Args:
......@@ -65,10 +89,12 @@ class Detector(object):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
"""
def __init__(self,
pred_config,
def __init__(
self,
model_dir,
device='CPU',
run_mode='paddle',
......@@ -78,8 +104,10 @@ class Detector(object):
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
self.pred_config = pred_config
enable_mkldnn=False,
output_dir='output',
threshold=0.5, ):
self.pred_config = self.set_config(model_dir)
self.predictor, self.config = load_predictor(
model_dir,
run_mode=run_mode,
......@@ -95,6 +123,12 @@ class Detector(object):
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.batch_size = batch_size
self.output_dir = output_dir
self.threshold = threshold
def set_config(self, model_dir):
    """Create the prediction config object for ``model_dir``.

    Args:
        model_dir (str): directory handed to PredictConfig.

    Returns:
        The PredictConfig instance built from ``model_dir``.
    """
    pred_config = PredictConfig(model_dir)
    return pred_config
def preprocess(self, image_list):
preprocess_ops = []
......@@ -110,49 +144,34 @@ class Detector(object):
input_im_lst.append(im)
input_im_info_lst.append(im_info)
inputs = create_inputs(input_im_lst, input_im_info_lst)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
return inputs
def postprocess(self,
np_boxes,
np_masks,
inputs,
np_boxes_num,
threshold=0.5):
def postprocess(self, inputs, result):
# postprocess output of predictor
results = {}
results['boxes'] = np_boxes
results['boxes_num'] = np_boxes_num
if np_masks is not None:
results['masks'] = np_masks
return results
np_boxes_num = result['boxes_num']
if np_boxes_num[0] <= 0:
print('[WARNNING] No object detected.')
result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
result = {k: v for k, v in result.items() if v is not None}
return result
def predict(self, image_list, threshold=0.5, repeats=1, add_timer=True):
def predict(self, repeats=1):
'''
Args:
image_list (list): list of image
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
repeats (int): repeats number for prediction
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matrix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
MaskRCNN's result include 'masks': np.ndarray:
shape: [N, im_h, im_w]
'''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction
np_boxes, np_masks = None, None
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
......@@ -163,32 +182,136 @@ class Detector(object):
if self.pred_config.mask:
masks_tensor = self.predictor.get_output_handle(output_names[2])
np_masks = masks_tensor.copy_to_cpu()
result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
return result
def merge_batch_result(self, batch_result):
    """Merge per-batch result dicts into one result dict.

    Args:
        batch_result (list[dict]): one result dict per batch. When more
            than one batch is merged, every value must be something
            ``np.concatenate`` accepts (arrays or lists).

    Returns:
        dict: the single element itself when there is exactly one batch;
        otherwise a dict mapping every key seen in any batch to the
        concatenation of that key's values in batch order. An empty
        ``batch_result`` yields an empty dict.
    """
    if len(batch_result) == 1:
        return batch_result[0]
    collected = {}
    for res in batch_result:
        for k, v in res.items():
            # Batches need not share the same keys (the no-detection result
            # carries only 'boxes' and 'boxes_num'), so collect by whatever
            # key appears instead of trusting the first batch's keys.
            collected.setdefault(k, []).append(v)
    return {k: np.concatenate(v) for k, v in collected.items()}
def get_timer(self):
    """Return the timer object stored in ``self.det_times``."""
    timer = self.det_times
    return timer
if add_timer:
def predict_image(self,
image_list,
run_benchmark=False,
repeats=1,
visual=True):
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
results = []
for i in range(batch_loop_cnt):
start_index = i * self.batch_size
end_index = min((i + 1) * self.batch_size, len(image_list))
batch_image_list = image_list[start_index:end_index]
if run_benchmark:
# preprocess
inputs = self.preprocess(batch_image_list) # warmup
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
result = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
# postprocess
results = []
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNNING] No object detected.')
results = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
result_warmup = self.postprocess(inputs, result) # warmup
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(batch_image_list)
cm, gm, gu = get_current_memory_mb()
self.cpu_mem += cm
self.gpu_mem += gm
self.gpu_util += gu
else:
results = self.postprocess(
np_boxes, np_masks, inputs, np_boxes_num, threshold=threshold)
if add_timer:
# preprocess
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# postprocess
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(image_list)
self.det_times.img_num += len(batch_image_list)
if visual:
visualize(
batch_image_list,
result,
self.pred_config.labels,
output_dir=self.output_dir,
threshold=self.threshold)
results.append(result)
if visual:
print('Test iter {}'.format(i))
results = self.merge_batch_result(results)
return results
def get_timer(self):
    """Return the timer object stored in ``self.det_times``."""
    timer = self.det_times
    return timer
def predict_video(self, video_file, camera_id):
    """Detect objects on every frame of a video source and save the result.

    Frames are read from the camera when ``camera_id`` is not -1, otherwise
    from ``video_file``. Each frame is run through ``self.predict_image``,
    drawn with ``visualize_box_mask`` and written to an mp4 file inside
    ``self.output_dir``.

    Args:
        video_file (str): path of the input video; only used when
            ``camera_id`` is -1. The output file reuses its base name.
        camera_id (int): camera device id. Any value other than -1 reads
            from the camera, writes to 'output.mp4' and shows a preview
            window that can be closed by pressing 'q'.
    """
    video_out_name = 'output.mp4'
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
    else:
        capture = cv2.VideoCapture(video_file)
        video_out_name = os.path.split(video_file)[-1]
    writer = None
    try:
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame: %d' % (index))
            index += 1
            results = self.predict_image([frame], visual=False)

            im = visualize_box_mask(
                frame,
                results,
                self.pred_config.labels,
                threshold=self.threshold)
            im = np.array(im)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release both handles even when a frame fails, so the output file
        # is finalized and the input device/file is not left open.
        if writer is not None:
            writer.release()
        capture.release()
class DetectorSOLOv2(Detector):
"""
Args:
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
......@@ -200,10 +323,13 @@ class DetectorSOLOv2(Detector):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
"""
def __init__(self,
pred_config,
def __init__(
self,
model_dir,
device='CPU',
run_mode='paddle',
......@@ -213,48 +339,33 @@ class DetectorSOLOv2(Detector):
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
self.pred_config = pred_config
self.predictor, self.config = load_predictor(
model_dir,
enable_mkldnn=False,
output_dir='./',
threshold=0.5, ):
super(DetectorSOLOv2, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
def predict(self, image, threshold=0.5, repeats=1, add_timer=True):
def predict(self, repeats=1):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns:
results (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
result (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
'cate_label': label of segm, shape:[N]
'cate_score': confidence score of segm, shape:[N]
'''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
np_label, np_score, np_segms = None, None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
......@@ -266,21 +377,18 @@ class DetectorSOLOv2(Detector):
2]).copy_to_cpu()
np_segms = self.predictor.get_output_handle(output_names[
3]).copy_to_cpu()
if add_timer:
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.img_num += 1
return dict(
result = dict(
segm=np_segms,
label=np_label,
score=np_score,
boxes_num=np_boxes_num)
return result
class DetectorPicoDet(Detector):
"""
Args:
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
......@@ -294,8 +402,8 @@ class DetectorPicoDet(Detector):
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self,
pred_config,
def __init__(
self,
model_dir,
device='CPU',
run_mode='paddle',
......@@ -305,50 +413,46 @@ class DetectorPicoDet(Detector):
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
self.pred_config = pred_config
self.predictor, self.config = load_predictor(
model_dir,
enable_mkldnn=False,
output_dir='./',
threshold=0.5, ):
super(DetectorPicoDet, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
def postprocess(self, inputs, result):
    """Run PicoDetPostProcess on the raw outputs collected by ``predict``.

    Args:
        inputs (dict): preprocessed inputs; 'image', 'im_shape' and
            'scale_factor' are read.
        result (dict): raw predictor outputs. NOTE: at this stage
            ``result['boxes']`` holds the score list and
            ``result['boxes_num']`` holds the box list, exactly as
            ``predict`` packs them.

    Returns:
        dict: 'boxes' and 'boxes_num' as produced by PicoDetPostProcess.
    """
    # postprocess output of predictor
    score_outputs = result['boxes']
    box_outputs = result['boxes_num']
    input_hw = inputs['image'].shape[2:]
    nms_threshold = self.pred_config.nms['nms_threshold']
    decoder = PicoDetPostProcess(
        input_hw,
        inputs['im_shape'],
        inputs['scale_factor'],
        strides=self.pred_config.fpn_stride,
        nms_threshold=nms_threshold)
    np_boxes, np_boxes_num = decoder(score_outputs, box_outputs)
    return {'boxes': np_boxes, 'boxes_num': np_boxes_num}
def predict(self, image, threshold=0.5, repeats=1, add_timer=True):
def predict(self, repeats=1):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matrix element:[class, score, x_min, y_min, x_max, y_max]
'''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
np_score_list, np_boxes_list = [], []
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model_prediction
for i in range(repeats):
self.predictor.run()
np_score_list.clear()
......@@ -362,22 +466,8 @@ class DetectorPicoDet(Detector):
np_boxes_list.append(
self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu())
if add_timer:
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.img_num += 1
self.det_times.postprocess_time_s.start()
# postprocess
self.postprocess = PicoDetPostProcess(
inputs['image'].shape[2:],
inputs['im_shape'],
inputs['scale_factor'],
strides=self.pred_config.fpn_stride,
nms_threshold=self.pred_config.nms['nms_threshold'])
np_boxes, np_boxes_num = self.postprocess(np_score_list, np_boxes_list)
if add_timer:
self.det_times.postprocess_time_s.end()
return dict(boxes=np_boxes, boxes_num=np_boxes_num)
result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
return result
def create_inputs(imgs, im_info):
......@@ -596,26 +686,26 @@ def get_test_images(infer_dir, infer_img):
return images
def visualize(image_list, results, labels, output_dir='output/', threshold=0.5):
def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
# visualize the predict result
start_idx = 0
for idx, image_file in enumerate(image_list):
im_bboxes_num = results['boxes_num'][idx]
im_bboxes_num = result['boxes_num'][idx]
im_results = {}
if 'boxes' in results:
im_results['boxes'] = results['boxes'][start_idx:start_idx +
if 'boxes' in result:
im_results['boxes'] = result['boxes'][start_idx:start_idx +
im_bboxes_num, :]
if 'masks' in results:
im_results['masks'] = results['masks'][start_idx:start_idx +
if 'masks' in result:
im_results['masks'] = result['masks'][start_idx:start_idx +
im_bboxes_num, :]
if 'segm' in results:
im_results['segm'] = results['segm'][start_idx:start_idx +
if 'segm' in result:
im_results['segm'] = result['segm'][start_idx:start_idx +
im_bboxes_num, :]
if 'label' in results:
im_results['label'] = results['label'][start_idx:start_idx +
if 'label' in result:
im_results['label'] = result['label'][start_idx:start_idx +
im_bboxes_num]
if 'score' in results:
im_results['score'] = results['score'][start_idx:start_idx +
if 'score' in result:
im_results['score'] = result['score'][start_idx:start_idx +
im_bboxes_num]
start_idx += im_bboxes_num
......@@ -636,86 +726,18 @@ def print_arguments(args):
print('------------------------------------------')
def predict_image(detector, image_list, batch_size=1):
    """Run ``detector`` over ``image_list`` in batches of ``batch_size``.

    With ``FLAGS.run_benchmark`` set, every batch is predicted twice with
    ``repeats=10`` (first without the timer as warmup, then with it) and
    the current memory usage is added to the detector's counters.
    Otherwise every batch is predicted once and the result is visualized
    into ``FLAGS.output_dir``.

    Args:
        detector: object providing ``predict``, ``pred_config.labels`` and
            the ``cpu_mem`` / ``gpu_mem`` / ``gpu_util`` counters.
        image_list (list): images to process.
        batch_size (int): number of images per ``predict`` call.
    """
    total = len(image_list)
    batch_loop_cnt = math.ceil(float(total) / batch_size)
    for i in range(batch_loop_cnt):
        first = i * batch_size
        last = min(first + batch_size, total)
        batch_image_list = image_list[first:last]
        if not FLAGS.run_benchmark:
            results = detector.predict(batch_image_list, FLAGS.threshold)
            visualize(
                batch_image_list,
                results,
                detector.pred_config.labels,
                output_dir=FLAGS.output_dir,
                threshold=FLAGS.threshold)
            continue
        # warmup
        detector.predict(
            batch_image_list, FLAGS.threshold, repeats=10, add_timer=False)
        # run benchmark
        detector.predict(
            batch_image_list, FLAGS.threshold, repeats=10, add_timer=True)
        cpu_mb, gpu_mb, gpu_util = get_current_memory_mb()
        detector.cpu_mem += cpu_mb
        detector.gpu_mem += gpu_mb
        detector.gpu_util += gpu_util
        print('Test iter {}'.format(i))
def predict_video(detector, camera_id):
    """Detect objects on every frame of a video source and save the result.

    Frames are read from the camera when ``camera_id`` is not -1, otherwise
    from ``FLAGS.video_file``. Each frame is predicted with ``detector``,
    drawn with ``visualize_box_mask`` and written to an mp4 file inside
    ``FLAGS.output_dir``.

    Args:
        detector: object providing ``predict`` and ``pred_config.labels``.
        camera_id (int): camera device id. Any value other than -1 reads
            from the camera, writes to 'output.mp4' and shows a preview
            window that can be closed by pressing 'q'.
    """
    video_out_name = 'output.mp4'
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
    else:
        capture = cv2.VideoCapture(FLAGS.video_file)
        video_out_name = os.path.split(FLAGS.video_file)[-1]
    writer = None
    try:
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(FLAGS.output_dir):
            os.makedirs(FLAGS.output_dir)
        out_path = os.path.join(FLAGS.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame: %d' % (index))
            index += 1
            results = detector.predict([frame], FLAGS.threshold)
            im = visualize_box_mask(
                frame,
                results,
                detector.pred_config.labels,
                threshold=FLAGS.threshold)
            im = np.array(im)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release both handles even when a frame fails, so the output file
        # is finalized and the input device/file is not left open.
        if writer is not None:
            writer.release()
        capture.release()
def main():
pred_config = PredictConfig(FLAGS.model_dir)
deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
detector_func = 'Detector'
if pred_config.arch == 'SOLOv2':
if arch == 'SOLOv2':
detector_func = 'DetectorSOLOv2'
elif pred_config.arch == 'PicoDet':
elif arch == 'PicoDet':
detector_func = 'DetectorPicoDet'
detector = eval(detector_func)(pred_config,
FLAGS.model_dir,
detector = eval(detector_func)(FLAGS.model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
......@@ -724,41 +746,29 @@ def main():
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
else:
# predict from image
if FLAGS.image_dir is None and FLAGS.image_file is not None:
assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list, FLAGS.batch_size)
detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
model_dir = FLAGS.model_dir
mode = FLAGS.run_mode
model_dir = FLAGS.model_dir
model_info = {
'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
data_info = {
'batch_size': FLAGS.batch_size,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
det_log = PaddleInferBenchmark(detector.config, model_info,
data_info, perf_info, mems)
det_log('Det')
bench_log(detector, img_list, model_info, name='DET')
if __name__ == '__main__':
......
......@@ -23,10 +23,16 @@ import cv2
import math
import numpy as np
import paddle
import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from preprocess import preprocess, NormalizeImage, Permute
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
from visualize import draw_pose
from visualize import visualize_pose
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
......@@ -40,13 +46,13 @@ KEYPOINT_SUPPORT_MODELS = {
}
class KeyPoint_Detector(Detector):
class KeyPointDetector(Detector):
"""
Args:
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
......@@ -58,7 +64,6 @@ class KeyPoint_Detector(Detector):
"""
def __init__(self,
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
......@@ -69,9 +74,10 @@ class KeyPoint_Detector(Detector):
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
output_dir='output',
threshold=0.5,
use_dark=True):
super(KeyPoint_Detector, self).__init__(
pred_config=pred_config,
super(KeyPointDetector, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
......@@ -81,9 +87,14 @@ class KeyPoint_Detector(Detector):
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
self.use_dark = use_dark
def set_config(self, model_dir):
    """Create the keypoint prediction config object for ``model_dir``.

    Args:
        model_dir (str): directory handed to PredictConfig_KeyPoint.

    Returns:
        The PredictConfig_KeyPoint instance built from ``model_dir``.
    """
    keypoint_config = PredictConfig_KeyPoint(model_dir)
    return keypoint_config
def get_person_from_rect(self, image, results, det_threshold=0.5):
# crop the person result from image
self.det_times.preprocess_time_s.start()
......@@ -103,34 +114,22 @@ class KeyPoint_Detector(Detector):
self.det_times.preprocess_time_s.end()
return rect_images, new_rects, org_rects
def preprocess(self, image_list):
preprocess_ops = []
for op_info in self.pred_config.preprocess_infos:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
input_im_lst = []
input_im_info_lst = []
for im in image_list:
im, im_info = preprocess(im, preprocess_ops)
input_im_lst.append(im)
input_im_info_lst.append(im_info)
inputs = create_inputs(input_im_lst, input_im_info_lst)
return inputs
def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5):
def postprocess(self, inputs, result):
np_heatmap = result['heatmap']
np_masks = result['masks']
# postprocess output of predictor
if KEYPOINT_SUPPORT_MODELS[
self.pred_config.arch] == 'keypoint_bottomup':
results = {}
h, w = inputs['im_shape'][0]
preds = [np_boxes]
preds = [np_heatmap]
if np_masks is not None:
preds += np_masks
preds += [h, w]
keypoint_postprocess = HrHRNetPostProcess()
results['keypoint'] = keypoint_postprocess(*preds)
kpts, scores = keypoint_postprocess(*preds)
results['keypoint'] = kpts
results['score'] = scores
return results
elif KEYPOINT_SUPPORT_MODELS[
self.pred_config.arch] == 'keypoint_topdown':
......@@ -139,44 +138,31 @@ class KeyPoint_Detector(Detector):
center = np.round(imshape / 2.)
scale = imshape / 200.
keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark)
results['keypoint'] = keypoint_postprocess(np_boxes, center, scale)
kpts, scores = keypoint_postprocess(np_heatmap, center, scale)
results['keypoint'] = kpts
results['score'] = scores
return results
else:
raise ValueError("Unsupported arch: {}, expect {}".format(
self.pred_config.arch, KEYPOINT_SUPPORT_MODELS))
def predict(self, image_list, threshold=0.5, repeats=1, add_timer=True):
def predict(self, repeats=1):
'''
Args:
image_list (list): list of image
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matrix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape: [N, im_h, im_w]
'''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction
np_heatmap, np_masks = None, None
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
heatmap_tensor = self.predictor.get_output_handle(output_names[0])
np_heatmap = heatmap_tensor.copy_to_cpu()
if self.pred_config.tagmap:
masks_tensor = self.predictor.get_output_handle(output_names[1])
heat_k = self.predictor.get_output_handle(output_names[2])
......@@ -185,18 +171,113 @@ class KeyPoint_Detector(Detector):
masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(),
inds_k.copy_to_cpu()
]
if add_timer:
result = dict(heatmap=np_heatmap, masks=np_masks)
return result
def predict_image(self,
image_list,
run_benchmark=False,
repeats=1,
visual=True):
results = []
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
for i in range(batch_loop_cnt):
start_index = i * self.batch_size
end_index = min((i + 1) * self.batch_size, len(image_list))
batch_image_list = image_list[start_index:end_index]
if run_benchmark:
# preprocess
inputs = self.preprocess(batch_image_list) # warmup
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
result_warmup = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats)
# postprocess
result_warmup = self.postprocess(inputs, result) # warmup
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(batch_image_list)
cm, gm, gu = get_current_memory_mb()
self.cpu_mem += cm
self.gpu_mem += gm
self.gpu_util += gu
else:
# preprocess
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# postprocess
results = self.postprocess(
np_boxes, np_masks, inputs, threshold=threshold)
if add_timer:
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(image_list)
self.det_times.img_num += len(batch_image_list)
if visual:
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
visualize(
batch_image_list,
result,
visual_thresh=self.threshold,
save_dir=self.output_dir)
results.append(result)
if visual:
print('Test iter {}'.format(i))
results = self.merge_batch_result(results)
return results
def predict_video(self, video_file, camera_id):
    """Run keypoint prediction frame by frame on a video file or camera.

    Every frame is read with OpenCV, sent through ``self.predict_image``
    with visualization disabled, drawn by ``visualize_pose`` and appended
    to an ``mp4v`` video written inside ``self.output_dir``.

    Args:
        video_file (str): path of the input video. Only used when
            ``camera_id`` is -1; its base name becomes the output name.
        camera_id (int): camera device index, or -1 to read ``video_file``.
            In camera mode the output is named ``output.mp4``, each frame
            is also shown in a window, and pressing 'q' stops the loop.
    """
    video_name = 'output.mp4'
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
    else:
        capture = cv2.VideoCapture(video_file)
        video_name = os.path.split(video_file)[-1]

    writer = None
    try:
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        os.makedirs(self.output_dir, exist_ok=True)
        out_path = os.path.join(self.output_dir, video_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame: %d' % (index))
            index += 1
            results = self.predict_image([frame], visual=False)
            im = visualize_pose(
                frame,
                results,
                visual_thresh=self.threshold,
                returnimg=True)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always finalize the output file and free the capture device,
        # including when prediction or drawing fails on some frame.
        if writer is not None:
            writer.release()
        capture.release()
def create_inputs(imgs, im_info):
"""generate input for different model type
......@@ -258,90 +339,44 @@ class PredictConfig_KeyPoint():
print('--------------------------------------------')
def predict_image(detector, image_list):
for i, img_file in enumerate(image_list):
if FLAGS.run_benchmark:
# warmup
detector.predict(
[img_file], FLAGS.threshold, repeats=10, add_timer=False)
# run benchmark
detector.predict(
[img_file], FLAGS.threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file))
else:
results = detector.predict([img_file], FLAGS.threshold)
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
draw_pose(
img_file,
results,
visual_thread=FLAGS.threshold,
save_dir=FLAGS.output_dir)
def predict_video(detector, camera_id):
video_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name + '.mp4')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame: %d' % (index))
index += 1
results = detector.predict([frame], FLAGS.threshold)
im = draw_pose(
frame, results, visual_thread=FLAGS.threshold, returnimg=True)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
def visualize(image_list, results, visual_thresh=0.6, save_dir='output'):
    """Draw the keypoint result of each image in a batch.

    For the image at position ``i`` the length-1 slices ``[i:i + 1]`` of
    ``results['keypoint']`` and ``results['score']`` are packed as
    ``{'keypoint': [skeleton, score]}`` and handed to ``visualize_pose``.

    Args:
        image_list (list): images of the batch, in the same order as the
            entries of ``results``.
        results (dict): batch output holding ``'keypoint'`` and ``'score'``.
        visual_thresh (float): forwarded to ``visualize_pose``.
        save_dir (str): forwarded to ``visualize_pose``.
    """
    # One dict is reused for every image; only its 'keypoint' entry changes.
    per_image = {}
    for idx, image_file in enumerate(image_list):
        all_kpts, all_scores = results['keypoint'], results['score']
        window = slice(idx, idx + 1)
        per_image['keypoint'] = [all_kpts[window], all_scores[window]]
        visualize_pose(
            image_file,
            per_image,
            visual_thresh=visual_thresh,
            save_dir=save_dir)
def main():
pred_config = PredictConfig_KeyPoint(FLAGS.model_dir)
detector = KeyPoint_Detector(
pred_config,
detector = KeyPointDetector(
FLAGS.model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
use_dark=FLAGS.use_dark)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list)
detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
......
......@@ -362,7 +362,8 @@ def affine_transform(pt, t):
def translate_to_ori_images(keypoint_result, batch_records):
    """Offset keypoint coordinates by the per-record values in ``batch_records``.

    Keypoints and scores are read from their own entries of
    ``keypoint_result``; ``'keypoint'`` is not unpacked as a
    ``(kpts, scores)`` pair.

    Args:
        keypoint_result (dict): must provide ``'keypoint'`` (array whose
            last axis has at least two channels) and ``'score'``.
        batch_records (np.ndarray): 2-D array with one row per record.
            Column 0 is added to channel 0 and column 1 to channel 1 of
            the matching record's keypoints.
            NOTE(review): presumably the x/y origin of each crop in the
            original image -- confirm against the caller.

    Returns:
        tuple: ``(kpts, scores)``. ``kpts`` is the same array object stored
        in ``keypoint_result['keypoint']``; it is shifted in place.
    """
    kpts = keypoint_result['keypoint']
    scores = keypoint_result['score']
    # Slicing with 0:1 / 1:2 keeps a trailing axis so each record's offset
    # broadcasts over all of that record's keypoints.
    kpts[..., 0] += batch_records[:, 0:1]
    kpts[..., 1] += batch_records[:, 1:2]
    return kpts, scores
......@@ -18,21 +18,24 @@ import yaml
import cv2
import numpy as np
from collections import defaultdict
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, PredictConfig
from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from ppdet.modeling.mot.tracker import JDETracker
from ppdet.modeling.mot.visualization import plot_tracking_dict
from ppdet.modeling.mot.utils import MOTTimer, write_mot_results
from pptracking.python.mot import JDETracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
from pptracking.python.visualize import plot_tracking, plot_tracking_dict
# Global dictionary
MOT_SUPPORT_MODELS = {
MOT_JDE_SUPPORT_MODELS = {
'JDE',
'FairMOT',
}
......@@ -41,7 +44,6 @@ MOT_SUPPORT_MODELS = {
class JDE_Detector(Detector):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
......@@ -56,8 +58,8 @@ class JDE_Detector(Detector):
"""
def __init__(self,
pred_config,
model_dir,
tracker_config=None,
device='CPU',
run_mode='paddle',
batch_size=1,
......@@ -66,9 +68,10 @@ class JDE_Detector(Detector):
trt_opt_shape=608,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
enable_mkldnn=False,
output_dir='output',
threshold=0.5):
super(JDE_Detector, self).__init__(
pred_config=pred_config,
model_dir=model_dir,
device=device,
run_mode=run_mode,
......@@ -78,17 +81,21 @@ class JDE_Detector(Detector):
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
assert pred_config.tracker, "Tracking model should have tracker"
self.num_classes = len(pred_config.labels)
tp = pred_config.tracker
min_box_area = tp['min_box_area'] if 'min_box_area' in tp else 200
vertical_ratio = tp['vertical_ratio'] if 'vertical_ratio' in tp else 1.6
conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0.
tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7
metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean'
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
assert batch_size == 1, "MOT model only supports batch_size=1."
self.det_times = Timer(with_tracker=True)
self.num_classes = len(self.pred_config.labels)
# tracker config
assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
cfg = self.pred_config.tracker
min_box_area = cfg.get('min_box_area', 200)
vertical_ratio = cfg.get('vertical_ratio', 1.6)
conf_thres = cfg.get('conf_thres', 0.0)
tracked_thresh = cfg.get('tracked_thresh', 0.7)
metric_type = cfg.get('metric_type', 'euclidean')
self.tracker = JDETracker(
num_classes=self.num_classes,
......@@ -98,7 +105,18 @@ class JDE_Detector(Detector):
tracked_thresh=tracked_thresh,
metric_type=metric_type)
def postprocess(self, pred_dets, pred_embs, threshold):
def postprocess(self, inputs, result):
    """Normalize the raw predictor output before tracking.

    Args:
        inputs: not used by this method.
        result (dict): predictor output; ``result['pred_dets']`` must be
            an array whose first dimension is the number of detections.

    Returns:
        dict: ``result`` without the entries whose value is ``None``. When
        there are no detections a warning is printed and the output only
        contains ``'pred_dets'`` as an empty ``[0, 6]`` array.
    """
    # NOTE(review): in the no-detection case 'pred_embs' is removed from the
    # returned dict -- confirm that consumers do not index it unconditionally.
    detections = result['pred_dets']
    if detections.shape[0] <= 0:
        print('[WARNNING] No object detected.')
        result = {'pred_dets': np.zeros([0, 6]), 'pred_embs': None}
    kept = {}
    for key, value in result.items():
        if value is not None:
            kept[key] = value
    return kept
def tracking(self, det_results):
pred_dets = det_results['pred_dets']
pred_embs = det_results['pred_embs']
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
online_tlwhs = defaultdict(list)
......@@ -110,7 +128,6 @@ class JDE_Detector(Detector):
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tscore < threshold: continue
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio:
......@@ -120,100 +137,123 @@ class JDE_Detector(Detector):
online_scores[cls_id].append(tscore)
return online_tlwhs, online_scores, online_ids
def predict(self, image_list, threshold=0.5, repeats=1, add_timer=True):
def predict(self, repeats=1):
'''
Args:
image_list (list): list of image
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
repeats (int): repeats number for prediction
Returns:
online_tlwhs, online_scores, online_ids (dict[np.array])
result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
matrix element:[x_min, y_min, x_max, y_max, score, class]
FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
shape: [N, 128]
'''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
pred_dets, pred_embs = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction
np_pred_dets, np_pred_embs = None, None
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
pred_dets = boxes_tensor.copy_to_cpu()
np_pred_dets = boxes_tensor.copy_to_cpu()
embs_tensor = self.predictor.get_output_handle(output_names[1])
pred_embs = embs_tensor.copy_to_cpu()
np_pred_embs = embs_tensor.copy_to_cpu()
result = dict(pred_dets=np_pred_dets, pred_embs=np_pred_embs)
return result
def predict_image(self,
image_list,
run_benchmark=False,
repeats=1,
visual=True):
mot_results = []
num_classes = self.num_classes
image_list.sort()
ids2names = self.pred_config.labels
data_type = 'mcmot' if num_classes > 1 else 'mot'
for frame_id, img_file in enumerate(image_list):
batch_image_list = [img_file] # bs=1 in MOT model
if run_benchmark:
# preprocess
inputs = self.preprocess(batch_image_list) # warmup
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
if add_timer:
# model prediction
result_warmup = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
# postprocess
online_tlwhs, online_scores, online_ids = self.postprocess(
pred_dets, pred_embs, threshold)
if add_timer:
result_warmup = self.postprocess(inputs, result) # warmup
self.det_times.postprocess_time_s.start()
det_result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return online_tlwhs, online_scores, online_ids
def predict_image(detector, image_list):
results = []
num_classes = detector.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels
# tracking
result_warmup = self.tracking(det_result)
self.det_times.tracking_time_s.start()
online_tlwhs, online_scores, online_ids = self.tracking(
det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
image_list.sort()
for frame_id, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
if FLAGS.run_benchmark:
# warmup
detector.predict(
[frame], FLAGS.threshold, repeats=10, add_timer=False)
# run benchmark
detector.predict(
[frame], FLAGS.threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(frame_id, img_file))
self.cpu_mem += cm
self.gpu_mem += gm
self.gpu_util += gu
else:
online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold)
online_im = plot_tracking_dict(
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
self.det_times.postprocess_time_s.start()
det_result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
# tracking process
self.det_times.tracking_time_s.start()
online_tlwhs, online_scores, online_ids = self.tracking(
det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
if visual:
if frame_id % 10 == 0:
print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {})
im = plot_tracking_dict(
frame,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id,
frame_id=frame_id,
ids2names=ids2names)
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
img_name = os.path.split(img_file)[-1]
out_path = os.path.join(FLAGS.output_dir, img_name)
cv2.imwrite(out_path, online_im)
print("save result to: " + out_path)
seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results
def predict_video(detector, camera_id):
video_name = 'mot_output.mp4'
def predict_video(self, video_file, camera_id):
video_out_name = 'mot_output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
capture = cv2.VideoCapture(video_file)
video_out_name = os.path.split(video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
......@@ -221,33 +261,37 @@ def predict_video(detector, camera_id):
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
frame_id = 1
timer = MOTTimer()
results = defaultdict(list) # support single class and multi classes
num_classes = detector.num_classes
num_classes = self.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels
ids2names = self.pred_config.labels
while (1):
ret, frame = capture.read()
if not ret:
break
if frame_id % 10 == 0:
print('Tracking frame: %d' % (frame_id))
frame_id += 1
timer.tic()
online_tlwhs, online_scores, online_ids = detector.predict(
[frame], FLAGS.threshold)
mot_results = self.predict_image([frame], visual=False)
timer.toc()
online_tlwhs, online_scores, online_ids = mot_results[0]
for cls_id in range(num_classes):
results[cls_id].append((frame_id + 1, online_tlwhs[cls_id],
online_scores[cls_id], online_ids[cls_id]))
results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
fps = 1. / timer.average_time
fps = 1. / timer.duration
im = plot_tracking_dict(
frame,
num_classes,
......@@ -257,41 +301,17 @@ def predict_video(detector, camera_id):
frame_id=frame_id,
fps=fps,
ids2names=ids2names)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im)
frame_id += 1
print('detect frame: %d' % (frame_id))
writer.write(im)
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results, data_type, num_classes)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()
def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector = JDE_Detector(
pred_config,
FLAGS.model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
......@@ -304,34 +324,22 @@ def main():
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list)
detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
else:
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
model_dir = FLAGS.model_dir
mode = FLAGS.run_mode
model_dir = FLAGS.model_dir
model_info = {
'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
data_info = {
'batch_size': 1,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
det_log = PaddleInferBenchmark(detector.config, model_info,
data_info, perf_info, mems)
det_log('MOT')
bench_log(detector, img_list, model_info, name='MOT')
if __name__ == '__main__':
......
......@@ -13,31 +13,34 @@
# limitations under the License.
import os
import json
import cv2
import math
import copy
import numpy as np
from collections import defaultdict
import paddle
from utils import get_current_memory_mb
from infer import Detector, PredictConfig, print_arguments, get_test_images
from visualize import draw_pose
import yaml
import copy
from collections import defaultdict
from mot_keypoint_unite_utils import argsparser
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from det_keypoint_unite_infer import predict_with_given_det, bench_log
from mot_jde_infer import JDE_Detector
from preprocess import decode_image
from infer import print_arguments, get_test_images
from mot_sde_infer import SDE_Detector, MOT_SDE_SUPPORT_MODELS
from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS
from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
from det_keypoint_unite_infer import predict_with_given_det
from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images
from ppdet.modeling.mot.visualization import plot_tracking_dict
from ppdet.modeling.mot.utils import MOTTimer as FPSTimer
from ppdet.modeling.mot.utils import write_mot_results
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
# Global dictionary
KEYPOINT_SUPPORT_MODELS = {
'HigherHRNet': 'keypoint_bottomup',
'HRNet': 'keypoint_topdown'
}
from pptracking.python.visualize import plot_tracking, plot_tracking_dict
from pptracking.python.mot.utils import MOTTimer as FPSTimer
def convert_mot_to_det(tlwhs, scores):
......@@ -49,94 +52,87 @@ def convert_mot_to_det(tlwhs, scores):
# support single class now
results['boxes'] = np.vstack(
[np.hstack([0, scores[i], xyxys[i]]) for i in range(num_mot)])
results['boxes_num'] = np.array([num_mot])
return results
def mot_keypoint_unite_predict_image(mot_model,
keypoint_model,
def mot_topdown_unite_predict(mot_detector,
topdown_keypoint_detector,
image_list,
keypoint_batch_size=1):
num_classes = mot_model.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
data_type = 'mot'
keypoint_batch_size=1,
save_res=False):
det_timer = mot_detector.get_timer()
store_res = []
image_list.sort()
num_classes = mot_detector.num_classes
for i, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
# Decode image in advance in mot + pose prediction
det_timer.preprocess_time_s.start()
image, _ = decode_image(img_file, {})
det_timer.preprocess_time_s.end()
if FLAGS.run_benchmark:
# warmup
online_tlwhs, online_scores, online_ids = mot_model.predict(
[frame], FLAGS.mot_threshold, repeats=10, add_timer=False)
# run benchmark
online_tlwhs, online_scores, online_ids = mot_model.predict(
[frame], FLAGS.mot_threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb()
mot_model.cpu_mem += cm
mot_model.gpu_mem += gm
mot_model.gpu_util += gu
else:
online_tlwhs, online_scores, online_ids = mot_model.predict(
[frame], FLAGS.mot_threshold)
keypoint_arch = keypoint_model.pred_config.arch
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown':
results = convert_mot_to_det(online_tlwhs, online_scores)
keypoint_results = predict_with_given_det(
frame, results, keypoint_model, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold,
FLAGS.run_benchmark)
mot_results = mot_detector.predict_image(
[image], run_benchmark=True, repeats=10)
cm, gm, gu = get_current_memory_mb()
mot_detector.cpu_mem += cm
mot_detector.gpu_mem += gm
mot_detector.gpu_util += gu
else:
if FLAGS.run_benchmark:
keypoint_results = keypoint_model.predict(
[frame],
FLAGS.keypoint_threshold,
repeats=10,
add_timer=False)
repeats = 10 if FLAGS.run_benchmark else 1
keypoint_results = keypoint_model.predict(
[frame], FLAGS.keypoint_threshold, repeats=repeats)
mot_results = mot_detector.predict_image([image], visual=False)
online_tlwhs, online_scores, online_ids = mot_results[
0] # only support bs=1 in MOT model
results = convert_mot_to_det(
online_tlwhs[0],
online_scores[0]) # only support single class for mot + pose
if results['boxes_num'] == 0:
continue
keypoint_res = predict_with_given_det(
image, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
if save_res:
store_res.append([
i, keypoint_res['bbox'],
[keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
])
if FLAGS.run_benchmark:
cm, gm, gu = get_current_memory_mb()
keypoint_model.cpu_mem += cm
keypoint_model.gpu_mem += gm
keypoint_model.gpu_util += gu
topdown_keypoint_detector.cpu_mem += cm
topdown_keypoint_detector.gpu_mem += gm
topdown_keypoint_detector.gpu_util += gu
else:
im = draw_pose(
frame,
keypoint_results,
visual_thread=FLAGS.keypoint_threshold,
returnimg=True,
ids=online_ids[0]
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown'
else None)
online_im = plot_tracking_dict(
im,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=i)
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
img_name = os.path.split(img_file)[-1]
out_path = os.path.join(FLAGS.output_dir, img_name)
cv2.imwrite(out_path, online_im)
print("save result to: " + out_path)
def mot_keypoint_unite_predict_video(mot_model,
keypoint_model,
visualize_pose(
img_file,
keypoint_res,
visual_thresh=FLAGS.keypoint_threshold,
save_dir=FLAGS.output_dir)
if save_res:
"""
1) store_res: a list of image_data
2) image_data: [imageid, rects, [keypoints, scores]]
3) rects: list of rect [xmin, ymin, xmax, ymax]
4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
5) scores: mean of all joint conf
"""
with open("det_keypoint_unite_image_results.json", 'w') as wf:
json.dump(store_res, wf, indent=4)
def mot_topdown_unite_predict_video(mot_detector,
topdown_keypoint_detector,
camera_id,
keypoint_batch_size=1):
keypoint_batch_size=1,
save_res=False):
video_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
......@@ -150,17 +146,12 @@ def mot_keypoint_unite_predict_video(mot_model,
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer_mot = FPSTimer()
timer_kp = FPSTimer()
timer_mot_kp = FPSTimer()
timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer()
# support single class and multi classes, but should be single class here
mot_results = defaultdict(list)
num_classes = mot_model.num_classes
num_classes = mot_detector.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
data_type = 'mot'
......@@ -168,43 +159,41 @@ def mot_keypoint_unite_predict_video(mot_model,
ret, frame = capture.read()
if not ret:
break
if frame_id % 10 == 0:
print('Tracking frame: %d' % (frame_id))
frame_id += 1
timer_mot_kp.tic()
# mot model
timer_mot.tic()
online_tlwhs, online_scores, online_ids = mot_model.predict(
[frame], FLAGS.mot_threshold)
mot_results = mot_detector.predict_image([frame], visual=False)
timer_mot.toc()
mot_results[0].append(
(frame_id + 1, online_tlwhs[0], online_scores[0], online_ids[0]))
mot_fps = 1. / timer_mot.average_time
online_tlwhs, online_scores, online_ids = mot_results[0]
results = convert_mot_to_det(
online_tlwhs[0],
online_scores[0]) # only support single class for mot + pose
if results['boxes_num'] == 0:
continue
# keypoint model
timer_kp.tic()
keypoint_arch = keypoint_model.pred_config.arch
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown':
results = convert_mot_to_det(online_tlwhs[0], online_scores[0])
keypoint_results = predict_with_given_det(
frame, results, keypoint_model, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold,
FLAGS.run_benchmark)
else:
keypoint_results = keypoint_model.predict([frame],
FLAGS.keypoint_threshold)
keypoint_res = predict_with_given_det(
frame, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
timer_kp.toc()
timer_mot_kp.toc()
kp_fps = 1. / timer_kp.average_time
mot_kp_fps = 1. / timer_mot_kp.average_time
im = draw_pose(
kp_fps = 1. / timer_kp.duration
mot_kp_fps = 1. / timer_mot_kp.duration
im = visualize_pose(
frame,
keypoint_results,
visual_thread=FLAGS.keypoint_threshold,
keypoint_res,
visual_thresh=FLAGS.keypoint_threshold,
returnimg=True,
ids=online_ids[0]
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown' else
None)
ids=online_ids[0])
online_im = plot_tracking_dict(
im = plot_tracking_dict(
im,
num_classes,
online_tlwhs,
......@@ -213,55 +202,40 @@ def mot_keypoint_unite_predict_video(mot_model,
frame_id=frame_id,
fps=mot_kp_fps)
im = np.array(online_im)
frame_id += 1
print('detect frame: %d' % (frame_id))
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im)
if camera_id != -1:
cv2.imshow('Tracking and keypoint results', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, mot_results, data_type, num_classes)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()
print('output_video saved to: {}'.format(out_path))
def main():
pred_config = PredictConfig(FLAGS.mot_model_dir)
mot_model = JDE_Detector(
pred_config,
FLAGS.mot_model_dir,
deploy_file = os.path.join(FLAGS.mot_model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
mot_detector_func = 'SDE_Detector'
if arch in MOT_JDE_SUPPORT_MODELS:
mot_detector_func = 'JDE_Detector'
mot_detector = eval(mot_detector_func)(FLAGS.mot_model_dir,
FLAGS.tracker_config,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=1,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.mot_threshold,
output_dir=FLAGS.output_dir)
pred_config = PredictConfig_KeyPoint(FLAGS.keypoint_model_dir)
keypoint_model = KeyPoint_Detector(
pred_config,
topdown_keypoint_detector = KeyPointDetector(
FLAGS.keypoint_model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
......@@ -272,22 +246,27 @@ def main():
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.keypoint_threshold,
output_dir=FLAGS.output_dir,
use_dark=FLAGS.use_dark)
keypoint_arch = topdown_keypoint_detector.pred_config.arch
assert KEYPOINT_SUPPORT_MODELS[
keypoint_arch] == 'keypoint_topdown', 'MOT-Keypoint unite inference only supports topdown models.'
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
mot_keypoint_unite_predict_video(mot_model, keypoint_model,
FLAGS.camera_id,
FLAGS.keypoint_batch_size)
mot_topdown_unite_predict_video(
mot_detector, topdown_keypoint_detector, FLAGS.camera_id,
FLAGS.keypoint_batch_size, FLAGS.save_res)
else:
# predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
mot_keypoint_unite_predict_image(mot_model, keypoint_model, img_list,
FLAGS.keypoint_batch_size)
mot_topdown_unite_predict(mot_detector, topdown_keypoint_detector,
img_list, FLAGS.keypoint_batch_size,
FLAGS.save_res)
if not FLAGS.run_benchmark:
mot_model.det_times.info(average=True)
keypoint_model.det_times.info(average=True)
mot_detector.det_times.info(average=True)
topdown_keypoint_detector.det_times.info(average=True)
else:
mode = FLAGS.run_mode
mot_model_dir = FLAGS.mot_model_dir
......@@ -295,14 +274,15 @@ def main():
'model_name': mot_model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
bench_log(mot_model, img_list, mot_model_info, name='MOT')
bench_log(mot_detector, img_list, mot_model_info, name='MOT')
keypoint_model_dir = FLAGS.keypoint_model_dir
keypoint_model_info = {
'model_name': keypoint_model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
bench_log(keypoint_model, img_list, keypoint_model_info, 'KeyPoint')
bench_log(topdown_keypoint_detector, img_list, keypoint_model_info,
FLAGS.keypoint_batch_size, 'KeyPoint')
if __name__ == '__main__':
......
......@@ -123,5 +123,17 @@ def argsparser():
type=bool,
default=True,
help='whether to use darkpose to get better keypoint position predict ')
parser.add_argument(
'--save_res',
type=bool,
default=False,
help=(
"whether to save predict results to json file"
"1) store_res: a list of image_data"
"2) image_data: [imageid, rects, [keypoints, scores]]"
"3) rects: list of rect [xmin, ymin, xmax, ymax]"
"4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list"
"5) scores: mean of all joint conf"))
parser.add_argument(
"--tracker_config", type=str, default=None, help=("tracker donfig"))
return parser
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,93 +18,38 @@ import yaml
import cv2
import numpy as np
from collections import defaultdict
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from picodet_postprocess import PicoDetPostProcess
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, DetectorPicoDet, get_test_images, print_arguments, PredictConfig
from infer import load_predictor
from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
from ppdet.modeling.mot.tracker import DeepSORTTracker
from ppdet.modeling.mot.visualization import plot_tracking
from ppdet.modeling.mot.utils import MOTTimer, write_mot_results
# Global dictionary
MOT_SUPPORT_MODELS = {'DeepSORT'}
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from pptracking.python.mot import JDETracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
from pptracking.python.visualize import plot_tracking, plot_tracking_dict
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    """Build a ``PaddleInferBenchmark`` record for ``detector`` and emit it.

    Args:
        detector: object exposing the accumulated ``cpu_mem``, ``gpu_mem`` and
            ``gpu_util`` counters, a ``det_times`` timer with a
            ``report(average=True)`` method, and a ``config`` attribute.
        img_list (list): images that were processed; only its length is used,
            to average the accumulated memory/utilisation counters.
        model_info (dict): passed through unchanged to the benchmark logger.
        batch_size (int): recorded in the ``data_info`` section of the log.
        name (str|None): identifier handed to the benchmark logger call.
    """
    # An empty image list would otherwise raise ZeroDivisionError below; with
    # a divisor of 1 the accumulated counters are reported unchanged.
    num_imgs = max(len(img_list), 1)
    mems = {
        'cpu_rss_mb': detector.cpu_mem / num_imgs,
        'gpu_rss_mb': detector.gpu_mem / num_imgs,
        'gpu_util': detector.gpu_util * 100 / num_imgs
    }
    perf_info = detector.det_times.report(average=True)
    data_info = {
        'batch_size': batch_size,
        'shape': "dynamic_shape",
        'data_num': perf_info['img_num']
    }
    log = PaddleInferBenchmark(detector.config, model_info, data_info,
                               perf_info, mems)
    log(name)
def scale_coords(coords, input_shape, im_shape, scale_factor):
    """Shift and rescale box coordinates, clamp negatives to zero, and round.

    Even columns of ``coords`` are shifted by half the difference between
    ``input_shape[1]`` and ``im_shape[0][1]``, odd columns by half the
    difference between ``input_shape[0]`` and ``im_shape[0][0]``. The first
    four columns are then divided by ``scale_factor[0][0]`` and clamped so
    that no value is negative.

    Args:
        coords (np.ndarray): 2-D float array with one box per row; it is
            modified in place. Presumably laid out as [x0, y0, x1, y1] --
            confirm against the caller.
        input_shape: pair indexed as (height, width).
        im_shape: batched shape; only the first entry is read, indexed as
            (height, width).
        scale_factor: batched scale factors; only ``scale_factor[0][0]`` is
            read.

    Returns:
        np.ndarray: a rounded copy of the adjusted coordinates. ``coords``
        itself keeps the adjusted but unrounded values.
    """
    # A zero-row array has no maximum, so the clip below would raise
    # ValueError; there is nothing to adjust in that case anyway.
    if coords.size == 0:
        return coords.round()

    im_shape = im_shape[0]
    ratio = scale_factor[0][0]
    pad_w = (input_shape[1] - int(im_shape[1])) / 2
    pad_h = (input_shape[0] - int(im_shape[0])) / 2
    coords[:, 0::2] -= pad_w
    coords[:, 1::2] -= pad_h
    coords[:, 0:4] /= ratio
    # The upper bound is the current maximum, so only the lower bound of the
    # clip has any effect: negative coordinates become 0.
    coords[:, :4] = np.clip(coords[:, :4], a_min=0, a_max=coords[:, :4].max())
    return coords.round()
def clip_box(xyxy, input_shape, im_shape, scale_factor):
    """Clip boxes to the rescaled image bounds and drop degenerate ones.

    The bounds are ``im_shape[0]`` divided by ``scale_factor[0][0]`` and
    truncated to integers. Even columns are clipped to the width bound and
    odd columns to the height bound.

    Args:
        xyxy (np.ndarray): 2-D array with one box per row, read as
            [x0, y0, x1, y1]; it is clipped in place.
        input_shape: unused; kept so existing callers keep working.
        im_shape: batched shape; only the first entry is read, indexed as
            (height, width).
        scale_factor: batched scale factors; only ``scale_factor[0][0]`` is
            read.

    Returns:
        tuple: ``(boxes, keep_idx)`` where ``boxes`` holds the rows whose
        clipped width and height are both positive, and ``keep_idx`` is the
        ``np.nonzero`` tuple of the (N, 1) validity mask, so ``keep_idx[0]``
        contains the kept row indices.
    """
    im_shape = im_shape[0]
    ratio = scale_factor[0][0]
    img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]

    xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], a_min=0, a_max=img0_shape[1])
    xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], a_min=0, a_max=img0_shape[0])
    # Column slices (2:3, 0:1, ...) keep the arrays 2-D, so the mask is (N, 1).
    w = xyxy[:, 2:3] - xyxy[:, 0:1]
    h = xyxy[:, 3:4] - xyxy[:, 1:2]
    mask = np.logical_and(h > 0, w > 0)
    keep_idx = np.nonzero(mask)
    return xyxy[keep_idx[0]], keep_idx
def preprocess_reid(imgs,
                    w=64,
                    h=192,
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)):
    """Resize, normalise and stack image crops into one float32 batch.

    Each image is resized with ``cv2.resize`` to ``(w, h)``, has its last
    (channel) axis reversed, is converted to float32 with the channel axis
    moved first, scaled by 1/255, and then normalised with ``mean`` and
    ``std``.

    Args:
        imgs (list): image arrays, each with three channels on the last axis.
        w (int): target width passed to ``cv2.resize``.
        h (int): target height passed to ``cv2.resize``.
        mean (sequence): three per-channel means subtracted after scaling.
        std (sequence): three per-channel divisors applied after the mean.

    Returns:
        np.ndarray: float32 array of shape (len(imgs), 3, h, w); an empty
        (0, 3, h, w) batch when ``imgs`` is empty.
    """
    # np.concatenate raises ValueError on an empty list, so return an empty
    # batch of the right shape instead.
    if len(imgs) == 0:
        return np.zeros((0, 3, h, w), dtype='float32')

    # The normalisation constants do not depend on the image: build them once.
    img_mean = np.array(mean).reshape((3, 1, 1))
    img_std = np.array(std).reshape((3, 1, 1))

    im_batch = []
    for img in imgs:
        img = cv2.resize(img, (w, h))
        img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
        # In-place ops keep the float32 dtype of the image.
        img -= img_mean
        img /= img_std
        im_batch.append(np.expand_dims(img, axis=0))
    return np.concatenate(im_batch, 0)
# Global dictionary
MOT_SDE_SUPPORT_MODELS = {
'DeepSORT',
'ByteTrack',
'YOLO',
}
class SDE_Detector(Detector):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
tracker_config (str): tracker config path
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
......@@ -112,22 +57,24 @@ class SDE_Detector(Detector):
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
use_dark(bool): whether to use postprocess in DarkPose
"""
def __init__(self,
pred_config,
model_dir,
tracker_config,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
trt_opt_shape=608,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
enable_mkldnn=False,
output_dir='output',
threshold=0.5):
super(SDE_Detector, self).__init__(
pred_config=pred_config,
model_dir=model_dir,
device=device,
run_mode=run_mode,
......@@ -137,424 +84,153 @@ class SDE_Detector(Detector):
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
self.pred_config = pred_config
def postprocess(self, boxes, input_shape, im_shape, scale_factor, threshold,
scaled):
over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0]
if len(over_thres_idx) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
else:
boxes = boxes[over_thres_idx]
if not scaled:
# scaled means whether the coords after detector outputs
# have been scaled back to the original image, set True
# in general detector, set False in JDE YOLOv3.
pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape,
scale_factor)
else:
pred_bboxes = boxes[:, 2:]
pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape,
scale_factor)
if len(keep_idx[0]) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
pred_scores = boxes[:, 1:2][keep_idx[0]]
pred_cls_ids = boxes[:, 0:1][keep_idx[0]]
pred_tlwhs = np.concatenate(
(pred_xyxys[:, 0:2], pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1),
axis=1)
enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
assert batch_size == 1, "MOT model only supports batch_size=1."
self.det_times = Timer(with_tracker=True)
self.num_classes = len(self.pred_config.labels)
# tracker config
self.tracker_config = tracker_config
cfg = yaml.safe_load(open(self.tracker_config))['tracker']
min_box_area = cfg.get('min_box_area', 200)
vertical_ratio = cfg.get('vertical_ratio', 1.6)
use_byte = cfg.get('use_byte', True)
match_thres = cfg.get('match_thres', 0.9)
conf_thres = cfg.get('conf_thres', 0.6)
low_conf_thres = cfg.get('low_conf_thres', 0.1)
self.tracker = JDETracker(
use_byte=use_byte,
num_classes=self.num_classes,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
match_thres=match_thres,
conf_thres=conf_thres,
low_conf_thres=low_conf_thres)
def tracking(self, det_results):
pred_dets = det_results['boxes']
pred_embs = None
pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1)
(pred_dets[:, 2:], pred_dets[:, 1:2], pred_dets[:, 0:1]), 1)
# pred_dets should be 'x0, y0, x1, y1, score, cls_id'
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
for cls_id in range(self.num_classes):
online_targets = online_targets_dict[cls_id]
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio:
continue
online_tlwhs[cls_id].append(tlwh)
online_ids[cls_id].append(tid)
online_scores[cls_id].append(tscore)
return pred_dets, pred_xyxys
return online_tlwhs, online_scores, online_ids
def predict(self, image, scaled, threshold=0.5, repeats=1, add_timer=True):
'''
Args:
image (np.ndarray): image numpy data
scaled (bool): whether the coords after detector outputs are scaled,
default False in jde yolov3, set True in general detector.
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns:
pred_dets (np.ndarray, [N, 6])
'''
def predict_image(self,
image_list,
run_benchmark=False,
repeats=1,
visual=True):
mot_results = []
num_classes = self.num_classes
image_list.sort()
ids2names = self.pred_config.labels
for frame_id, img_file in enumerate(image_list):
batch_image_list = [img_file] # bs=1 in MOT model
if run_benchmark:
# preprocess
if add_timer:
inputs = self.preprocess(batch_image_list) # warmup
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
boxes = boxes_tensor.copy_to_cpu()
if add_timer:
# model prediction
result_warmup = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
# postprocess
if len(boxes) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
else:
input_shape = inputs['image'].shape[2:]
im_shape = inputs['im_shape']
scale_factor = inputs['scale_factor']
pred_dets, pred_xyxys = self.postprocess(
boxes, input_shape, im_shape, scale_factor, threshold, scaled)
if add_timer:
result_warmup = self.postprocess(inputs, result) # warmup
self.det_times.postprocess_time_s.start()
det_result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return pred_dets, pred_xyxys
class SDE_DetectorPicoDet(DetectorPicoDet):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self,
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
trt_opt_shape=608,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
super(SDE_DetectorPicoDet, self).__init__(
pred_config=pred_config,
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
self.pred_config = pred_config
def postprocess_bboxes(self, boxes, input_shape, im_shape, scale_factor,
threshold):
over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0]
if len(over_thres_idx) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
else:
boxes = boxes[over_thres_idx]
pred_bboxes = boxes[:, 2:]
pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape,
scale_factor)
if len(keep_idx[0]) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
pred_scores = boxes[:, 1:2][keep_idx[0]]
pred_cls_ids = boxes[:, 0:1][keep_idx[0]]
pred_tlwhs = np.concatenate(
(pred_xyxys[:, 0:2], pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1),
axis=1)
pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1)
return pred_dets, pred_xyxys
def predict(self, image, scaled, threshold=0.5, repeats=1, add_timer=True):
'''
Args:
image (np.ndarray): image numpy data
scaled (bool): whether the coords after detector outputs are scaled,
default False in jde yolov3, set True in general detector.
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns:
pred_dets (np.ndarray, [N, 6])
'''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction
np_score_list, np_boxes_list = [], []
for i in range(repeats):
self.predictor.run()
np_score_list.clear()
np_boxes_list.clear()
output_names = self.predictor.get_output_names()
num_outs = int(len(output_names) / 2)
for out_idx in range(num_outs):
np_score_list.append(
self.predictor.get_output_handle(output_names[out_idx])
.copy_to_cpu())
np_boxes_list.append(
self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu())
if add_timer:
self.det_times.inference_time_s.end(repeats=repeats)
# tracking
result_warmup = self.tracking(det_result)
self.det_times.tracking_time_s.start()
online_tlwhs, online_scores, online_ids = self.tracking(
det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
self.det_times.postprocess_time_s.start()
# postprocess
self.postprocess = PicoDetPostProcess(
inputs['image'].shape[2:],
inputs['im_shape'],
inputs['scale_factor'],
strides=self.pred_config.fpn_stride,
nms_threshold=self.pred_config.nms['nms_threshold'])
boxes, boxes_num = self.postprocess(np_score_list, np_boxes_list)
if len(boxes) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
else:
input_shape = inputs['image'].shape[2:]
im_shape = inputs['im_shape']
scale_factor = inputs['scale_factor']
pred_dets, pred_xyxys = self.postprocess_bboxes(
boxes, input_shape, im_shape, scale_factor, threshold)
if add_timer:
self.det_times.postprocess_time_s.end()
return pred_dets, pred_xyxys
cm, gm, gu = get_current_memory_mb()
self.cpu_mem += cm
self.gpu_mem += gm
self.gpu_util += gu
class SDE_ReID(object):
def __init__(self,
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=50,
trt_min_shape=1,
trt_max_shape=1088,
trt_opt_shape=608,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
self.pred_config = pred_config
self.predictor, self.config = load_predictor(
model_dir,
run_mode=run_mode,
batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.batch_size = batch_size
assert pred_config.tracker, "Tracking model should have tracker"
pt = pred_config.tracker
max_age = pt['max_age'] if 'max_age' in pt else 30
max_iou_distance = pt[
'max_iou_distance'] if 'max_iou_distance' in pt else 0.7
self.tracker = DeepSORTTracker(
max_age=max_age, max_iou_distance=max_iou_distance)
def get_crops(self, xyxy, ori_img):
w, h = self.tracker.input_size
else:
self.det_times.preprocess_time_s.start()
crops = []
xyxy = xyxy.astype(np.int64)
ori_img = ori_img.transpose(1, 0, 2) # [h,w,3]->[w,h,3]
for i, bbox in enumerate(xyxy):
crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
crops.append(crop)
crops = preprocess_reid(crops, w, h)
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
return crops
def preprocess(self, crops):
# to keep fast speed, only use topk crops
crops = crops[:self.batch_size]
inputs = {}
inputs['crops'] = np.array(crops).astype('float32')
return inputs
def postprocess(self, pred_dets, pred_embs):
tracker = self.tracker
tracker.predict()
online_targets = tracker.update(pred_dets, pred_embs)
online_tlwhs, online_scores, online_ids = [], [], []
for t in online_targets:
if not t.is_confirmed() or t.time_since_update > 1:
continue
tlwh = t.to_tlwh()
tscore = t.score
tid = t.track_id
if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue
if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > tracker.vertical_ratio:
continue
online_tlwhs.append(tlwh)
online_scores.append(tscore)
online_ids.append(tid)
return online_tlwhs, online_scores, online_ids
def predict(self, crops, pred_dets, repeats=1, add_timer=True):
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(crops)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# model prediction
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
feature_tensor = self.predictor.get_output_handle(output_names[0])
pred_embs = feature_tensor.copy_to_cpu()
if add_timer:
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
# postprocess
online_tlwhs, online_scores, online_ids = self.postprocess(pred_dets,
pred_embs)
if add_timer:
det_result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return online_tlwhs, online_scores, online_ids
# tracking process
self.det_times.tracking_time_s.start()
online_tlwhs, online_scores, online_ids = self.tracking(
det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
def predict_image(detector, reid_model, image_list):
image_list.sort()
for i, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
if FLAGS.run_benchmark:
# warmup
pred_dets, pred_xyxys = detector.predict(
[frame],
FLAGS.scaled,
FLAGS.threshold,
repeats=10,
add_timer=True)
# run benchmark
pred_dets, pred_xyxys = detector.predict(
[frame],
FLAGS.scaled,
FLAGS.threshold,
repeats=10,
add_timer=True)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file))
else:
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled,
FLAGS.threshold)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'.
format(i))
online_im = frame
else:
# reid process
crops = reid_model.get_crops(pred_xyxys, frame)
if FLAGS.run_benchmark:
# warmup
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, pred_dets, repeats=10, add_timer=False)
# run benchmark
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, pred_dets, repeats=10, add_timer=False)
else:
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, pred_dets)
online_im = plot_tracking(
frame, online_tlwhs, online_ids, online_scores, frame_id=i)
if visual:
if frame_id % 10 == 0:
print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {})
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
img_name = os.path.split(img_file)[-1]
out_path = os.path.join(FLAGS.output_dir, img_name)
cv2.imwrite(out_path, online_im)
print("save result to: " + out_path)
im = plot_tracking_dict(
frame,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
ids2names=[])
seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results
def predict_video(detector, reid_model, camera_id):
def predict_video(self, video_file, camera_id):
video_out_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'mot_output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
capture = cv2.VideoCapture(video_file)
video_out_name = os.path.split(video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
......@@ -562,86 +238,62 @@ def predict_video(detector, reid_model, camera_id):
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images:
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
frame_id = 1
timer = MOTTimer()
results = defaultdict(list)
results = defaultdict(list) # support single class and multi classes
num_classes = self.num_classes
while (1):
ret, frame = capture.read()
if not ret:
break
timer.tic()
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled,
FLAGS.threshold)
if frame_id % 10 == 0:
print('Tracking frame: %d' % (frame_id))
frame_id += 1
if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'.
format(frame_id))
timer.toc()
im = frame
else:
# reid process
crops = reid_model.get_crops(pred_xyxys, frame)
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, pred_dets)
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
timer.tic()
mot_results = self.predict_image([frame], visual=False)
timer.toc()
fps = 1. / timer.average_time
im = plot_tracking(
online_tlwhs, online_scores, online_ids = mot_results[0]
for cls_id in range(num_classes):
results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
fps = 1. / timer.duration
im = plot_tracking_dict(
frame,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=fps)
fps=fps,
ids2names=[])
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im)
frame_id += 1
print('detect frame:%d' % (frame_id))
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()
def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector_func = 'SDE_Detector'
if pred_config.arch == 'PicoDet':
detector_func = 'SDE_DetectorPicoDet'
detector = eval(detector_func)(pred_config,
deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
assert arch in MOT_SDE_SUPPORT_MODELS, '{} is not supported.'.format(arch)
detector = SDE_Detector(
FLAGS.model_dir,
FLAGS.tracker_config,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
......@@ -650,48 +302,30 @@ def main():
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
pred_config = PredictConfig(FLAGS.reid_model_dir)
reid_model = SDE_ReID(
pred_config,
FLAGS.reid_model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.reid_batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir)
# predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, reid_model, FLAGS.camera_id)
detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
else:
# predict from image
if FLAGS.image_dir is None and FLAGS.image_file is not None:
assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, reid_model, img_list)
detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark:
detector.det_times.info(average=True)
reid_model.det_times.info(average=True)
else:
mode = FLAGS.run_mode
det_model_dir = FLAGS.model_dir
det_model_info = {
'model_name': det_model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
bench_log(detector, img_list, det_model_info, name='Det')
reid_model_dir = FLAGS.reid_model_dir
reid_model_info = {
'model_name': reid_model_dir.strip('/').split('/')[-1],
model_dir = FLAGS.model_dir
model_info = {
'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
bench_log(reid_model, img_list, reid_model_info, name='ReID')
bench_log(detector, img_list, model_info, name='MOT')
if __name__ == '__main__':
......
# Tracker config for the MOT SDE Detector; ByteTracker is used by default.
# The tracker of the MOT JDE Detector is exported together with the model.
# 'min_box_area' and 'vertical_ratio' here are set for pedestrians; modify them to track other objects.
tracker:
use_byte: true
conf_thres: 0.6
low_conf_thres: 0.1
match_thres: 0.9
min_box_area: 100
vertical_ratio: 1.6
......@@ -118,6 +118,8 @@ def argsparser():
default=False,
help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector.")
parser.add_argument(
"--tracker_config", type=str, default=None, help=("tracker donfig"))
parser.add_argument(
"--reid_model_dir",
type=str,
......@@ -165,29 +167,36 @@ class Times(object):
class Timer(Times):
def __init__(self, with_tracker=False):
    """Create one accumulator per pipeline stage.

    Args:
        with_tracker (bool): When True, the tracking stage is included in
            the totals and printed summaries. Defaults to False, so
            existing ``Timer()`` callers keep working unchanged.
    """
    super(Timer, self).__init__()
    self.with_tracker = with_tracker
    # One `Times` accumulator for each stage of the pipeline.
    self.preprocess_time_s = Times()
    self.inference_time_s = Times()
    self.postprocess_time_s = Times()
    self.tracking_time_s = Times()
    # Number of processed images; used as the divisor when averaging.
    self.img_num = 0
def info(self, average=False):
total_time = self.preprocess_time_s.value(
) + self.inference_time_s.value() + self.postprocess_time_s.value()
pre_time = self.preprocess_time_s.value()
infer_time = self.inference_time_s.value()
post_time = self.postprocess_time_s.value()
track_time = self.tracking_time_s.value()
total_time = pre_time + infer_time + post_time
if self.with_tracker:
total_time = total_time + track_time
total_time = round(total_time, 4)
print("------------------ Inference Time Info ----------------------")
print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
self.img_num))
preprocess_time = round(
self.preprocess_time_s.value() / max(1, self.img_num),
4) if average else self.preprocess_time_s.value()
postprocess_time = round(
self.postprocess_time_s.value() / max(1, self.img_num),
4) if average else self.postprocess_time_s.value()
inference_time = round(self.inference_time_s.value() /
max(1, self.img_num),
4) if average else self.inference_time_s.value()
preprocess_time = round(pre_time / max(1, self.img_num),
4) if average else pre_time
postprocess_time = round(post_time / max(1, self.img_num),
4) if average else post_time
inference_time = round(infer_time / max(1, self.img_num),
4) if average else infer_time
tracking_time = round(track_time / max(1, self.img_num),
4) if average else track_time
average_latency = total_time / max(1, self.img_num)
qps = 0
......@@ -195,6 +204,12 @@ class Timer(Times):
qps = 1 / average_latency
print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
average_latency * 1000, qps))
if self.with_tracker:
print(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
postprocess_time * 1000, tracking_time * 1000))
else:
print(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
......@@ -202,18 +217,23 @@ class Timer(Times):
def report(self, average=False):
    """Collect the timing statistics of every stage into a dictionary.

    Args:
        average (bool): If True, each stage time is divided by
            ``max(1, self.img_num)`` and rounded to 4 decimals; otherwise
            the accumulated stage time is reported as is.

    Returns:
        dict: with keys ``preprocess_time_s``, ``inference_time_s``,
        ``postprocess_time_s``, ``tracking_time_s``, ``img_num`` and
        ``total_time_s``. ``total_time_s`` is the sum of the accumulated
        (non-averaged) stage times rounded to 4 decimals; tracking time
        is added to it only when ``self.with_tracker`` is set.
    """
    pre_time = self.preprocess_time_s.value()
    infer_time = self.inference_time_s.value()
    post_time = self.postprocess_time_s.value()
    track_time = self.tracking_time_s.value()

    # Guard against division by zero when no image has been processed.
    img_num = max(1, self.img_num)

    dic = {}
    dic['preprocess_time_s'] = round(pre_time / img_num,
                                     4) if average else pre_time
    dic['inference_time_s'] = round(infer_time / img_num,
                                    4) if average else infer_time
    dic['postprocess_time_s'] = round(post_time / img_num,
                                      4) if average else post_time
    # Average the tracking time itself (the postprocess time was
    # mistakenly used here before).
    dic['tracking_time_s'] = round(track_time / img_num,
                                   4) if average else track_time
    dic['img_num'] = self.img_num

    total_time = pre_time + infer_time + post_time
    if self.with_tracker:
        total_time = total_time + track_time
    dic['total_time_s'] = round(total_time, 4)
    return dic
......
......@@ -224,9 +224,9 @@ def get_color(idx):
return color
def draw_pose(imgfile,
def visualize_pose(imgfile,
results,
visual_thread=0.6,
visual_thresh=0.6,
save_name='pose.jpg',
save_dir='output',
returnimg=False,
......@@ -239,7 +239,6 @@ def draw_pose(imgfile,
logger.error('Matplotlib not found, please install matplotlib.'
'for example: `pip install matplotlib`.')
raise e
skeletons, scores = results['keypoint']
skeletons = np.array(skeletons)
kpt_nums = 17
......@@ -276,7 +275,7 @@ def draw_pose(imgfile,
canvas = img.copy()
for i in range(kpt_nums):
for j in range(len(skeletons)):
if skeletons[j][i, 2] < visual_thread:
if skeletons[j][i, 2] < visual_thresh:
continue
if ids is None:
color = colors[i] if color_set is None else colors[color_set[j]
......@@ -300,8 +299,8 @@ def draw_pose(imgfile,
for i in range(NUM_EDGES):
for j in range(len(skeletons)):
edge = EDGES[i]
if skeletons[j][edge[0], 2] < visual_thread or skeletons[j][edge[
1], 2] < visual_thread:
if skeletons[j][edge[0], 2] < visual_thresh or skeletons[j][edge[
1], 2] < visual_thresh:
continue
cur_canvas = canvas.copy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册