未验证 提交 7dccc8f6 编写于 作者: W wangguanzhong 提交者: GitHub

Refactor python deploy (#5253)

* refactor det deploy

* refactor keypoint deploy

* fix solov2

* fit mot sde pipeline infer

* refine mot sde infer

* fit mot jde infer pipeline

* fit mot pose unite infer

* precommit for format

* refine keypoint detector name

* clean codes

* fix keypoint infer

* refine format
Co-authored-by: NFeng Ni <nemonameless@qq.com>
上级 ef83ab8a
...@@ -89,6 +89,8 @@ class PaddleInferBenchmark(object): ...@@ -89,6 +89,8 @@ class PaddleInferBenchmark(object):
self.preprocess_time_s = perf_info.get('preprocess_time_s', 0) self.preprocess_time_s = perf_info.get('preprocess_time_s', 0)
self.postprocess_time_s = perf_info.get('postprocess_time_s', 0) self.postprocess_time_s = perf_info.get('postprocess_time_s', 0)
self.with_tracker = True if 'tracking_time_s' in perf_info else False
self.tracking_time_s = perf_info.get('tracking_time_s', 0)
self.total_time_s = perf_info.get('total_time_s', 0) self.total_time_s = perf_info.get('total_time_s', 0)
self.inference_time_s_90 = perf_info.get("inference_time_s_90", "") self.inference_time_s_90 = perf_info.get("inference_time_s_90", "")
...@@ -235,8 +237,18 @@ class PaddleInferBenchmark(object): ...@@ -235,8 +237,18 @@ class PaddleInferBenchmark(object):
) )
self.logger.info( self.logger.info(
f"{identifier} total time spent(s): {self.total_time_s}") f"{identifier} total time spent(s): {self.total_time_s}")
if self.with_tracker:
self.logger.info(
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, "
f"inference_time(ms): {round(self.inference_time_s*1000, 1)}, "
f"postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}, "
f"tracking_time(ms): {round(self.tracking_time_s*1000, 1)}")
else:
self.logger.info( self.logger.info(
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}" f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, "
f"inference_time(ms): {round(self.inference_time_s*1000, 1)}, "
f"postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
) )
if self.inference_time_s_90: if self.inference_time_s_90:
self.looger.info( self.looger.info(
......
...@@ -18,12 +18,13 @@ import cv2 ...@@ -18,12 +18,13 @@ import cv2
import math import math
import numpy as np import numpy as np
import paddle import paddle
import yaml
from det_keypoint_unite_utils import argsparser from det_keypoint_unite_utils import argsparser
from preprocess import decode_image from preprocess import decode_image
from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
from visualize import draw_pose from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images from keypoint_postprocess import translate_to_ori_images
...@@ -34,24 +35,6 @@ KEYPOINT_SUPPORT_MODELS = { ...@@ -34,24 +35,6 @@ KEYPOINT_SUPPORT_MODELS = {
} }
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
data_info = {
'batch_size': batch_size,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
log = PaddleInferBenchmark(detector.config, model_info, data_info,
perf_info, mems)
log(name)
def predict_with_given_det(image, det_res, keypoint_detector, def predict_with_given_det(image, det_res, keypoint_detector,
keypoint_batch_size, det_threshold, keypoint_batch_size, det_threshold,
keypoint_threshold, run_benchmark): keypoint_threshold, run_benchmark):
...@@ -59,32 +42,15 @@ def predict_with_given_det(image, det_res, keypoint_detector, ...@@ -59,32 +42,15 @@ def predict_with_given_det(image, det_res, keypoint_detector,
image, det_res, det_threshold) image, det_res, det_threshold)
keypoint_vector = [] keypoint_vector = []
score_vector = [] score_vector = []
rect_vector = det_rects
batch_loop_cnt = math.ceil(float(len(rec_images)) / keypoint_batch_size)
for i in range(batch_loop_cnt):
start_index = i * keypoint_batch_size
end_index = min((i + 1) * keypoint_batch_size, len(rec_images))
batch_images = rec_images[start_index:end_index]
batch_records = np.array(records[start_index:end_index])
if run_benchmark:
# warmup
keypoint_result = keypoint_detector.predict(
batch_images, keypoint_threshold, repeats=10, add_timer=False)
# run benchmark
keypoint_result = keypoint_detector.predict(
batch_images, keypoint_threshold, repeats=10, add_timer=True)
else:
keypoint_result = keypoint_detector.predict(batch_images,
keypoint_threshold)
orgkeypoints, scores = translate_to_ori_images(keypoint_result,
batch_records)
keypoint_vector.append(orgkeypoints)
score_vector.append(scores)
rect_vector = det_rects
keypoint_results = keypoint_detector.predict_image(
rec_images, run_benchmark, repeats=10, visual=False)
keypoint_vector, score_vector = translate_to_ori_images(keypoint_results,
np.array(records))
keypoint_res = {} keypoint_res = {}
keypoint_res['keypoint'] = [ keypoint_res['keypoint'] = [
np.vstack(keypoint_vector).tolist(), np.vstack(score_vector).tolist() keypoint_vector.tolist(), score_vector.tolist()
] if len(keypoint_vector) > 0 else [[], []] ] if len(keypoint_vector) > 0 else [[], []]
keypoint_res['bbox'] = rect_vector keypoint_res['bbox'] = rect_vector
return keypoint_res return keypoint_res
...@@ -104,18 +70,15 @@ def topdown_unite_predict(detector, ...@@ -104,18 +70,15 @@ def topdown_unite_predict(detector,
det_timer.preprocess_time_s.end() det_timer.preprocess_time_s.end()
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
# warmup results = detector.predict_image(
results = detector.predict( [image], run_benchmark=True, repeats=10)
[image], FLAGS.det_threshold, repeats=10, add_timer=False)
# run benchmark
results = detector.predict(
[image], FLAGS.det_threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb() cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm detector.cpu_mem += cm
detector.gpu_mem += gm detector.gpu_mem += gm
detector.gpu_util += gu detector.gpu_util += gu
else: else:
results = detector.predict([image], FLAGS.det_threshold) results = detector.predict_image([image], visual=False)
if results['boxes_num'] == 0: if results['boxes_num'] == 0:
continue continue
...@@ -137,10 +100,10 @@ def topdown_unite_predict(detector, ...@@ -137,10 +100,10 @@ def topdown_unite_predict(detector,
else: else:
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
draw_pose( visualize_pose(
img_file, img_file,
keypoint_res, keypoint_res,
visual_thread=FLAGS.keypoint_threshold, visual_thresh=FLAGS.keypoint_threshold,
save_dir=FLAGS.output_dir) save_dir=FLAGS.output_dir)
if save_res: if save_res:
""" """
...@@ -164,8 +127,7 @@ def topdown_unite_predict_video(detector, ...@@ -164,8 +127,7 @@ def topdown_unite_predict_video(detector,
capture = cv2.VideoCapture(camera_id) capture = cv2.VideoCapture(camera_id)
else: else:
capture = cv2.VideoCapture(FLAGS.video_file) capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.splitext(os.path.basename(FLAGS.video_file))[ video_name = os.path.split(FLAGS.video_file)[-1]
0] + '.mp4'
# Get Video info : resolution, fps, frame count # Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
...@@ -176,7 +138,7 @@ def topdown_unite_predict_video(detector, ...@@ -176,7 +138,7 @@ def topdown_unite_predict_video(detector,
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v') fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 0 index = 0
store_res = [] store_res = []
...@@ -188,16 +150,17 @@ def topdown_unite_predict_video(detector, ...@@ -188,16 +150,17 @@ def topdown_unite_predict_video(detector,
print('detect frame: %d' % (index)) print('detect frame: %d' % (index))
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = detector.predict([frame2], FLAGS.det_threshold)
results = detector.predict_image([frame2], visual=False)
keypoint_res = predict_with_given_det( keypoint_res = predict_with_given_det(
frame2, results, topdown_keypoint_detector, keypoint_batch_size, frame2, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark) FLAGS.det_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
im = draw_pose( im = visualize_pose(
frame, frame,
keypoint_res, keypoint_res,
visual_thread=FLAGS.keypoint_threshold, visual_thresh=FLAGS.keypoint_threshold,
returnimg=True) returnimg=True)
if save_res: if save_res:
store_res.append([ store_res.append([
...@@ -211,6 +174,7 @@ def topdown_unite_predict_video(detector, ...@@ -211,6 +174,7 @@ def topdown_unite_predict_video(detector,
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
writer.release() writer.release()
print('output_video saved to: {}'.format(out_path))
if save_res: if save_res:
""" """
1) store_res: a list of frame_data 1) store_res: a list of frame_data
...@@ -224,13 +188,15 @@ def topdown_unite_predict_video(detector, ...@@ -224,13 +188,15 @@ def topdown_unite_predict_video(detector,
def main(): def main():
pred_config = PredictConfig(FLAGS.det_model_dir) deploy_file = os.path.join(FLAGS.det_model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
detector_func = 'Detector' detector_func = 'Detector'
if pred_config.arch == 'PicoDet': if arch == 'PicoDet':
detector_func = 'DetectorPicoDet' detector_func = 'DetectorPicoDet'
detector = eval(detector_func)(pred_config, detector = eval(detector_func)(FLAGS.det_model_dir,
FLAGS.det_model_dir,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
trt_min_shape=FLAGS.trt_min_shape, trt_min_shape=FLAGS.trt_min_shape,
...@@ -238,14 +204,10 @@ def main(): ...@@ -238,14 +204,10 @@ def main():
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode, trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn) enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.det_threshold)
pred_config = PredictConfig_KeyPoint(FLAGS.keypoint_model_dir) topdown_keypoint_detector = KeyPointDetector(
assert KEYPOINT_SUPPORT_MODELS[
pred_config.
arch] == 'keypoint_topdown', 'Detection-Keypoint unite inference only supports topdown models.'
topdown_keypoint_detector = KeyPoint_Detector(
pred_config,
FLAGS.keypoint_model_dir, FLAGS.keypoint_model_dir,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
...@@ -257,6 +219,9 @@ def main(): ...@@ -257,6 +219,9 @@ def main():
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn, enable_mkldnn=FLAGS.enable_mkldnn,
use_dark=FLAGS.use_dark) use_dark=FLAGS.use_dark)
keypoint_arch = topdown_keypoint_detector.pred_config.arch
assert KEYPOINT_SUPPORT_MODELS[
keypoint_arch] == 'keypoint_topdown', 'Detection-Keypoint unite inference only supports topdown models.'
# predict from video file or camera video stream # predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1: if FLAGS.video_file is not None or FLAGS.camera_id != -1:
......
...@@ -24,9 +24,15 @@ import paddle ...@@ -24,9 +24,15 @@ import paddle
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
import sys
# add deploy path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from visualize import visualize_box_mask from visualize import visualize_box_mask
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
...@@ -47,9 +53,27 @@ SUPPORT_MODELS = { ...@@ -47,9 +53,27 @@ SUPPORT_MODELS = {
'PicoDet', 'PicoDet',
'CenterNet', 'CenterNet',
'TOOD', 'TOOD',
'StrongBaseline',
} }
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
data_info = {
'batch_size': batch_size,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
log = PaddleInferBenchmark(detector.config, model_info, data_info,
perf_info, mems)
log(name)
class Detector(object): class Detector(object):
""" """
Args: Args:
...@@ -65,10 +89,12 @@ class Detector(object): ...@@ -65,10 +89,12 @@ class Detector(object):
calibration, trt_calib_mode need to set True calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
""" """
def __init__(self, def __init__(
pred_config, self,
model_dir, model_dir,
device='CPU', device='CPU',
run_mode='paddle', run_mode='paddle',
...@@ -78,8 +104,10 @@ class Detector(object): ...@@ -78,8 +104,10 @@ class Detector(object):
trt_opt_shape=640, trt_opt_shape=640,
trt_calib_mode=False, trt_calib_mode=False,
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False): enable_mkldnn=False,
self.pred_config = pred_config output_dir='output',
threshold=0.5, ):
self.pred_config = self.set_config(model_dir)
self.predictor, self.config = load_predictor( self.predictor, self.config = load_predictor(
model_dir, model_dir,
run_mode=run_mode, run_mode=run_mode,
...@@ -95,6 +123,12 @@ class Detector(object): ...@@ -95,6 +123,12 @@ class Detector(object):
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn)
self.det_times = Timer() self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.batch_size = batch_size
self.output_dir = output_dir
self.threshold = threshold
def set_config(self, model_dir):
return PredictConfig(model_dir)
def preprocess(self, image_list): def preprocess(self, image_list):
preprocess_ops = [] preprocess_ops = []
...@@ -110,49 +144,34 @@ class Detector(object): ...@@ -110,49 +144,34 @@ class Detector(object):
input_im_lst.append(im) input_im_lst.append(im)
input_im_info_lst.append(im_info) input_im_info_lst.append(im_info)
inputs = create_inputs(input_im_lst, input_im_info_lst) inputs = create_inputs(input_im_lst, input_im_info_lst)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
return inputs return inputs
def postprocess(self, def postprocess(self, inputs, result):
np_boxes,
np_masks,
inputs,
np_boxes_num,
threshold=0.5):
# postprocess output of predictor # postprocess output of predictor
results = {} np_boxes_num = result['boxes_num']
results['boxes'] = np_boxes if np_boxes_num[0] <= 0:
results['boxes_num'] = np_boxes_num print('[WARNNING] No object detected.')
if np_masks is not None: result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
results['masks'] = np_masks result = {k: v for k, v in result.items() if v is not None}
return results return result
def predict(self, image_list, threshold=0.5, repeats=1, add_timer=True): def predict(self, repeats=1):
''' '''
Args: Args:
image_list (list): list of image repeats (int): repeats number for prediction
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns: Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max] matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray: MaskRCNN's result include 'masks': np.ndarray:
shape: [N, im_h, im_w] shape: [N, im_h, im_w]
''' '''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction # model prediction
np_boxes, np_masks = None, None
for i in range(repeats): for i in range(repeats):
self.predictor.run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
...@@ -163,32 +182,136 @@ class Detector(object): ...@@ -163,32 +182,136 @@ class Detector(object):
if self.pred_config.mask: if self.pred_config.mask:
masks_tensor = self.predictor.get_output_handle(output_names[2]) masks_tensor = self.predictor.get_output_handle(output_names[2])
np_masks = masks_tensor.copy_to_cpu() np_masks = masks_tensor.copy_to_cpu()
result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
return result
def merge_batch_result(self, batch_result):
if len(batch_result) == 1:
return batch_result[0]
res_key = batch_result[0].keys()
results = {k: [] for k in res_key}
for res in batch_result:
for k, v in res.items():
results[k].append(v)
for k, v in results.items():
results[k] = np.concatenate(v)
return results
def get_timer(self):
return self.det_times
if add_timer: def predict_image(self,
image_list,
run_benchmark=False,
repeats=1,
visual=True):
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
results = []
for i in range(batch_loop_cnt):
start_index = i * self.batch_size
end_index = min((i + 1) * self.batch_size, len(image_list))
batch_image_list = image_list[start_index:end_index]
if run_benchmark:
# preprocess
inputs = self.preprocess(batch_image_list) # warmup
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
result = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats) self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
# postprocess # postprocess
results = [] result_warmup = self.postprocess(inputs, result) # warmup
if reduce(lambda x, y: x * y, np_boxes.shape) < 6: self.det_times.postprocess_time_s.start()
print('[WARNNING] No object detected.') result = self.postprocess(inputs, result)
results = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]} self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(batch_image_list)
cm, gm, gu = get_current_memory_mb()
self.cpu_mem += cm
self.gpu_mem += gm
self.gpu_util += gu
else: else:
results = self.postprocess( # preprocess
np_boxes, np_masks, inputs, np_boxes_num, threshold=threshold) self.det_times.preprocess_time_s.start()
if add_timer: inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# postprocess
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(image_list) self.det_times.img_num += len(batch_image_list)
if visual:
visualize(
batch_image_list,
result,
self.pred_config.labels,
output_dir=self.output_dir,
threshold=self.threshold)
results.append(result)
if visual:
print('Test iter {}'.format(i))
results = self.merge_batch_result(results)
return results return results
def get_timer(self): def predict_video(self, video_file, camera_id):
return self.det_times video_out_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(video_file)
video_out_name = os.path.split(video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame: %d' % (index))
index += 1
results = self.predict_image([frame], visual=False)
im = visualize_box_mask(
frame,
results,
self.pred_config.labels,
threshold=self.threshold)
im = np.array(im)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
class DetectorSOLOv2(Detector): class DetectorSOLOv2(Detector):
""" """
Args: Args:
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
...@@ -200,10 +323,13 @@ class DetectorSOLOv2(Detector): ...@@ -200,10 +323,13 @@ class DetectorSOLOv2(Detector):
calibration, trt_calib_mode need to set True calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
""" """
def __init__(self, def __init__(
pred_config, self,
model_dir, model_dir,
device='CPU', device='CPU',
run_mode='paddle', run_mode='paddle',
...@@ -213,48 +339,33 @@ class DetectorSOLOv2(Detector): ...@@ -213,48 +339,33 @@ class DetectorSOLOv2(Detector):
trt_opt_shape=640, trt_opt_shape=640,
trt_calib_mode=False, trt_calib_mode=False,
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False): enable_mkldnn=False,
self.pred_config = pred_config output_dir='./',
self.predictor, self.config = load_predictor( threshold=0.5, ):
model_dir, super(DetectorSOLOv2, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode, run_mode=run_mode,
batch_size=batch_size, batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape, trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape, trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape, trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode, trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn,
self.det_times = Timer() output_dir=output_dir,
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 threshold=threshold, )
def predict(self, image, threshold=0.5, repeats=1, add_timer=True): def predict(self, repeats=1):
''' '''
Args: Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns: Returns:
results (dict): 'segm': np.ndarray,shape:[N, im_h, im_w] result (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
'cate_label': label of segm, shape:[N] 'cate_label': label of segm, shape:[N]
'cate_score': confidence score of segm, shape:[N] 'cate_score': confidence score of segm, shape:[N]
''' '''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
np_label, np_score, np_segms = None, None, None np_label, np_score, np_segms = None, None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
for i in range(repeats): for i in range(repeats):
self.predictor.run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
...@@ -266,21 +377,18 @@ class DetectorSOLOv2(Detector): ...@@ -266,21 +377,18 @@ class DetectorSOLOv2(Detector):
2]).copy_to_cpu() 2]).copy_to_cpu()
np_segms = self.predictor.get_output_handle(output_names[ np_segms = self.predictor.get_output_handle(output_names[
3]).copy_to_cpu() 3]).copy_to_cpu()
if add_timer:
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.img_num += 1
return dict( result = dict(
segm=np_segms, segm=np_segms,
label=np_label, label=np_label,
score=np_score, score=np_score,
boxes_num=np_boxes_num) boxes_num=np_boxes_num)
return result
class DetectorPicoDet(Detector): class DetectorPicoDet(Detector):
""" """
Args: Args:
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
...@@ -294,8 +402,8 @@ class DetectorPicoDet(Detector): ...@@ -294,8 +402,8 @@ class DetectorPicoDet(Detector):
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
""" """
def __init__(self, def __init__(
pred_config, self,
model_dir, model_dir,
device='CPU', device='CPU',
run_mode='paddle', run_mode='paddle',
...@@ -305,50 +413,46 @@ class DetectorPicoDet(Detector): ...@@ -305,50 +413,46 @@ class DetectorPicoDet(Detector):
trt_opt_shape=640, trt_opt_shape=640,
trt_calib_mode=False, trt_calib_mode=False,
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False): enable_mkldnn=False,
self.pred_config = pred_config output_dir='./',
self.predictor, self.config = load_predictor( threshold=0.5, ):
model_dir, super(DetectorPicoDet, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode, run_mode=run_mode,
batch_size=batch_size, batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape, trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape, trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape, trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode, trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn,
self.det_times = Timer() output_dir=output_dir,
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 threshold=threshold, )
def postprocess(self, inputs, result):
# postprocess output of predictor
np_score_list = result['boxes']
np_boxes_list = result['boxes_num']
postprocessor = PicoDetPostProcess(
inputs['image'].shape[2:],
inputs['im_shape'],
inputs['scale_factor'],
strides=self.pred_config.fpn_stride,
nms_threshold=self.pred_config.nms['nms_threshold'])
np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
return result
def predict(self, image, threshold=0.5, repeats=1, add_timer=True): def predict(self, repeats=1):
''' '''
Args: Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns: Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max] matix element:[class, score, x_min, y_min, x_max, y_max]
''' '''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
np_score_list, np_boxes_list = [], [] np_score_list, np_boxes_list = [], []
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model_prediction
for i in range(repeats): for i in range(repeats):
self.predictor.run() self.predictor.run()
np_score_list.clear() np_score_list.clear()
...@@ -362,22 +466,8 @@ class DetectorPicoDet(Detector): ...@@ -362,22 +466,8 @@ class DetectorPicoDet(Detector):
np_boxes_list.append( np_boxes_list.append(
self.predictor.get_output_handle(output_names[ self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu()) out_idx + num_outs]).copy_to_cpu())
if add_timer: result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
self.det_times.inference_time_s.end(repeats=repeats) return result
self.det_times.img_num += 1
self.det_times.postprocess_time_s.start()
# postprocess
self.postprocess = PicoDetPostProcess(
inputs['image'].shape[2:],
inputs['im_shape'],
inputs['scale_factor'],
strides=self.pred_config.fpn_stride,
nms_threshold=self.pred_config.nms['nms_threshold'])
np_boxes, np_boxes_num = self.postprocess(np_score_list, np_boxes_list)
if add_timer:
self.det_times.postprocess_time_s.end()
return dict(boxes=np_boxes, boxes_num=np_boxes_num)
def create_inputs(imgs, im_info): def create_inputs(imgs, im_info):
...@@ -596,26 +686,26 @@ def get_test_images(infer_dir, infer_img): ...@@ -596,26 +686,26 @@ def get_test_images(infer_dir, infer_img):
return images return images
def visualize(image_list, results, labels, output_dir='output/', threshold=0.5): def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
# visualize the predict result # visualize the predict result
start_idx = 0 start_idx = 0
for idx, image_file in enumerate(image_list): for idx, image_file in enumerate(image_list):
im_bboxes_num = results['boxes_num'][idx] im_bboxes_num = result['boxes_num'][idx]
im_results = {} im_results = {}
if 'boxes' in results: if 'boxes' in result:
im_results['boxes'] = results['boxes'][start_idx:start_idx + im_results['boxes'] = result['boxes'][start_idx:start_idx +
im_bboxes_num, :] im_bboxes_num, :]
if 'masks' in results: if 'masks' in result:
im_results['masks'] = results['masks'][start_idx:start_idx + im_results['masks'] = result['masks'][start_idx:start_idx +
im_bboxes_num, :] im_bboxes_num, :]
if 'segm' in results: if 'segm' in result:
im_results['segm'] = results['segm'][start_idx:start_idx + im_results['segm'] = result['segm'][start_idx:start_idx +
im_bboxes_num, :] im_bboxes_num, :]
if 'label' in results: if 'label' in result:
im_results['label'] = results['label'][start_idx:start_idx + im_results['label'] = result['label'][start_idx:start_idx +
im_bboxes_num] im_bboxes_num]
if 'score' in results: if 'score' in result:
im_results['score'] = results['score'][start_idx:start_idx + im_results['score'] = result['score'][start_idx:start_idx +
im_bboxes_num] im_bboxes_num]
start_idx += im_bboxes_num start_idx += im_bboxes_num
...@@ -636,86 +726,18 @@ def print_arguments(args): ...@@ -636,86 +726,18 @@ def print_arguments(args):
print('------------------------------------------') print('------------------------------------------')
def predict_image(detector, image_list, batch_size=1):
batch_loop_cnt = math.ceil(float(len(image_list)) / batch_size)
for i in range(batch_loop_cnt):
start_index = i * batch_size
end_index = min((i + 1) * batch_size, len(image_list))
batch_image_list = image_list[start_index:end_index]
if FLAGS.run_benchmark:
# warmup
detector.predict(
batch_image_list, FLAGS.threshold, repeats=10, add_timer=False)
# run benchmark
detector.predict(
batch_image_list, FLAGS.threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}'.format(i))
else:
results = detector.predict(batch_image_list, FLAGS.threshold)
visualize(
batch_image_list,
results,
detector.pred_config.labels,
output_dir=FLAGS.output_dir,
threshold=FLAGS.threshold)
def predict_video(detector, camera_id):
video_out_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_out_name = os.path.split(FLAGS.video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame: %d' % (index))
index += 1
results = detector.predict([frame], FLAGS.threshold)
im = visualize_box_mask(
frame,
results,
detector.pred_config.labels,
threshold=FLAGS.threshold)
im = np.array(im)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
def main(): def main():
pred_config = PredictConfig(FLAGS.model_dir) deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
detector_func = 'Detector' detector_func = 'Detector'
if pred_config.arch == 'SOLOv2': if arch == 'SOLOv2':
detector_func = 'DetectorSOLOv2' detector_func = 'DetectorSOLOv2'
elif pred_config.arch == 'PicoDet': elif arch == 'PicoDet':
detector_func = 'DetectorPicoDet' detector_func = 'DetectorPicoDet'
detector = eval(detector_func)(pred_config, detector = eval(detector_func)(FLAGS.model_dir,
FLAGS.model_dir,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
...@@ -724,41 +746,29 @@ def main(): ...@@ -724,41 +746,29 @@ def main():
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode, trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn) enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir)
# predict from video file or camera video stream # predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1: if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id) detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
else: else:
# predict from image # predict from image
if FLAGS.image_dir is None and FLAGS.image_file is not None: if FLAGS.image_dir is None and FLAGS.image_file is not None:
assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None" assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list, FLAGS.batch_size) detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
detector.det_times.info(average=True) detector.det_times.info(average=True)
else: else:
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
model_dir = FLAGS.model_dir
mode = FLAGS.run_mode mode = FLAGS.run_mode
model_dir = FLAGS.model_dir
model_info = { model_info = {
'model_name': model_dir.strip('/').split('/')[-1], 'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1] 'precision': mode.split('_')[-1]
} }
data_info = { bench_log(detector, img_list, model_info, name='DET')
'batch_size': FLAGS.batch_size,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
det_log = PaddleInferBenchmark(detector.config, model_info,
data_info, perf_info, mems)
det_log('Det')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -23,10 +23,16 @@ import cv2 ...@@ -23,10 +23,16 @@ import cv2
import math import math
import numpy as np import numpy as np
import paddle import paddle
import sys
# add deploy path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from preprocess import preprocess, NormalizeImage, Permute from preprocess import preprocess, NormalizeImage, Permute
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
from visualize import draw_pose from visualize import visualize_pose
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
...@@ -40,13 +46,13 @@ KEYPOINT_SUPPORT_MODELS = { ...@@ -40,13 +46,13 @@ KEYPOINT_SUPPORT_MODELS = {
} }
class KeyPoint_Detector(Detector): class KeyPointDetector(Detector):
""" """
Args: Args:
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt trt_opt_shape (int): opt shape for dynamic shape in trt
...@@ -58,7 +64,6 @@ class KeyPoint_Detector(Detector): ...@@ -58,7 +64,6 @@ class KeyPoint_Detector(Detector):
""" """
def __init__(self, def __init__(self,
pred_config,
model_dir, model_dir,
device='CPU', device='CPU',
run_mode='paddle', run_mode='paddle',
...@@ -69,9 +74,10 @@ class KeyPoint_Detector(Detector): ...@@ -69,9 +74,10 @@ class KeyPoint_Detector(Detector):
trt_calib_mode=False, trt_calib_mode=False,
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False, enable_mkldnn=False,
output_dir='output',
threshold=0.5,
use_dark=True): use_dark=True):
super(KeyPoint_Detector, self).__init__( super(KeyPointDetector, self).__init__(
pred_config=pred_config,
model_dir=model_dir, model_dir=model_dir,
device=device, device=device,
run_mode=run_mode, run_mode=run_mode,
...@@ -81,9 +87,14 @@ class KeyPoint_Detector(Detector): ...@@ -81,9 +87,14 @@ class KeyPoint_Detector(Detector):
trt_opt_shape=trt_opt_shape, trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode, trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn,
output_dir=output_dir,
threshold=threshold, )
self.use_dark = use_dark self.use_dark = use_dark
def set_config(self, model_dir):
return PredictConfig_KeyPoint(model_dir)
def get_person_from_rect(self, image, results, det_threshold=0.5): def get_person_from_rect(self, image, results, det_threshold=0.5):
# crop the person result from image # crop the person result from image
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
...@@ -103,34 +114,22 @@ class KeyPoint_Detector(Detector): ...@@ -103,34 +114,22 @@ class KeyPoint_Detector(Detector):
self.det_times.preprocess_time_s.end() self.det_times.preprocess_time_s.end()
return rect_images, new_rects, org_rects return rect_images, new_rects, org_rects
def preprocess(self, image_list): def postprocess(self, inputs, result):
preprocess_ops = [] np_heatmap = result['heatmap']
for op_info in self.pred_config.preprocess_infos: np_masks = result['masks']
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
input_im_lst = []
input_im_info_lst = []
for im in image_list:
im, im_info = preprocess(im, preprocess_ops)
input_im_lst.append(im)
input_im_info_lst.append(im_info)
inputs = create_inputs(input_im_lst, input_im_info_lst)
return inputs
def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5):
# postprocess output of predictor # postprocess output of predictor
if KEYPOINT_SUPPORT_MODELS[ if KEYPOINT_SUPPORT_MODELS[
self.pred_config.arch] == 'keypoint_bottomup': self.pred_config.arch] == 'keypoint_bottomup':
results = {} results = {}
h, w = inputs['im_shape'][0] h, w = inputs['im_shape'][0]
preds = [np_boxes] preds = [np_heatmap]
if np_masks is not None: if np_masks is not None:
preds += np_masks preds += np_masks
preds += [h, w] preds += [h, w]
keypoint_postprocess = HrHRNetPostProcess() keypoint_postprocess = HrHRNetPostProcess()
results['keypoint'] = keypoint_postprocess(*preds) kpts, scores = keypoint_postprocess(*preds)
results['keypoint'] = kpts
results['score'] = scores
return results return results
elif KEYPOINT_SUPPORT_MODELS[ elif KEYPOINT_SUPPORT_MODELS[
self.pred_config.arch] == 'keypoint_topdown': self.pred_config.arch] == 'keypoint_topdown':
...@@ -139,44 +138,31 @@ class KeyPoint_Detector(Detector): ...@@ -139,44 +138,31 @@ class KeyPoint_Detector(Detector):
center = np.round(imshape / 2.) center = np.round(imshape / 2.)
scale = imshape / 200. scale = imshape / 200.
keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark) keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark)
results['keypoint'] = keypoint_postprocess(np_boxes, center, scale) kpts, scores = keypoint_postprocess(np_heatmap, center, scale)
results['keypoint'] = kpts
results['score'] = scores
return results return results
else: else:
raise ValueError("Unsupported arch: {}, expect {}".format( raise ValueError("Unsupported arch: {}, expect {}".format(
self.pred_config.arch, KEYPOINT_SUPPORT_MODELS)) self.pred_config.arch, KEYPOINT_SUPPORT_MODELS))
def predict(self, image_list, threshold=0.5, repeats=1, add_timer=True): def predict(self, repeats=1):
''' '''
Args: Args:
image_list (list): list of image
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns: Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max] matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray: MaskRCNN's results include 'masks': np.ndarray:
shape: [N, im_h, im_w] shape: [N, im_h, im_w]
''' '''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction # model prediction
np_heatmap, np_masks = None, None
for i in range(repeats): for i in range(repeats):
self.predictor.run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0]) heatmap_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu() np_heatmap = heatmap_tensor.copy_to_cpu()
if self.pred_config.tagmap: if self.pred_config.tagmap:
masks_tensor = self.predictor.get_output_handle(output_names[1]) masks_tensor = self.predictor.get_output_handle(output_names[1])
heat_k = self.predictor.get_output_handle(output_names[2]) heat_k = self.predictor.get_output_handle(output_names[2])
...@@ -185,18 +171,113 @@ class KeyPoint_Detector(Detector): ...@@ -185,18 +171,113 @@ class KeyPoint_Detector(Detector):
masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(), masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(),
inds_k.copy_to_cpu() inds_k.copy_to_cpu()
] ]
if add_timer: result = dict(heatmap=np_heatmap, masks=np_masks)
return result
def predict_image(self,
image_list,
run_benchmark=False,
repeats=1,
visual=True):
results = []
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
for i in range(batch_loop_cnt):
start_index = i * self.batch_size
end_index = min((i + 1) * self.batch_size, len(image_list))
batch_image_list = image_list[start_index:end_index]
if run_benchmark:
# preprocess
inputs = self.preprocess(batch_image_list) # warmup
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
result_warmup = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats) self.det_times.inference_time_s.end(repeats=repeats)
# postprocess
result_warmup = self.postprocess(inputs, result) # warmup
self.det_times.postprocess_time_s.start() self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(batch_image_list)
cm, gm, gu = get_current_memory_mb()
self.cpu_mem += cm
self.gpu_mem += gm
self.gpu_util += gu
else:
# preprocess
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# postprocess # postprocess
results = self.postprocess( self.det_times.postprocess_time_s.start()
np_boxes, np_masks, inputs, threshold=threshold) result = self.postprocess(inputs, result)
if add_timer:
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(image_list) self.det_times.img_num += len(batch_image_list)
if visual:
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
visualize(
batch_image_list,
result,
visual_thresh=self.threshold,
save_dir=self.output_dir)
results.append(result)
if visual:
print('Test iter {}'.format(i))
results = self.merge_batch_result(results)
return results return results
def predict_video(self, video_file, camera_id):
video_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(video_file)
video_name = os.path.split(video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame: %d' % (index))
index += 1
results = self.predict_image([frame], visual=False)
im = visualize_pose(
frame, results, visual_thresh=self.threshold, returnimg=True)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
def create_inputs(imgs, im_info): def create_inputs(imgs, im_info):
"""generate input for different model type """generate input for different model type
...@@ -258,90 +339,44 @@ class PredictConfig_KeyPoint(): ...@@ -258,90 +339,44 @@ class PredictConfig_KeyPoint():
print('--------------------------------------------') print('--------------------------------------------')
def predict_image(detector, image_list): def visualize(image_list, results, visual_thresh=0.6, save_dir='output'):
for i, img_file in enumerate(image_list): im_results = {}
if FLAGS.run_benchmark: for i, image_file in enumerate(image_list):
# warmup skeletons = results['keypoint']
detector.predict( scores = results['score']
[img_file], FLAGS.threshold, repeats=10, add_timer=False) skeleton = skeletons[i:i + 1]
# run benchmark score = scores[i:i + 1]
detector.predict( im_results['keypoint'] = [skeleton, score]
[img_file], FLAGS.threshold, repeats=10, add_timer=True) visualize_pose(
cm, gm, gu = get_current_memory_mb() image_file,
detector.cpu_mem += cm im_results,
detector.gpu_mem += gm visual_thresh=visual_thresh,
detector.gpu_util += gu save_dir=save_dir)
print('Test iter {}, file name:{}'.format(i, img_file))
else:
results = detector.predict([img_file], FLAGS.threshold)
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
draw_pose(
img_file,
results,
visual_thread=FLAGS.threshold,
save_dir=FLAGS.output_dir)
def predict_video(detector, camera_id):
video_name = 'output.mp4'
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
# Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(capture.get(cv2.CAP_PROP_FPS))
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name + '.mp4')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
ret, frame = capture.read()
if not ret:
break
print('detect frame: %d' % (index))
index += 1
results = detector.predict([frame], FLAGS.threshold)
im = draw_pose(
frame, results, visual_thread=FLAGS.threshold, returnimg=True)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
def main(): def main():
pred_config = PredictConfig_KeyPoint(FLAGS.model_dir) detector = KeyPointDetector(
detector = KeyPoint_Detector(
pred_config,
FLAGS.model_dir, FLAGS.model_dir,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
trt_min_shape=FLAGS.trt_min_shape, trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape, trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode, trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn, enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
use_dark=FLAGS.use_dark) use_dark=FLAGS.use_dark)
# predict from video file or camera video stream # predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1: if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id) detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
else: else:
# predict from image # predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list) detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
detector.det_times.info(average=True) detector.det_times.info(average=True)
else: else:
......
...@@ -362,7 +362,8 @@ def affine_transform(pt, t): ...@@ -362,7 +362,8 @@ def affine_transform(pt, t):
def translate_to_ori_images(keypoint_result, batch_records): def translate_to_ori_images(keypoint_result, batch_records):
kpts, scores = keypoint_result['keypoint'] kpts = keypoint_result['keypoint']
scores = keypoint_result['score']
kpts[..., 0] += batch_records[:, 0:1] kpts[..., 0] += batch_records[:, 0:1]
kpts[..., 1] += batch_records[:, 1:2] kpts[..., 1] += batch_records[:, 1:2]
return kpts, scores return kpts, scores
...@@ -18,21 +18,24 @@ import yaml ...@@ -18,21 +18,24 @@ import yaml
import cv2 import cv2
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
import paddle import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, PredictConfig
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from ppdet.modeling.mot.tracker import JDETracker from pptracking.python.mot import JDETracker
from ppdet.modeling.mot.visualization import plot_tracking_dict from pptracking.python.mot.utils import MOTTimer, write_mot_results
from ppdet.modeling.mot.utils import MOTTimer, write_mot_results from pptracking.python.visualize import plot_tracking, plot_tracking_dict
# Global dictionary # Global dictionary
MOT_SUPPORT_MODELS = { MOT_JDE_SUPPORT_MODELS = {
'JDE', 'JDE',
'FairMOT', 'FairMOT',
} }
...@@ -41,7 +44,6 @@ MOT_SUPPORT_MODELS = { ...@@ -41,7 +44,6 @@ MOT_SUPPORT_MODELS = {
class JDE_Detector(Detector): class JDE_Detector(Detector):
""" """
Args: Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
...@@ -56,8 +58,8 @@ class JDE_Detector(Detector): ...@@ -56,8 +58,8 @@ class JDE_Detector(Detector):
""" """
def __init__(self, def __init__(self,
pred_config,
model_dir, model_dir,
tracker_config=None,
device='CPU', device='CPU',
run_mode='paddle', run_mode='paddle',
batch_size=1, batch_size=1,
...@@ -66,9 +68,10 @@ class JDE_Detector(Detector): ...@@ -66,9 +68,10 @@ class JDE_Detector(Detector):
trt_opt_shape=608, trt_opt_shape=608,
trt_calib_mode=False, trt_calib_mode=False,
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False): enable_mkldnn=False,
output_dir='output',
threshold=0.5):
super(JDE_Detector, self).__init__( super(JDE_Detector, self).__init__(
pred_config=pred_config,
model_dir=model_dir, model_dir=model_dir,
device=device, device=device,
run_mode=run_mode, run_mode=run_mode,
...@@ -78,17 +81,21 @@ class JDE_Detector(Detector): ...@@ -78,17 +81,21 @@ class JDE_Detector(Detector):
trt_opt_shape=trt_opt_shape, trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode, trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn,
assert batch_size == 1, "The JDE Detector only supports batch size=1 now" output_dir=output_dir,
assert pred_config.tracker, "Tracking model should have tracker" threshold=threshold, )
self.num_classes = len(pred_config.labels) assert batch_size == 1, "MOT model only supports batch_size=1."
self.det_times = Timer(with_tracker=True)
tp = pred_config.tracker self.num_classes = len(self.pred_config.labels)
min_box_area = tp['min_box_area'] if 'min_box_area' in tp else 200
vertical_ratio = tp['vertical_ratio'] if 'vertical_ratio' in tp else 1.6 # tracker config
conf_thres = tp['conf_thres'] if 'conf_thres' in tp else 0. assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
tracked_thresh = tp['tracked_thresh'] if 'tracked_thresh' in tp else 0.7 cfg = self.pred_config.tracker
metric_type = tp['metric_type'] if 'metric_type' in tp else 'euclidean' min_box_area = cfg.get('min_box_area', 200)
vertical_ratio = cfg.get('vertical_ratio', 1.6)
conf_thres = cfg.get('conf_thres', 0.0)
tracked_thresh = cfg.get('tracked_thresh', 0.7)
metric_type = cfg.get('metric_type', 'euclidean')
self.tracker = JDETracker( self.tracker = JDETracker(
num_classes=self.num_classes, num_classes=self.num_classes,
...@@ -98,7 +105,18 @@ class JDE_Detector(Detector): ...@@ -98,7 +105,18 @@ class JDE_Detector(Detector):
tracked_thresh=tracked_thresh, tracked_thresh=tracked_thresh,
metric_type=metric_type) metric_type=metric_type)
def postprocess(self, pred_dets, pred_embs, threshold): def postprocess(self, inputs, result):
# postprocess output of predictor
np_boxes = result['pred_dets']
if np_boxes.shape[0] <= 0:
print('[WARNNING] No object detected.')
result = {'pred_dets': np.zeros([0, 6]), 'pred_embs': None}
result = {k: v for k, v in result.items() if v is not None}
return result
def tracking(self, det_results):
pred_dets = det_results['pred_dets']
pred_embs = det_results['pred_embs']
online_targets_dict = self.tracker.update(pred_dets, pred_embs) online_targets_dict = self.tracker.update(pred_dets, pred_embs)
online_tlwhs = defaultdict(list) online_tlwhs = defaultdict(list)
...@@ -110,7 +128,6 @@ class JDE_Detector(Detector): ...@@ -110,7 +128,6 @@ class JDE_Detector(Detector):
tlwh = t.tlwh tlwh = t.tlwh
tid = t.track_id tid = t.track_id
tscore = t.score tscore = t.score
if tscore < threshold: continue
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio: 3] > self.tracker.vertical_ratio:
...@@ -120,100 +137,123 @@ class JDE_Detector(Detector): ...@@ -120,100 +137,123 @@ class JDE_Detector(Detector):
online_scores[cls_id].append(tscore) online_scores[cls_id].append(tscore)
return online_tlwhs, online_scores, online_ids return online_tlwhs, online_scores, online_ids
def predict(self, image_list, threshold=0.5, repeats=1, add_timer=True): def predict(self, repeats=1):
''' '''
Args: Args:
image_list (list): list of image repeats (int): repeats number for prediction
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns: Returns:
online_tlwhs, online_scores, online_ids (dict[np.array]) result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
matix element:[x_min, y_min, x_max, y_max, score, class]
FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
shape: [N, 128]
''' '''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image_list)
pred_dets, pred_embs = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction # model prediction
np_pred_dets, np_pred_embs = None, None
for i in range(repeats): for i in range(repeats):
self.predictor.run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0]) boxes_tensor = self.predictor.get_output_handle(output_names[0])
pred_dets = boxes_tensor.copy_to_cpu() np_pred_dets = boxes_tensor.copy_to_cpu()
embs_tensor = self.predictor.get_output_handle(output_names[1]) embs_tensor = self.predictor.get_output_handle(output_names[1])
pred_embs = embs_tensor.copy_to_cpu() np_pred_embs = embs_tensor.copy_to_cpu()
result = dict(pred_dets=np_pred_dets, pred_embs=np_pred_embs)
return result
def predict_image(self,
image_list,
run_benchmark=False,
repeats=1,
visual=True):
mot_results = []
num_classes = self.num_classes
image_list.sort()
ids2names = self.pred_config.labels
data_type = 'mcmot' if num_classes > 1 else 'mot'
for frame_id, img_file in enumerate(image_list):
batch_image_list = [img_file] # bs=1 in MOT model
if run_benchmark:
# preprocess
inputs = self.preprocess(batch_image_list) # warmup
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
if add_timer: # model prediction
result_warmup = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats) self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
# postprocess # postprocess
online_tlwhs, online_scores, online_ids = self.postprocess( result_warmup = self.postprocess(inputs, result) # warmup
pred_dets, pred_embs, threshold) self.det_times.postprocess_time_s.start()
if add_timer: det_result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return online_tlwhs, online_scores, online_ids
def predict_image(detector, image_list): # tracking
results = [] result_warmup = self.tracking(det_result)
num_classes = detector.num_classes self.det_times.tracking_time_s.start()
data_type = 'mcmot' if num_classes > 1 else 'mot' online_tlwhs, online_scores, online_ids = self.tracking(
ids2names = detector.pred_config.labels det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
image_list.sort()
for frame_id, img_file in enumerate(image_list):
frame = cv2.imread(img_file)
if FLAGS.run_benchmark:
# warmup
detector.predict(
[frame], FLAGS.threshold, repeats=10, add_timer=False)
# run benchmark
detector.predict(
[frame], FLAGS.threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb() cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm self.cpu_mem += cm
detector.gpu_mem += gm self.gpu_mem += gm
detector.gpu_util += gu self.gpu_util += gu
print('Test iter {}, file name:{}'.format(frame_id, img_file))
else: else:
online_tlwhs, online_scores, online_ids = detector.predict( self.det_times.preprocess_time_s.start()
[frame], FLAGS.threshold) inputs = self.preprocess(batch_image_list)
online_im = plot_tracking_dict( self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
self.det_times.postprocess_time_s.start()
det_result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
# tracking process
self.det_times.tracking_time_s.start()
online_tlwhs, online_scores, online_ids = self.tracking(
det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
if visual:
if frame_id % 10 == 0:
print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {})
im = plot_tracking_dict(
frame, frame,
num_classes, num_classes,
online_tlwhs, online_tlwhs,
online_ids, online_ids,
online_scores, online_scores,
frame_id, frame_id=frame_id,
ids2names=ids2names) ids2names=ids2names)
if FLAGS.save_images: seq_name = image_list[0].split('/')[-2]
if not os.path.exists(FLAGS.output_dir): save_dir = os.path.join(self.output_dir, seq_name)
os.makedirs(FLAGS.output_dir) if not os.path.exists(save_dir):
img_name = os.path.split(img_file)[-1] os.makedirs(save_dir)
out_path = os.path.join(FLAGS.output_dir, img_name) cv2.imwrite(
cv2.imwrite(out_path, online_im) os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
print("save result to: " + out_path)
mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results
def predict_video(detector, camera_id): def predict_video(self, video_file, camera_id):
video_name = 'mot_output.mp4' video_out_name = 'mot_output.mp4'
if camera_id != -1: if camera_id != -1:
capture = cv2.VideoCapture(camera_id) capture = cv2.VideoCapture(camera_id)
else: else:
capture = cv2.VideoCapture(FLAGS.video_file) capture = cv2.VideoCapture(video_file)
video_name = os.path.split(FLAGS.video_file)[-1] video_out_name = os.path.split(video_file)[-1]
# Get Video info : resolution, fps, frame count # Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
...@@ -221,33 +261,37 @@ def predict_video(detector, camera_id): ...@@ -221,33 +261,37 @@ def predict_video(detector, camera_id):
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count)) print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(self.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(self.output_dir, video_out_name)
if not FLAGS.save_images: fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
frame_id = 1
timer = MOTTimer() timer = MOTTimer()
results = defaultdict(list) # support single class and multi classes results = defaultdict(list) # support single class and multi classes
num_classes = detector.num_classes num_classes = self.num_classes
data_type = 'mcmot' if num_classes > 1 else 'mot' data_type = 'mcmot' if num_classes > 1 else 'mot'
ids2names = detector.pred_config.labels ids2names = self.pred_config.labels
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
break break
if frame_id % 10 == 0:
print('Tracking frame: %d' % (frame_id))
frame_id += 1
timer.tic() timer.tic()
online_tlwhs, online_scores, online_ids = detector.predict( mot_results = self.predict_image([frame], visual=False)
[frame], FLAGS.threshold)
timer.toc() timer.toc()
online_tlwhs, online_scores, online_ids = mot_results[0]
for cls_id in range(num_classes): for cls_id in range(num_classes):
results[cls_id].append((frame_id + 1, online_tlwhs[cls_id], results[cls_id].append(
online_scores[cls_id], online_ids[cls_id])) (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
fps = 1. / timer.average_time fps = 1. / timer.duration
im = plot_tracking_dict( im = plot_tracking_dict(
frame, frame,
num_classes, num_classes,
...@@ -257,41 +301,17 @@ def predict_video(detector, camera_id): ...@@ -257,41 +301,17 @@ def predict_video(detector, camera_id):
frame_id=frame_id, frame_id=frame_id,
fps=fps, fps=fps,
ids2names=ids2names) ids2names=ids2names)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im)
frame_id += 1 writer.write(im)
print('detect frame: %d' % (frame_id))
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking Detection', im) cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results, data_type, num_classes)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release() writer.release()
def main(): def main():
pred_config = PredictConfig(FLAGS.model_dir)
detector = JDE_Detector( detector = JDE_Detector(
pred_config,
FLAGS.model_dir, FLAGS.model_dir,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
...@@ -304,34 +324,22 @@ def main(): ...@@ -304,34 +324,22 @@ def main():
# predict from video file or camera video stream # predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1: if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id) detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
else: else:
# predict from image # predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, img_list) detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
detector.det_times.info(average=True) detector.det_times.info(average=True)
else: else:
mems = {
'cpu_rss_mb': detector.cpu_mem / len(img_list),
'gpu_rss_mb': detector.gpu_mem / len(img_list),
'gpu_util': detector.gpu_util * 100 / len(img_list)
}
perf_info = detector.det_times.report(average=True)
model_dir = FLAGS.model_dir
mode = FLAGS.run_mode mode = FLAGS.run_mode
model_dir = FLAGS.model_dir
model_info = { model_info = {
'model_name': model_dir.strip('/').split('/')[-1], 'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1] 'precision': mode.split('_')[-1]
} }
data_info = { bench_log(detector, img_list, model_info, name='MOT')
'batch_size': 1,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
det_log = PaddleInferBenchmark(detector.config, model_info,
data_info, perf_info, mems)
det_log('MOT')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -13,31 +13,34 @@ ...@@ -13,31 +13,34 @@
# limitations under the License. # limitations under the License.
import os import os
import json
import cv2 import cv2
import math import math
import copy
import numpy as np import numpy as np
from collections import defaultdict
import paddle import paddle
import yaml
from utils import get_current_memory_mb import copy
from infer import Detector, PredictConfig, print_arguments, get_test_images from collections import defaultdict
from visualize import draw_pose
from mot_keypoint_unite_utils import argsparser from mot_keypoint_unite_utils import argsparser
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint from preprocess import decode_image
from det_keypoint_unite_infer import predict_with_given_det, bench_log from infer import print_arguments, get_test_images
from mot_jde_infer import JDE_Detector from mot_sde_infer import SDE_Detector, MOT_SDE_SUPPORT_MODELS
from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS
from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
from det_keypoint_unite_infer import predict_with_given_det
from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images
from ppdet.modeling.mot.visualization import plot_tracking_dict # add python path
from ppdet.modeling.mot.utils import MOTTimer as FPSTimer import sys
from ppdet.modeling.mot.utils import write_mot_results parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
# Global dictionary from pptracking.python.visualize import plot_tracking, plot_tracking_dict
KEYPOINT_SUPPORT_MODELS = { from pptracking.python.mot.utils import MOTTimer as FPSTimer
'HigherHRNet': 'keypoint_bottomup',
'HRNet': 'keypoint_topdown'
}
def convert_mot_to_det(tlwhs, scores): def convert_mot_to_det(tlwhs, scores):
...@@ -49,94 +52,87 @@ def convert_mot_to_det(tlwhs, scores): ...@@ -49,94 +52,87 @@ def convert_mot_to_det(tlwhs, scores):
# support single class now # support single class now
results['boxes'] = np.vstack( results['boxes'] = np.vstack(
[np.hstack([0, scores[i], xyxys[i]]) for i in range(num_mot)]) [np.hstack([0, scores[i], xyxys[i]]) for i in range(num_mot)])
results['boxes_num'] = np.array([num_mot])
return results return results
def mot_keypoint_unite_predict_image(mot_model, def mot_topdown_unite_predict(mot_detector,
keypoint_model, topdown_keypoint_detector,
image_list, image_list,
keypoint_batch_size=1): keypoint_batch_size=1,
num_classes = mot_model.num_classes save_res=False):
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.' det_timer = mot_detector.get_timer()
data_type = 'mot' store_res = []
image_list.sort() image_list.sort()
num_classes = mot_detector.num_classes
for i, img_file in enumerate(image_list): for i, img_file in enumerate(image_list):
frame = cv2.imread(img_file) # Decode image in advance in mot + pose prediction
det_timer.preprocess_time_s.start()
image, _ = decode_image(img_file, {})
det_timer.preprocess_time_s.end()
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
# warmup mot_results = mot_detector.predict_image(
online_tlwhs, online_scores, online_ids = mot_model.predict( [image], run_benchmark=True, repeats=10)
[frame], FLAGS.mot_threshold, repeats=10, add_timer=False)
# run benchmark
online_tlwhs, online_scores, online_ids = mot_model.predict(
[frame], FLAGS.mot_threshold, repeats=10, add_timer=True)
cm, gm, gu = get_current_memory_mb()
mot_model.cpu_mem += cm
mot_model.gpu_mem += gm
mot_model.gpu_util += gu
else:
online_tlwhs, online_scores, online_ids = mot_model.predict(
[frame], FLAGS.mot_threshold)
keypoint_arch = keypoint_model.pred_config.arch
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown':
results = convert_mot_to_det(online_tlwhs, online_scores)
keypoint_results = predict_with_given_det(
frame, results, keypoint_model, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold,
FLAGS.run_benchmark)
cm, gm, gu = get_current_memory_mb()
mot_detector.cpu_mem += cm
mot_detector.gpu_mem += gm
mot_detector.gpu_util += gu
else: else:
if FLAGS.run_benchmark: mot_results = mot_detector.predict_image([image], visual=False)
keypoint_results = keypoint_model.predict(
[frame], online_tlwhs, online_scores, online_ids = mot_results[
FLAGS.keypoint_threshold, 0] # only support bs=1 in MOT model
repeats=10, results = convert_mot_to_det(
add_timer=False) online_tlwhs[0],
online_scores[0]) # only support single class for mot + pose
repeats = 10 if FLAGS.run_benchmark else 1 if results['boxes_num'] == 0:
keypoint_results = keypoint_model.predict( continue
[frame], FLAGS.keypoint_threshold, repeats=repeats)
keypoint_res = predict_with_given_det(
image, results, topdown_keypoint_detector, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
if save_res:
store_res.append([
i, keypoint_res['bbox'],
[keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
])
if FLAGS.run_benchmark: if FLAGS.run_benchmark:
cm, gm, gu = get_current_memory_mb() cm, gm, gu = get_current_memory_mb()
keypoint_model.cpu_mem += cm topdown_keypoint_detector.cpu_mem += cm
keypoint_model.gpu_mem += gm topdown_keypoint_detector.gpu_mem += gm
keypoint_model.gpu_util += gu topdown_keypoint_detector.gpu_util += gu
else: else:
im = draw_pose(
frame,
keypoint_results,
visual_thread=FLAGS.keypoint_threshold,
returnimg=True,
ids=online_ids[0]
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown'
else None)
online_im = plot_tracking_dict(
im,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=i)
if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
img_name = os.path.split(img_file)[-1] visualize_pose(
out_path = os.path.join(FLAGS.output_dir, img_name) img_file,
cv2.imwrite(out_path, online_im) keypoint_res,
print("save result to: " + out_path) visual_thresh=FLAGS.keypoint_threshold,
save_dir=FLAGS.output_dir)
def mot_keypoint_unite_predict_video(mot_model, if save_res:
keypoint_model, """
1) store_res: a list of image_data
2) image_data: [imageid, rects, [keypoints, scores]]
3) rects: list of rect [xmin, ymin, xmax, ymax]
4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
5) scores: mean of all joint conf
"""
with open("det_keypoint_unite_image_results.json", 'w') as wf:
json.dump(store_res, wf, indent=4)
def mot_topdown_unite_predict_video(mot_detector,
topdown_keypoint_detector,
camera_id, camera_id,
keypoint_batch_size=1): keypoint_batch_size=1,
save_res=False):
video_name = 'output.mp4'
if camera_id != -1: if camera_id != -1:
capture = cv2.VideoCapture(camera_id) capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else: else:
capture = cv2.VideoCapture(FLAGS.video_file) capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1] video_name = os.path.split(FLAGS.video_file)[-1]
...@@ -150,17 +146,12 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -150,17 +146,12 @@ def mot_keypoint_unite_predict_video(mot_model,
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(FLAGS.output_dir, video_name)
if not FLAGS.save_images: fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer_mot = FPSTimer() timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer()
timer_kp = FPSTimer()
timer_mot_kp = FPSTimer()
# support single class and multi classes, but should be single class here num_classes = mot_detector.num_classes
mot_results = defaultdict(list)
num_classes = mot_model.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.' assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
data_type = 'mot' data_type = 'mot'
...@@ -168,43 +159,41 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -168,43 +159,41 @@ def mot_keypoint_unite_predict_video(mot_model,
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
break break
if frame_id % 10 == 0:
print('Tracking frame: %d' % (frame_id))
frame_id += 1
timer_mot_kp.tic() timer_mot_kp.tic()
# mot model
timer_mot.tic() timer_mot.tic()
online_tlwhs, online_scores, online_ids = mot_model.predict( mot_results = mot_detector.predict_image([frame], visual=False)
[frame], FLAGS.mot_threshold)
timer_mot.toc() timer_mot.toc()
mot_results[0].append( online_tlwhs, online_scores, online_ids = mot_results[0]
(frame_id + 1, online_tlwhs[0], online_scores[0], online_ids[0])) results = convert_mot_to_det(
mot_fps = 1. / timer_mot.average_time online_tlwhs[0],
online_scores[0]) # only support single class for mot + pose
if results['boxes_num'] == 0:
continue
# keypoint model
timer_kp.tic() timer_kp.tic()
keypoint_res = predict_with_given_det(
keypoint_arch = keypoint_model.pred_config.arch frame, results, topdown_keypoint_detector, keypoint_batch_size,
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown': FLAGS.mot_threshold, FLAGS.keypoint_threshold, FLAGS.run_benchmark)
results = convert_mot_to_det(online_tlwhs[0], online_scores[0])
keypoint_results = predict_with_given_det(
frame, results, keypoint_model, keypoint_batch_size,
FLAGS.mot_threshold, FLAGS.keypoint_threshold,
FLAGS.run_benchmark)
else:
keypoint_results = keypoint_model.predict([frame],
FLAGS.keypoint_threshold)
timer_kp.toc() timer_kp.toc()
timer_mot_kp.toc() timer_mot_kp.toc()
kp_fps = 1. / timer_kp.average_time
mot_kp_fps = 1. / timer_mot_kp.average_time
im = draw_pose( kp_fps = 1. / timer_kp.duration
mot_kp_fps = 1. / timer_mot_kp.duration
im = visualize_pose(
frame, frame,
keypoint_results, keypoint_res,
visual_thread=FLAGS.keypoint_threshold, visual_thresh=FLAGS.keypoint_threshold,
returnimg=True, returnimg=True,
ids=online_ids[0] ids=online_ids[0])
if KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown' else
None)
online_im = plot_tracking_dict( im = plot_tracking_dict(
im, im,
num_classes, num_classes,
online_tlwhs, online_tlwhs,
...@@ -213,55 +202,40 @@ def mot_keypoint_unite_predict_video(mot_model, ...@@ -213,55 +202,40 @@ def mot_keypoint_unite_predict_video(mot_model,
frame_id=frame_id, frame_id=frame_id,
fps=mot_kp_fps) fps=mot_kp_fps)
im = np.array(online_im)
frame_id += 1
print('detect frame: %d' % (frame_id))
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im) writer.write(im)
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking and keypoint results', im) cv2.imshow('Tracking and keypoint results', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, mot_results, data_type, num_classes)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release() writer.release()
print('output_video saved to: {}'.format(out_path))
def main(): def main():
pred_config = PredictConfig(FLAGS.mot_model_dir) deploy_file = os.path.join(FLAGS.mot_model_dir, 'infer_cfg.yml')
mot_model = JDE_Detector( with open(deploy_file) as f:
pred_config, yml_conf = yaml.safe_load(f)
FLAGS.mot_model_dir, arch = yml_conf['arch']
mot_detector_func = 'SDE_Detector'
if arch in MOT_JDE_SUPPORT_MODELS:
mot_detector_func = 'JDE_Detector'
mot_detector = eval(mot_detector_func)(FLAGS.mot_model_dir,
FLAGS.tracker_config,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
batch_size=1,
trt_min_shape=FLAGS.trt_min_shape, trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape, trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode, trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn) enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.mot_threshold,
output_dir=FLAGS.output_dir)
pred_config = PredictConfig_KeyPoint(FLAGS.keypoint_model_dir) topdown_keypoint_detector = KeyPointDetector(
keypoint_model = KeyPoint_Detector(
pred_config,
FLAGS.keypoint_model_dir, FLAGS.keypoint_model_dir,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
...@@ -272,22 +246,27 @@ def main(): ...@@ -272,22 +246,27 @@ def main():
trt_calib_mode=FLAGS.trt_calib_mode, trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn, enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.keypoint_threshold,
output_dir=FLAGS.output_dir,
use_dark=FLAGS.use_dark) use_dark=FLAGS.use_dark)
keypoint_arch = topdown_keypoint_detector.pred_config.arch
assert KEYPOINT_SUPPORT_MODELS[
keypoint_arch] == 'keypoint_topdown', 'MOT-Keypoint unite inference only supports topdown models.'
# predict from video file or camera video stream # predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1: if FLAGS.video_file is not None or FLAGS.camera_id != -1:
mot_keypoint_unite_predict_video(mot_model, keypoint_model, mot_topdown_unite_predict_video(
FLAGS.camera_id, mot_detector, topdown_keypoint_detector, FLAGS.camera_id,
FLAGS.keypoint_batch_size) FLAGS.keypoint_batch_size, FLAGS.save_res)
else: else:
# predict from image # predict from image
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
mot_keypoint_unite_predict_image(mot_model, keypoint_model, img_list, mot_topdown_unite_predict(mot_detector, topdown_keypoint_detector,
FLAGS.keypoint_batch_size) img_list, FLAGS.keypoint_batch_size,
FLAGS.save_res)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
mot_model.det_times.info(average=True) mot_detector.det_times.info(average=True)
keypoint_model.det_times.info(average=True) topdown_keypoint_detector.det_times.info(average=True)
else: else:
mode = FLAGS.run_mode mode = FLAGS.run_mode
mot_model_dir = FLAGS.mot_model_dir mot_model_dir = FLAGS.mot_model_dir
...@@ -295,14 +274,15 @@ def main(): ...@@ -295,14 +274,15 @@ def main():
'model_name': mot_model_dir.strip('/').split('/')[-1], 'model_name': mot_model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1] 'precision': mode.split('_')[-1]
} }
bench_log(mot_model, img_list, mot_model_info, name='MOT') bench_log(mot_detector, img_list, mot_model_info, name='MOT')
keypoint_model_dir = FLAGS.keypoint_model_dir keypoint_model_dir = FLAGS.keypoint_model_dir
keypoint_model_info = { keypoint_model_info = {
'model_name': keypoint_model_dir.strip('/').split('/')[-1], 'model_name': keypoint_model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1] 'precision': mode.split('_')[-1]
} }
bench_log(keypoint_model, img_list, keypoint_model_info, 'KeyPoint') bench_log(topdown_keypoint_detector, img_list, keypoint_model_info,
FLAGS.keypoint_batch_size, 'KeyPoint')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -123,5 +123,17 @@ def argsparser(): ...@@ -123,5 +123,17 @@ def argsparser():
type=bool, type=bool,
default=True, default=True,
help='whether to use darkpose to get better keypoint position predict ') help='whether to use darkpose to get better keypoint position predict ')
parser.add_argument(
'--save_res',
type=bool,
default=False,
help=(
"whether to save predict results to json file"
"1) store_res: a list of image_data"
"2) image_data: [imageid, rects, [keypoints, scores]]"
"3) rects: list of rect [xmin, ymin, xmax, ymax]"
"4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list"
"5) scores: mean of all joint conf"))
parser.add_argument(
"--tracker_config", type=str, default=None, help=("tracker donfig"))
return parser return parser
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,93 +18,38 @@ import yaml ...@@ -18,93 +18,38 @@ import yaml
import cv2 import cv2
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
import paddle import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from picodet_postprocess import PicoDetPostProcess
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, DetectorPicoDet, get_test_images, print_arguments, PredictConfig
from infer import load_predictor
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
from ppdet.modeling.mot.tracker import DeepSORTTracker # add python path
from ppdet.modeling.mot.visualization import plot_tracking import sys
from ppdet.modeling.mot.utils import MOTTimer, write_mot_results parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
# Global dictionary
MOT_SUPPORT_MODELS = {'DeepSORT'}
from pptracking.python.mot import JDETracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
from pptracking.python.visualize import plot_tracking, plot_tracking_dict
def bench_log(detector, img_list, model_info, batch_size=1, name=None): # Global dictionary
mems = { MOT_SDE_SUPPORT_MODELS = {
'cpu_rss_mb': detector.cpu_mem / len(img_list), 'DeepSORT',
'gpu_rss_mb': detector.gpu_mem / len(img_list), 'ByteTrack',
'gpu_util': detector.gpu_util * 100 / len(img_list) 'YOLO',
} }
perf_info = detector.det_times.report(average=True)
data_info = {
'batch_size': batch_size,
'shape': "dynamic_shape",
'data_num': perf_info['img_num']
}
log = PaddleInferBenchmark(detector.config, model_info, data_info,
perf_info, mems)
log(name)
def scale_coords(coords, input_shape, im_shape, scale_factor):
im_shape = im_shape[0]
ratio = scale_factor[0][0]
pad_w = (input_shape[1] - int(im_shape[1])) / 2
pad_h = (input_shape[0] - int(im_shape[0])) / 2
coords[:, 0::2] -= pad_w
coords[:, 1::2] -= pad_h
coords[:, 0:4] /= ratio
coords[:, :4] = np.clip(coords[:, :4], a_min=0, a_max=coords[:, :4].max())
return coords.round()
def clip_box(xyxy, input_shape, im_shape, scale_factor):
im_shape = im_shape[0]
ratio = scale_factor[0][0]
img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]
xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], a_min=0, a_max=img0_shape[1])
xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], a_min=0, a_max=img0_shape[0])
w = xyxy[:, 2:3] - xyxy[:, 0:1]
h = xyxy[:, 3:4] - xyxy[:, 1:2]
mask = np.logical_and(h > 0, w > 0)
keep_idx = np.nonzero(mask)
return xyxy[keep_idx[0]], keep_idx
def preprocess_reid(imgs,
w=64,
h=192,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]):
im_batch = []
for img in imgs:
img = cv2.resize(img, (w, h))
img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
img = np.expand_dims(img, axis=0)
im_batch.append(img)
im_batch = np.concatenate(im_batch, 0)
return im_batch
class SDE_Detector(Detector): class SDE_Detector(Detector):
""" """
Args: Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
tracker_config (str): tracker config path
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt trt_opt_shape (int): opt shape for dynamic shape in trt
...@@ -112,22 +57,24 @@ class SDE_Detector(Detector): ...@@ -112,22 +57,24 @@ class SDE_Detector(Detector):
calibration, trt_calib_mode need to set True calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN enable_mkldnn (bool): whether to open MKLDNN
use_dark(bool): whether to use postprocess in DarkPose
""" """
def __init__(self, def __init__(self,
pred_config,
model_dir, model_dir,
tracker_config,
device='CPU', device='CPU',
run_mode='paddle', run_mode='paddle',
batch_size=1, batch_size=1,
trt_min_shape=1, trt_min_shape=1,
trt_max_shape=1088, trt_max_shape=1280,
trt_opt_shape=608, trt_opt_shape=640,
trt_calib_mode=False, trt_calib_mode=False,
cpu_threads=1, cpu_threads=1,
enable_mkldnn=False): enable_mkldnn=False,
output_dir='output',
threshold=0.5):
super(SDE_Detector, self).__init__( super(SDE_Detector, self).__init__(
pred_config=pred_config,
model_dir=model_dir, model_dir=model_dir,
device=device, device=device,
run_mode=run_mode, run_mode=run_mode,
...@@ -137,424 +84,153 @@ class SDE_Detector(Detector): ...@@ -137,424 +84,153 @@ class SDE_Detector(Detector):
trt_opt_shape=trt_opt_shape, trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode, trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads, cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn) enable_mkldnn=enable_mkldnn,
assert batch_size == 1, "The JDE Detector only supports batch size=1 now" output_dir=output_dir,
self.pred_config = pred_config threshold=threshold, )
assert batch_size == 1, "MOT model only supports batch_size=1."
def postprocess(self, boxes, input_shape, im_shape, scale_factor, threshold, self.det_times = Timer(with_tracker=True)
scaled): self.num_classes = len(self.pred_config.labels)
over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0]
if len(over_thres_idx) == 0: # tracker config
pred_dets = np.zeros((1, 6), dtype=np.float32) self.tracker_config = tracker_config
pred_xyxys = np.zeros((1, 4), dtype=np.float32) cfg = yaml.safe_load(open(self.tracker_config))['tracker']
return pred_dets, pred_xyxys min_box_area = cfg.get('min_box_area', 200)
else: vertical_ratio = cfg.get('vertical_ratio', 1.6)
boxes = boxes[over_thres_idx] use_byte = cfg.get('use_byte', True)
match_thres = cfg.get('match_thres', 0.9)
if not scaled: conf_thres = cfg.get('conf_thres', 0.6)
# scaled means whether the coords after detector outputs low_conf_thres = cfg.get('low_conf_thres', 0.1)
# have been scaled back to the original image, set True
# in general detector, set False in JDE YOLOv3. self.tracker = JDETracker(
pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape, use_byte=use_byte,
scale_factor) num_classes=self.num_classes,
else: min_box_area=min_box_area,
pred_bboxes = boxes[:, 2:] vertical_ratio=vertical_ratio,
match_thres=match_thres,
pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape, conf_thres=conf_thres,
scale_factor) low_conf_thres=low_conf_thres)
if len(keep_idx[0]) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32) def tracking(self, det_results):
pred_xyxys = np.zeros((1, 4), dtype=np.float32) pred_dets = det_results['boxes']
return pred_dets, pred_xyxys pred_embs = None
pred_scores = boxes[:, 1:2][keep_idx[0]]
pred_cls_ids = boxes[:, 0:1][keep_idx[0]]
pred_tlwhs = np.concatenate(
(pred_xyxys[:, 0:2], pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1),
axis=1)
pred_dets = np.concatenate( pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1) (pred_dets[:, 2:], pred_dets[:, 1:2], pred_dets[:, 0:1]), 1)
# pred_dets should be 'x0, y0, x1, y1, score, cls_id'
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
online_tlwhs = defaultdict(list)
online_scores = defaultdict(list)
online_ids = defaultdict(list)
for cls_id in range(self.num_classes):
online_targets = online_targets_dict[cls_id]
for t in online_targets:
tlwh = t.tlwh
tid = t.track_id
tscore = t.score
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
continue
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > self.tracker.vertical_ratio:
continue
online_tlwhs[cls_id].append(tlwh)
online_ids[cls_id].append(tid)
online_scores[cls_id].append(tscore)
return pred_dets, pred_xyxys return online_tlwhs, online_scores, online_ids
def predict(self, image, scaled, threshold=0.5, repeats=1, add_timer=True): def predict_image(self,
''' image_list,
Args: run_benchmark=False,
image (np.ndarray): image numpy data repeats=1,
scaled (bool): whether the coords after detector outputs are scaled, visual=True):
default False in jde yolov3, set True in general detector. mot_results = []
threshold (float): threshold of predicted box' score num_classes = self.num_classes
repeats (int): repeat number for prediction image_list.sort()
add_timer (bool): whether add timer during prediction ids2names = self.pred_config.labels
Returns: for frame_id, img_file in enumerate(image_list):
pred_dets (np.ndarray, [N, 6]) batch_image_list = [img_file] # bs=1 in MOT model
''' if run_benchmark:
# preprocess # preprocess
if add_timer: inputs = self.preprocess(batch_image_list) # warmup
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image) inputs = self.preprocess(batch_image_list)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end() self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
boxes = boxes_tensor.copy_to_cpu()
if add_timer: # model prediction
result_warmup = self.predict(repeats=repeats) # warmup
self.det_times.inference_time_s.start()
result = self.predict(repeats=repeats)
self.det_times.inference_time_s.end(repeats=repeats) self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
# postprocess # postprocess
if len(boxes) == 0: result_warmup = self.postprocess(inputs, result) # warmup
pred_dets = np.zeros((1, 6), dtype=np.float32) self.det_times.postprocess_time_s.start()
pred_xyxys = np.zeros((1, 4), dtype=np.float32) det_result = self.postprocess(inputs, result)
else:
input_shape = inputs['image'].shape[2:]
im_shape = inputs['im_shape']
scale_factor = inputs['scale_factor']
pred_dets, pred_xyxys = self.postprocess(
boxes, input_shape, im_shape, scale_factor, threshold, scaled)
if add_timer:
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return pred_dets, pred_xyxys
class SDE_DetectorPicoDet(DetectorPicoDet):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self, # tracking
pred_config, result_warmup = self.tracking(det_result)
model_dir, self.det_times.tracking_time_s.start()
device='CPU', online_tlwhs, online_scores, online_ids = self.tracking(
run_mode='paddle', det_result)
batch_size=1, self.det_times.tracking_time_s.end()
trt_min_shape=1,
trt_max_shape=1088,
trt_opt_shape=608,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
super(SDE_DetectorPicoDet, self).__init__(
pred_config=pred_config,
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
assert batch_size == 1, "The JDE Detector only supports batch size=1 now"
self.pred_config = pred_config
def postprocess_bboxes(self, boxes, input_shape, im_shape, scale_factor,
threshold):
over_thres_idx = np.nonzero(boxes[:, 1:2] >= threshold)[0]
if len(over_thres_idx) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
else:
boxes = boxes[over_thres_idx]
pred_bboxes = boxes[:, 2:]
pred_xyxys, keep_idx = clip_box(pred_bboxes, input_shape, im_shape,
scale_factor)
if len(keep_idx[0]) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
return pred_dets, pred_xyxys
pred_scores = boxes[:, 1:2][keep_idx[0]]
pred_cls_ids = boxes[:, 0:1][keep_idx[0]]
pred_tlwhs = np.concatenate(
(pred_xyxys[:, 0:2], pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1),
axis=1)
pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1)
return pred_dets, pred_xyxys
def predict(self, image, scaled, threshold=0.5, repeats=1, add_timer=True):
'''
Args:
image (np.ndarray): image numpy data
scaled (bool): whether the coords after detector outputs are scaled,
default False in jde yolov3, set True in general detector.
threshold (float): threshold of predicted box' score
repeats (int): repeat number for prediction
add_timer (bool): whether add timer during prediction
Returns:
pred_dets (np.ndarray, [N, 6])
'''
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start()
# model prediction
np_score_list, np_boxes_list = [], []
for i in range(repeats):
self.predictor.run()
np_score_list.clear()
np_boxes_list.clear()
output_names = self.predictor.get_output_names()
num_outs = int(len(output_names) / 2)
for out_idx in range(num_outs):
np_score_list.append(
self.predictor.get_output_handle(output_names[out_idx])
.copy_to_cpu())
np_boxes_list.append(
self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu())
if add_timer:
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.img_num += 1 self.det_times.img_num += 1
self.det_times.postprocess_time_s.start()
# postprocess
self.postprocess = PicoDetPostProcess(
inputs['image'].shape[2:],
inputs['im_shape'],
inputs['scale_factor'],
strides=self.pred_config.fpn_stride,
nms_threshold=self.pred_config.nms['nms_threshold'])
boxes, boxes_num = self.postprocess(np_score_list, np_boxes_list)
if len(boxes) == 0:
pred_dets = np.zeros((1, 6), dtype=np.float32)
pred_xyxys = np.zeros((1, 4), dtype=np.float32)
else:
input_shape = inputs['image'].shape[2:]
im_shape = inputs['im_shape']
scale_factor = inputs['scale_factor']
pred_dets, pred_xyxys = self.postprocess_bboxes(
boxes, input_shape, im_shape, scale_factor, threshold)
if add_timer:
self.det_times.postprocess_time_s.end()
return pred_dets, pred_xyxys
cm, gm, gu = get_current_memory_mb()
self.cpu_mem += cm
self.gpu_mem += gm
self.gpu_util += gu
class SDE_ReID(object): else:
def __init__(self,
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=50,
trt_min_shape=1,
trt_max_shape=1088,
trt_opt_shape=608,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False):
self.pred_config = pred_config
self.predictor, self.config = load_predictor(
model_dir,
run_mode=run_mode,
batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn)
self.det_times = Timer()
self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
self.batch_size = batch_size
assert pred_config.tracker, "Tracking model should have tracker"
pt = pred_config.tracker
max_age = pt['max_age'] if 'max_age' in pt else 30
max_iou_distance = pt[
'max_iou_distance'] if 'max_iou_distance' in pt else 0.7
self.tracker = DeepSORTTracker(
max_age=max_age, max_iou_distance=max_iou_distance)
def get_crops(self, xyxy, ori_img):
w, h = self.tracker.input_size
self.det_times.preprocess_time_s.start() self.det_times.preprocess_time_s.start()
crops = [] inputs = self.preprocess(batch_image_list)
xyxy = xyxy.astype(np.int64)
ori_img = ori_img.transpose(1, 0, 2) # [h,w,3]->[w,h,3]
for i, bbox in enumerate(xyxy):
crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
crops.append(crop)
crops = preprocess_reid(crops, w, h)
self.det_times.preprocess_time_s.end() self.det_times.preprocess_time_s.end()
return crops
def preprocess(self, crops):
# to keep fast speed, only use topk crops
crops = crops[:self.batch_size]
inputs = {}
inputs['crops'] = np.array(crops).astype('float32')
return inputs
def postprocess(self, pred_dets, pred_embs):
tracker = self.tracker
tracker.predict()
online_targets = tracker.update(pred_dets, pred_embs)
online_tlwhs, online_scores, online_ids = [], [], []
for t in online_targets:
if not t.is_confirmed() or t.time_since_update > 1:
continue
tlwh = t.to_tlwh()
tscore = t.score
tid = t.track_id
if tlwh[2] * tlwh[3] <= tracker.min_box_area: continue
if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
3] > tracker.vertical_ratio:
continue
online_tlwhs.append(tlwh)
online_scores.append(tscore)
online_ids.append(tid)
return online_tlwhs, online_scores, online_ids
def predict(self, crops, pred_dets, repeats=1, add_timer=True):
# preprocess
if add_timer:
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(crops)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
if add_timer:
self.det_times.preprocess_time_s.end()
self.det_times.inference_time_s.start() self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# model prediction
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
feature_tensor = self.predictor.get_output_handle(output_names[0])
pred_embs = feature_tensor.copy_to_cpu()
if add_timer:
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start() self.det_times.postprocess_time_s.start()
det_result = self.postprocess(inputs, result)
# postprocess
online_tlwhs, online_scores, online_ids = self.postprocess(pred_dets,
pred_embs)
if add_timer:
self.det_times.postprocess_time_s.end() self.det_times.postprocess_time_s.end()
self.det_times.img_num += 1
return online_tlwhs, online_scores, online_ids
# tracking process
self.det_times.tracking_time_s.start()
online_tlwhs, online_scores, online_ids = self.tracking(
det_result)
self.det_times.tracking_time_s.end()
self.det_times.img_num += 1
def predict_image(detector, reid_model, image_list): if visual:
image_list.sort() if frame_id % 10 == 0:
for i, img_file in enumerate(image_list): print('Tracking frame {}'.format(frame_id))
frame = cv2.imread(img_file) frame, _ = decode_image(img_file, {})
if FLAGS.run_benchmark:
# warmup
pred_dets, pred_xyxys = detector.predict(
[frame],
FLAGS.scaled,
FLAGS.threshold,
repeats=10,
add_timer=True)
# run benchmark
pred_dets, pred_xyxys = detector.predict(
[frame],
FLAGS.scaled,
FLAGS.threshold,
repeats=10,
add_timer=True)
cm, gm, gu = get_current_memory_mb()
detector.cpu_mem += cm
detector.gpu_mem += gm
detector.gpu_util += gu
print('Test iter {}, file name:{}'.format(i, img_file))
else:
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled,
FLAGS.threshold)
if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
print('Frame {} has no object, try to modify score threshold.'.
format(i))
online_im = frame
else:
# reid process
crops = reid_model.get_crops(pred_xyxys, frame)
if FLAGS.run_benchmark:
# warmup
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, pred_dets, repeats=10, add_timer=False)
# run benchmark
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, pred_dets, repeats=10, add_timer=False)
else:
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, pred_dets)
online_im = plot_tracking(
frame, online_tlwhs, online_ids, online_scores, frame_id=i)
if FLAGS.save_images: im = plot_tracking_dict(
if not os.path.exists(FLAGS.output_dir): frame,
os.makedirs(FLAGS.output_dir) num_classes,
img_name = os.path.split(img_file)[-1] online_tlwhs,
out_path = os.path.join(FLAGS.output_dir, img_name) online_ids,
cv2.imwrite(out_path, online_im) online_scores,
print("save result to: " + out_path) frame_id=frame_id,
ids2names=[])
seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results
def predict_video(detector, reid_model, camera_id): def predict_video(self, video_file, camera_id):
video_out_name = 'output.mp4'
if camera_id != -1: if camera_id != -1:
capture = cv2.VideoCapture(camera_id) capture = cv2.VideoCapture(camera_id)
video_name = 'mot_output.mp4'
else: else:
capture = cv2.VideoCapture(FLAGS.video_file) capture = cv2.VideoCapture(video_file)
video_name = os.path.split(FLAGS.video_file)[-1] video_out_name = os.path.split(video_file)[-1]
# Get Video info : resolution, fps, frame count # Get Video info : resolution, fps, frame count
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
...@@ -562,86 +238,62 @@ def predict_video(detector, reid_model, camera_id): ...@@ -562,86 +238,62 @@ def predict_video(detector, reid_model, camera_id):
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
print("fps: %d, frame_count: %d" % (fps, frame_count)) print("fps: %d, frame_count: %d" % (fps, frame_count))
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(self.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name) out_path = os.path.join(self.output_dir, video_out_name)
if not FLAGS.save_images: fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
frame_id = 1
timer = MOTTimer() timer = MOTTimer()
results = defaultdict(list) results = defaultdict(list) # support single class and multi classes
num_classes = self.num_classes
while (1): while (1):
ret, frame = capture.read() ret, frame = capture.read()
if not ret: if not ret:
break break
timer.tic() if frame_id % 10 == 0:
pred_dets, pred_xyxys = detector.predict([frame], FLAGS.scaled, print('Tracking frame: %d' % (frame_id))
FLAGS.threshold) frame_id += 1
if len(pred_dets) == 1 and np.sum(pred_dets) == 0: timer.tic()
print('Frame {} has no object, try to modify score threshold.'. mot_results = self.predict_image([frame], visual=False)
format(frame_id))
timer.toc()
im = frame
else:
# reid process
crops = reid_model.get_crops(pred_xyxys, frame)
online_tlwhs, online_scores, online_ids = reid_model.predict(
crops, pred_dets)
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
timer.toc() timer.toc()
fps = 1. / timer.average_time online_tlwhs, online_scores, online_ids = mot_results[0]
im = plot_tracking( for cls_id in range(num_classes):
results[cls_id].append(
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
fps = 1. / timer.duration
im = plot_tracking_dict(
frame, frame,
num_classes,
online_tlwhs, online_tlwhs,
online_ids, online_ids,
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
fps=fps) fps=fps,
ids2names=[])
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
else:
writer.write(im) writer.write(im)
frame_id += 1
print('detect frame:%d' % (frame_id))
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking Detection', im) cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(save_dir,
out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release() writer.release()
def main(): def main():
pred_config = PredictConfig(FLAGS.model_dir) deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
detector_func = 'SDE_Detector' with open(deploy_file) as f:
if pred_config.arch == 'PicoDet': yml_conf = yaml.safe_load(f)
detector_func = 'SDE_DetectorPicoDet' arch = yml_conf['arch']
assert arch in MOT_SDE_SUPPORT_MODELS, '{} is not supported.'.format(arch)
detector = eval(detector_func)(pred_config, detector = SDE_Detector(
FLAGS.model_dir, FLAGS.model_dir,
FLAGS.tracker_config,
device=FLAGS.device, device=FLAGS.device,
run_mode=FLAGS.run_mode, run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
...@@ -650,48 +302,30 @@ def main(): ...@@ -650,48 +302,30 @@ def main():
trt_opt_shape=FLAGS.trt_opt_shape, trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode, trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads, cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn) enable_mkldnn=FLAGS.enable_mkldnn,
threshold=FLAGS.threshold,
pred_config = PredictConfig(FLAGS.reid_model_dir) output_dir=FLAGS.output_dir)
reid_model = SDE_ReID(
pred_config,
FLAGS.reid_model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.reid_batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn)
# predict from video file or camera video stream # predict from video file or camera video stream
if FLAGS.video_file is not None or FLAGS.camera_id != -1: if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, reid_model, FLAGS.camera_id) detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
else: else:
# predict from image # predict from image
if FLAGS.image_dir is None and FLAGS.image_file is not None:
assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(detector, reid_model, img_list) detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
detector.det_times.info(average=True) detector.det_times.info(average=True)
reid_model.det_times.info(average=True)
else: else:
mode = FLAGS.run_mode mode = FLAGS.run_mode
det_model_dir = FLAGS.model_dir model_dir = FLAGS.model_dir
det_model_info = { model_info = {
'model_name': det_model_dir.strip('/').split('/')[-1], 'model_name': model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1]
}
bench_log(detector, img_list, det_model_info, name='Det')
reid_model_dir = FLAGS.reid_model_dir
reid_model_info = {
'model_name': reid_model_dir.strip('/').split('/')[-1],
'precision': mode.split('_')[-1] 'precision': mode.split('_')[-1]
} }
bench_log(reid_model, img_list, reid_model_info, name='ReID') bench_log(detector, img_list, model_info, name='MOT')
if __name__ == '__main__': if __name__ == '__main__':
......
# config of tracker for MOT SDE Detector, use ByteTracker as default.
# The tracker of MOT JDE Detector is exported together with the model.
# Here 'min_box_area' and 'vertical_ratio' are set for pedestrian, you can modify for other objects tracking.
tracker:
use_byte: true
conf_thres: 0.6
low_conf_thres: 0.1
match_thres: 0.9
min_box_area: 100
vertical_ratio: 1.6
...@@ -118,6 +118,8 @@ def argsparser(): ...@@ -118,6 +118,8 @@ def argsparser():
default=False, default=False,
help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 " help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector.") "True in general detector.")
parser.add_argument(
"--tracker_config", type=str, default=None, help=("tracker donfig"))
parser.add_argument( parser.add_argument(
"--reid_model_dir", "--reid_model_dir",
type=str, type=str,
...@@ -165,29 +167,36 @@ class Times(object): ...@@ -165,29 +167,36 @@ class Times(object):
class Timer(Times): class Timer(Times):
def __init__(self): def __init__(self, with_tracker=False):
super(Timer, self).__init__() super(Timer, self).__init__()
self.with_tracker = with_tracker
self.preprocess_time_s = Times() self.preprocess_time_s = Times()
self.inference_time_s = Times() self.inference_time_s = Times()
self.postprocess_time_s = Times() self.postprocess_time_s = Times()
self.tracking_time_s = Times()
self.img_num = 0 self.img_num = 0
def info(self, average=False): def info(self, average=False):
total_time = self.preprocess_time_s.value( pre_time = self.preprocess_time_s.value()
) + self.inference_time_s.value() + self.postprocess_time_s.value() infer_time = self.inference_time_s.value()
post_time = self.postprocess_time_s.value()
track_time = self.tracking_time_s.value()
total_time = pre_time + infer_time + post_time
if self.with_tracker:
total_time = total_time + track_time
total_time = round(total_time, 4) total_time = round(total_time, 4)
print("------------------ Inference Time Info ----------------------") print("------------------ Inference Time Info ----------------------")
print("total_time(ms): {}, img_num: {}".format(total_time * 1000, print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
self.img_num)) self.img_num))
preprocess_time = round( preprocess_time = round(pre_time / max(1, self.img_num),
self.preprocess_time_s.value() / max(1, self.img_num), 4) if average else pre_time
4) if average else self.preprocess_time_s.value() postprocess_time = round(post_time / max(1, self.img_num),
postprocess_time = round( 4) if average else post_time
self.postprocess_time_s.value() / max(1, self.img_num), inference_time = round(infer_time / max(1, self.img_num),
4) if average else self.postprocess_time_s.value() 4) if average else infer_time
inference_time = round(self.inference_time_s.value() / tracking_time = round(track_time / max(1, self.img_num),
max(1, self.img_num), 4) if average else track_time
4) if average else self.inference_time_s.value()
average_latency = total_time / max(1, self.img_num) average_latency = total_time / max(1, self.img_num)
qps = 0 qps = 0
...@@ -195,6 +204,12 @@ class Timer(Times): ...@@ -195,6 +204,12 @@ class Timer(Times):
qps = 1 / average_latency qps = 1 / average_latency
print("average latency time(ms): {:.2f}, QPS: {:2f}".format( print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
average_latency * 1000, qps)) average_latency * 1000, qps))
if self.with_tracker:
print(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
postprocess_time * 1000, tracking_time * 1000))
else:
print( print(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}". "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000, format(preprocess_time * 1000, inference_time * 1000,
...@@ -202,18 +217,23 @@ class Timer(Times): ...@@ -202,18 +217,23 @@ class Timer(Times):
def report(self, average=False): def report(self, average=False):
dic = {} dic = {}
dic['preprocess_time_s'] = round( pre_time = self.preprocess_time_s.value()
self.preprocess_time_s.value() / max(1, self.img_num), infer_time = self.inference_time_s.value()
4) if average else self.preprocess_time_s.value() post_time = self.postprocess_time_s.value()
dic['postprocess_time_s'] = round( track_time = self.tracking_time_s.value()
self.postprocess_time_s.value() / max(1, self.img_num),
4) if average else self.postprocess_time_s.value() dic['preprocess_time_s'] = round(pre_time / max(1, self.img_num),
dic['inference_time_s'] = round( 4) if average else pre_time
self.inference_time_s.value() / max(1, self.img_num), dic['inference_time_s'] = round(infer_time / max(1, self.img_num),
4) if average else self.inference_time_s.value() 4) if average else infer_time
dic['postprocess_time_s'] = round(post_time / max(1, self.img_num),
4) if average else post_time
dic['tracking_time_s'] = round(post_time / max(1, self.img_num),
4) if average else track_time
dic['img_num'] = self.img_num dic['img_num'] = self.img_num
total_time = self.preprocess_time_s.value( total_time = pre_time + infer_time + post_time
) + self.inference_time_s.value() + self.postprocess_time_s.value() if self.with_tracker:
total_time = total_time + track_time
dic['total_time_s'] = round(total_time, 4) dic['total_time_s'] = round(total_time, 4)
return dic return dic
......
...@@ -224,9 +224,9 @@ def get_color(idx): ...@@ -224,9 +224,9 @@ def get_color(idx):
return color return color
def draw_pose(imgfile, def visualize_pose(imgfile,
results, results,
visual_thread=0.6, visual_thresh=0.6,
save_name='pose.jpg', save_name='pose.jpg',
save_dir='output', save_dir='output',
returnimg=False, returnimg=False,
...@@ -239,7 +239,6 @@ def draw_pose(imgfile, ...@@ -239,7 +239,6 @@ def draw_pose(imgfile,
logger.error('Matplotlib not found, please install matplotlib.' logger.error('Matplotlib not found, please install matplotlib.'
'for example: `pip install matplotlib`.') 'for example: `pip install matplotlib`.')
raise e raise e
skeletons, scores = results['keypoint'] skeletons, scores = results['keypoint']
skeletons = np.array(skeletons) skeletons = np.array(skeletons)
kpt_nums = 17 kpt_nums = 17
...@@ -276,7 +275,7 @@ def draw_pose(imgfile, ...@@ -276,7 +275,7 @@ def draw_pose(imgfile,
canvas = img.copy() canvas = img.copy()
for i in range(kpt_nums): for i in range(kpt_nums):
for j in range(len(skeletons)): for j in range(len(skeletons)):
if skeletons[j][i, 2] < visual_thread: if skeletons[j][i, 2] < visual_thresh:
continue continue
if ids is None: if ids is None:
color = colors[i] if color_set is None else colors[color_set[j] color = colors[i] if color_set is None else colors[color_set[j]
...@@ -300,8 +299,8 @@ def draw_pose(imgfile, ...@@ -300,8 +299,8 @@ def draw_pose(imgfile,
for i in range(NUM_EDGES): for i in range(NUM_EDGES):
for j in range(len(skeletons)): for j in range(len(skeletons)):
edge = EDGES[i] edge = EDGES[i]
if skeletons[j][edge[0], 2] < visual_thread or skeletons[j][edge[ if skeletons[j][edge[0], 2] < visual_thresh or skeletons[j][edge[
1], 2] < visual_thread: 1], 2] < visual_thresh:
continue continue
cur_canvas = canvas.copy() cur_canvas = canvas.copy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册