# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import yaml
import cv2
import numpy as np
from collections import defaultdict
import paddle

from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor

# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

from pptracking.python.mot import JDETracker, DeepSORTTracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict


class SDE_Detector(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        tracker_config (str): tracker config path
        device (str): device to run on, one of CPU/GPU/XPU/NPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference, default is 1
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline
            quantization calibration, trt_calib_mode needs to be set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        output_dir (str): the path of output, default as 'output'
        threshold (float): Score threshold of the detected bbox, default as 0.5
        save_images (bool): Whether to save visualization image results, default as False
        save_mot_txts (bool): Whether to save tracking results (txt), default as False
        reid_model_dir (str): ReID model directory, default None for ByteTrack, but must be set for DeepSORT
    """

    def __init__(self,
                 model_dir,
                 tracker_config,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5,
                 save_images=False,
                 save_mot_txts=False,
                 reid_model_dir=None):
        super(SDE_Detector, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
        self.save_images = save_images
        self.save_mot_txts = save_mot_txts
        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)

        # reid config
        self.use_reid = reid_model_dir is not None
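        # a ReID predictor is only attached when reid_model_dir is given
        # (DeepSORT-style tracking); ByteTrack runs from detection boxes alone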
        if self.use_reid:
            self.reid_pred_config = self.set_config(reid_model_dir)
            self.reid_predictor, self.config = load_predictor(
                reid_model_dir,
                run_mode=run_mode,
                batch_size=50,  # reid_batch_size
                min_subgraph_size=self.reid_pred_config.min_subgraph_size,
                device=device,
                use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
                trt_min_shape=trt_min_shape,
                trt_max_shape=trt_max_shape,
                trt_opt_shape=trt_opt_shape,
                trt_calib_mode=trt_calib_mode,
                cpu_threads=cpu_threads,
                enable_mkldnn=enable_mkldnn)
        else:
            self.reid_pred_config = None
            self.reid_predictor = None

        assert tracker_config is not None, 'Note that tracker_config should be set.'
        self.tracker_config = tracker_config
        with open(self.tracker_config) as f:
            tracker_cfg = yaml.safe_load(f)
        cfg = tracker_cfg[tracker_cfg['type']]
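        # A sketch of the expected tracker config layout (illustrative; the
        # key names and values mirror the cfg.get(...) defaults below):
        #
        #   type: JDETracker          # or DeepSORTTracker
        #   JDETracker:
        #     use_byte: False
        #     det_thresh: 0.3
        #     conf_thres: 0.6
        #     low_conf_thres: 0.1
        #     match_thres: 0.9
        #     min_box_area: 0
        #     vertical_ratio: 0
        #   DeepSORTTracker:
        #     budget: 100
        #     max_age: 30
        #     max_iou_distance: 0.7
        #     matching_threshold: 0.2
        #     min_box_area: 0
        #     vertical_ratio: 0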

        # tracker config
        self.use_deepsort_tracker = tracker_cfg['type'] == 'DeepSORTTracker'
        if self.use_deepsort_tracker:
            # use DeepSORTTracker
            if self.reid_pred_config is not None and hasattr(
                    self.reid_pred_config, 'tracker'):
                cfg = self.reid_pred_config.tracker
            budget = cfg.get('budget', 100)
            max_age = cfg.get('max_age', 30)
            max_iou_distance = cfg.get('max_iou_distance', 0.7)
            matching_threshold = cfg.get('matching_threshold', 0.2)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)

            self.tracker = DeepSORTTracker(
                budget=budget,
                max_age=max_age,
                max_iou_distance=max_iou_distance,
                matching_threshold=matching_threshold,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio, )
        else:
            # use ByteTracker
            use_byte = cfg.get('use_byte', False)
            det_thresh = cfg.get('det_thresh', 0.3)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)
            match_thres = cfg.get('match_thres', 0.9)
            conf_thres = cfg.get('conf_thres', 0.6)
            low_conf_thres = cfg.get('low_conf_thres', 0.1)

            self.tracker = JDETracker(
                use_byte=use_byte,
                det_thresh=det_thresh,
                num_classes=self.num_classes,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio,
                match_thres=match_thres,
                conf_thres=conf_thres,
                low_conf_thres=low_conf_thres, )

    def postprocess(self, inputs, result):
        # postprocess output of predictor
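        # result['boxes'] is an [N, 6] array laid out as
        # [cls_id, score, x0, y0, x1, y1] per detection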
        np_boxes_num = result['boxes_num']
        if np_boxes_num[0] <= 0:
            print('[WARNING] No object detected.')
            result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def reidprocess(self, det_results, repeats=1):
        pred_dets = det_results['boxes']
        pred_xyxys = pred_dets[:, 2:6]

        ori_image = det_results['ori_image']
        ori_image_shape = ori_image.shape[:2]
        pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)

        if len(keep_idx[0]) == 0:
            det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
            det_results['embeddings'] = None
            return det_results

        pred_dets = pred_dets[keep_idx[0]]
        pred_xyxys = pred_dets[:, 2:6]

        w, h = self.tracker.input_size
        crops = get_crops(pred_xyxys, ori_image, w, h)

        # to keep fast speed, only use topk crops
        crops = crops[:50]  # reid_batch_size
        det_results['crops'] = np.array(crops).astype('float32')
        det_results['boxes'] = pred_dets[:50]

        input_names = self.reid_predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.reid_predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(det_results[input_names[i]])

        # model prediction
        for i in range(repeats):
            self.reid_predictor.run()
            output_names = self.reid_predictor.get_output_names()
            feature_tensor = self.reid_predictor.get_output_handle(
                output_names[0])
            pred_embs = feature_tensor.copy_to_cpu()
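            # pred_embs holds one appearance feature per kept crop; the
            # embedding dimension depends on the exported ReID model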

        det_results['embeddings'] = pred_embs
        return det_results

    def tracking(self, det_results):
        pred_dets = det_results['boxes']  # 'cls_id, score, x0, y0, x1, y1'
        pred_embs = det_results.get('embeddings', None)
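        # 'embeddings' is only present when a ReID model ran (see reidprocess);
        # without it pred_embs is None and the tracker matches on boxes only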

        if self.use_deepsort_tracker:
            # use DeepSORTTracker, which only supports a single class
            self.tracker.predict()
            online_targets = self.tracker.update(pred_dets, pred_embs)
            online_tlwhs, online_scores, online_ids = [], [], []
            for t in online_targets:
                if not t.is_confirmed() or t.time_since_update > 1:
                    continue
                tlwh = t.to_tlwh()
                tscore = t.score
                tid = t.track_id
                if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                        3] > self.tracker.vertical_ratio:
                    continue
                online_tlwhs.append(tlwh)
                online_scores.append(tscore)
                online_ids.append(tid)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }
            return tracking_outs
        else:
            # use ByteTracker, which supports multiple classes
            online_tlwhs = defaultdict(list)
            online_scores = defaultdict(list)
            online_ids = defaultdict(list)
            online_targets_dict = self.tracker.update(pred_dets, pred_embs)
            for cls_id in range(self.num_classes):
                online_targets = online_targets_dict[cls_id]
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    tscore = t.score
                    if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
                        continue
                    if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                            3] > self.tracker.vertical_ratio:
                        continue
                    online_tlwhs[cls_id].append(tlwh)
                    online_ids[cls_id].append(tid)
                    online_scores[cls_id].append(tscore)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }
            return tracking_outs

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      seq_name=None):
        num_classes = self.num_classes
        image_list.sort()
        ids2names = self.pred_config.labels
        mot_results = []
        for frame_id, img_file in enumerate(image_list):
            batch_image_list = [img_file]  # bs=1 in MOT model
            frame, _ = decode_image(img_file, {})
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result_warmup = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                result_warmup = self.tracking(det_result)
                self.det_times.tracking_time_s.start()
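                # the timed region repeats reidprocess so that tracking_time
                # also accounts for ReID feature extraction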
                if self.use_reid:
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu

            else:
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking process
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

            online_tlwhs = tracking_outs['online_tlwhs']
            online_scores = tracking_outs['online_scores']
            online_ids = tracking_outs['online_ids']

            mot_results.append([online_tlwhs, online_scores, online_ids])

            if visual:
                if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
                if isinstance(online_tlwhs, defaultdict):
                    im = plot_tracking_dict(
                        frame,
                        num_classes,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=ids2names)
                else:
                    im = plot_tracking(
                        frame,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=ids2names)
                save_dir = os.path.join(self.output_dir, seq_name)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)

        return mot_results

    def predict_video(self, video_file, camera_id):
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        video_format = 'mp4v'
        fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        frame_id = 1
        timer = MOTTimer()
        results = defaultdict(list)
        num_classes = self.num_classes
        data_type = 'mcmot' if num_classes > 1 else 'mot'
        ids2names = self.pred_config.labels

        while True:
            ret, frame = capture.read()
            if not ret:
                break
            if frame_id % 10 == 0:
                print('Tracking frame: %d' % (frame_id))
            frame_id += 1

            timer.tic()
            seq_name = video_out_name.split('.')[0]
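            # frame[:, :, ::-1] flips OpenCV's BGR channel order to RGB
            # for the detector's preprocessing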
            mot_results = self.predict_image(
                [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
            timer.toc()

            # bs=1 in MOT model
            online_tlwhs, online_scores, online_ids = mot_results[0]

            fps = 1. / timer.duration
            if self.use_deepsort_tracker:
                # use DeepSORTTracker, which only supports a single class
                results[0].append(
                    (frame_id + 1, online_tlwhs, online_scores, online_ids))
                im = plot_tracking(
                    frame,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
435 436
                    fps=fps,
                    ids2names=ids2names)
            else:
                # use ByteTracker, which supports multiple classes
                for cls_id in range(num_classes):
                    results[cls_id].append(
                        (frame_id + 1, online_tlwhs[cls_id],
                         online_scores[cls_id], online_ids[cls_id]))
                im = plot_tracking_dict(
                    frame,
                    num_classes,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    ids2names=ids2names)

            writer.write(im)
            if camera_id != -1:
                cv2.imshow('MOT Tracking', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if self.save_mot_txts:
            result_filename = os.path.join(
                self.output_dir, video_out_name.split('.')[-2] + '.txt')
            write_mot_results(result_filename, results)

        writer.release()


def main():
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']
    detector = SDE_Detector(
        FLAGS.model_dir,
        tracker_config=FLAGS.tracker_config,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=1,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        output_dir=FLAGS.output_dir,
        threshold=FLAGS.threshold,
        save_images=FLAGS.save_images,
        save_mot_txts=FLAGS.save_mot_txts, )

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        seq_name = FLAGS.image_dir.split('/')[-1]
        detector.predict_image(
            img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)

        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='MOT')


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU'
                            ], "device should be CPU, GPU, NPU or XPU"

    main()
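
# A minimal invocation sketch (paths and model names are placeholders; the
# flag spellings follow the argsparser in utils.py):
#
#   python mot_sde_infer.py \
#       --model_dir=output_inference/my_detector \
#       --tracker_config=tracker_config.yml \
#       --video_file=input.mp4 --device=GPU --save_mot_txts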