diff --git a/configs/datasets/mot.yml b/configs/datasets/mot.yml index e26919566028e0ff511c2e43400aee69cf8418d5..d6056968d6ed794a10e30a1db16db63c2106270c 100644 --- a/configs/datasets/mot.yml +++ b/configs/datasets/mot.yml @@ -27,24 +27,10 @@ EvalMOTDataset: task: MOT16_train dataset_dir: dataset/mot data_root: MOT16/images/train - keep_ori_im: False # set True if save visualization images or video + keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT # for MOT video inference TestMOTDataset: !MOTVideoDataset dataset_dir: dataset/mot keep_ori_im: True # set True if save visualization images or video - - -# for detection or reid evaluation, following the JDE paper, but no use in MOT evaluation -EvalDataset: - !MOTDataSet - dataset_dir: dataset/mot - image_lists: ['citypersons.val', 'caltech.val'] # for detection evaluation - # image_lists: ['caltech.10k.val', 'cuhksysu.val', 'prw.val'] # for reid evaluation - data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide'] - -# for detection inference, no use in MOT inference -TestDataset: - !ImageFolder - dataset_dir: dataset/mot diff --git a/configs/mot/deepsort/_base_/deepsort_reader_1088x608.yml b/configs/mot/deepsort/_base_/deepsort_reader_1088x608.yml index d0920dbcef74914701009929a60855f88d56517e..0ef44508561b8babc413b13deff51447f31829c0 100644 --- a/configs/mot/deepsort/_base_/deepsort_reader_1088x608.yml +++ b/configs/mot/deepsort/_base_/deepsort_reader_1088x608.yml @@ -1,14 +1,3 @@ -TestReader: - inputs_def: - image_shape: [3, 608, 1088] - sample_transforms: - - Decode: {} - - LetterBoxResize: {target_size: [608, 1088]} - - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} - - Permute: {} - batch_size: 1 - - EvalMOTReader: sample_transforms: - Decode: {} @@ -22,7 +11,6 @@ TestMOTReader: inputs_def: image_shape: [3, 608, 1088] sample_transforms: - - Decode: {} - LetterBoxResize: {target_size: [608, 1088]} - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], 
is_scale: True} - Permute: {} diff --git a/configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml b/configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml index e79904a0cc376d59c8bea8ecbbca1629a58debef..442ced2bbeee7e1078e801658ed4faca0a7d3b2c 100644 --- a/configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml +++ b/configs/mot/deepsort/_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml @@ -1,48 +1,46 @@ architecture: DeepSORT -pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/DarkNet53_pretrained.pdparams +pretrain_weights: None DeepSORT: - detector: YOLOv3 + detector: YOLOv3 # JDE version reid: PCBPyramid tracker: DeepSORTTracker +# JDE version for MOT dataset YOLOv3: backbone: DarkNet neck: YOLOv3FPN yolo_head: YOLOv3Head - post_process: BBoxPostProcess + post_process: JDEBBoxPostProcess DarkNet: depth: 53 return_idx: [2, 3, 4] + freeze_norm: True -# use default config -# YOLOv3FPN: +YOLOv3FPN: + freeze_norm: True YOLOv3Head: - anchors: [[10, 13], [16, 30], [33, 23], - [30, 61], [62, 45], [59, 119], - [116, 90], [156, 198], [373, 326]] - anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] - loss: YOLOv3Loss - -YOLOv3Loss: - ignore_thresh: 0.7 - downsample: [32, 16, 8] - label_smooth: false - -BBoxPostProcess: + anchors: [[128,384], [180,540], [256,640], [512,640], + [32,96], [45,135], [64,192], [90,271], + [8,24], [11,34], [16,48], [23,68]] + anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + loss: JDEDetectionLoss + +JDEBBoxPostProcess: decode: - name: YOLOBox - conf_thresh: 0.2 + name: JDEBox + conf_thresh: 0.3 downsample_ratio: 32 - clip_bbox: true nms: name: MultiClassNMS - keep_top_k: 100 + keep_top_k: 500 score_threshold: 0.01 - nms_threshold: 0.45 - nms_top_k: 1000 + nms_threshold: 0.5 + nms_top_k: 2000 + normalized: true + return_idx: false PCBPyramid: num_conv_out_channels: 128 @@ -52,7 +50,7 @@ DeepSORTTracker: budget: 100 max_age: 70 n_init: 3 - metric_type: 
'cosine' + metric_type: cosine matching_threshold: 0.2 max_iou_distance: 0.9 - motion: 'KalmanFilter' + motion: KalmanFilter diff --git a/configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml b/configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml index 33f0ca16a82c746326ccb797539358d9199186a3..f6a4bd19cfc16021396841ddb9aec5c0bb293263 100644 --- a/configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml +++ b/configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml @@ -5,14 +5,12 @@ _BASE_: [ '_base_/deepsort_reader_1088x608.yml', ] -metric: MOT - EvalMOTDataset: !MOTImageFolder task: MOT16_train dataset_dir: dataset/mot data_root: MOT16/images/train - keep_ori_im: True # set True if used in DeepSORT + keep_ori_im: True # set as True in DeepSORT det_weights: None reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams diff --git a/configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml b/configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml new file mode 100644 index 0000000000000000000000000000000000000000..ca8ec5d56b7ae8b31e93fc435b0a86a9440dd02f --- /dev/null +++ b/configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml @@ -0,0 +1,28 @@ +_BASE_: [ + '../../datasets/mot.yml', + '../../runtime.yml', + '_base_/deepsort_yolov3_darknet53_pcb_pyramid_r101.yml', + '_base_/deepsort_reader_1088x608.yml', +] + +EvalMOTDataset: + !MOTImageFolder + task: MOT16_train + dataset_dir: dataset/mot + data_root: MOT16/images/train + keep_ori_im: True # set as True in DeepSORT + +det_weights: https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams +reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams + +DeepSORT: + detector: YOLOv3 + reid: PCBPyramid + tracker: DeepSORTTracker + +# JDE version for MOT dataset +YOLOv3: + backbone: DarkNet + neck: YOLOv3FPN + yolo_head: YOLOv3Head + post_process: JDEBBoxPostProcess diff --git a/configs/mot/jde/_base_/jde_reader_1088x608.yml 
b/configs/mot/jde/_base_/jde_reader_1088x608.yml index 41709b9bcc76cae674a0773b98245232fa4a8b88..3e41b3721059e57bfc979baea66840bd54b4df3c 100644 --- a/configs/mot/jde/_base_/jde_reader_1088x608.yml +++ b/configs/mot/jde/_base_/jde_reader_1088x608.yml @@ -28,38 +28,6 @@ TrainReader: use_shared_memory: true -EvalReader: - sample_transforms: - - Decode: {} - - LetterBoxResize: {target_size: [608, 1088]} - - BboxXYXY2XYWH: {} - - NormalizeBox: {} - - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} - - Permute: {} - batch_transforms: - - Gt2JDETargetMax: - anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] - anchors: [[[128,384], [180,540], [256,640], [512,640]], - [[32,96], [45,135], [64,192], [90,271]], - [[8,24], [11,34], [16,48], [23,68]]] - downsample_ratios: [32, 16, 8] - max_iou_thresh: 0.60 - - BboxCXCYWH2XYXY: {} - - Norm2PixelBbox: {} - batch_size: 1 - - -TestReader: - inputs_def: - image_shape: [3, 608, 1088] - sample_transforms: - - Decode: {} - - LetterBoxResize: {target_size: [608, 1088]} - - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} - - Permute: {} - batch_size: 1 - - EvalMOTReader: sample_transforms: - Decode: {} diff --git a/configs/mot/jde/_base_/jde_reader_576x320.yml b/configs/mot/jde/_base_/jde_reader_576x320.yml index 4b204f7ba80f170a00ed4c74d8960418ddaa44fa..fc50f44001df0809f340617fc4bab9b50ceff220 100644 --- a/configs/mot/jde/_base_/jde_reader_576x320.yml +++ b/configs/mot/jde/_base_/jde_reader_576x320.yml @@ -28,38 +28,6 @@ TrainReader: use_shared_memory: true -EvalReader: - sample_transforms: - - Decode: {} - - LetterBoxResize: {target_size: [320, 576]} - - BboxXYXY2XYWH: {} - - NormalizeBox: {} - - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} - - Permute: {} - batch_transforms: - - Gt2JDETargetMax: - anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] - anchors: [[[85,255], [120,320], [170,320], [340,320]], - [[21,64], [30,90], [43,128], [60,180]], - [[6,16], [8,23], 
[11,32], [16,45]]] - downsample_ratios: [32, 16, 8] - max_iou_thresh: 0.60 - - BboxCXCYWH2XYXY: {} - - Norm2PixelBbox: {} - batch_size: 1 - - -TestReader: - inputs_def: - image_shape: [3, 320, 576] - sample_transforms: - - Decode: {} - - LetterBoxResize: {target_size: [320, 576]} - - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} - - Permute: {} - batch_size: 1 - - EvalMOTReader: sample_transforms: - Decode: {} diff --git a/configs/mot/jde/_base_/jde_reader_864x480.yml b/configs/mot/jde/_base_/jde_reader_864x480.yml index 2f6b822ac0cda560285c9631be72f858738b78d2..9178fa2b2d42319d52f2a2bd158b9d7be9095ace 100644 --- a/configs/mot/jde/_base_/jde_reader_864x480.yml +++ b/configs/mot/jde/_base_/jde_reader_864x480.yml @@ -28,38 +28,6 @@ TrainReader: use_shared_memory: true -EvalReader: - sample_transforms: - - Decode: {} - - LetterBoxResize: {target_size: [480, 864]} - - BboxXYXY2XYWH: {} - - NormalizeBox: {} - - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} - - Permute: {} - batch_transforms: - - Gt2JDETargetMax: - anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] - anchors: [[[102,305], [143, 429], [203,508], [407,508]], - [[25,76], [36,107], [51,152], [71,215]], - [[6,19], [9,27], [13,38], [18,54]]] - downsample_ratios: [32, 16, 8] - max_iou_thresh: 0.60 - - BboxCXCYWH2XYXY: {} - - Norm2PixelBbox: {} - batch_size: 1 - - -TestReader: - inputs_def: - image_shape: [3, 480, 864] - sample_transforms: - - Decode: {} - - LetterBoxResize: {target_size: [480, 864]} - - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} - - Permute: {} - batch_size: 1 - - EvalMOTReader: sample_transforms: - Decode: {} diff --git a/deploy/python/mot_infer.py b/deploy/python/mot_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..557af13e4c228d90738fe68d581d45cb1cc8eb90 --- /dev/null +++ b/deploy/python/mot_infer.py @@ -0,0 +1,366 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import yaml +import cv2 +import numpy as np +import paddle +from benchmark_utils import PaddleInferBenchmark +from preprocess import preprocess, NormalizeImage, Permute +from mot_preprocess import LetterBoxResize + +from tracker import JDETracker +from ppdet.modeling.mot import visualization as mot_vis +from ppdet.modeling.mot.utils import Timer as MOTTimer + +from paddle.inference import Config +from paddle.inference import create_predictor +from utils import argsparser, Timer, get_current_memory_mb +from infer import get_test_images, print_arguments + +# Global dictionary +MOT_SUPPORT_MODELS = { + 'JDE', + 'FairMOT', +} + + +class MOT_Detector(object): + """ + Args: + pred_config (object): config of model, defined by `Config(model_dir)` + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + use_gpu (bool): whether use gpu + run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + pred_config, + model_dir, + 
use_gpu=False, + run_mode='fluid', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1088, + trt_opt_shape=608, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): + self.pred_config = pred_config + self.predictor, self.config = load_predictor( + model_dir, + run_mode=run_mode, + batch_size=batch_size, + min_subgraph_size=self.pred_config.min_subgraph_size, + use_gpu=use_gpu, + use_dynamic_shape=self.pred_config.use_dynamic_shape, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) + self.det_times = Timer() + self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0 + self.tracker = JDETracker() + + def preprocess(self, im): + preprocess_ops = [] + for op_info in self.pred_config.preprocess_infos: + new_op_info = op_info.copy() + op_type = new_op_info.pop('type') + preprocess_ops.append(eval(op_type)(**new_op_info)) + im, im_info = preprocess(im, preprocess_ops) + inputs = create_inputs(im, im_info) + return inputs + + def postprocess(self, pred_dets, pred_embs): + online_targets = self.tracker.update(pred_dets, pred_embs) + online_tlwhs, online_ids = [], [] + for t in online_targets: + tlwh = t.tlwh + tid = t.track_id + vertical = tlwh[2] / tlwh[3] > 1.6 + if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical: + online_tlwhs.append(tlwh) + online_ids.append(tid) + return online_tlwhs, online_ids + + def predict(self, image, threshold=0.5, repeats=1): + ''' + Args: + image (dict): dict(['image', 'im_shape', 'scale_factor']) + threshold (float): threshold of predicted box' score + Returns: + online_tlwhs, online_ids (np.ndarray) + ''' + self.det_times.preprocess_time_s.start() + inputs = self.preprocess(image) + self.det_times.preprocess_time_s.end() + pred_dets, pred_embs = None, None + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = 
self.predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + self.det_times.inference_time_s.start() + for i in range(repeats): + self.predictor.run() + output_names = self.predictor.get_output_names() + boxes_tensor = self.predictor.get_output_handle(output_names[0]) + pred_dets = boxes_tensor.copy_to_cpu() + embs_tensor = self.predictor.get_output_handle(output_names[1]) + pred_embs = embs_tensor.copy_to_cpu() + + self.det_times.inference_time_s.end(repeats=repeats) + + self.det_times.postprocess_time_s.start() + online_tlwhs, online_ids = self.postprocess(pred_dets, pred_embs) + self.det_times.postprocess_time_s.end() + self.det_times.img_num += 1 + return online_tlwhs, online_ids + + +def create_inputs(im, im_info): + """generate input for different model type + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + model_arch (str): model type + Returns: + inputs (dict): input of model + """ + inputs = {} + inputs['image'] = np.array((im, )).astype('float32') + inputs['im_shape'] = np.array((im_info['im_shape'], )).astype('float32') + inputs['scale_factor'] = np.array( + (im_info['scale_factor'], )).astype('float32') + return inputs + + +class PredictConfig_MOT(): + """set config of preprocess, postprocess and visualize + Args: + model_dir (str): root path of model.yml + """ + + def __init__(self, model_dir): + # parsing Yaml config for Preprocess + deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + with open(deploy_file) as f: + yml_conf = yaml.safe_load(f) + self.check_model(yml_conf) + self.arch = yml_conf['arch'] + self.preprocess_infos = yml_conf['Preprocess'] + self.min_subgraph_size = yml_conf['min_subgraph_size'] + self.labels = yml_conf['label_list'] + self.mask = False + self.use_dynamic_shape = yml_conf['use_dynamic_shape'] + if 'mask' in yml_conf: + self.mask = yml_conf['mask'] + self.print_config() + + def check_model(self, yml_conf): + """ + Raises: + ValueError: loaded 
model not in supported model type + """ + for support_model in MOT_SUPPORT_MODELS: + if support_model in yml_conf['arch']: + return True + raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ + 'arch'], MOT_SUPPORT_MODELS)) + + def print_config(self): + print('----------- Model Configuration -----------') + print('%s: %s' % ('Model Arch', self.arch)) + print('%s: ' % ('Transform Order')) + for op_info in self.preprocess_infos: + print('--%s: %s' % ('transform op', op_info['type'])) + print('--------------------------------------------') + + +def load_predictor(model_dir, + run_mode='fluid', + batch_size=1, + use_gpu=False, + min_subgraph_size=3, + use_dynamic_shape=False, + trt_min_shape=1, + trt_max_shape=1088, + trt_opt_shape=608, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False): + """set AnalysisConfig, generate AnalysisPredictor + Args: + model_dir (str): root path of __model__ and __params__ + run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8) + batch_size (int): size of pre batch in inference + use_gpu (bool): whether use gpu + use_dynamic_shape (bool): use dynamic shape or not + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + Returns: + predictor (PaddlePredictor): AnalysisPredictor + Raises: + ValueError: predict by TensorRT need use_gpu == True. 
+ """ + if not use_gpu and not run_mode == 'fluid': + raise ValueError( + "Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}" + .format(run_mode, use_gpu)) + config = Config( + os.path.join(model_dir, 'model.pdmodel'), + os.path.join(model_dir, 'model.pdiparams')) + precision_map = { + 'trt_int8': Config.Precision.Int8, + 'trt_fp32': Config.Precision.Float32, + 'trt_fp16': Config.Precision.Half + } + if use_gpu: + # initial GPU memory(M), device ID + config.enable_use_gpu(200, 0) + # optimize graph and fuse op + config.switch_ir_optim(True) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(cpu_threads) + if enable_mkldnn: + try: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + except Exception as e: + print( + "The current environment does not support `mkldnn`, so disable mkldnn." + ) + pass + + if run_mode in precision_map.keys(): + config.enable_tensorrt_engine( + workspace_size=1 << 10, + max_batch_size=batch_size, + min_subgraph_size=min_subgraph_size, + precision_mode=precision_map[run_mode], + use_static=False, + use_calib_mode=trt_calib_mode) + + if use_dynamic_shape: + min_input_shape = {'image': [1, 3, trt_min_shape, trt_min_shape]} + max_input_shape = {'image': [1, 3, trt_max_shape, trt_max_shape]} + opt_input_shape = {'image': [1, 3, trt_opt_shape, trt_opt_shape]} + config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, + opt_input_shape) + print('trt set dynamic shape done!') + + # disable print log when predict + config.disable_glog_info() + # enable shared memory + config.enable_memory_optim() + # disable feed, fetch OP, needed by zero_copy_run + config.switch_use_feed_fetch_ops(False) + predictor = create_predictor(config) + return predictor, config + + +def predict_video(detector, camera_id): + if camera_id != -1: + capture = cv2.VideoCapture(camera_id) + video_name = 'mot_output.mp4' + else: + capture = 
cv2.VideoCapture(FLAGS.video_file) + video_name = os.path.split(FLAGS.video_file)[-1] + fps = 30 + frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) + print('frame_count', frame_count) + width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + # yapf: disable + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # yapf: enable + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + out_path = os.path.join(FLAGS.output_dir, video_name) + writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) + frame_id = 0 + timer = MOTTimer() + while (1): + ret, frame = capture.read() + if not ret: + break + timer.tic() + online_tlwhs, online_ids = detector.predict(frame, FLAGS.threshold) + timer.toc() + + online_im = mot_vis.plot_tracking( + frame, + online_tlwhs, + online_ids, + frame_id=frame_id, + fps=1. / timer.average_time) + frame_id += 1 + print('detect frame:%d' % (frame_id)) + im = np.array(online_im) + writer.write(im) + if camera_id != -1: + cv2.imshow('Tracking Detection', im) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + writer.release() + + +def main(): + pred_config = PredictConfig_MOT(FLAGS.model_dir) + detector = MOT_Detector( + pred_config, + FLAGS.model_dir, + use_gpu=FLAGS.use_gpu, + run_mode=FLAGS.run_mode, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn) + + # predict from video file or camera video stream + if FLAGS.video_file is not None or FLAGS.camera_id != -1: + predict_video(detector, FLAGS.camera_id) + else: + print('MOT models do not support predict single image.') + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + print_arguments(FLAGS) + + main() diff --git a/deploy/python/mot_preprocess.py b/deploy/python/mot_preprocess.py new file mode 
100644 index 0000000000000000000000000000000000000000..55e248efb0cb0ddad48aa3fc1de017d8ba7f1b1e --- /dev/null +++ b/deploy/python/mot_preprocess.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +class LetterBoxResize(object): + def __init__(self, target_size): + """ + Resize image to target size, convert normalized xywh to pixel xyxy + format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]). + Args: + target_size (int|list): image target size. 
+ """ + super(LetterBoxResize, self).__init__() + if isinstance(target_size, int): + target_size = [target_size, target_size] + self.target_size = target_size + + def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)): + # letterbox: resize a rectangular image to a padded rectangular + shape = img.shape[:2] # [height, width] + ratio_h = float(height) / shape[0] + ratio_w = float(width) / shape[1] + ratio = min(ratio_h, ratio_w) + new_shape = (round(shape[1] * ratio), + round(shape[0] * ratio)) # [width, height] + padw = (width - new_shape[0]) / 2 + padh = (height - new_shape[1]) / 2 + top, bottom = round(padh - 0.1), round(padh + 0.1) + left, right = round(padw - 0.1), round(padw + 0.1) + + img = cv2.resize( + img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border + img = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, + value=color) # padded rectangular + return img, ratio, padw, padh + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + assert len(self.target_size) == 2 + assert self.target_size[0] > 0 and self.target_size[1] > 0 + height, width = self.target_size + h, w = im.shape[:2] + im, ratio, padw, padh = self.letterbox(im, height=height, width=width) + + new_shape = [round(h * ratio), round(w * ratio)] + im_info['im_shape'] = np.array(new_shape, dtype=np.float32) + im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32) + return im, im_info diff --git a/deploy/python/tracker/__init__.py b/deploy/python/tracker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c8eb0c0d1d6660c29dfe69e13b76729e68be5ca5 --- /dev/null +++ b/deploy/python/tracker/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import jde_tracker +from . import deepsort_tracker + +from .jde_tracker import * +from .deepsort_tracker import * diff --git a/deploy/python/tracker/deepsort_tracker.py b/deploy/python/tracker/deepsort_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..c6576cee220bce89eac1cdfd9fdc36ef41e23236 --- /dev/null +++ b/deploy/python/tracker/deepsort_tracker.py @@ -0,0 +1,158 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/tracker.py +""" + +import numpy as np +from ppdet.modeling.mot.matching.deepsort_matching import NearestNeighborDistanceMetric +from ppdet.modeling.mot.matching.deepsort_matching import iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix +from ppdet.modeling.mot.tracker.base_sde_tracker import Track + +__all__ = ['DeepSORTTracker'] + + +class DeepSORTTracker(object): + __inject__ = ['motion'] + """ + DeepSORT tracker + + Args: + img_size (list): input image size, [h, w] + budget (int): If not None, fix samples per class to at most this number. + Removes the oldest samples when the budget is reached. + max_age (int): maximum number of missed misses before a track is deleted + n_init (float): Number of frames that a track remains in initialization + phase. Number of consecutive detections before the track is confirmed. + The track state is set to `Deleted` if a miss occurs within the first + `n_init` frames. + metric_type (str): either "euclidean" or "cosine", the distance metric + used for measurement to track association. + matching_threshold (float): samples with larger distance are + considered an invalid match. + max_iou_distance (float): max iou distance threshold + motion (object): KalmanFilter instance + """ + + def __init__(self, + img_size=[608, 1088], + budget=100, + max_age=30, + n_init=3, + metric_type='cosine', + matching_threshold=0.2, + max_iou_distance=0.7, + motion='KalmanFilter'): + self.img_size = img_size + self.max_age = max_age + self.n_init = n_init + self.metric = NearestNeighborDistanceMetric(metric_type, + matching_threshold, budget) + self.max_iou_distance = max_iou_distance + self.motion = motion + + self.tracks = [] + self._next_id = 1 + + def predict(self): + """ + Propagate track state distributions one time step forward. + This function should be called once every time step, before `update`. 
+ """ + for track in self.tracks: + track.predict(self.motion) + + def update(self, detections): + """ + Perform measurement update and track management. + Args: + detections (list): List[ppdet.modeling.mot.utils.Detection] + A list of detections at the current time step. + """ + # Run matching cascade. + matches, unmatched_tracks, unmatched_detections = \ + self._match(detections) + + # Update track set. + for track_idx, detection_idx in matches: + self.tracks[track_idx].update(self.motion, + detections[detection_idx]) + for track_idx in unmatched_tracks: + self.tracks[track_idx].mark_missed() + for detection_idx in unmatched_detections: + self._initiate_track(detections[detection_idx]) + self.tracks = [t for t in self.tracks if not t.is_deleted()] + + # Update distance metric. + active_targets = [t.track_id for t in self.tracks if t.is_confirmed()] + features, targets = [], [] + for track in self.tracks: + if not track.is_confirmed(): + continue + features += track.features + targets += [track.track_id for _ in track.features] + track.features = [] + self.metric.partial_fit( + np.asarray(features), np.asarray(targets), active_targets) + output_stracks = self.tracks + return output_stracks + + def _match(self, detections): + def gated_metric(tracks, dets, track_indices, detection_indices): + features = np.array([dets[i].feature for i in detection_indices]) + targets = np.array([tracks[i].track_id for i in track_indices]) + cost_matrix = self.metric.distance(features, targets) + cost_matrix = gate_cost_matrix(self.motion, cost_matrix, tracks, + dets, track_indices, + detection_indices) + return cost_matrix + + # Split track set into confirmed and unconfirmed tracks. + confirmed_tracks = [ + i for i, t in enumerate(self.tracks) if t.is_confirmed() + ] + unconfirmed_tracks = [ + i for i, t in enumerate(self.tracks) if not t.is_confirmed() + ] + + # Associate confirmed tracks using appearance features. 
+ matches_a, unmatched_tracks_a, unmatched_detections = \ + matching_cascade( + gated_metric, self.metric.matching_threshold, self.max_age, + self.tracks, detections, confirmed_tracks) + + # Associate remaining tracks together with unconfirmed tracks using IOU. + iou_track_candidates = unconfirmed_tracks + [ + k for k in unmatched_tracks_a + if self.tracks[k].time_since_update == 1 + ] + unmatched_tracks_a = [ + k for k in unmatched_tracks_a + if self.tracks[k].time_since_update != 1 + ] + matches_b, unmatched_tracks_b, unmatched_detections = \ + min_cost_matching( + iou_cost, self.max_iou_distance, self.tracks, + detections, iou_track_candidates, unmatched_detections) + + matches = matches_a + matches_b + unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b)) + return matches, unmatched_tracks, unmatched_detections + + def _initiate_track(self, detection): + mean, covariance = self.motion.initiate(detection.to_xyah()) + self.tracks.append( + Track(mean, covariance, self._next_id, self.n_init, self.max_age, + detection.feature)) + self._next_id += 1 diff --git a/deploy/python/tracker/jde_tracker.py b/deploy/python/tracker/jde_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..907906f0e322fddeed4b482fcc4ad9e50ebb996c --- /dev/null +++ b/deploy/python/tracker/jde_tracker.py @@ -0,0 +1,241 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is borrowed from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py +""" +import numpy as np +from ppdet.modeling.mot.matching import jde_matching as matching +from ppdet.modeling.mot.motion import KalmanFilter +from ppdet.modeling.mot.tracker.base_jde_tracker import TrackState, BaseTrack, STrack +from ppdet.modeling.mot.tracker.base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks + +__all__ = ['JDETracker'] + + +class JDETracker(object): + __inject__ = ['motion'] + """ + JDE tracker + + Args: + det_thresh (float): threshold of detection score + track_buffer (int): buffer for tracker + min_box_area (int): min box area to filter out low quality boxes + tracked_thresh (float): linear assignment threshold of tracked + stracks and detections + r_tracked_thresh (float): linear assignment threshold of + tracked stracks and unmatched detections + unconfirmed_thresh (float): linear assignment threshold of + unconfirmed stracks and unmatched detections + motion (object): KalmanFilter instance + conf_thres (float): confidence threshold for tracking + metric_type (str): either "euclidean" or "cosine", the distance metric + used for measurement to track association. 
+ """ + + def __init__(self, + det_thresh=0.3, + track_buffer=30, + min_box_area=200, + tracked_thresh=0.7, + r_tracked_thresh=0.5, + unconfirmed_thresh=0.7, + motion='KalmanFilter', + conf_thres=0, + metric_type='euclidean'): + self.det_thresh = det_thresh + self.track_buffer = track_buffer + self.min_box_area = min_box_area + self.tracked_thresh = tracked_thresh + self.r_tracked_thresh = r_tracked_thresh + self.unconfirmed_thresh = unconfirmed_thresh + self.motion = KalmanFilter() + self.conf_thres = conf_thres + self.metric_type = metric_type + + self.frame_id = 0 + self.tracked_stracks = [] + self.lost_stracks = [] + self.removed_stracks = [] + + self.max_time_lost = 0 + # max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer) + + def update(self, pred_dets, pred_embs): + """ + Processes the image frame and finds bounding box(detections). + Associates the detection with corresponding tracklets and also handles + lost, removed, refound and active tracklets. + + Args: + pred_dets (Tensor): Detection results of the image, shape is [N, 5]. + pred_embs (Tensor): Embedding results of the image, shape is [N, 512]. + + Return: + output_stracks (list): The list contains information regarding the + online_tracklets for the received image tensor. + """ + self.frame_id += 1 + activated_starcks = [] + # for storing active tracks, for the current frame + refind_stracks = [] + # Lost Tracks whose detections are obtained in the current frame + lost_stracks = [] + # The tracks which are not obtained in the current frame but are not + # removed. (Lost for some time less than the threshold for removing) + removed_stracks = [] + + remain_inds = np.nonzero(pred_dets[:, 4] > self.conf_thres) + if len(remain_inds) == 0: + pred_dets = np.zeros([0, 1]) + pred_embs = np.zeros([0, 1]) + else: + pred_dets = pred_dets[remain_inds] + pred_embs = pred_embs[remain_inds] + + # Filter out the image with box_num = 0.
pred_dets = [[0.0, 0.0, 0.0 ,0.0]] + empty_pred = True if len(pred_dets) == 1 and np.sum( + pred_dets) == 0.0 else False + """ Step 1: Network forward, get detections & embeddings""" + if len(pred_dets) > 0 and not empty_pred: + detections = [ + STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) + for (tlbrs, f) in zip(pred_dets, pred_embs) + ] + else: + detections = [] + ''' Add newly detected tracklets to tracked_stracks''' + unconfirmed = [] + tracked_stracks = [] # type: list[STrack] + for track in self.tracked_stracks: + if not track.is_activated: + # previous tracks which are not active in the current frame are added in unconfirmed list + unconfirmed.append(track) + else: + # Active tracks are added to the local list 'tracked_stracks' + tracked_stracks.append(track) + """ Step 2: First association, with embedding""" + # Combining currently tracked_stracks and lost_stracks + strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) + # Predict the current location with KF + STrack.multi_predict(strack_pool, self.motion) + + dists = matching.embedding_distance( + strack_pool, detections, metric=self.metric_type) + dists = matching.fuse_motion(self.motion, dists, strack_pool, + detections) + # The dists is the list of distances of the detection with the tracks in strack_pool + matches, u_track, u_detection = matching.linear_assignment( + dists, thresh=self.tracked_thresh) + # The matches is the array for corresponding matches of the detection with the corresponding strack_pool + + for itracked, idet in matches: + # itracked is the id of the track and idet is the detection + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + # If the track is active, add the detection to the track + track.update(detections[idet], self.frame_id) + activated_starcks.append(track) + else: + # We have obtained a detection from a track which is not active, + # hence put the track in refind_stracks list + track.re_activate(det, 
self.frame_id, new_id=False) + refind_stracks.append(track) + + # None of the steps below happen if there are no undetected tracks. + """ Step 3: Second association, with IOU""" + detections = [detections[i] for i in u_detection] + # detections is now a list of the unmatched detections + r_tracked_stracks = [] + # This is container for stracks which were tracked till the previous + # frame but no detection was found for it in the current frame. + + for i in u_track: + if strack_pool[i].state == TrackState.Tracked: + r_tracked_stracks.append(strack_pool[i]) + dists = matching.iou_distance(r_tracked_stracks, detections) + matches, u_track, u_detection = matching.linear_assignment( + dists, thresh=self.r_tracked_thresh) + # matches is the list of detections which matched with corresponding + # tracks by IOU distance method. + + for itracked, idet in matches: + track = r_tracked_stracks[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + # Same process done for some unmatched detections, but now considering IOU_distance as measure + + for it in u_track: + track = r_tracked_stracks[it] + if not track.state == TrackState.Lost: + track.mark_lost() + lost_stracks.append(track) + # If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost + '''Deal with unconfirmed tracks, usually tracks with only one beginning frame''' + detections = [detections[i] for i in u_detection] + dists = matching.iou_distance(unconfirmed, detections) + matches, u_unconfirmed, u_detection = matching.linear_assignment( + dists, thresh=self.unconfirmed_thresh) + for itracked, idet in matches: + unconfirmed[itracked].update(detections[idet], self.frame_id) + activated_starcks.append(unconfirmed[itracked]) + + # The tracks which are yet not matched + for it in 
u_unconfirmed: + track = unconfirmed[it] + track.mark_removed() + removed_stracks.append(track) + + # after all these confirmation steps, if a new detection is found, it is initialized for a new track + """ Step 4: Init new stracks""" + for inew in u_detection: + track = detections[inew] + if track.score < self.det_thresh: + continue + track.activate(self.motion, self.frame_id) + activated_starcks.append(track) + """ Step 5: Update state""" + # If the tracks are lost for more frames than the threshold number, the tracks are removed. + for track in self.lost_stracks: + if self.frame_id - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + # Update the self.tracked_stracks and self.lost_stracks using the updates in this step. + self.tracked_stracks = [ + t for t in self.tracked_stracks if t.state == TrackState.Tracked + ] + self.tracked_stracks = joint_stracks(self.tracked_stracks, + activated_starcks) + self.tracked_stracks = joint_stracks(self.tracked_stracks, + refind_stracks) + + self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks) + self.lost_stracks.extend(lost_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) + self.removed_stracks.extend(removed_stracks) + self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks( + self.tracked_stracks, self.lost_stracks) + # get scores of lost tracks + output_stracks = [ + track for track in self.tracked_stracks if track.is_activated + ] + + return output_stracks diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 40a4fdc98a56840d3bc483b4f7936de01e37e087..f49a0285cb1333de5beb128dbad3b1cc3d4a7072 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -39,9 +39,13 @@ TRT_MIN_SUBGRAPH = { 'SOLOv2': 60, 'HigherHRNet': 3, 'HRNet': 3, + 'DeepSORT': 3, + 'JDE': 3, + 'FairMOT': 3, } KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet'] +MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT'] 
def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): @@ -54,7 +58,9 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): label_list = [str(cat) for cat in catid2name.values()] sample_transforms = reader_cfg['sample_transforms'] - for st in sample_transforms[1:]: + if arch != 'mot_arch': + sample_transforms = sample_transforms[1:] + for st in sample_transforms: for key, value in st.items(): p = {'type': key} if key == 'Resize': @@ -106,9 +112,17 @@ def _dump_infer_config(config, path, image_shape, model): label_arch = 'detection_arch' if infer_arch in KEYPOINT_ARCH: label_arch = 'keypoint_arch' + + if infer_arch in MOT_ARCH: + label_arch = 'mot_arch' + reader_cfg = config['TestMOTReader'] + dataset_cfg = config['TestMOTDataset'] + else: + reader_cfg = config['TestReader'] + dataset_cfg = config['TestDataset'] + infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader( - config['TestReader'], config['TestDataset'], config['metric'], - label_arch, image_shape) + reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape) if infer_arch == 'S2ANet': # TODO: move background to num_classes diff --git a/ppdet/engine/tracker.py b/ppdet/engine/tracker.py index 4040d9200158c747249eaaf4d48a2cd88712d47b..0d72be36ee1206694034d235ca57cc3c27707946 100644 --- a/ppdet/engine/tracker.py +++ b/ppdet/engine/tracker.py @@ -32,7 +32,6 @@ from ppdet.metrics import Metric, MOTMetric import ppdet.utils.stats as stats from .callbacks import Callback, ComposeCallback -from .export_utils import _dump_infer_config from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) @@ -104,8 +103,10 @@ class Tracker(object): def load_weights_sde(self, det_weights, reid_weights): if self.model.detector: - load_weight(self.model.detector, det_weights, self.optimizer) - load_weight(self.model.reid, reid_weights, self.optimizer) + load_weight(self.model.detector, det_weights) + load_weight(self.model.reid, reid_weights) + else: + 
load_weight(self.model.reid, reid_weights, self.optimizer) def _eval_seq_jde(self, dataloader, @@ -130,7 +131,8 @@ class Tracker(object): # forward timer.tic() - online_targets = self.model(data) + pred_dets, pred_embs = self.model(data) + online_targets = self.model.tracker.update(pred_dets, pred_embs) online_tlwhs, online_ids = [], [] for t in online_targets: @@ -199,7 +201,9 @@ class Tracker(object): # forward timer.tic() - online_targets = self.model(data) + detections = self.model(data) + self.model.tracker.predict() + online_targets = self.model.tracker.update(detections) online_tlwhs = [] online_ids = [] @@ -239,6 +243,7 @@ class Tracker(object): "model_type should be 'JDE', 'DeepSORT' or 'FairMOT'" # run tracking + n_frame = 0 timer_avgs, timer_calls = [], [] for seq in seqs: diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index cef560d9d3cadf923fd0683ba5996a541be8fd65..4a0a0af18c4eb998f13acc44a341e039e77a624d 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -17,6 +17,7 @@ from __future__ import division from __future__ import print_function import os +import sys import time import random import datetime @@ -46,6 +47,8 @@ logger = setup_logger('ppdet.engine') __all__ = ['Trainer'] +MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT'] + class Trainer(object): def __init__(self, cfg, mode='train'): @@ -57,7 +60,15 @@ class Trainer(object): self.is_loaded_weights = False # build data loader - self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] + if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']: + self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())] + else: + self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] + + if cfg.architecture == 'DeepSORT' and self.mode == 'train': + logger.error('DeepSORT has no need of training on mot dataset.') + sys.exit(1) + if self.mode == 'train': self.loader = create('{}Reader'.format(self.mode.capitalize()))( self.dataset, cfg.worker_num) @@ -225,15 
+236,16 @@ class Trainer(object): if self.is_loaded_weights: return self.start_epoch = 0 - if hasattr(self.model, 'detector'): - if self.model.__class__.__name__ == 'FairMOT': - load_pretrain_weight(self.model, weights) - else: - load_pretrain_weight(self.model.detector, weights) - else: - load_pretrain_weight(self.model, weights) + load_pretrain_weight(self.model, weights) logger.debug("Load weights {} to start training".format(weights)) + def load_weights_sde(self, det_weights, reid_weights): + if self.model.detector: + load_weight(self.model.detector, det_weights) + load_weight(self.model.reid, reid_weights) + else: + load_weight(self.model.reid, reid_weights) + def resume_weights(self, weights): # support Distill resume weights if hasattr(self.model, 'student_model'): @@ -472,8 +484,12 @@ class Trainer(object): if not os.path.exists(save_dir): os.makedirs(save_dir) image_shape = None - if 'inputs_def' in self.cfg['TestReader']: - inputs_def = self.cfg['TestReader']['inputs_def'] + if self.cfg.architecture in MOT_ARCH: + test_reader_name = 'TestMOTReader' + else: + test_reader_name = 'TestReader' + if 'inputs_def' in self.cfg[test_reader_name]: + inputs_def = self.cfg[test_reader_name]['inputs_def'] image_shape = inputs_def.get('image_shape', None) # set image_shape=[3, -1, -1] as default if image_shape is None: diff --git a/ppdet/modeling/architectures/deepsort.py b/ppdet/modeling/architectures/deepsort.py index f8113a6f4448f03bdf216ec93fd0b2c69d85a666..66184fb7b18ae1d5ebcf2ed10853d0efee14ba19 100644 --- a/ppdet/modeling/architectures/deepsort.py +++ b/ppdet/modeling/architectures/deepsort.py @@ -61,7 +61,6 @@ class DeepSORT(BaseArch): } def _forward(self): - assert 'ori_image' in self.inputs load_dets = 'pred_bboxes' in self.inputs and 'pred_scores' in self.inputs ori_image = self.inputs['ori_image'] @@ -102,10 +101,7 @@ class DeepSORT(BaseArch): else: detections = [] - self.tracker.predict() - online_targets = self.tracker.update(detections) - - return 
online_targets + return detections def get_pred(self): return self._forward() diff --git a/ppdet/modeling/architectures/fairmot.py b/ppdet/modeling/architectures/fairmot.py index 1a29e3f59bf5003bbf8f28053797262361fc8323..05af9f20f0d0f6b1e59349e29ccb554d13814782 100755 --- a/ppdet/modeling/architectures/fairmot.py +++ b/ppdet/modeling/architectures/fairmot.py @@ -91,12 +91,9 @@ class FairMOT(BaseArch): embedding = paddle.transpose(embedding, [0, 2, 3, 1]) embedding = paddle.reshape(embedding, [-1, paddle.shape(embedding)[-1]]) - id_feature = paddle.gather(embedding, bbox_inds) - dets = det_outs['bbox'] - id_feature = id_feature - # Note: the tracker only considers batch_size=1 and num_classses=1 - online_targets = self.tracker.update(dets, id_feature) - return online_targets + pred_embs = paddle.gather(embedding, bbox_inds) + pred_dets = det_outs['bbox'] + return pred_dets, pred_embs def get_pred(self): output = self._forward() diff --git a/ppdet/modeling/architectures/jde.py b/ppdet/modeling/architectures/jde.py index b2f70a8d45d37061c10d76035307b937e83591aa..0871177bd88f8f0ac242cf14c10047f429a5346a 100644 --- a/ppdet/modeling/architectures/jde.py +++ b/ppdet/modeling/architectures/jde.py @@ -37,7 +37,7 @@ class JDE(BaseArch): tracker (object): tracker instance metric (str): 'MOTDet' for training and detection evaluation, 'ReID' for ReID embedding evaluation, or 'MOT' for multi object tracking - evaluation。 + evaluation. """ def __init__(self, @@ -107,11 +107,11 @@ class JDE(BaseArch): pred_dets = paddle.concat((bbox[:, 2:], bbox[:, 1:2]), axis=1) + boxes_idx = paddle.cast(boxes_idx, 'int64') emb_valid = paddle.gather_nd(emb_outs, boxes_idx) pred_embs = paddle.gather_nd(emb_valid, nms_keep_idx) - online_targets = self.tracker.update(pred_dets, pred_embs) - return online_targets + return pred_dets, pred_embs else: raise ValueError("Unknown metric {} for multi object tracking.". 
diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 41e75d0f4f533cda5637976fa1ad1a7a8e28fabb..a10d53aeb03c270a70ccdd27c563c3564e31548b 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -342,7 +342,7 @@ class RCNNBox(object): origin_shape = paddle.floor(im_shape / scale_factor + 0.5) scale_list = [] origin_shape_list = [] - + batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1]) # bbox_pred.shape: [N, C*4] for idx in range(batch_size): @@ -863,9 +863,7 @@ class JDEBox(object): return paddle.stack([gx1, gy1, gx2, gy2], axis=1) def decode_delta_map(self, delta_map, anchors): - delta_map_shape = paddle.shape(delta_map) - delta_map_shape.stop_gradient = True - nB, nA, nGh, nGw, _ = delta_map_shape[:] + nB, nA, nGh, nGw, _ = delta_map.shape[:] anchor_mesh = self.generate_anchor(nGh, nGw, anchors) # only support bs=1 anchor_mesh = paddle.unsqueeze(anchor_mesh, 0) @@ -875,9 +873,27 @@ class JDEBox(object): delta_map, shape=[-1, 4]), paddle.reshape( anchor_mesh, shape=[-1, 4])) - pred_map = paddle.reshape(pred_list, shape=[nB, -1, 4]) + pred_map = paddle.reshape(pred_list, shape=[nB, nA * nGh * nGw, 4]) return pred_map + def _postprocessing_by_level(self, nA, stride, head_out, anchor_vec): + boxes_shape = head_out.shape + nB, nGh, nGw = 1, boxes_shape[-2], boxes_shape[-1] + # only support bs=1 + p = paddle.reshape( + head_out, shape=[nB, nA, self.num_classes + 5, nGh, nGw]) + p = paddle.transpose(p, perm=[0, 1, 3, 4, 2]) # [nB, 4, nGh, nGw, 6] + p_box = p[:, :, :, :, :4] + boxes = self.decode_delta_map(p_box, anchor_vec) # [nB, 4*nGh*nGw, 4] + boxes = boxes * stride + + p_conf = paddle.transpose( + p[:, :, :, :, 4:6], perm=[0, 4, 1, 2, 3]) # [nB, 2, 4, 19, 34] + p_conf = F.softmax( + p_conf, axis=1)[:, 1, :, :, :].unsqueeze(-1) # [nB, 4, 19, 34, 1] + scores = paddle.reshape(p_conf, shape=[nB, nA * nGh * nGw, 1]) + return boxes, scores + def __call__(self, yolo_head_out, anchors): bbox_pred_list = [] for i, head_out in 
enumerate(yolo_head_out): @@ -885,43 +901,16 @@ class JDEBox(object): anc_w, anc_h = anchors[i][0::2], anchors[i][1::2] anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride nA = len(anc_w) - boxes_shape = paddle.shape(head_out) - boxes_shape.stop_gradient = True - nB, nGh, nGw = boxes_shape[0], boxes_shape[-2], boxes_shape[-1] - - p = head_out.reshape((nB, nA, self.num_classes + 5, nGh, nGw)) - p = paddle.transpose(p, perm=[0, 1, 3, 4, 2]) # [nB, 4, 19, 34, 6] - p_box = p[:, :, :, :, :4] # [nB, 4, 19, 34, 4] - boxes = self.decode_delta_map(p_box, anchor_vec) # [nB, 4*19*34, 4] - boxes = boxes * stride - - p_conf = paddle.transpose( - p[:, :, :, :, 4:6], perm=[0, 4, 1, 2, 3]) # [nB, 2, 4, 19, 34] - p_conf = F.softmax( - p_conf, - axis=1)[:, 1, :, :, :].unsqueeze(-1) # [nB, 4, 19, 34, 1] - scores = paddle.reshape(p_conf, shape=[nB, -1, 1]) - + boxes, scores = self._postprocessing_by_level(nA, stride, head_out, + anchor_vec) bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1)) - yolo_boxes_pred = paddle.concat(bbox_pred_list, axis=1) - boxes_idx = paddle.nonzero(yolo_boxes_pred[:, :, -1] > self.conf_thresh) - boxes_idx.stop_gradient = True - if boxes_idx.shape[0] == 0: # TODO: deploy - boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64')) - yolo_boxes_out = paddle.to_tensor( - np.array( - [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32')) - yolo_scores_out = paddle.to_tensor( - np.array( - [[[0.0]]], dtype='float32')) - return boxes_idx, yolo_boxes_out, yolo_scores_out - - yolo_boxes = paddle.gather_nd(yolo_boxes_pred, boxes_idx) - yolo_boxes_out = paddle.reshape(yolo_boxes[:, :4], shape=[nB, -1, 4]) - yolo_scores_out = paddle.reshape(yolo_boxes[:, 4:5], shape=[nB, 1, -1]) - boxes_idx = boxes_idx[:, 1:] - return boxes_idx, yolo_boxes_out, yolo_scores_out # [163], [1, 163, 4], [1, 1, 163] + yolo_boxes_scores = paddle.concat(bbox_pred_list, axis=1) + boxes_idx_over_conf_thr = paddle.nonzero( + yolo_boxes_scores[:, :, -1] > self.conf_thresh) + 
boxes_idx_over_conf_thr.stop_gradient = True + + return boxes_idx_over_conf_thr, yolo_boxes_scores @register diff --git a/ppdet/modeling/mot/utils.py b/ppdet/modeling/mot/utils.py index 4bf295921723539795474830f365c0055f68227c..f0af9ab21464f94f49bebbe2cf2e3b28bde461b0 100644 --- a/ppdet/modeling/mot/utils.py +++ b/ppdet/modeling/mot/utils.py @@ -121,14 +121,13 @@ def load_det_results(det_file, num_frames): def scale_coords(coords, input_shape, im_shape, scale_factor): im_shape = im_shape.numpy()[0] - ratio = scale_factor.numpy()[0][0] - img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)] - - pad_w = (input_shape[1] - round(img0_shape[1] * ratio)) / 2 - pad_h = (input_shape[0] - round(img0_shape[0] * ratio)) / 2 + ratio = scale_factor[0][0] + pad_w = (input_shape[1] - int(im_shape[1])) / 2 + pad_h = (input_shape[0] - int(im_shape[0])) / 2 + coords = paddle.cast(coords, 'float32') coords[:, 0::2] -= pad_w coords[:, 1::2] -= pad_h - coords[:, 0:4] /= paddle.to_tensor(ratio) + coords[:, 0:4] /= ratio coords[:, :4] = paddle.clip(coords[:, :4], min=0, max=coords[:, :4].max()) return coords.round() diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index accef57eec70f3fabf39697e24117362193c16d1..df0b467ded57f163b2b4d05f0fc89c0995c584f7 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -329,8 +329,38 @@ class S2ANetBBoxPostProcess(nn.Layer): @register -class JDEBBoxPostProcess(BBoxPostProcess): - def __call__(self, head_out, anchors): +class JDEBBoxPostProcess(nn.Layer): + __shared__ = ['num_classes'] + __inject__ = ['decode', 'nms'] + + def __init__(self, num_classes=1, decode=None, nms=None, return_idx=True): + super(JDEBBoxPostProcess, self).__init__() + self.num_classes = num_classes + self.decode = decode + self.nms = nms + self.return_idx = return_idx + + self.fake_bbox_pred = paddle.to_tensor( + np.array( + [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32')) + self.fake_bbox_num = 
paddle.to_tensor(np.array([1], dtype='int32')) + self.fake_nms_keep_idx = paddle.to_tensor( + np.array( + [[0]], dtype='int32')) + + self.fake_yolo_boxes_out = paddle.to_tensor( + np.array( + [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32')) + self.fake_yolo_scores_out = paddle.to_tensor( + np.array( + [[[0.0]]], dtype='float32')) + self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64')) + + def forward(self, + head_out, + anchors, + im_shape=[[608, 1088]], + scale_factor=[[1.0, 1.0]]): """ Decode the bbox and do NMS for JDE model. @@ -345,17 +375,31 @@ class JDEBBoxPostProcess(BBoxPostProcess): bbox_num (Tensor): The number of prediction of each batch with shape [N]. nms_keep_idx (Tensor): The index of kept bboxes after NMS. """ - boxes_idx, bboxes, score = self.decode(head_out, anchors) - bbox_pred, bbox_num, nms_keep_idx = self.nms(bboxes, score, - self.num_classes) - if bbox_pred.shape[0] == 0: - bbox_pred = paddle.to_tensor( - np.array( - [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32')) - bbox_num = paddle.to_tensor(np.array([1], dtype='int32')) - nms_keep_idx = paddle.to_tensor(np.array([[0]], dtype='int32')) + boxes_idx, yolo_boxes_scores = self.decode(head_out, anchors) - return boxes_idx, bbox_pred, bbox_num, nms_keep_idx + if len(boxes_idx) == 0: + boxes_idx = self.fake_boxes_idx + yolo_boxes_out = self.fake_yolo_boxes_out + yolo_scores_out = self.fake_yolo_scores_out + else: + yolo_boxes = paddle.gather_nd(yolo_boxes_scores, boxes_idx) + # TODO: only support bs=1 now + yolo_boxes_out = paddle.reshape( + yolo_boxes[:, :4], shape=[1, len(boxes_idx), 4]) + yolo_scores_out = paddle.reshape( + yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)]) + boxes_idx = boxes_idx[:, 1:] + + bbox_pred, bbox_num, nms_keep_idx = self.nms( + yolo_boxes_out, yolo_scores_out, self.num_classes) + if bbox_pred.shape[0] == 0: + bbox_pred = self.fake_bbox_pred + bbox_num = self.fake_bbox_num + nms_keep_idx = self.fake_nms_keep_idx + if self.return_idx: + return boxes_idx, 
bbox_pred, bbox_num, nms_keep_idx + else: + return bbox_pred, bbox_num @register @@ -420,7 +464,7 @@ class CenterNetPostProcess(TTFBox): x2 = xs + wh[:, 0:1] / 2 y2 = ys + wh[:, 1:2] / 2 - n, c, feat_h, feat_w = paddle.shape(hm) + n, c, feat_h, feat_w = hm.shape[:] padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2 padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2 x1 = x1 * self.down_ratio diff --git a/ppdet/modeling/reid/jde_embedding_head.py b/ppdet/modeling/reid/jde_embedding_head.py index 5b108387a1e9071de7e461567939ed07cac2eb03..da2f72941f330e85aa3e154f710ab804a9b2f0ae 100644 --- a/ppdet/modeling/reid/jde_embedding_head.py +++ b/ppdet/modeling/reid/jde_embedding_head.py @@ -60,7 +60,7 @@ class JDEEmbeddingHead(nn.Layer): def __init__( self, num_classes=1, - num_identifiers=1, # defined by dataset.total_identities + num_identifiers=14455, # defined by dataset.total_identities when training anchor_levels=3, anchor_scales=4, embedding_dim=512, diff --git a/tools/eval_mot.py b/tools/eval_mot.py index 57475548663166f6b548e48a49e46bbe578473b2..59d8a7a433d4584ab8bd8b589cadeb930314fea6 100644 --- a/tools/eval_mot.py +++ b/tools/eval_mot.py @@ -47,7 +47,7 @@ def parse_args(): parser.add_argument( "--det_results_dir", type=str, - default=None, + default='', help="Directory name for detection results.") parser.add_argument( '--output_dir', diff --git a/tools/export_model.py b/tools/export_model.py index 8cf3885c88552ca9b48f8b8d6796377d96912a2e..213be3d09054ba08e12fe38b523ee7bec5391b46 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -63,7 +63,13 @@ def run(FLAGS, cfg): trainer = Trainer(cfg, mode='test') # load weights - trainer.load_weights(cfg.weights) + if cfg.architecture in ['DeepSORT']: + if cfg.det_weights != 'None': + trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights) + else: + trainer.load_weights_sde(None, cfg.reid_weights) + else: + trainer.load_weights(cfg.weights) # export model trainer.export(FLAGS.output_dir) diff --git 
a/tools/infer_mot.py b/tools/infer_mot.py index 397a8b7e8a7b2868f715def7658d294d70c19cce..2067375776c58ac98b5f02282becefd0c5d07cce 100644 --- a/tools/infer_mot.py +++ b/tools/infer_mot.py @@ -49,7 +49,7 @@ def parse_args(): parser.add_argument( "--det_results_dir", type=str, - default=None, + default='', help="Directory name for detection results.") parser.add_argument( '--output_dir',