Unverified commit af424a9a, authored by George Ni, committed by GitHub

[MOT] add deepsort deploy (#3515)

* add deepsort export model, reid infer

* fix jde_yolov3 metric

* fix deepsort deploy

* fix doc, clean code

* add mot deploy infer image, add run_benchmark
Parent 3a497c5f
......@@ -39,7 +39,7 @@ or
pip install -r requirements.txt
```
**Notes:**
- Install `cython_bbox` for Windows: `pip install -e git+https://github.com/samson-wang/cython_bbox.git#egg=cython-bbox`. You can refer to this [tutorial](https://stackoverflow.com/questions/60349980/is-there-a-way-to-install-cython-bbox-for-windows)
- Install `cython_bbox` for Windows: `pip install -e git+https://github.com/samson-wang/cython_bbox.git#egg=cython-bbox`. You can refer to this [tutorial](https://stackoverflow.com/questions/60349980/is-there-a-way-to-install-cython-bbox-for-windows).
- Evaluation in a Windows CUDA 11 environment may not work normally. It will be fixed as soon as possible. You can switch to a CUDA 10.2 or CUDA 10.1 environment for normal evaluation.
......
......@@ -80,6 +80,26 @@ CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsor
**Notes:**
Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first. On the Linux (Ubuntu) platform you can install it directly with the following command: `apt-get update && apt-get install -y ffmpeg`.
### 3. Export model
```bash
# 1. Export the detection model
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
# 2. Export the ReID model
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
# or
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
```
### 4. Using exported model for python inference
```bash
python deploy/python/mot_reid_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**Notes:**
The tracking model predicts on videos and does not support prediction on a single image. The visualized tracking-result video is saved by default. You can add `--save_mot_txts` to save the txt result files, or `--save_images` to save the visualization images.
## Citations
```
@inproceedings{Wojke2017simple,
......
......@@ -82,6 +82,27 @@ CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/deepsort/deepsor
**Notes:**
Please make sure that [ffmpeg](https://ffmpeg.org/ffmpeg.html) is installed first. On the Linux (Ubuntu) platform you can install it directly with the following command: `apt-get update && apt-get install -y ffmpeg`.
### 3. Export the inference model
```bash
# 1. First export the detection model
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/jde_yolov3_darknet53_30e_1088x608.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/jde_yolov3_darknet53_30e_1088x608.pdparams
# 2. Then export the ReID model
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_yolov3_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
# or
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/deepsort/deepsort_pcb_pyramid_r101.yml -o reid_weights=https://paddledet.bj.bcebos.com/models/mot/deepsort_pcb_pyramid_r101.pdparams
```
### 4. Inference in Python with the exported model
```bash
python deploy/python/mot_reid_infer.py --model_dir=output_inference/jde_yolov3_darknet53_30e_1088x608/ --reid_model_dir=output_inference/deepsort_yolov3_pcb_pyramid_r101/ --video_file={your video name}.mp4 --device=GPU --save_mot_txts
```
**Notes:**
The tracking model predicts on videos and does not support prediction on a single image. The visualized tracking-result video is saved by default. You can add `--save_mot_txts` to save the txt result files, or `--save_images` to save the visualization images.
## Citations
```
@inproceedings{Wojke2017simple,
......
_BASE_: [
'../../datasets/mot.yml',
'../../runtime.yml',
'../jde/_base_/optimizer_30e.yml',
'../jde/_base_/jde_reader_1088x608.yml',
]
weights: output/jde_yolov3_darknet53_30e_1088x608/model_final
metric: MOTDet
EvalReader:
  inputs_def:
    num_max_boxes: 50
  sample_transforms:
    - Decode: {}
    - LetterBoxResize: {target_size: [608, 1088]}
    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
    - Permute: {}
  batch_size: 1

TestReader:
  inputs_def:
    image_shape: [3, 608, 1088]
  sample_transforms:
    - Decode: {}
    - LetterBoxResize: {target_size: [608, 1088]}
    - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
    - Permute: {}
  batch_size: 1

EvalDataset:
  !MOTDataSet
    dataset_dir: dataset/mot
    image_lists: ['mot16.train']
    data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']

TestDataset:
  !ImageFolder
    anno_path: None
architecture: YOLOv3
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/DarkNet53_pretrained.pdparams
# JDE version for MOT dataset
YOLOv3:
  backbone: DarkNet
  neck: YOLOv3FPN
  yolo_head: YOLOv3Head
  post_process: JDEBBoxPostProcess

DarkNet:
  depth: 53
  return_idx: [2, 3, 4]
  freeze_norm: True

YOLOv3FPN:
  freeze_norm: True

YOLOv3Head:
  anchors: [[128,384], [180,540], [256,640], [512,640],
            [32,96], [45,135], [64,192], [90,271],
            [8,24], [11,34], [16,48], [23,68]]
  anchor_masks: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
  loss: JDEDetectionLoss

JDEDetectionLoss:
  for_mot: False

JDEBBoxPostProcess:
  decode:
    name: JDEBox
    conf_thresh: 0.3
    downsample_ratio: 32
  nms:
    name: MultiClassNMS
    keep_top_k: 500
    score_threshold: 0.01
    nms_threshold: 0.5
    nms_top_k: 2000
    normalized: true
  return_idx: false
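For reference, a minimal sketch (hypothetical `jde_model.yml` file name, assuming PyYAML) of reading the post-process settings back out of the model config above, much like the deploy code parses `infer_cfg.yml` with `yaml.safe_load`:

```python
# Load the model YAML above and inspect the decode/NMS thresholds.
import yaml

with open('jde_model.yml') as f:  # assumed local copy of the config above
    cfg = yaml.safe_load(f)

post = cfg['JDEBBoxPostProcess']
print(post['decode']['conf_thresh'])   # 0.3
print(post['nms']['score_threshold'])  # 0.01
```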
......@@ -27,7 +27,7 @@ from paddle.inference import Config
from paddle.inference import create_predictor
from benchmark_utils import PaddleInferBenchmark
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize
from visualize import visualize_box_mask
from utils import argsparser, Timer, get_current_memory_mb
......@@ -41,6 +41,9 @@ SUPPORT_MODELS = {
    'SOLOv2',
    'TTFNet',
    'S2ANet',
    'JDE',
    'FairMOT',
    'DeepSORT',
}
......
......@@ -19,8 +19,7 @@ import cv2
import numpy as np
import paddle
from benchmark_utils import PaddleInferBenchmark
from preprocess import preprocess, NormalizeImage, Permute
from mot_preprocess import LetterBoxResize
from preprocess import preprocess, NormalizeImage, Permute, LetterBoxResize
from tracker import JDETracker
from ppdet.modeling.mot import visualization as mot_vis
......@@ -29,7 +28,7 @@ from ppdet.modeling.mot.utils import Timer as MOTTimer
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from infer import get_test_images, print_arguments
from infer import get_test_images, print_arguments, PredictConfig
# Global dictionary
MOT_SUPPORT_MODELS = {
......@@ -69,8 +68,8 @@ class MOT_Detector(object):
self.predictor, self.config = load_predictor(
model_dir,
run_mode=run_mode,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
min_subgraph_size=self.pred_config.min_subgraph_size,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
......@@ -109,10 +108,10 @@ class MOT_Detector(object):
online_scores.append(tscore)
return online_tlwhs, online_scores, online_ids
def predict(self, image, threshold=0.5, repeats=1):
def predict(self, image, threshold=0.5, warmup=0, repeats=1):
'''
Args:
image (dict): dict(['image', 'im_shape', 'scale_factor'])
image (np.ndarray): numpy image data
threshold (float): threshold of the predicted box score
Returns:
online_tlwhs, online_ids (np.ndarray)
......@@ -120,12 +119,19 @@ class MOT_Detector(object):
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(image)
self.det_times.preprocess_time_s.end()
pred_dets, pred_embs = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
for i in range(warmup):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
pred_dets = boxes_tensor.copy_to_cpu()
self.det_times.inference_time_s.start()
for i in range(repeats):
self.predictor.run()
......@@ -134,7 +140,6 @@ class MOT_Detector(object):
pred_dets = boxes_tensor.copy_to_cpu()
embs_tensor = self.predictor.get_output_handle(output_names[1])
pred_embs = embs_tensor.copy_to_cpu()
self.det_times.inference_time_s.end(repeats=repeats)
self.det_times.postprocess_time_s.start()
......@@ -150,7 +155,6 @@ def create_inputs(im, im_info):
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
model_arch (str): model type
Returns:
inputs (dict): input of model
"""
......@@ -162,48 +166,6 @@ def create_inputs(im, im_info):
return inputs
class PredictConfig_MOT():
"""set config of preprocess, postprocess and visualize
Args:
model_dir (str): root path of model.yml
"""
def __init__(self, model_dir):
# parsing Yaml config for Preprocess
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.labels = yml_conf['label_list']
self.mask = False
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
if 'mask' in yml_conf:
self.mask = yml_conf['mask']
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in MOT_SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], MOT_SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def load_predictor(model_dir,
run_mode='fluid',
batch_size=1,
......@@ -217,6 +179,7 @@ def load_predictor(model_dir,
cpu_threads=1,
enable_mkldnn=False):
"""set AnalysisConfig, generate AnalysisPredictor
Note: only support batch_size=1 now
Args:
model_dir (str): root path of __model__ and __params__
run_mode (str): mode of running(fluid/trt_fp32/trt_fp16/trt_int8)
......@@ -325,6 +288,30 @@ def write_mot_results(filename, results, data_type='mot'):
f.write(line)
def predict_image(detector, image_list):
    results = []
    for i, img_file in enumerate(image_list):
        frame = cv2.imread(img_file)
        if FLAGS.run_benchmark:
            detector.predict(frame, FLAGS.threshold, warmup=10, repeats=10)
            cm, gm, gu = get_current_memory_mb()
            detector.cpu_mem += cm
            detector.gpu_mem += gm
            detector.gpu_util += gu
            print('Test iter {}, file name:{}'.format(i, img_file))
        else:
            online_tlwhs, online_scores, online_ids = detector.predict(
                frame, FLAGS.threshold)
            online_im = mot_vis.plot_tracking(
                frame, online_tlwhs, online_ids, online_scores, frame_id=i)
            if FLAGS.save_images:
                if not os.path.exists(FLAGS.output_dir):
                    os.makedirs(FLAGS.output_dir)
                cv2.imwrite(
                    os.path.join(FLAGS.output_dir, img_file), online_im)
def predict_video(detector, camera_id):
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
......@@ -364,8 +351,7 @@ def predict_video(detector, camera_id):
online_ids,
online_scores,
frame_id=frame_id,
fps=fps,
threhold=FLAGS.threshold)
fps=fps)
if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
if not os.path.exists(save_dir):
......@@ -381,7 +367,7 @@ def predict_video(detector, camera_id):
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if FLAGS.save_results:
if FLAGS.save_mot_txts:
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
......@@ -389,7 +375,7 @@ def predict_video(detector, camera_id):
def main():
pred_config = PredictConfig_MOT(FLAGS.model_dir)
pred_config = PredictConfig(FLAGS.model_dir)
detector = MOT_Detector(
pred_config,
FLAGS.model_dir,
......@@ -406,7 +392,32 @@ def main():
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
predict_video(detector, FLAGS.camera_id)
else:
print('MOT models do not support predict single image.')
        # predict from image
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        predict_image(detector, img_list)
        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mems = {
                'cpu_rss_mb': detector.cpu_mem / len(img_list),
                'gpu_rss_mb': detector.gpu_mem / len(img_list),
                'gpu_util': detector.gpu_util * 100 / len(img_list)
            }
            perf_info = detector.det_times.report(average=True)
            model_dir = FLAGS.model_dir
            mode = FLAGS.run_mode
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            data_info = {
                'batch_size': 1,
                'shape': "dynamic_shape",
                'data_num': perf_info['img_num']
            }
            det_log = PaddleInferBenchmark(detector.config, model_info,
                                           data_info, perf_info, mems)
            det_log('MOT')
if __name__ == '__main__':
......
......@@ -20,16 +20,55 @@ import paddle
from mot_keypoint_unite_utils import argsparser
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from keypoint_det_unite_infer import bench_log
from keypoint_visualize import draw_pose
from benchmark_utils import PaddleInferBenchmark
from utils import Timer
from tracker import JDETracker
from mot_preprocess import LetterBoxResize
from mot_infer import MOT_Detector, PredictConfig_MOT, write_mot_results
from infer import print_arguments
from preprocess import LetterBoxResize
from mot_infer import MOT_Detector, write_mot_results
from infer import Detector, PredictConfig, print_arguments, get_test_images
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as FPSTimer
from utils import get_current_memory_mb
def mot_keypoint_unite_predict_image(mot_model, keypoint_model, image_list):
    for i, img_file in enumerate(image_list):
        frame = cv2.imread(img_file)
        if FLAGS.run_benchmark:
            mot_model.predict(
                frame, FLAGS.mot_threshold, warmup=10, repeats=10)
            cm, gm, gu = get_current_memory_mb()
            mot_model.cpu_mem += cm
            mot_model.gpu_mem += gm
            mot_model.gpu_util += gu

            keypoint_model.predict(
                [frame], FLAGS.keypoint_threshold, warmup=10, repeats=10)
            cm, gm, gu = get_current_memory_mb()
            keypoint_model.cpu_mem += cm
            keypoint_model.gpu_mem += gm
            keypoint_model.gpu_util += gu
        else:
            online_tlwhs, online_scores, online_ids = mot_model.predict(
                frame, FLAGS.mot_threshold)
            keypoint_results = keypoint_model.predict([frame],
                                                      FLAGS.keypoint_threshold)
            im = draw_pose(
                frame,
                keypoint_results,
                visual_thread=FLAGS.keypoint_threshold,
                returnimg=True)
            online_im = mot_vis.plot_tracking(
                im, online_tlwhs, online_ids, online_scores, frame_id=i)
            if FLAGS.save_images:
                if not os.path.exists(FLAGS.output_dir):
                    os.makedirs(FLAGS.output_dir)
                cv2.imwrite(
                    os.path.join(FLAGS.output_dir, img_file), online_im)
def mot_keypoint_unite_predict_video(mot_model, keypoint_model, camera_id):
......@@ -117,7 +156,7 @@ def mot_keypoint_unite_predict_video(mot_model, keypoint_model, camera_id):
def main():
pred_config = PredictConfig_MOT(FLAGS.mot_model_dir)
pred_config = PredictConfig(FLAGS.mot_model_dir)
mot_model = MOT_Detector(
pred_config,
FLAGS.mot_model_dir,
......@@ -149,7 +188,28 @@ def main():
mot_keypoint_unite_predict_video(mot_model, keypoint_model,
FLAGS.camera_id)
else:
print('Do not support unite predict single image.')
        # predict from image
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        mot_keypoint_unite_predict_image(mot_model, keypoint_model, img_list)
        if not FLAGS.run_benchmark:
            mot_model.det_times.info(average=True)
            keypoint_model.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            mot_model_dir = FLAGS.mot_model_dir
            mot_model_info = {
                'model_name': mot_model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(mot_model, img_list, mot_model_info, name='MOT')

            keypoint_model_dir = FLAGS.keypoint_model_dir
            keypoint_model_info = {
                'model_name': keypoint_model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(keypoint_model, img_list, keypoint_model_info,
                      'KeyPoint')
if __name__ == '__main__':
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
class LetterBoxResize(object):
    def __init__(self, target_size):
        """
        Resize image to target size, convert normalized xywh to pixel xyxy
        format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
        Args:
            target_size (int|list): image target size.
        """
        super(LetterBoxResize, self).__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size

    def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
        # letterbox: resize a rectangular image into a padded rectangle
        shape = img.shape[:2]  # [height, width]
        ratio_h = float(height) / shape[0]
        ratio_w = float(width) / shape[1]
        ratio = min(ratio_h, ratio_w)
        new_shape = (round(shape[1] * ratio),
                     round(shape[0] * ratio))  # [width, height]
        padw = (width - new_shape[0]) / 2
        padh = (height - new_shape[1]) / 2
        top, bottom = round(padh - 0.1), round(padh + 0.1)
        left, right = round(padw - 0.1), round(padw + 0.1)

        img = cv2.resize(
            img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT,
            value=color)  # padded rectangular image
        return img, ratio, padw, padh

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        height, width = self.target_size
        h, w = im.shape[:2]
        im, ratio, padw, padh = self.letterbox(im, height=height, width=width)

        new_shape = [round(h * ratio), round(w * ratio)]
        im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
        im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
        return im, im_info
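Below is a quick usage sketch of the transform above (not part of the original file; the frame size and values are hypothetical, and it assumes the class and its `cv2`/`numpy` imports are in scope):

```python
# Letterbox a hypothetical 720x1280 frame to the 608x1088 JDE input size:
# the aspect ratio is preserved and the remainder is padded with gray borders.
resize_op = LetterBoxResize(target_size=[608, 1088])
frame = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)
im_info = {}

out, im_info = resize_op(frame, im_info)
print(out.shape)                # (608, 1088, 3): padded network input
print(im_info['im_shape'])      # [608. 1081.]: resized content before padding
print(im_info['scale_factor'])  # approx [0.8444 0.8444]: min(608/720, 1088/1280)
```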
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import yaml
import cv2
import numpy as np
import paddle
from benchmark_utils import PaddleInferBenchmark
from preprocess import preprocess, NormalizeImage, Permute, LetterBoxResize
from tracker import DeepSORTTracker
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as MOTTimer
from ppdet.modeling.mot.utils import Detection
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from infer import get_test_images, print_arguments, PredictConfig, Detector
from mot_infer import create_inputs, load_predictor, write_mot_results
# Global dictionary
MOT_SUPPORT_MODELS = {'DeepSORT'}
def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    mems = {
        'cpu_rss_mb': detector.cpu_mem / len(img_list),
        'gpu_rss_mb': detector.gpu_mem / len(img_list),
        'gpu_util': detector.gpu_util * 100 / len(img_list)
    }
    perf_info = detector.det_times.report(average=True)
    data_info = {
        'batch_size': batch_size,
        'shape': "dynamic_shape",
        'data_num': perf_info['img_num']
    }
    log = PaddleInferBenchmark(detector.config, model_info, data_info,
                               perf_info, mems)
    log(name)


def scale_coords(coords, input_shape, im_shape, scale_factor):
    im_shape = im_shape[0]
    ratio = scale_factor[0][0]
    pad_w = (input_shape[1] - int(im_shape[1])) / 2
    pad_h = (input_shape[0] - int(im_shape[0])) / 2
    coords[:, 0::2] -= pad_w
    coords[:, 1::2] -= pad_h
    coords[:, 0:4] /= ratio
    coords[:, :4] = np.clip(coords[:, :4], a_min=0, a_max=coords[:, :4].max())
    return coords.round()


def clip_box(xyxy, input_shape, im_shape, scale_factor):
    im_shape = im_shape[0]
    ratio = scale_factor[0][0]
    img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]
    xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], a_min=0, a_max=img0_shape[1])
    xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], a_min=0, a_max=img0_shape[0])
    return xyxy
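To make the mapping concrete, here is a small numeric sketch (hypothetical values, assuming the helpers above are in scope): `scale_coords` subtracts the letterbox padding, then divides by the resize ratio.

```python
# Map a box from the 608x1088 padded input back to a 720x1280 source frame
# (ratio ~0.8444, letterboxed content 608x1081, 3.5 px of padding per side).
import numpy as np

input_shape = (608, 1088)                         # network input [h, w]
im_shape = np.array([[608., 1081.]])              # letterboxed content [h, w]
scale_factor = np.array([[0.8444, 0.8444]])       # resize ratio
coords = np.array([[103.5, 50.0, 203.5, 150.0]])  # xyxy on the padded input

print(scale_coords(coords.copy(), input_shape, im_shape, scale_factor))
# -> [[118.  59. 237. 178.]] in original-frame pixels
```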
def get_crops(xyxy, ori_img, pred_scores, w, h):
    crops = []
    keep_scores = []
    xyxy = xyxy.astype(np.int64)
    ori_img = ori_img.transpose(1, 0, 2)  # [h,w,3] -> [w,h,3]
    for i, bbox in enumerate(xyxy):
        if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
            continue
        crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
        crops.append(crop)
        keep_scores.append(pred_scores[i])
    if len(crops) == 0:
        return [], []
    crops = preprocess_reid(crops, w, h)
    return crops, keep_scores


def preprocess_reid(imgs,
                    w=64,
                    h=192,
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]):
    im_batch = []
    for img in imgs:
        img = cv2.resize(img, (w, h))
        img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
        img_mean = np.array(mean).reshape((3, 1, 1))
        img_std = np.array(std).reshape((3, 1, 1))
        img -= img_mean
        img /= img_std
        img = np.expand_dims(img, axis=0)
        im_batch.append(img)
    im_batch = np.concatenate(im_batch, 0)
    return im_batch
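A short sketch of the detector-to-ReID hand-off built from the two helpers above (hypothetical detection values): crop each detected pedestrian and normalize it into the `crops` batch that the ReID model consumes.

```python
# One hypothetical pedestrian box, cropped from a frame and packed into a
# normalized NCHW batch for the ReID network.
import numpy as np

frame = np.random.randint(0, 255, (608, 1088, 3), dtype=np.uint8)
pred_bboxes = np.array([[10., 20., 74., 212.]])  # xyxy in frame pixels
pred_scores = np.array([[0.9]])

crops, keep_scores = get_crops(pred_bboxes, frame, pred_scores, w=64, h=192)
print(crops.shape)  # (1, 3, 192, 64): ready for the 'crops' input tensor
print(keep_scores)  # [array([0.9])]: scores of the kept boxes
```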
class MOT_Detector(object):
    """
    Args:
        pred_config (object): config of model, defined by `Config(model_dir)`
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): device to run on, one of CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (fluid/trt_fp32/trt_fp16)
        trt_min_shape (int): min shape for dynamic shape in TRT
        trt_max_shape (int): max shape for dynamic shape in TRT
        trt_opt_shape (int): opt shape for dynamic shape in TRT
        trt_calib_mode (bool): if the model is produced by TRT offline
            quantization calibration, trt_calib_mode needs to be True
        cpu_threads (int): number of CPU threads
        enable_mkldnn (bool): whether to enable MKLDNN
    """

    def __init__(self,
                 pred_config,
                 model_dir,
                 device='CPU',
                 run_mode='fluid',
                 trt_min_shape=1,
                 trt_max_shape=1088,
                 trt_opt_shape=608,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False):
        self.pred_config = pred_config
        self.predictor, self.config = load_predictor(
            model_dir,
            run_mode=run_mode,
            device=device,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn)
        self.det_times = Timer()
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0

    def preprocess(self, im):
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))
        im, im_info = preprocess(im, preprocess_ops)
        inputs = create_inputs(im, im_info)
        return inputs

    def postprocess(self, boxes, input_shape, im_shape, scale_factor,
                    threshold):
        pred_bboxes = scale_coords(boxes[:, 2:], input_shape, im_shape,
                                   scale_factor)
        pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape,
                               scale_factor)
        pred_scores = boxes[:, 1:2]
        keep_mask = pred_scores[:, 0] >= threshold
        return pred_bboxes[keep_mask], pred_scores[keep_mask]

    def predict(self, image, threshold=0.5, warmup=0, repeats=1):
        '''
        Args:
            image (np.ndarray): image numpy data
            threshold (float): threshold of the predicted box score
        Returns:
            pred_bboxes, pred_scores (np.ndarray)
        '''
        self.det_times.preprocess_time_s.start()
        inputs = self.preprocess(image)
        self.det_times.preprocess_time_s.end()

        pred_bboxes, pred_scores = None, None
        input_names = self.predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(inputs[input_names[i]])

        for i in range(warmup):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            boxes_tensor = self.predictor.get_output_handle(output_names[0])
            boxes = boxes_tensor.copy_to_cpu()

        self.det_times.inference_time_s.start()
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            boxes_tensor = self.predictor.get_output_handle(output_names[0])
            boxes = boxes_tensor.copy_to_cpu()
        self.det_times.inference_time_s.end(repeats=repeats)

        self.det_times.postprocess_time_s.start()
        input_shape = inputs['image'].shape[2:]
        im_shape = inputs['im_shape']
        scale_factor = inputs['scale_factor']
        pred_bboxes, pred_scores = self.postprocess(
            boxes, input_shape, im_shape, scale_factor, threshold)
        self.det_times.postprocess_time_s.end()
        self.det_times.img_num += 1
        return pred_bboxes, pred_scores
class MOT_ReID(object):
    def __init__(self,
                 pred_config,
                 model_dir,
                 device='CPU',
                 run_mode='fluid',
                 trt_min_shape=1,
                 trt_max_shape=1088,
                 trt_opt_shape=608,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False):
        self.pred_config = pred_config
        self.predictor, self.config = load_predictor(
            model_dir,
            run_mode=run_mode,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn)
        self.det_times = Timer()
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
        self.tracker = DeepSORTTracker()

    def preprocess(self, crops):
        inputs = {}
        inputs['crops'] = np.array(crops).astype('float32')
        return inputs

    def postprocess(self, bbox_tlwh, pred_scores, features):
        detections = [
            Detection(tlwh, score, feat)
            for tlwh, score, feat in zip(bbox_tlwh, pred_scores, features)
        ]
        self.tracker.predict()
        online_targets = self.tracker.update(detections)

        online_tlwhs = []
        online_scores = []
        online_ids = []
        for track in online_targets:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            online_tlwhs.append(track.to_tlwh())
            online_scores.append(1.0)
            online_ids.append(track.track_id)
        return online_tlwhs, online_scores, online_ids

    def predict(self, crops, bbox_tlwh, pred_scores, warmup=0, repeats=1):
        self.det_times.preprocess_time_s.start()
        inputs = self.preprocess(crops)
        self.det_times.preprocess_time_s.end()

        pred_dets, pred_embs = None, None
        input_names = self.predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(inputs[input_names[i]])

        for i in range(warmup):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            feature_tensor = self.predictor.get_output_handle(output_names[0])
            features = feature_tensor.copy_to_cpu()

        self.det_times.inference_time_s.start()
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            feature_tensor = self.predictor.get_output_handle(output_names[0])
            features = feature_tensor.copy_to_cpu()
        self.det_times.inference_time_s.end(repeats=repeats)

        self.det_times.postprocess_time_s.start()
        online_tlwhs, online_scores, online_ids = self.postprocess(
            bbox_tlwh, pred_scores, features)
        self.det_times.postprocess_time_s.end()
        self.det_times.img_num += 1
        return online_tlwhs, online_scores, online_ids
def predict_image(detector, reid_model, image_list):
    results = []
    for i, img_file in enumerate(image_list):
        frame = cv2.imread(img_file)
        if FLAGS.run_benchmark:
            pred_bboxes, pred_scores = detector.predict(
                frame, FLAGS.threshold, warmup=10, repeats=10)
            cm, gm, gu = get_current_memory_mb()
            detector.cpu_mem += cm
            detector.gpu_mem += gm
            detector.gpu_util += gu
            print('Test iter {}, file name:{}'.format(i, img_file))
        else:
            pred_bboxes, pred_scores = detector.predict(frame, FLAGS.threshold)

        # convert xyxy detections to tlwh and crop pedestrians for ReID
        bbox_tlwh = np.concatenate(
            (pred_bboxes[:, 0:2],
             pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1),
            axis=1)
        crops, pred_scores = get_crops(
            pred_bboxes, frame, pred_scores, w=64, h=192)

        if FLAGS.run_benchmark:
            online_tlwhs, online_scores, online_ids = reid_model.predict(
                crops, bbox_tlwh, pred_scores, warmup=10, repeats=10)
        else:
            online_tlwhs, online_scores, online_ids = reid_model.predict(
                crops, bbox_tlwh, pred_scores)
            online_im = mot_vis.plot_tracking(
                frame, online_tlwhs, online_ids, online_scores, frame_id=i)
            if FLAGS.save_images:
                if not os.path.exists(FLAGS.output_dir):
                    os.makedirs(FLAGS.output_dir)
                cv2.imwrite(
                    os.path.join(FLAGS.output_dir, img_file), online_im)
def predict_video(detector, reid_model, camera_id):
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
        video_name = 'mot_output.mp4'
    else:
        capture = cv2.VideoCapture(FLAGS.video_file)
        video_name = os.path.split(FLAGS.video_file)[-1]
    fps = 30
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    print('frame_count', frame_count)
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # yapf: disable
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # yapf: enable
    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)
    out_path = os.path.join(FLAGS.output_dir, video_name)
    writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
    frame_id = 0
    timer = MOTTimer()
    results = []
    while True:
        ret, frame = capture.read()
        if not ret:
            break
        timer.tic()
        pred_bboxes, pred_scores = detector.predict(frame, FLAGS.threshold)
        timer.toc()

        bbox_tlwh = np.concatenate(
            (pred_bboxes[:, 0:2],
             pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1),
            axis=1)
        crops, pred_scores = get_crops(
            pred_bboxes, frame, pred_scores, w=64, h=192)

        online_tlwhs, online_scores, online_ids = reid_model.predict(
            crops, bbox_tlwh, pred_scores)
        results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))

        fps = 1. / timer.average_time
        online_im = mot_vis.plot_tracking(
            frame,
            online_tlwhs,
            online_ids,
            online_scores,
            frame_id=frame_id,
            fps=fps)
        if FLAGS.save_images:
            save_dir = os.path.join(FLAGS.output_dir,
                                    video_name.split('.')[-2])
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            cv2.imwrite(
                os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                online_im)
        frame_id += 1
        print('detect frame:%d' % (frame_id))
        im = np.array(online_im)
        writer.write(im)
        if camera_id != -1:
            cv2.imshow('Tracking Detection', im)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    if FLAGS.save_mot_txts:
        result_filename = os.path.join(FLAGS.output_dir,
                                       video_name.split('.')[-2] + '.txt')
        write_mot_results(result_filename, results)
    writer.release()
def main():
    pred_config = PredictConfig(FLAGS.model_dir)
    detector = MOT_Detector(
        pred_config,
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn)

    pred_config = PredictConfig(FLAGS.reid_model_dir)
    reid_model = MOT_ReID(
        pred_config,
        FLAGS.reid_model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        predict_video(detector, reid_model, FLAGS.camera_id)
    else:
        # predict from image
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        predict_image(detector, reid_model, img_list)

        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
            reid_model.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            det_model_dir = FLAGS.model_dir
            det_model_info = {
                'model_name': det_model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, det_model_info, name='Det')

            reid_model_dir = FLAGS.reid_model_dir
            reid_model_info = {
                'model_name': reid_model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(reid_model, img_list, reid_model_info, name='ReID')


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'
                            ], "device should be CPU, GPU or XPU"
    main()
......@@ -193,6 +193,60 @@ class PadStride(object):
return padding_im, im_info
class LetterBoxResize(object):
    def __init__(self, target_size):
        """
        Resize image to target size, convert normalized xywh to pixel xyxy
        format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
        Args:
            target_size (int|list): image target size.
        """
        super(LetterBoxResize, self).__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size

    def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
        # letterbox: resize a rectangular image into a padded rectangle
        shape = img.shape[:2]  # [height, width]
        ratio_h = float(height) / shape[0]
        ratio_w = float(width) / shape[1]
        ratio = min(ratio_h, ratio_w)
        new_shape = (round(shape[1] * ratio),
                     round(shape[0] * ratio))  # [width, height]
        padw = (width - new_shape[0]) / 2
        padh = (height - new_shape[1]) / 2
        top, bottom = round(padh - 0.1), round(padh + 0.1)
        left, right = round(padw - 0.1), round(padw + 0.1)

        img = cv2.resize(
            img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT,
            value=color)  # padded rectangular image
        return img, ratio, padw, padh

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        height, width = self.target_size
        h, w = im.shape[:2]
        im, ratio, padw, padh = self.letterbox(im, height=height, width=width)

        new_shape = [round(h * ratio), round(w * ratio)]
        im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
        im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
        return im, im_info
def preprocess(im, preprocess_ops):
    # process image by preprocess_ops
    im_info = {
......
......@@ -16,6 +16,7 @@ This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_so
"""
import numpy as np
from ppdet.modeling.mot.motion import KalmanFilter
from ppdet.modeling.mot.matching.deepsort_matching import NearestNeighborDistanceMetric
from ppdet.modeling.mot.matching.deepsort_matching import iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix
from ppdet.modeling.mot.tracker.base_sde_tracker import Track
......@@ -24,7 +25,6 @@ __all__ = ['DeepSORTTracker']
class DeepSORTTracker(object):
__inject__ = ['motion']
"""
DeepSORT tracker
......@@ -60,7 +60,7 @@ class DeepSORTTracker(object):
self.metric = NearestNeighborDistanceMetric(metric_type,
matching_threshold, budget)
self.max_iou_distance = max_iou_distance
self.motion = motion
self.motion = KalmanFilter()
self.tracks = []
self._next_id = 1
......
......@@ -108,6 +108,12 @@ def argsparser():
        '--save_mot_txts',
        action='store_true',
        help='Save tracking results (txt).')
    parser.add_argument(
        "--reid_model_dir",
        type=str,
        default=None,
        help=("Directory include:'model.pdiparams', 'model.pdmodel', "
              "'infer_cfg.yml', created by tools/export_model.py."))
    parser.add_argument(
        '--use_dark',
        type=bool,
......
......@@ -24,7 +24,7 @@ import numpy as np
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import Timer, load_det_results
from ppdet.modeling.mot import visualization as mot_vis
......@@ -188,9 +188,12 @@ class Tracker(object):
logger.info('Processing frame {} ({:.2f} fps)'.format(
frame_id, 1. / max(1e-5, timer.average_time)))
ori_image = data['ori_image']
input_shape = data['image'].shape[2:]
im_shape = data['im_shape']
scale_factor = data['scale_factor']
timer.tic()
if not use_detector:
timer.tic()
dets = dets_list[frame_id]
bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32')
pred_scores = paddle.to_tensor(dets['score'], dtype='float32')
......@@ -203,14 +206,35 @@ class Tracker(object):
else:
pred_bboxes = []
pred_scores = []
data.update({
'pred_bboxes': pred_bboxes,
'pred_scores': pred_scores
})
else:
outs = self.model.detector(data)
if outs['bbox_num'] > 0:
pred_bboxes = scale_coords(outs['bbox'][:, 2:], input_shape,
im_shape, scale_factor)
pred_scores = outs['bbox'][:, 1:2]
else:
pred_bboxes = []
pred_scores = []
# forward
timer.tic()
detections = self.model(data)
pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape,
scale_factor)
bbox_tlwh = paddle.concat(
(pred_bboxes[:, 0:2],
pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1),
axis=1)
crops, pred_scores = get_crops(
pred_bboxes, ori_image, pred_scores, w=64, h=192)
crops = paddle.to_tensor(crops)
pred_scores = paddle.to_tensor(pred_scores)
data.update({'crops': crops})
features = self.model(data)
features = features.numpy()
detections = [
Detection(tlwh, score, feat)
for tlwh, score, feat in zip(bbox_tlwh, pred_scores, features)
]
self.model.tracker.predict()
online_targets = self.model.tracker.update(detections)
......
......@@ -36,7 +36,7 @@ from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results, save_result
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
from ppdet.metrics import RBoxMetric
from ppdet.metrics import RBoxMetric, JDEDetMetric
from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats
......@@ -243,6 +243,8 @@ class Trainer(object):
len(eval_dataset), self.cfg.num_joints,
self.cfg.save_dir)
]
            elif self.cfg.metric == 'MOTDet':
                self._metrics = [JDEDetMetric(), ]
            else:
                logger.warn("Metric not support for metric type {}".format(
                    self.cfg.metric))
......@@ -545,6 +547,11 @@ class Trainer(object):
"scale_factor": InputSpec(
shape=[None, 2], name='scale_factor')
}]
if self.cfg.architecture == 'DeepSORT':
input_spec[0].update({
"crops": InputSpec(
shape=[None, 3, 192, 64], name='crops')
})
# dy2st and save model
if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
......
......@@ -20,12 +20,14 @@ import copy
import numpy as np
import paddle
import paddle.nn.functional as F
from ppdet.modeling.bbox_utils import bbox_iou_np_expand
from .map_utils import ap_per_class
from .metrics import Metric
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['MOTEvaluator', 'MOTMetric']
__all__ = ['MOTEvaluator', 'MOTMetric', 'JDEDetMetric']
def read_mot_results(filename, is_gt=False, is_ignore=False):
......@@ -236,3 +238,67 @@ class MOTMetric(Metric):
    def get_results(self):
        return self.strsummary
class JDEDetMetric(Metric):
    # Note: this detection AP metric is different from COCOMetric or VOCMetric,
    # and the bbox coordinates are not scaled back to the original image.
    def __init__(self, overlap_thresh=0.5):
        self.overlap_thresh = overlap_thresh
        self.reset()

    def reset(self):
        self.AP_accum = np.zeros(1)
        self.AP_accum_count = np.zeros(1)

    def update(self, inputs, outputs):
        bboxes = outputs['bbox'][:, 2:].numpy()
        scores = outputs['bbox'][:, 1].numpy()
        labels = outputs['bbox'][:, 0].numpy()
        bbox_lengths = outputs['bbox_num'].numpy()
        if bboxes.shape[0] == 1 and bboxes.sum() == 0.0:
            return
        gt_boxes = inputs['gt_bbox'].numpy()[0]
        gt_labels = inputs['gt_class'].numpy()[0]
        if gt_labels.shape[0] == 0:
            return

        correct = []
        detected = []
        for i in range(bboxes.shape[0]):
            obj_pred = 0
            pred_bbox = bboxes[i].reshape(1, 4)
            # Compute IoU with target boxes
            iou = bbox_iou_np_expand(pred_bbox, gt_boxes, x1y1x2y2=True)[0]
            # Extract index of largest overlap
            best_i = np.argmax(iou)
            # If overlap exceeds threshold and classification is correct, mark as correct
            if iou[best_i] > self.overlap_thresh and obj_pred == gt_labels[
                    best_i] and best_i not in detected:
                correct.append(1)
                detected.append(best_i)
            else:
                correct.append(0)

        # Compute Average Precision (AP) per class
        target_cls = list(gt_labels.T[0])
        AP, AP_class, R, P = ap_per_class(
            tp=correct,
            conf=scores,
            pred_cls=np.zeros_like(scores),
            target_cls=target_cls)
        self.AP_accum_count += np.bincount(AP_class, minlength=1)
        self.AP_accum += np.bincount(AP_class, minlength=1, weights=AP)

    def accumulate(self):
        logger.info("Accumulating evaluation results...")
        self.map_stat = self.AP_accum[0] / (self.AP_accum_count[0] + 1E-16)

    def log(self):
        map_stat = 100. * self.map_stat
        logger.info("mAP({:.2f}) = {:.2f}%".format(self.overlap_thresh,
                                                   map_stat))

    def get_results(self):
        return self.map_stat
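A minimal driving sketch for the new metric (hypothetical single-image batch; assumes `JDEDetMetric` is importable from `ppdet.metrics`, as the Trainer change above does):

```python
# Feed one hypothetical batch of predictions and ground truth, then report mAP.
import paddle
from ppdet.metrics import JDEDetMetric

metric = JDEDetMetric(overlap_thresh=0.5)
outputs = {
    # each detection row: [class, score, x1, y1, x2, y2] on the network input
    'bbox': paddle.to_tensor([[0., 0.9, 10., 10., 50., 80.]]),
    'bbox_num': paddle.to_tensor([1]),
}
inputs = {
    'gt_bbox': paddle.to_tensor([[[12., 11., 49., 78.]]]),
    'gt_class': paddle.to_tensor([[[0]]]),
}
metric.update(inputs, outputs)
metric.accumulate()
metric.log()  # e.g. "mAP(0.50) = 100.00%" for this single matching box
```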
......@@ -61,47 +61,9 @@ class DeepSORT(BaseArch):
}
def _forward(self):
load_dets = 'pred_bboxes' in self.inputs and 'pred_scores' in self.inputs
ori_image = self.inputs['ori_image']
input_shape = self.inputs['image'].shape[2:]
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
if self.detector and not load_dets:
outs = self.detector(self.inputs)
if outs['bbox_num'] > 0:
pred_bboxes = scale_coords(outs['bbox'][:, 2:], input_shape,
im_shape, scale_factor)
pred_scores = outs['bbox'][:, 1:2]
else:
pred_bboxes = []
pred_scores = []
else:
pred_bboxes = self.inputs['pred_bboxes']
pred_scores = self.inputs['pred_scores']
if len(pred_bboxes) > 0:
pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape,
scale_factor)
bbox_tlwh = paddle.concat(
(pred_bboxes[:, 0:2],
pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1),
axis=1)
crops, pred_scores = get_crops(
pred_bboxes, ori_image, pred_scores, w=64, h=192)
if len(crops) > 0:
features = self.reid(paddle.to_tensor(crops))
detections = [Detection(bbox_tlwh[i], conf, features[i])\
for i, conf in enumerate(pred_scores)]
else:
detections = []
else:
detections = []
return detections
        crops = self.inputs['crops']
        features = self.reid(crops)
        return features

    def get_pred(self):
        return self._forward()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
from ..post_process import JDEBBoxPostProcess
__all__ = ['YOLOv3']
......@@ -39,6 +54,7 @@ class YOLOv3(BaseArch):
        self.yolo_head = yolo_head
        self.post_process = post_process
        self.for_mot = for_mot
        self.return_idx = isinstance(post_process, JDEBBoxPostProcess)
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......@@ -90,9 +106,13 @@ class YOLOv3(BaseArch):
'emb_feats': emb_feats,
}
else:
bbox, bbox_num = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
            if self.return_idx:
                _, bbox, bbox_num, _ = self.post_process(
                    yolo_head_outs, self.yolo_head.mask_anchors)
            else:
                bbox, bbox_num = self.post_process(
                    yolo_head_outs, self.yolo_head.mask_anchors,
                    self.inputs['im_shape'], self.inputs['scale_factor'])
            output = {'bbox': bbox, 'bbox_num': bbox_num}
            return output
......
......@@ -28,9 +28,10 @@ __all__ = ['JDEDetectionLoss', 'JDEEmbeddingLoss', 'JDELoss']
class JDEDetectionLoss(nn.Layer):
__shared__ = ['num_classes']
def __init__(self, num_classes=1):
def __init__(self, num_classes=1, for_mot=True):
super(JDEDetectionLoss, self).__init__()
self.num_classes = num_classes
self.for_mot = for_mot
def det_loss(self, p_det, anchor, t_conf, t_box):
pshape = paddle.shape(p_det)
......@@ -92,7 +93,17 @@ class JDEDetectionLoss(nn.Layer):
loss_conf, loss_box = self.det_loss(p_det, anchor, t_conf, t_box)
loss_confs.append(loss_conf)
loss_boxes.append(loss_box)
return {'loss_confs': loss_confs, 'loss_boxes': loss_boxes}
        if self.for_mot:
            return {'loss_confs': loss_confs, 'loss_boxes': loss_boxes}
        else:
            jde_conf_losses = sum(loss_confs)
            jde_box_losses = sum(loss_boxes)
            jde_det_losses = {
                "loss_conf": jde_conf_losses,
                "loss_box": jde_box_losses,
                "loss": jde_conf_losses + jde_box_losses,
            }
            return jde_det_losses
@register
......
......@@ -82,7 +82,7 @@ class Detection(object):
def __init__(self, tlwh, confidence, feature):
self.tlwh = np.asarray(tlwh, dtype=np.float32)
self.confidence = np.asarray(confidence, dtype=np.float32)
self.feature = feature.numpy()
self.feature = feature
def to_tlbr(self):
"""
......
......@@ -355,11 +355,7 @@ class JDEBBoxPostProcess(nn.Layer):
[[[0.0]]], dtype='float32'))
self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))
def forward(self,
head_out,
anchors,
im_shape=[[608, 1088]],
scale_factor=[[1.0, 1.0]]):
def forward(self, head_out, anchors):
"""
Decode the bbox and do NMS for JDE model.
......@@ -389,16 +385,21 @@ class JDEBBoxPostProcess(nn.Layer):
yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)])
boxes_idx = boxes_idx[:, 1:]
bbox_pred, bbox_num, nms_keep_idx = self.nms(
yolo_boxes_out, yolo_scores_out, self.num_classes)
if bbox_pred.shape[0] == 0:
bbox_pred = self.fake_bbox_pred
bbox_num = self.fake_bbox_num
nms_keep_idx = self.fake_nms_keep_idx
        if self.return_idx:
            bbox_pred, bbox_num, nms_keep_idx = self.nms(
                yolo_boxes_out, yolo_scores_out, self.num_classes)
            if bbox_pred.shape[0] == 0:
                bbox_pred = self.fake_bbox_pred
                bbox_num = self.fake_bbox_num
                nms_keep_idx = self.fake_nms_keep_idx
            return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
        else:
            bbox_pred, bbox_num, _ = self.nms(yolo_boxes_out, yolo_scores_out,
                                              self.num_classes)
            if bbox_pred.shape[0] == 0:
                bbox_pred = self.fake_bbox_pred
                bbox_num = self.fake_bbox_num
            return _, bbox_pred, bbox_num, _
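Downstream, the deploy-side detector consumes the `return_idx: false` output as a plain `[class, score, x1, y1, x2, y2]` array; a small sketch of splitting and thresholding it, mirroring `MOT_Detector.postprocess` in the deploy code above (hypothetical values):

```python
# Hypothetical post-NMS rows: [class, score, x1, y1, x2, y2].
import numpy as np

boxes = np.array([[0., 0.87, 103.5, 50., 203.5, 150.],
                  [0., 0.22, 10.0, 12.0, 40.0, 90.0]])
pred_scores = boxes[:, 1:2]
pred_bboxes = boxes[:, 2:]
keep_mask = pred_scores[:, 0] >= 0.5  # i.e. FLAGS.threshold
print(pred_bboxes[keep_mask])         # only the 0.87 box survives
```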
@register
......
......@@ -78,8 +78,7 @@ class PCBPyramid(nn.Layer):
for idx_branches in range(self.num_branches):
if idx_branches >= sum(self.num_in_each_level[0:idx_levels + 1]):
idx_levels += 1
if self.used_levels[idx_levels] == 0:
continue
pyramid_conv_list.append(
nn.Sequential(
nn.Conv2D(input_ch, num_conv_out_channels, 1),
......@@ -89,8 +88,7 @@ class PCBPyramid(nn.Layer):
for idx_branches in range(self.num_branches):
if idx_branches >= sum(self.num_in_each_level[0:idx_levels + 1]):
idx_levels += 1
if self.used_levels[idx_levels] == 0:
continue
name = "Linear_branch_id_{}".format(idx_branches)
fc = nn.Linear(
in_features=num_conv_out_channels,
......@@ -113,8 +111,6 @@ class PCBPyramid(nn.Layer):
for idx_branches in range(self.num_branches):
if idx_branches >= sum(self.num_in_each_level[0:idx_levels + 1]):
idx_levels += 1
if self.used_levels[idx_levels] == 0:
continue
idx_in_each_level = idx_branches - sum(self.num_in_each_level[
0:idx_levels])
stripe_size_in_each_level = each_stripe_size * (idx_levels + 1)
......