diff --git a/configs/mot/README.md b/configs/mot/README.md index 59f35407e00a17c43b9e9a8cfc39a013f8c020ac..d500e484ac7c8f0f44066a8c6b1710d4804ab403 100644 --- a/configs/mot/README.md +++ b/configs/mot/README.md @@ -39,7 +39,7 @@ or pip install -r requirements.txt ``` **Notes:** -- Install `cython_bbox` for Windows: `pip install -e git+https://github.com/samson-wang/cython_bbox.git#egg=cython-bbox`. You can refer to this [tutorial](https://stackoverflow.com/questions/60349980/is-there-a-way-to-install-cython-bbox-for-windows) +- Install `cython_bbox` for Windows: `pip install -e git+https://github.com/samson-wang/cython_bbox.git#egg=cython-bbox`. You can refer to this [tutorial](https://stackoverflow.com/questions/60349980/is-there-a-way-to-install-cython-bbox-for-windows). - Evaluation on Windows CUDA 11 environment may not be normally. It will be repaired as soon as possible. You can change to CUDA 10.2 or CUDA 10.1 environment for normal evaluation. diff --git a/configs/mot/deepsort/README.md b/configs/mot/deepsort/README.md index b7a258138d604c5b6395494c1476abf93ece568a..96cdcc4cfb3d25198859a2a84720c693cec6c7a9 100644 --- a/configs/mot/deepsort/README.md +++ b/configs/mot/deepsort/README.md @@ -80,6 +80,26 @@ CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsor **Notes:** Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first, on Linux(Ubuntu) platform you can directly install it by the following command:`apt-get update && apt-get install -y ffmpeg`. +### 3. Export model + +```bash +1.export detection model +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams + +2.export ReID model +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams +or +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams +``` + +### 4. Using exported model for python inference + +```bash +python deploy/python/mot_reid_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts +``` +**Notes:** +The tracking model is used to predict the video, and does not support the prediction of a single image. The visualization video of the tracking results is saved by default. You can add `--save_mot_txts` to save the txt result file, or `--save_images` to save the visualization images. + ## Citations ``` @inproceedings{Wojke2017simple, diff --git a/configs/mot/deepsort/README_cn.md b/configs/mot/deepsort/README_cn.md index a4b7ba0765459bed92f18c072724310999da0696..3c82a83b75440d0dc19af561587c7e06fad955ca 100644 --- a/configs/mot/deepsort/README_cn.md +++ b/configs/mot/deepsort/README_cn.md @@ -82,6 +82,27 @@ CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsor **注意:** 请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`。 +### 3. 导出预测模型 + +```bash +1.先导出检测模型 +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams + +2.再导出ReID模型 +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams + +或 +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams +``` + +### 4. 用导出的模型基于Python去预测 + +```bash +python deploy/python/mot_reid_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts +``` +**注意:** + 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。 + ## 引用 ``` @inproceedings{Wojke2017simple, diff --git a/configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml b/configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml new file mode 100644 index 0000000000000000000000000000000000000000..834064648c428badd091a10602f69bb38475469b --- /dev/null +++ b/configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml @@ -0,0 +1,81 @@ +_BASE_: [ + '../../datasets/mot.yml', + '../../runtime.yml', + '../jde/_base_/optimizer_30e.yml', + '../jde/_base_/jde_reader_1088x608.yml', +] +weights: output/jde_yolov3_darknet53_30e_1088x608/model_final + +metric: MOTDet + +EvalReader: + inputs_def: + num_max_boxes: 50 + sample_transforms: + - Decode: {} + - LetterBoxResize: {target_size: [608, 1088]} + - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} + - Permute: {} + batch_size: 1 + +TestReader: + inputs_def: + image_shape: [3, 608, 1088] + sample_transforms: + - Decode: {} + - LetterBoxResize: {target_size: [608, 1088]} + - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} + - Permute: {} + batch_size: 1 + +EvalDataset: + !MOTDataSet + dataset_dir: dataset/mot + image_lists: ['mot16.train'] + data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide'] + +TestDataset: + !ImageFolder + anno_path: None + +architecture: YOLOv3 +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/DarkNet53_pretrained.pdparams + +# JDE version for MOT dataset +YOLOv3: + backbone: DarkNet + neck: YOLOv3FPN + yolo_head: YOLOv3Head + post_process: JDEBBoxPostProcess + +DarkNet: + depth: 53 + return_idx: [2, 3, 4] + freeze_norm: True + +YOLOv3FPN: + freeze_norm: True + +YOLOv3Head: + anchors: [[128,384], [180,540], [256,640], [512,640], + [32,96], [45,135], [64,192], [90,271], + [8,24], [11,34], [16,48], [23,68]] + anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + loss: JDEDetectionLoss + +JDEDetectionLoss: + for_mot: False + +JDEBBoxPostProcess: + decode: + name: JDEBox + conf_thresh: 0.3 + downsample_ratio: 32 + nms: + name: MultiClassNMS + keep_top_k: 500 + score_threshold: 0.01 + nms_threshold: 0.5 + nms_top_k: 2000 + normalized: true + return_idx: false diff --git a/deploy/python/infer.py b/deploy/python/infer.py index d07128d2b16be5bb1c698ca9f3f2b35fa6fd3dd8..e270bbddf6367c29945ef9cfad99154fc36f8705 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -27,7 +27,7 @@ from paddle.inference import Config from paddle.inference import create_predictor from benchmark_utils import PaddleInferBenchmark -from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride +from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize from visualize import visualize_box_mask from utils import argsparser, Timer, get_current_memory_mb @@ -41,6 +41,9 @@ SUPPORT_MODELS = { 'SOLOv2', 'TTFNet', 'S2ANet', + 'JDE', + 'FairMOT', + 'DeepSORT', } diff --git a/deploy/python/mot_infer.py b/deploy/python/mot_infer.py index e13b8f32071afc6144747d4298e1cdb617fe34b7..288cb9d5c07a5ae9c3cb5ceab0cbfaecc97ec589 100644 --- a/deploy/python/mot_infer.py +++ b/deploy/python/mot_infer.py @@ -19,8 +19,7 @@ import cv2 import numpy as np import paddle from benchmark_utils import PaddleInferBenchmark -from preprocess import preprocess, NormalizeImage, Permute -from mot_preprocess import LetterBoxResize +from preprocess import preprocess, NormalizeImage, Permute, LetterBoxResize from tracker import JDETracker from ppdet.modeling.mot import visualization as mot_vis @@ -29,7 +28,7 @@ from ppdet.modeling.mot.utils import Timer as MOTTimer from paddle.inference import Config from paddle.inference import create_predictor from utils import argsparser, Timer, get_current_memory_mb -from infer import get_test_images, print_arguments +from infer import get_test_images, print_arguments, PredictConfig # Global dictionary MOT_SUPPORT_MODELS = { @@ -69,8 +68,8 @@ class MOT_Detector(object): self.predictor, self.config = load_predictor( model_dir, run_mode=run_mode, - min_subgraph_size=self.pred_config.min_subgraph_size, device=device, + min_subgraph_size=self.pred_config.min_subgraph_size, use_dynamic_shape=self.pred_config.use_dynamic_shape, trt_min_shape=trt_min_shape, trt_max_shape=trt_max_shape, @@ -109,10 +108,10 @@ class MOT_Detector(object): online_scores.append(tscore) return online_tlwhs, online_scores, online_ids - def predict(self, image, threshold=0.5, repeats=1): + def predict(self, image, threshold=0.5, warmup=0, repeats=1): ''' Args: - image (dict): dict(['image', 'im_shape', 'scale_factor']) + image (np.ndarray): numpy image data threshold (float): threshold of predicted box' score Returns: online_tlwhs, online_ids (np.ndarray) @@ -120,12 +119,19 @@ class MOT_Detector(object): self.det_times.preprocess_time_s.start() inputs = self.preprocess(image) self.det_times.preprocess_time_s.end() + pred_dets, pred_embs = None, None input_names = self.predictor.get_input_names() for i in range(len(input_names)): input_tensor = self.predictor.get_input_handle(input_names[i]) input_tensor.copy_from_cpu(inputs[input_names[i]]) + for i in range(warmup): + self.predictor.run() + output_names = self.predictor.get_output_names() + boxes_tensor = self.predictor.get_output_handle(output_names[0]) + pred_dets = boxes_tensor.copy_to_cpu() + self.det_times.inference_time_s.start() for i in range(repeats): self.predictor.run() @@ -134,7 +140,6 @@ class MOT_Detector(object): pred_dets = boxes_tensor.copy_to_cpu() embs_tensor = self.predictor.get_output_handle(output_names[1]) pred_embs = embs_tensor.copy_to_cpu() - self.det_times.inference_time_s.end(repeats=repeats) self.det_times.postprocess_time_s.start() @@ -150,7 +155,6 @@ def create_inputs(im, im_info): Args: im (np.ndarray): image (np.ndarray) im_info (dict): info of image - model_arch (str): model type Returns: inputs (dict): input of model """ @@ -162,48 +166,6 @@ def create_inputs(im, im_info): return inputs -class PredictConfig_MOT(): - """set config of preprocess, postprocess and visualize - Args: - model_dir (str): root path of model.yml - """ - - def __init__(self, model_dir): - # parsing Yaml config for Preprocess - deploy_file = os.path.join(model_dir, 'infer_cfg.yml') - with open(deploy_file) as f: - yml_conf = yaml.safe_load(f) - self.check_model(yml_conf) - self.arch = yml_conf['arch'] - self.preprocess_infos = yml_conf['Preprocess'] - self.min_subgraph_size = yml_conf['min_subgraph_size'] - self.labels = yml_conf['label_list'] - self.mask = False - self.use_dynamic_shape = yml_conf['use_dynamic_shape'] - if 'mask' in yml_conf: - self.mask = yml_conf['mask'] - self.print_config() - - def check_model(self, yml_conf): - """ - Raises: - ValueError: loaded model not in supported model type - """ - for support_model in MOT_SUPPORT_MODELS: - if support_model in yml_conf['arch']: - return True - raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ - 'arch'], MOT_SUPPORT_MODELS)) - - def print_config(self): - print('----------- Model Configuration -----------') - print('%s: %s' % ('Model Arch', self.arch)) - print('%s: ' % ('Transform Order')) - for op_info in self.preprocess_infos: - print('--%s: %s' % ('transform op', op_info['type'])) - print('--------------------------------------------') - - def load_predictor(model_dir, run_mode='fluid', batch_size=1, @@ -217,6 +179,7 @@ def load_predictor(model_dir, cpu_threads=1, enable_mkldnn=False): """set AnalysisConfig, generate AnalysisPredictor + Note: only support batch_size=1 now Args: model_dir (str): root path of __model__ and __params__ run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8) @@ -325,6 +288,30 @@ def write_mot_results(filename, results, data_type='mot'): f.write(line) +def predict_image(detector, image_list): + results = [] + for i, img_file in enumerate(image_list): + frame = cv2.imread(img_file) + if FLAGS.run_benchmark: + detector.predict(frame, FLAGS.threshold, warmup=10, repeats=10) + cm, gm, gu = get_current_memory_mb() + detector.cpu_mem += cm + detector.gpu_mem += gm + detector.gpu_util += gu + print('Test iter {}, file name:{}'.format(i, img_file)) + else: + online_tlwhs, online_scores, online_ids = detector.predict( + frame, FLAGS.threshold) + + online_im = mot_vis.plot_tracking( + frame, online_tlwhs, online_ids, online_scores, frame_id=i) + + if FLAGS.save_images: + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + cv2.imwrite(os.path.join(FLAGS.output_dir, img_file), online_im) + + def predict_video(detector, camera_id): if camera_id != -1: capture = cv2.VideoCapture(camera_id) @@ -364,8 +351,7 @@ def predict_video(detector, camera_id): online_ids, online_scores, frame_id=frame_id, - fps=fps, - threhold=FLAGS.threshold) + fps=fps) if FLAGS.save_images: save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) if not os.path.exists(save_dir): @@ -381,7 +367,7 @@ def predict_video(detector, camera_id): cv2.imshow('Tracking Detection', im) if cv2.waitKey(1) & 0xFF == ord('q'): break - if FLAGS.save_results: + if FLAGS.save_mot_txts: result_filename = os.path.join(FLAGS.output_dir, video_name.split('.')[-2] + '.txt') write_mot_results(result_filename, results) @@ -389,7 +375,7 @@ def predict_video(detector, camera_id): def main(): - pred_config = PredictConfig_MOT(FLAGS.model_dir) + pred_config = PredictConfig(FLAGS.model_dir) detector = MOT_Detector( pred_config, FLAGS.model_dir, @@ -406,7 +392,32 @@ def main(): if FLAGS.video_file is not None or FLAGS.camera_id != -1: predict_video(detector, FLAGS.camera_id) else: - print('MOT models do not support predict single image.') + # predict from image + img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) + predict_image(detector, img_list) + if not FLAGS.run_benchmark: + detector.det_times.info(average=True) + else: + mems = { + 'cpu_rss_mb': detector.cpu_mem / len(img_list), + 'gpu_rss_mb': detector.gpu_mem / len(img_list), + 'gpu_util': detector.gpu_util * 100 / len(img_list) + } + perf_info = detector.det_times.report(average=True) + model_dir = FLAGS.model_dir + mode = FLAGS.run_mode + model_info = { + 'model_name': model_dir.strip('/').split('/')[-1], + 'precision': mode.split('_')[-1] + } + data_info = { + 'batch_size': 1, + 'shape': "dynamic_shape", + 'data_num': perf_info['img_num'] + } + det_log = PaddleInferBenchmark(detector.config, model_info, + data_info, perf_info, mems) + det_log('MOT') if __name__ == '__main__': diff --git a/deploy/python/mot_keypoint_unite_infer.py b/deploy/python/mot_keypoint_unite_infer.py index 2a2a6c7f51da6f197f894b848aaef33e50ee8015..58411df4a5e7ce131a6d9c481ea2b4b317251a86 100644 --- a/deploy/python/mot_keypoint_unite_infer.py +++ b/deploy/python/mot_keypoint_unite_infer.py @@ -20,16 +20,55 @@ import paddle from mot_keypoint_unite_utils import argsparser from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint +from keypoint_det_unite_infer import bench_log from keypoint_visualize import draw_pose from benchmark_utils import PaddleInferBenchmark from utils import Timer from tracker import JDETracker -from mot_preprocess import LetterBoxResize -from mot_infer import MOT_Detector, PredictConfig_MOT, write_mot_results -from infer import print_arguments +from preprocess import LetterBoxResize +from mot_infer import MOT_Detector, write_mot_results +from infer import Detector, PredictConfig, print_arguments, get_test_images from ppdet.modeling.mot import visualization as mot_vis from ppdet.modeling.mot.utils import Timer as FPSTimer +from utils import get_current_memory_mb + + +def mot_keypoint_unite_predict_image(mot_model, keypoint_model, image_list): + for i, img_file in enumerate(image_list): + frame = cv2.imread(img_file) + + if FLAGS.run_benchmark: + mot_model.predict(frame, FLAGS.mot_threshold, warmup=10, repeats=10) + cm, gm, gu = get_current_memory_mb() + mot_model.cpu_mem += cm + mot_model.gpu_mem += gm + mot_model.gpu_util += gu + + keypoint_model.predict( + [frame], FLAGS.keypoint_threshold, warmup=10, repeats=10) + cm, gm, gu = get_current_memory_mb() + keypoint_model.cpu_mem += cm + keypoint_model.gpu_mem += gm + keypoint_model.gpu_util += gu + else: + online_tlwhs, online_scores, online_ids = mot_model.predict( + frame, FLAGS.mot_threshold) + keypoint_results = keypoint_model.predict([frame], + FLAGS.keypoint_threshold) + + im = draw_pose( + frame, + keypoint_results, + visual_thread=FLAGS.keypoint_threshold, + returnimg=True) + + online_im = mot_vis.plot_tracking( + im, online_tlwhs, online_ids, online_scores, frame_id=i) + if FLAGS.save_images: + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + cv2.imwrite(os.path.join(FLAGS.output_dir, img_file), online_im) def mot_keypoint_unite_predict_video(mot_model, keypoint_model, camera_id): @@ -117,7 +156,7 @@ def mot_keypoint_unite_predict_video(mot_model, keypoint_model, camera_id): def main(): - pred_config = PredictConfig_MOT(FLAGS.mot_model_dir) + pred_config = PredictConfig(FLAGS.mot_model_dir) mot_model = MOT_Detector( pred_config, FLAGS.mot_model_dir, @@ -149,7 +188,28 @@ def main(): mot_keypoint_unite_predict_video(mot_model, keypoint_model, FLAGS.camera_id) else: - print('Do not support unite predict single image.') + # predict from image + img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) + mot_keypoint_unite_predict_image(mot_model, keypoint_model, img_list) + + if not FLAGS.run_benchmark: + mot_model.det_times.info(average=True) + keypoint_model.det_times.info(average=True) + else: + mode = FLAGS.run_mode + mot_model_dir = FLAGS.mot_model_dir + mot_model_info = { + 'model_name': mot_model_dir.strip('/').split('/')[-1], + 'precision': mode.split('_')[-1] + } + bench_log(mot_model, img_list, mot_model_info, name='MOT') + + keypoint_model_dir = FLAGS.keypoint_model_dir + keypoint_model_info = { + 'model_name': keypoint_model_dir.strip('/').split('/')[-1], + 'precision': mode.split('_')[-1] + } + bench_log(keypoint_model, img_list, keypoint_model_info, 'KeyPoint') if __name__ == '__main__': diff --git a/deploy/python/mot_preprocess.py b/deploy/python/mot_preprocess.py deleted file mode 100644 index 55e248efb0cb0ddad48aa3fc1de017d8ba7f1b1e..0000000000000000000000000000000000000000 --- a/deploy/python/mot_preprocess.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 -import numpy as np - - -class LetterBoxResize(object): - def __init__(self, target_size): - """ - Resize image to target size, convert normalized xywh to pixel xyxy - format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]). - Args: - target_size (int|list): image target size. - """ - super(LetterBoxResize, self).__init__() - if isinstance(target_size, int): - target_size = [target_size, target_size] - self.target_size = target_size - - def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)): - # letterbox: resize a rectangular image to a padded rectangular - shape = img.shape[:2] # [height, width] - ratio_h = float(height) / shape[0] - ratio_w = float(width) / shape[1] - ratio = min(ratio_h, ratio_w) - new_shape = (round(shape[1] * ratio), - round(shape[0] * ratio)) # [width, height] - padw = (width - new_shape[0]) / 2 - padh = (height - new_shape[1]) / 2 - top, bottom = round(padh - 0.1), round(padh + 0.1) - left, right = round(padw - 0.1), round(padw + 0.1) - - img = cv2.resize( - img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border - img = cv2.copyMakeBorder( - img, top, bottom, left, right, cv2.BORDER_CONSTANT, - value=color) # padded rectangular - return img, ratio, padw, padh - - def __call__(self, im, im_info): - """ - Args: - im (np.ndarray): image (np.ndarray) - im_info (dict): info of image - Returns: - im (np.ndarray): processed image (np.ndarray) - im_info (dict): info of processed image - """ - assert len(self.target_size) == 2 - assert self.target_size[0] > 0 and self.target_size[1] > 0 - height, width = self.target_size - h, w = im.shape[:2] - im, ratio, padw, padh = self.letterbox(im, height=height, width=width) - - new_shape = [round(h * ratio), round(w * ratio)] - im_info['im_shape'] = np.array(new_shape, dtype=np.float32) - im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32) - return im, im_info diff --git a/deploy/python/mot_reid_infer.py b/deploy/python/mot_reid_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..176acb0c5e241cf01fc5e2130547a5be6515160a --- /dev/null +++ b/deploy/python/mot_reid_infer.py @@ -0,0 +1,478 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import yaml +import cv2 +import numpy as np +import paddle +from benchmark_utils import PaddleInferBenchmark +from preprocess import preprocess, NormalizeImage, Permute, LetterBoxResize +from tracker import DeepSORTTracker +from ppdet.modeling.mot import visualization as mot_vis +from ppdet.modeling.mot.utils import Timer as MOTTimer +from ppdet.modeling.mot.utils import Detection + +from paddle.inference import Config +from paddle.inference import create_predictor +from utils import argsparser, Timer, get_current_memory_mb +from infer import get_test_images, print_arguments, PredictConfig, Detector +from mot_infer import create_inputs, load_predictor, write_mot_results + +# Global dictionary +MOT_SUPPORT_MODELS = {'DeepSORT'} + + +def bench_log(detector, img_list, model_info, batch_size=1, name=None): + mems = { + 'cpu_rss_mb': detector.cpu_mem / len(img_list), + 'gpu_rss_mb': detector.gpu_mem / len(img_list), + 'gpu_util': detector.gpu_util * 100 / len(img_list) + } + perf_info = detector.det_times.report(average=True) + data_info = { + 'batch_size': batch_size, + 'shape': "dynamic_shape", + 'data_num': perf_info['img_num'] + } + log = PaddleInferBenchmark(detector.config, model_info, data_info, + perf_info, mems) + log(name) + + +def scale_coords(coords, input_shape, im_shape, scale_factor): + im_shape = im_shape[0] + ratio = scale_factor[0][0] + pad_w = (input_shape[1] - int(im_shape[1])) / 2 + pad_h = (input_shape[0] - int(im_shape[0])) / 2 + coords[:, 0::2] -= pad_w + coords[:, 1::2] -= pad_h + coords[:, 0:4] /= ratio + coords[:, :4] = np.clip(coords[:, :4], a_min=0, a_max=coords[:, :4].max()) + return coords.round() + + +def clip_box(xyxy, input_shape, im_shape, scale_factor): + im_shape = im_shape[0] + ratio = scale_factor[0][0] + img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)] + xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], a_min=0, a_max=img0_shape[1]) + xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], a_min=0, a_max=img0_shape[0]) + return xyxy + + +def get_crops(xyxy, ori_img, pred_scores, w, h): + crops = [] + keep_scores = [] + xyxy = xyxy.astype(np.int64) + ori_img = ori_img.transpose(1, 0, 2) # [h,w,3]->[w,h,3] + for i, bbox in enumerate(xyxy): + if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]: + continue + crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :] + crops.append(crop) + keep_scores.append(pred_scores[i]) + if len(crops) == 0: + return [], [] + crops = preprocess_reid(crops, w, h) + return crops, keep_scores + + +def preprocess_reid(imgs, + w=64, + h=192, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]): + im_batch = [] + for img in imgs: + img = cv2.resize(img, (w, h)) + img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255 + img_mean = np.array(mean).reshape((3, 1, 1)) + img_std = np.array(std).reshape((3, 1, 1)) + img -= img_mean + img /= img_std + img = np.expand_dims(img, axis=0) + im_batch.append(img) + im_batch = np.concatenate(im_batch, 0) + return im_batch + + +class MOT_Detector(object): + """ + Args: + pred_config (object): config of model, defined by `Config(model_dir)` + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + pred_config, + model_dir, + device='CPU', + run_mode='fluid', + trt_min_shape=1, + trt_max_shape=1088, + trt_opt_shape=608, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): + self.pred_config = pred_config + self.predictor, self.config = load_predictor( + model_dir, + run_mode=run_mode, + device=device, + min_subgraph_size=self.pred_config.min_subgraph_size, + use_dynamic_shape=self.pred_config.use_dynamic_shape, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) + self.det_times = Timer() + self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 + + def preprocess(self, im): + preprocess_ops = [] + for op_info in self.pred_config.preprocess_infos: + new_op_info = op_info.copy() + op_type = new_op_info.pop('type') + preprocess_ops.append(eval(op_type)(**new_op_info)) + im, im_info = preprocess(im, preprocess_ops) + inputs = create_inputs(im, im_info) + return inputs + + def postprocess(self, boxes, input_shape, im_shape, scale_factor, + threshold): + pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape, + scale_factor) + pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape, scale_factor) + pred_scores = boxes[:, 1:2] + keep_mask = pred_scores[:, 0] >= threshold + return pred_bboxes[keep_mask], pred_scores[keep_mask] + + def predict(self, image, threshold=0.5, warmup=0, repeats=1): + ''' + Args: + image (np.ndarray): image numpy data + threshold (float): threshold of predicted box' score + Returns: + pred_bboxes, pred_scores (np.ndarray) + ''' + self.det_times.preprocess_time_s.start() + inputs = self.preprocess(image) + self.det_times.preprocess_time_s.end() + + pred_bboxes, pred_scores = None, None + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + for i in range(warmup): + self.predictor.run() + output_names = self.predictor.get_output_names() + boxes_tensor = self.predictor.get_output_handle(output_names[0]) + boxes = boxes_tensor.copy_to_cpu() + + self.det_times.inference_time_s.start() + for i in range(repeats): + self.predictor.run() + output_names = self.predictor.get_output_names() + boxes_tensor = self.predictor.get_output_handle(output_names[0]) + boxes = boxes_tensor.copy_to_cpu() + self.det_times.inference_time_s.end(repeats=repeats) + + self.det_times.postprocess_time_s.start() + input_shape = inputs['image'].shape[2:] + im_shape = inputs['im_shape'] + scale_factor = inputs['scale_factor'] + pred_bboxes, pred_scores = self.postprocess( + boxes, input_shape, im_shape, scale_factor, threshold) + self.det_times.postprocess_time_s.end() + self.det_times.img_num += 1 + return pred_bboxes, pred_scores + + +class MOT_ReID(object): + def __init__(self, + pred_config, + model_dir, + device='CPU', + run_mode='fluid', + trt_min_shape=1, + trt_max_shape=1088, + trt_opt_shape=608, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): + self.pred_config = pred_config + self.predictor, self.config = load_predictor( + model_dir, + run_mode=run_mode, + min_subgraph_size=self.pred_config.min_subgraph_size, + device=device, + use_dynamic_shape=self.pred_config.use_dynamic_shape, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) + self.det_times = Timer() + self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 + + self.tracker = DeepSORTTracker() + + def preprocess(self, crops): + inputs = {} + inputs['crops'] = np.array(crops).astype('float32') + return inputs + + def postprocess(self, bbox_tlwh, pred_scores, features): + detections = [ + Detection(tlwh, score, feat) + for tlwh, score, feat in zip(bbox_tlwh, pred_scores, features) + ] + self.tracker.predict() + online_targets = self.tracker.update(detections) + + online_tlwhs = [] + online_scores = [] + online_ids = [] + for track in online_targets: + if not track.is_confirmed() or track.time_since_update > 1: + continue + online_tlwhs.append(track.to_tlwh()) + online_scores.append(1.0) + online_ids.append(track.track_id) + return online_tlwhs, online_scores, online_ids + + def predict(self, crops, bbox_tlwh, pred_scores, warmup=0, repeats=1): + self.det_times.preprocess_time_s.start() + inputs = self.preprocess(crops) + self.det_times.preprocess_time_s.end() + + pred_dets, pred_embs = None, None + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + for i in range(warmup): + self.predictor.run() + output_names = self.predictor.get_output_names() + feature_tensor = self.predictor.get_output_handle(output_names[0]) + features = feature_tensor.copy_to_cpu() + + self.det_times.inference_time_s.start() + for i in range(repeats): + self.predictor.run() + output_names = self.predictor.get_output_names() + feature_tensor = self.predictor.get_output_handle(output_names[0]) + features = feature_tensor.copy_to_cpu() + self.det_times.inference_time_s.end(repeats=repeats) + + self.det_times.postprocess_time_s.start() + online_tlwhs, online_scores, online_ids = self.postprocess( + bbox_tlwh, pred_scores, features) + self.det_times.postprocess_time_s.end() + self.det_times.img_num += 1 + return online_tlwhs, online_scores, online_ids + + +def predict_image(detector, reid_model, image_list): + results = [] + for i, img_file in enumerate(image_list): + frame = cv2.imread(img_file) + if FLAGS.run_benchmark: + pred_bboxes, pred_scores = detector.predict( + frame, FLAGS.threshold, warmup=10, repeats=10) + cm, gm, gu = get_current_memory_mb() + detector.cpu_mem += cm + detector.gpu_mem += gm + detector.gpu_util += gu + print('Test iter {}, file name:{}'.format(i, img_file)) + else: + pred_bboxes, pred_scores = detector.predict(frame, FLAGS.threshold) + + # process + bbox_tlwh = np.concatenate( + (pred_bboxes[:, 0:2], + pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1), + axis=1) + crops, pred_scores = get_crops( + pred_bboxes, frame, pred_scores, w=64, h=192) + + if FLAGS.run_benchmark: + online_tlwhs, online_scores, online_ids = reid_model.predict( + crops, bbox_tlwh, pred_scores, warmup=10, repeats=10) + else: + online_tlwhs, online_scores, online_ids = reid_model.predict( + crops, bbox_tlwh, pred_scores) + + online_im = mot_vis.plot_tracking( + frame, online_tlwhs, online_ids, online_scores, frame_id=i) + + if FLAGS.save_images: + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + cv2.imwrite(os.path.join(FLAGS.output_dir, img_file), online_im) + + +def predict_video(detector, reid_model, camera_id): + if camera_id != -1: + capture = cv2.VideoCapture(camera_id) + video_name = 'mot_output.mp4' + else: + capture = cv2.VideoCapture(FLAGS.video_file) + video_name = os.path.split(FLAGS.video_file)[-1] + fps = 30 + frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) + print('frame_count', frame_count) + width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + # yapf: disable + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # yapf: enable + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + out_path = os.path.join(FLAGS.output_dir, video_name) + writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) + frame_id = 0 + timer = MOTTimer() + results = [] + while (1): + ret, frame = capture.read() + if not ret: + break + timer.tic() + pred_bboxes, pred_scores = detector.predict(frame, FLAGS.threshold) + timer.toc() + + bbox_tlwh = np.concatenate( + (pred_bboxes[:, 0:2], + pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1), + axis=1) + crops, pred_scores = get_crops( + pred_bboxes, frame, pred_scores, w=64, h=192) + + online_tlwhs, online_scores, online_ids = reid_model.predict( + crops, bbox_tlwh, pred_scores) + + results.append((frame_id + 1, online_tlwhs, online_scores, online_ids)) + fps = 1. / timer.average_time + online_im = mot_vis.plot_tracking( + frame, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + fps=fps) + if FLAGS.save_images: + save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2]) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + cv2.imwrite( + os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), + online_im) + frame_id += 1 + print('detect frame:%d' % (frame_id)) + im = np.array(online_im) + writer.write(im) + if camera_id != -1: + cv2.imshow('Tracking Detection', im) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + if FLAGS.save_mot_txts: + result_filename = os.path.join(FLAGS.output_dir, + video_name.split('.')[-2] + '.txt') + write_mot_results(result_filename, results) + writer.release() + + +def main(): + pred_config = PredictConfig(FLAGS.model_dir) + detector = MOT_Detector( + pred_config, + FLAGS.model_dir, + device=FLAGS.device, + run_mode=FLAGS.run_mode, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) + + pred_config = PredictConfig(FLAGS.reid_model_dir) + reid_model = MOT_ReID( + pred_config, + FLAGS.reid_model_dir, + device=FLAGS.device, + run_mode=FLAGS.run_mode, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) + + # predict from video file or camera video stream + if FLAGS.video_file is not None or FLAGS.camera_id != -1: + predict_video(detector, reid_model, FLAGS.camera_id) + else: + # predict from image + img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) + predict_image(detector, reid_model, img_list) + + if not FLAGS.run_benchmark: + detector.det_times.info(average=True) + reid_model.det_times.info(average=True) + else: + mode = FLAGS.run_mode + det_model_dir = FLAGS.model_dir + det_model_info = { + 'model_name': det_model_dir.strip('/').split('/')[-1], + 'precision': mode.split('_')[-1] + } + bench_log(detector, img_list, det_model_info, name='Det') + + reid_model_dir = FLAGS.reid_model_dir + reid_model_info = { + 'model_name': reid_model_dir.strip('/').split('/')[-1], + 'precision': mode.split('_')[-1] + } + bench_log(reid_model, img_list, reid_model_info, name='ReID') + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + print_arguments(FLAGS) + FLAGS.device = FLAGS.device.upper() + assert FLAGS.device in ['CPU', 'GPU', 'XPU' + ], "device should be CPU, GPU or XPU" + + main() diff --git a/deploy/python/preprocess.py b/deploy/python/preprocess.py index 5e44596f1d8821521732caf5d6fd1e686b31cb69..d4fdd5fe7557217264498ca09d559c5b8cdd4404 100644 --- a/deploy/python/preprocess.py +++ b/deploy/python/preprocess.py @@ -193,6 +193,60 @@ class PadStride(object): return padding_im, im_info +class LetterBoxResize(object): + def __init__(self, target_size): + """ + Resize image to target size, convert normalized xywh to pixel xyxy + format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]). + Args: + target_size (int|list): image target size. + """ + super(LetterBoxResize, self).__init__() + if isinstance(target_size, int): + target_size = [target_size, target_size] + self.target_size = target_size + + def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)): + # letterbox: resize a rectangular image to a padded rectangular + shape = img.shape[:2] # [height, width] + ratio_h = float(height) / shape[0] + ratio_w = float(width) / shape[1] + ratio = min(ratio_h, ratio_w) + new_shape = (round(shape[1] * ratio), + round(shape[0] * ratio)) # [width, height] + padw = (width - new_shape[0]) / 2 + padh = (height - new_shape[1]) / 2 + top, bottom = round(padh - 0.1), round(padh + 0.1) + left, right = round(padw - 0.1), round(padw + 0.1) + + img = cv2.resize( + img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border + img = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, + value=color) # padded rectangular + return img, ratio, padw, padh + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + assert len(self.target_size) == 2 + assert self.target_size[0] > 0 and self.target_size[1] > 0 + height, width = self.target_size + h, w = im.shape[:2] + im, ratio, padw, padh = self.letterbox(im, height=height, width=width) + + new_shape = [round(h * ratio), round(w * ratio)] + im_info['im_shape'] = np.array(new_shape, dtype=np.float32) + im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32) + return im, im_info + + def preprocess(im, preprocess_ops): # process image by preprocess_ops im_info = { diff --git a/deploy/python/tracker/deepsort_tracker.py b/deploy/python/tracker/deepsort_tracker.py index c6576cee220bce89eac1cdfd9fdc36ef41e23236..789896e57657bc25c0b470d466549f181c11db75 100644 --- a/deploy/python/tracker/deepsort_tracker.py +++ b/deploy/python/tracker/deepsort_tracker.py @@ -16,6 +16,7 @@ This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_so """ import numpy as np +from ppdet.modeling.mot.motion import KalmanFilter from ppdet.modeling.mot.matching.deepsort_matching import NearestNeighborDistanceMetric from ppdet.modeling.mot.matching.deepsort_matching import iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix from ppdet.modeling.mot.tracker.base_sde_tracker import Track @@ -24,7 +25,6 @@ __all__ = ['DeepSORTTracker'] class DeepSORTTracker(object): - __inject__ = ['motion'] """ DeepSORT tracker @@ -60,7 +60,7 @@ class DeepSORTTracker(object): self.metric = NearestNeighborDistanceMetric(metric_type, matching_threshold, budget) self.max_iou_distance = max_iou_distance - self.motion = motion + self.motion = KalmanFilter() self.tracks = [] self._next_id = 1 diff --git a/deploy/python/utils.py b/deploy/python/utils.py index 7179d43ee3a1ed45ef0cb2157727d7254df1b0a7..ef28401ba4a4bd8465e2147af23c1e30155cb9f9 100644 --- a/deploy/python/utils.py +++ b/deploy/python/utils.py @@ -108,6 +108,12 @@ def argsparser(): '--save_mot_txts', action='store_true', help='Save tracking results (txt).') + parser.add_argument( + "--reid_model_dir", + type=str, + default=None, + help=("Directory include:'model.pdiparams', 'model.pdmodel', " + "'infer_cfg.yml', created by tools/export_model.py.")) parser.add_argument( '--use_dark', type=bool, diff --git a/ppdet/engine/tracker.py b/ppdet/engine/tracker.py index fb803948448a14805a84f47bd937cc2fd495a6c5..4f8a0808bdea6d6bace0296e3f1359076a36d29e 100644 --- a/ppdet/engine/tracker.py +++ b/ppdet/engine/tracker.py @@ -24,7 +24,7 @@ import numpy as np from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight - +from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box from ppdet.modeling.mot.utils import Timer, load_det_results from ppdet.modeling.mot import visualization as mot_vis @@ -188,9 +188,12 @@ class Tracker(object): logger.info('Processing frame {} ({:.2f} fps)'.format( frame_id, 1. / max(1e-5, timer.average_time))) + ori_image = data['ori_image'] + input_shape = data['image'].shape[2:] + im_shape = data['im_shape'] + scale_factor = data['scale_factor'] timer.tic() if not use_detector: - timer.tic() dets = dets_list[frame_id] bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32') pred_scores = paddle.to_tensor(dets['score'], dtype='float32') @@ -203,14 +206,35 @@ class Tracker(object): else: pred_bboxes = [] pred_scores = [] - data.update({ - 'pred_bboxes': pred_bboxes, - 'pred_scores': pred_scores - }) + else: + outs = self.model.detector(data) + if outs['bbox_num'] > 0: + pred_bboxes = scale_coords(outs['bbox'][:, 2:], input_shape, + im_shape, scale_factor) + pred_scores = outs['bbox'][:, 1:2] + else: + pred_bboxes = [] + pred_scores = [] - # forward - timer.tic() - detections = self.model(data) + pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape, + scale_factor) + bbox_tlwh = paddle.concat( + (pred_bboxes[:, 0:2], + pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1), + axis=1) + + crops, pred_scores = get_crops( + pred_bboxes, ori_image, pred_scores, w=64, h=192) + crops = paddle.to_tensor(crops) + pred_scores = paddle.to_tensor(pred_scores) + + data.update({'crops': crops}) + features = self.model(data) + features = features.numpy() + detections = [ + Detection(tlwh, score, feat) + for tlwh, score, feat in zip(bbox_tlwh, pred_scores, features) + ] self.model.tracker.predict() online_targets = self.model.tracker.update(detections) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index fa293cd31a748a31d5b7b483ec8497b4a99438e1..4ce3d263588e4d0672836e9627fd7b9955d3d70b 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -36,7 +36,7 @@ from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.visualizer import visualize_results, save_result from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval -from ppdet.metrics import RBoxMetric +from ppdet.metrics import RBoxMetric, JDEDetMetric from ppdet.data.source.category import get_categories import ppdet.utils.stats as stats @@ -243,6 +243,8 @@ class Trainer(object): len(eval_dataset), self.cfg.num_joints, self.cfg.save_dir) ] + elif self.cfg.metric == 'MOTDet': + self._metrics = [JDEDetMetric(), ] else: logger.warn("Metric not support for metric type {}".format( self.cfg.metric)) @@ -545,6 +547,11 @@ class Trainer(object): "scale_factor": InputSpec( shape=[None, 2], name='scale_factor') }] + if self.cfg.architecture == 'DeepSORT': + input_spec[0].update({ + "crops": InputSpec( + shape=[None, 3, 192, 64], name='crops') + }) # dy2st and save model if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT': diff --git a/ppdet/metrics/mot_metrics.py b/ppdet/metrics/mot_metrics.py index 770d2dea6d42a22ae97b0bc0aaf5686c25dc47af..e1bc8b93b65685acd8ac7153955c86b44df2e921 100644 --- a/ppdet/metrics/mot_metrics.py +++ b/ppdet/metrics/mot_metrics.py @@ -20,12 +20,14 @@ import copy import numpy as np import paddle import paddle.nn.functional as F +from ppdet.modeling.bbox_utils import bbox_iou_np_expand +from .map_utils import ap_per_class from .metrics import Metric from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) -__all__ = ['MOTEvaluator', 'MOTMetric'] +__all__ = ['MOTEvaluator', 'MOTMetric', 'JDEDetMetric'] def read_mot_results(filename, is_gt=False, is_ignore=False): @@ -236,3 +238,67 @@ class MOTMetric(Metric): def get_results(self): return self.strsummary + + +class JDEDetMetric(Metric): + # Note this detection AP metric is different from COCOMetric or VOCMetric, + # and the bboxes coordinates are not scaled to the original image + def __init__(self, overlap_thresh=0.5): + self.overlap_thresh = overlap_thresh + self.reset() + + def reset(self): + self.AP_accum = np.zeros(1) + self.AP_accum_count = np.zeros(1) + + def update(self, inputs, outputs): + bboxes = outputs['bbox'][:, 2:].numpy() + scores = outputs['bbox'][:, 1].numpy() + labels = outputs['bbox'][:, 0].numpy() + bbox_lengths = outputs['bbox_num'].numpy() + if bboxes.shape[0] == 1 and bboxes.sum() == 0.0: + return + + gt_boxes = inputs['gt_bbox'].numpy()[0] + gt_labels = inputs['gt_class'].numpy()[0] + if gt_labels.shape[0] == 0: + return + + correct = [] + detected = [] + for i in range(bboxes.shape[0]): + obj_pred = 0 + pred_bbox = bboxes[i].reshape(1, 4) + # Compute iou with target boxes + iou = bbox_iou_np_expand(pred_bbox, gt_boxes, x1y1x2y2=True)[0] + # Extract index of largest overlap + best_i = np.argmax(iou) + # If overlap exceeds threshold and classification is correct mark as correct + if iou[best_i] > self.overlap_thresh and obj_pred == gt_labels[ + best_i] and best_i not in detected: + correct.append(1) + detected.append(best_i) + else: + correct.append(0) + + # Compute Average Precision (AP) per class + target_cls = list(gt_labels.T[0]) + AP, AP_class, R, P = ap_per_class( + tp=correct, + conf=scores, + pred_cls=np.zeros_like(scores), + target_cls=target_cls) + self.AP_accum_count += np.bincount(AP_class, minlength=1) + self.AP_accum += np.bincount(AP_class, minlength=1, weights=AP) + + def accumulate(self): + logger.info("Accumulating evaluatation results...") + self.map_stat = self.AP_accum[0] / (self.AP_accum_count[0] + 1E-16) + + def log(self): + map_stat = 100. * self.map_stat + logger.info("mAP({:.2f}) = {:.2f}%".format(self.overlap_thresh, + map_stat)) + + def get_results(self): + return self.map_stat diff --git a/ppdet/modeling/architectures/deepsort.py b/ppdet/modeling/architectures/deepsort.py index 66184fb7b18ae1d5ebcf2ed10853d0efee14ba19..066f7a4ce791cf844d6e0124085c72aa0bfb6d1b 100644 --- a/ppdet/modeling/architectures/deepsort.py +++ b/ppdet/modeling/architectures/deepsort.py @@ -61,47 +61,9 @@ class DeepSORT(BaseArch): } def _forward(self): - load_dets = 'pred_bboxes' in self.inputs and 'pred_scores' in self.inputs - - ori_image = self.inputs['ori_image'] - input_shape = self.inputs['image'].shape[2:] - im_shape = self.inputs['im_shape'] - scale_factor = self.inputs['scale_factor'] - - if self.detector and not load_dets: - outs = self.detector(self.inputs) - if outs['bbox_num'] > 0: - pred_bboxes = scale_coords(outs['bbox'][:, 2:], input_shape, - im_shape, scale_factor) - pred_scores = outs['bbox'][:, 1:2] - else: - pred_bboxes = [] - pred_scores = [] - else: - pred_bboxes = self.inputs['pred_bboxes'] - pred_scores = self.inputs['pred_scores'] - - if len(pred_bboxes) > 0: - pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape, - scale_factor) - bbox_tlwh = paddle.concat( - (pred_bboxes[:, 0:2], - pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1), - axis=1) - - crops, pred_scores = get_crops( - pred_bboxes, ori_image, pred_scores, w=64, h=192) - - if len(crops) > 0: - features = self.reid(paddle.to_tensor(crops)) - detections = [Detection(bbox_tlwh[i], conf, features[i])\ - for i, conf in enumerate(pred_scores)] - else: - detections = [] - else: - detections = [] - - return detections + crops = self.inputs['crops'] + features = self.reid(crops) + return features def get_pred(self): return self._forward() diff --git a/ppdet/modeling/architectures/yolo.py b/ppdet/modeling/architectures/yolo.py index 0a035a94acf2cbf2cef8841bc4993a102b2aeb31..d5979e695c6af2d57a85ea8c0cf1fd92076ae409 100644 --- a/ppdet/modeling/architectures/yolo.py +++ b/ppdet/modeling/architectures/yolo.py @@ -1,9 +1,24 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import absolute_import from __future__ import division from __future__ import print_function from ppdet.core.workspace import register, create from .meta_arch import BaseArch +from ..post_process import JDEBBoxPostProcess __all__ = ['YOLOv3'] @@ -39,6 +54,7 @@ class YOLOv3(BaseArch): self.yolo_head = yolo_head self.post_process = post_process self.for_mot = for_mot + self.return_idx = isinstance(post_process, JDEBBoxPostProcess) @classmethod def from_config(cls, cfg, *args, **kwargs): @@ -90,9 +106,13 @@ class YOLOv3(BaseArch): 'emb_feats': emb_feats, } else: - bbox, bbox_num = self.post_process( - yolo_head_outs, self.yolo_head.mask_anchors, - self.inputs['im_shape'], self.inputs['scale_factor']) + if self.return_idx: + _, bbox, bbox_num, _ = self.post_process( + yolo_head_outs, self.yolo_head.mask_anchors) + else: + bbox, bbox_num = self.post_process( + yolo_head_outs, self.yolo_head.mask_anchors, + self.inputs['im_shape'], self.inputs['scale_factor']) output = {'bbox': bbox, 'bbox_num': bbox_num} return output diff --git a/ppdet/modeling/losses/jde_loss.py b/ppdet/modeling/losses/jde_loss.py index 59ace08f2fe7da375d5615ea8e5c02f2ebbbd192..5c3b5a61534e793b243526fabdcf604114ce2512 100644 --- a/ppdet/modeling/losses/jde_loss.py +++ b/ppdet/modeling/losses/jde_loss.py @@ -28,9 +28,10 @@ __all__ = ['JDEDetectionLoss', 'JDEEmbeddingLoss', 'JDELoss'] class JDEDetectionLoss(nn.Layer): __shared__ = ['num_classes'] - def __init__(self, num_classes=1): + def __init__(self, num_classes=1, for_mot=True): super(JDEDetectionLoss, self).__init__() self.num_classes = num_classes + self.for_mot = for_mot def det_loss(self, p_det, anchor, t_conf, t_box): pshape = paddle.shape(p_det) @@ -92,7 +93,17 @@ class JDEDetectionLoss(nn.Layer): loss_conf, loss_box = self.det_loss(p_det, anchor, t_conf, t_box) loss_confs.append(loss_conf) loss_boxes.append(loss_box) - return {'loss_confs': loss_confs, 'loss_boxes': loss_boxes} + if self.for_mot: + return {'loss_confs': loss_confs, 'loss_boxes': loss_boxes} + else: + jde_conf_losses = sum(loss_confs) + jde_box_losses = sum(loss_boxes) + jde_det_losses = { + "loss_conf": jde_conf_losses, + "loss_box": jde_box_losses, + "loss": jde_conf_losses + jde_box_losses, + } + return jde_det_losses @register diff --git a/ppdet/modeling/mot/utils.py b/ppdet/modeling/mot/utils.py index 13ea52d1c188ca3e7a7a54b32a90f1e3cde44424..2426e663aa9a8472865910d1454422bb07a27f39 100644 --- a/ppdet/modeling/mot/utils.py +++ b/ppdet/modeling/mot/utils.py @@ -82,7 +82,7 @@ class Detection(object): def __init__(self, tlwh, confidence, feature): self.tlwh = np.asarray(tlwh, dtype=np.float32) self.confidence = np.asarray(confidence, dtype=np.float32) - self.feature = feature.numpy() + self.feature = feature def to_tlbr(self): """ diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 80eef2547663f1189bc2c77264fc380fa82fd8d6..9c917242d92ccaca431c4ce245b99ed0a259f1b7 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -355,11 +355,7 @@ class JDEBBoxPostProcess(nn.Layer): [[[0.0]]], dtype='float32')) self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64')) - def forward(self, - head_out, - anchors, - im_shape=[[608, 1088]], - scale_factor=[[1.0, 1.0]]): + def forward(self, head_out, anchors): """ Decode the bbox and do NMS for JDE model. @@ -389,16 +385,21 @@ class JDEBBoxPostProcess(nn.Layer): yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)]) boxes_idx = boxes_idx[:, 1:] - bbox_pred, bbox_num, nms_keep_idx = self.nms( - yolo_boxes_out, yolo_scores_out, self.num_classes) - if bbox_pred.shape[0] == 0: - bbox_pred = self.fake_bbox_pred - bbox_num = self.fake_bbox_num - nms_keep_idx = self.fake_nms_keep_idx if self.return_idx: + bbox_pred, bbox_num, nms_keep_idx = self.nms( + yolo_boxes_out, yolo_scores_out, self.num_classes) + if bbox_pred.shape[0] == 0: + bbox_pred = self.fake_bbox_pred + bbox_num = self.fake_bbox_num + nms_keep_idx = self.fake_nms_keep_idx return boxes_idx, bbox_pred, bbox_num, nms_keep_idx else: - return bbox_pred, bbox_num + bbox_pred, bbox_num, _ = self.nms(yolo_boxes_out, yolo_scores_out, + self.num_classes) + if bbox_pred.shape[0] == 0: + bbox_pred = self.fake_bbox_pred + bbox_num = self.fake_bbox_num + return _, bbox_pred, bbox_num, _ @register diff --git a/ppdet/modeling/reid/pyramidal_embedding.py b/ppdet/modeling/reid/pyramidal_embedding.py index d74cd828ffce72e326d50013a4f72bc82069dbe1..f099a9655cb8080088ee69a517ceb2b441b81c9b 100644 --- a/ppdet/modeling/reid/pyramidal_embedding.py +++ b/ppdet/modeling/reid/pyramidal_embedding.py @@ -78,8 +78,7 @@ class PCBPyramid(nn.Layer): for idx_branches in range(self.num_branches): if idx_branches >= sum(self.num_in_each_level[0:idx_levels + 1]): idx_levels += 1 - if self.used_levels[idx_levels] == 0: - continue + pyramid_conv_list.append( nn.Sequential( nn.Conv2D(input_ch, num_conv_out_channels, 1), @@ -89,8 +88,7 @@ class PCBPyramid(nn.Layer): for idx_branches in range(self.num_branches): if idx_branches >= sum(self.num_in_each_level[0:idx_levels + 1]): idx_levels += 1 - if self.used_levels[idx_levels] == 0: - continue + name = "Linear_branch_id_{}".format(idx_branches) fc = nn.Linear( in_features=num_conv_out_channels, @@ -113,8 +111,6 @@ class PCBPyramid(nn.Layer): for idx_branches in range(self.num_branches): if idx_branches >= sum(self.num_in_each_level[0:idx_levels + 1]): idx_levels += 1 - if self.used_levels[idx_levels] == 0: - continue idx_in_each_level = idx_branches - sum(self.num_in_each_level[ 0:idx_levels]) stripe_size_in_each_level = each_stripe_size * (idx_levels + 1)