# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import yaml
import cv2
import numpy as np
F
Feng Ni 已提交
20 21 22 23
from collections import defaultdict
import paddle

from benchmark_utils import PaddleInferBenchmark
W
wangguanzhong 已提交
24 25
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
26
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
F
Feng Ni 已提交
27

W
wangguanzhong 已提交
28 29 30 31 32
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

33 34
from pptracking.python.mot import JDETracker, DeepSORTTracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box
W
wangguanzhong 已提交
35
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
G
George Ni 已提交
36 37


38
class SDE_Detector(Detector):
    """
    Detector whose per-frame detections are handed to a multi-object tracker.

    The tracker is built from the YAML file given by `tracker_config`: a
    DeepSORTTracker when the config `type` is 'DeepSORTTracker', otherwise a
    JDETracker. When `reid_model_dir` is set, a second predictor is loaded
    and used to compute embeddings for the detected boxes.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        tracker_config (str): tracker config path
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of per batch in inference, must be 1
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        output_dir (string): The path of output, default as 'output'
        threshold (float): Score threshold of the detected bbox, default as 0.5
        save_images (bool): Whether to save visualization image results, default as False
        save_mot_txts (bool): Whether to save tracking results (txt), default as False
        reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
    """

    def __init__(self,
                 model_dir,
                 tracker_config,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5,
                 save_images=False,
                 save_mot_txts=False,
                 reid_model_dir=None):
        """
        Build the detector, the optional ReID predictor and the tracker.

        The arguments are described in the class docstring.

        Raises:
            AssertionError: if `batch_size` is not 1 or `tracker_config` is None.
        """
        super(SDE_Detector, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
        self.save_images = save_images
        self.save_mot_txts = save_mot_txts
        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)

        # reid config: the ReID predictor is only loaded when a model dir is given
        self.use_reid = reid_model_dir is not None
        if self.use_reid:
            self.reid_pred_config = self.set_config(reid_model_dir)
            # NOTE(review): this rebinds self.config to the ReID predictor's
            # config -- confirm nothing relies on the detector's own config.
            self.reid_predictor, self.config = load_predictor(
                reid_model_dir,
                run_mode=run_mode,
                batch_size=50,  # reid_batch_size
                min_subgraph_size=self.reid_pred_config.min_subgraph_size,
                device=device,
                use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
                trt_min_shape=trt_min_shape,
                trt_max_shape=trt_max_shape,
                trt_opt_shape=trt_opt_shape,
                trt_calib_mode=trt_calib_mode,
                cpu_threads=cpu_threads,
                enable_mkldnn=enable_mkldnn)
        else:
            self.reid_pred_config = None
            self.reid_predictor = None

        assert tracker_config is not None, 'Note that tracker_config should be set.'
        self.tracker_config = tracker_config
        # Read the config inside a context manager so the file handle is
        # closed even when parsing fails.
        with open(self.tracker_config) as f:
            tracker_cfg = yaml.safe_load(f)
        # The section named by `type` holds that tracker's parameters.
        cfg = tracker_cfg[tracker_cfg['type']]

        # tracker config
        self.use_deepsort_tracker = tracker_cfg['type'] == 'DeepSORTTracker'
        if self.use_deepsort_tracker:
            # use DeepSORTTracker
            # Tracker settings shipped with the ReID model, when present,
            # take precedence over the tracker config file.
            if self.reid_pred_config is not None and hasattr(
                    self.reid_pred_config, 'tracker'):
                cfg = self.reid_pred_config.tracker
            budget = cfg.get('budget', 100)
            max_age = cfg.get('max_age', 30)
            max_iou_distance = cfg.get('max_iou_distance', 0.7)
            matching_threshold = cfg.get('matching_threshold', 0.2)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)

            self.tracker = DeepSORTTracker(
                budget=budget,
                max_age=max_age,
                max_iou_distance=max_iou_distance,
                matching_threshold=matching_threshold,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio, )
        else:
            # use ByteTracker
            use_byte = cfg.get('use_byte', False)
            det_thresh = cfg.get('det_thresh', 0.3)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)
            match_thres = cfg.get('match_thres', 0.9)
            conf_thres = cfg.get('conf_thres', 0.6)
            low_conf_thres = cfg.get('low_conf_thres', 0.1)

            self.tracker = JDETracker(
                use_byte=use_byte,
                det_thresh=det_thresh,
                num_classes=self.num_classes,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio,
                match_thres=match_thres,
                conf_thres=conf_thres,
                low_conf_thres=low_conf_thres, )

    def postprocess(self, inputs, result):
        """
        Clean up the raw predictor output.

        Args:
            inputs: preprocessed inputs; not used here.
            result (dict): predictor output holding at least 'boxes_num'.
        Returns:
            dict: `result` without its None-valued entries, or an empty
                detection (`boxes` of shape (0, 6), `boxes_num` of [0]) when
                the first image has no boxes.
        """
        if result['boxes_num'][0] <= 0:
            print('[WARNNING] No object detected.')
            return {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}

        # drop outputs the predictor did not produce
        kept = {}
        for key, value in result.items():
            if value is not None:
                kept[key] = value
        return kept

    def reidprocess(self, det_results, repeats=1):
        """
        Compute ReID embeddings for the detected boxes.

        Args:
            det_results (dict): detection result holding 'boxes' (columns
                2:6 are used as the box coordinates) and 'ori_image'.
                Updated in place.
            repeats (int): number of times the ReID predictor is run.
        Returns:
            dict: `det_results` with 'boxes' restricted to the boxes kept by
                `clip_box` (at most 50), plus 'crops' and 'embeddings'. When
                no box is kept, 'boxes' is a single all-zero row and
                'embeddings' is None.
        """
        boxes = det_results['boxes']
        frame = det_results['ori_image']

        # Only the keep indices are used; the clipped coordinates returned
        # by clip_box are discarded.
        _, kept = clip_box(boxes[:, 2:6], frame.shape[:2])
        if len(kept[0]) == 0:
            det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
            det_results['embeddings'] = None
            return det_results

        boxes = boxes[kept[0]]
        crop_w, crop_h = self.tracker.input_size
        crops = get_crops(boxes[:, 2:6], frame, crop_w, crop_h)

        # to keep fast speed, only use topk crops
        max_crops = 50  # reid_batch_size
        det_results['crops'] = np.array(crops[:max_crops]).astype('float32')
        det_results['boxes'] = boxes[:max_crops]

        # feed every input the ReID model declares from det_results
        predictor = self.reid_predictor
        for name in predictor.get_input_names():
            predictor.get_input_handle(name).copy_from_cpu(det_results[name])

        # model prediction
        for _ in range(repeats):
            predictor.run()
            first_output = predictor.get_output_names()[0]
            pred_embs = predictor.get_output_handle(first_output).copy_to_cpu()

        det_results['embeddings'] = pred_embs
        return det_results
W
wangguanzhong 已提交
212 213

    def tracking(self, det_results):
        """
        Update the tracker with one frame of detections.

        Args:
            det_results (dict): holds 'boxes' and, optionally, 'embeddings'.
        Returns:
            dict: 'online_tlwhs', 'online_scores' and 'online_ids'. With
                DeepSORTTracker each value is a list; otherwise each value
                is a defaultdict(list) keyed by class id.
        """
        detections = det_results['boxes']  # 'cls_id, score, x0, y0, x1, y1'
        embeddings = det_results.get('embeddings', None)
        tracker = self.tracker

        if not self.use_deepsort_tracker:
            # use ByteTracker, support multiple class
            tlwhs_by_cls = defaultdict(list)
            scores_by_cls = defaultdict(list)
            ids_by_cls = defaultdict(list)
            targets_by_cls = tracker.update(detections, embeddings)
            for cls_id in range(self.num_classes):
                for track in targets_by_cls[cls_id]:
                    box = track.tlwh
                    track_id = track.track_id
                    score = track.score
                    # skip boxes that are too small
                    if box[2] * box[3] <= tracker.min_box_area:
                        continue
                    # skip boxes whose width/height ratio exceeds the limit
                    if tracker.vertical_ratio > 0 and box[2] / box[
                            3] > tracker.vertical_ratio:
                        continue
                    tlwhs_by_cls[cls_id].append(box)
                    ids_by_cls[cls_id].append(track_id)
                    scores_by_cls[cls_id].append(score)

            return {
                'online_tlwhs': tlwhs_by_cls,
                'online_scores': scores_by_cls,
                'online_ids': ids_by_cls,
            }

        # use DeepSORTTracker, only support single class
        tracker.predict()
        tlwhs, scores, ids = [], [], []
        for track in tracker.update(detections, embeddings):
            # keep only confirmed tracks that were updated recently
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            box = track.to_tlwh()
            score = track.score
            track_id = track.track_id
            if tracker.vertical_ratio > 0 and box[2] / box[
                    3] > tracker.vertical_ratio:
                continue
            tlwhs.append(box)
            scores.append(score)
            ids.append(track_id)

        return {
            'online_tlwhs': tlwhs,
            'online_scores': scores,
            'online_ids': ids,
        }
268

W
wangguanzhong 已提交
269 270 271 272
    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      seq_name=None):
        """
        Detect and track objects over the images of one sequence.

        Args:
            image_list (list): images of the sequence; sorted in place and
                processed one at a time (batch size is 1).
            run_benchmark (bool): when True, each stage is run once as
                warmup before being timed, and the memory/GPU usage reported
                by `get_current_memory_mb` is accumulated on the detector.
            repeats (int): `repeats` value forwarded to `predict` in
                benchmark mode.
            visual (bool): when True, save a visualization of every frame.
            seq_name (str): sequence name, used as the sub-directory of
                `output_dir` for saved images. When None, images are saved
                directly in `output_dir`.
        Returns:
            list: one `[online_tlwhs, online_scores, online_ids]` entry per
                image, taken from the output of `tracking`.
        """
        num_classes = self.num_classes
        image_list.sort()
        ids2names = self.pred_config.labels
        # os.path.join rejects None, so fall back to output_dir itself when
        # no sequence name is given.
        if seq_name is None:
            save_dir = self.output_dir
        else:
            save_dir = os.path.join(self.output_dir, seq_name)

        mot_results = []
        for frame_id, img_file in enumerate(image_list):
            batch_image_list = [img_file]  # bs=1 in MOT model
            frame, _ = decode_image(img_file, {})
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                # NOTE(review): the warmup call also updates the tracker, so
                # in benchmark mode the tracker sees every frame twice.
                self.tracking(det_result)  # warmup
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu

            else:
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking process
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

            online_tlwhs = tracking_outs['online_tlwhs']
            online_scores = tracking_outs['online_scores']
            online_ids = tracking_outs['online_ids']

            mot_results.append([online_tlwhs, online_scores, online_ids])

            if visual:
                if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
                # `tracking` returns defaultdicts keyed by class id for the
                # multi-class tracker and plain lists for DeepSORT.
                if isinstance(online_tlwhs, defaultdict):
                    im = plot_tracking_dict(
                        frame,
                        num_classes,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=ids2names)
                else:
                    im = plot_tracking(
                        frame,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id)
                os.makedirs(save_dir, exist_ok=True)
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)

        return mot_results

    def predict_video(self, video_file, camera_id):
        """
        Track objects in a video file or camera stream and save the result.

        The annotated video is written to `output_dir`, named after the input
        file ('output.mp4' for a camera). When `save_mot_txts` is set, the
        tracking results are also written to a '.txt' file with the same
        base name.

        Args:
            video_file (str): path of the input video; used when
                `camera_id` is -1.
            camera_id (int): camera device id, or -1 to read `video_file`.
        """
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        os.makedirs(self.output_dir, exist_ok=True)
        out_path = os.path.join(self.output_dir, video_out_name)
        video_format = 'mp4v'
        fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        # Strip only the final extension, so names without a dot or with
        # several dots give the same base name for the sequence and the
        # results file.
        seq_name = os.path.splitext(video_out_name)[0]

        frame_id = 1  # 1-based index of the frame being processed
        timer = MOTTimer()
        results = defaultdict(list)
        num_classes = self.num_classes
        ids2names = self.pred_config.labels

        try:
            while True:
                ret, frame = capture.read()
                if not ret:
                    break
                if frame_id % 10 == 0:
                    print('Tracking frame: %d' % (frame_id))

                timer.tic()
                # reverse the channel order before handing the frame over
                mot_results = self.predict_image(
                    [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
                timer.toc()

                # bs=1 in MOT model
                online_tlwhs, online_scores, online_ids = mot_results[0]

                # Tracking speed drawn on the frame; kept separate from the
                # video fps and guarded against a zero duration.
                track_fps = 1. / max(1e-5, timer.duration)
                if self.use_deepsort_tracker:
                    # use DeepSORTTracker, only support single class
                    results[0].append(
                        (frame_id, online_tlwhs, online_scores, online_ids))
                    im = plot_tracking(
                        frame,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        fps=track_fps)
                else:
                    # use ByteTracker, support multiple class
                    for cls_id in range(num_classes):
                        results[cls_id].append(
                            (frame_id, online_tlwhs[cls_id],
                             online_scores[cls_id], online_ids[cls_id]))
                    im = plot_tracking_dict(
                        frame,
                        num_classes,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        fps=track_fps,
                        ids2names=ids2names)

                writer.write(im)
                if camera_id != -1:
                    cv2.imshow('Mask Detection', im)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                frame_id += 1
        finally:
            # release the capture and the writer even if tracking fails
            capture.release()
            writer.release()

        if self.save_mot_txts:
            result_filename = os.path.join(self.output_dir, seq_name + '.txt')
            # NOTE(review): a 'mcmot'/'mot' data type used to be computed
            # here but was never passed on -- confirm whether
            # write_mot_results needs it (and num_classes) for multi-class
            # results.
            write_mot_results(result_filename, results)
G
George Ni 已提交
463 464 465


def main():
    """
    Build an SDE_Detector from the command-line FLAGS and run tracking.

    A video file or camera stream is processed with `predict_video`;
    otherwise the images given by `--image_dir` / `--image_file` are
    processed with `predict_image`, followed by a timing report or, with
    `--run_benchmark`, a benchmark log.
    """
    # Parse the deploy config up front so a missing or malformed
    # infer_cfg.yml is reported before any model is loaded.
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yaml.safe_load(f)
    detector = SDE_Detector(
        FLAGS.model_dir,
        tracker_config=FLAGS.tracker_config,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=1,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        output_dir=FLAGS.output_dir,
        threshold=FLAGS.threshold,
        save_images=FLAGS.save_images,
        save_mot_txts=FLAGS.save_mot_txts,
        # Forward the ReID model when the parser provides the option;
        # getattr keeps parsers without it working (ReID stays disabled).
        reid_model_dir=getattr(FLAGS, 'reid_model_dir', None))

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        # Name the sequence after the image directory, or after the image
        # file when only --image_file is given (image_dir is None then).
        if FLAGS.image_dir is not None:
            seq_name = os.path.basename(os.path.normpath(FLAGS.image_dir))
        else:
            seq_name = os.path.splitext(os.path.basename(FLAGS.image_file))[0]
        detector.predict_image(
            img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)

        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='MOT')
G
George Ni 已提交
509 510 511 512 513 514 515 516 517 518 519 520


if __name__ == '__main__':
    # Script entry point: switch paddle to static mode, parse and print the
    # command-line options, normalize the device name, then run main().
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)

    # device names are compared in upper case
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in (
        'CPU', 'GPU', 'XPU'), "device should be CPU, GPU or XPU"

    main()