# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
import cv2
import numpy as np
import math
import paddle
import sys
import copy
from collections import defaultdict
from collections.abc import Sequence  # collections.Sequence was removed in Python 3.10
from datacollector import DataCollector, Result

# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

from cfg_utils import argsparser, print_arguments, merge_cfg
from pipe_utils import PipeTimer
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint

from python.infer import Detector, DetectorPicoDet
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
from python.preprocess import decode_image, ShortSizeScale
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action, visualize_vehicleplate

from pptracking.python.mot_sde_infer import SDE_Detector
from pptracking.python.mot.visualize import plot_tracking_dict
from pptracking.python.mot.utils import flow_statistic, update_object_info

from pphuman.attr_infer import AttrDetector
from pphuman.video_action_infer import VideoActionRecognizer
from pphuman.action_infer import SkeletonActionRecognizer, DetActionRecognizer, ClsActionRecognizer
from pphuman.action_utils import KeyPointBuff, ActionVisualHelper
from pphuman.reid import ReID
from pphuman.mtmct import mtmct_process

from ppvehicle.vehicle_plate import PlateRecognizer
from ppvehicle.vehicle_attr import VehicleAttr

from download import auto_download_model


class Pipeline(object):
    """
    Pipeline: parses the input source and dispatches one or more PipePredictor instances.

    Args:
        args (argparse.Namespace): arguments in pipeline, which contain environment and runtime settings
        cfg (dict): config of models in pipeline
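
    Example (a minimal usage sketch mirroring main() below; FLAGS comes from
    argsparser() and cfg from merge_cfg(FLAGS)):
        pipeline = Pipeline(FLAGS, cfg)
        pipeline.run_multithreads()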
    """

    def __init__(self, args, cfg):
        self.multi_camera = False
        reid_cfg = cfg.get('REID', False)
        self.enable_mtmct = reid_cfg['enable'] if reid_cfg else False
        self.is_video = False
        self.output_dir = args.output_dir
        self.vis_result = cfg['visual']
        self.input = self._parse_input(args.image_file, args.image_dir,
                                       args.video_file, args.video_dir,
                                       args.camera_id, args.rtsp)
        if self.multi_camera:
            self.predictor = []
            for name in self.input:
                predictor_item = PipePredictor(
                    args, cfg, is_video=True, multi_camera=True)
                predictor_item.set_file_name(name)
                self.predictor.append(predictor_item)

        else:
            self.predictor = PipePredictor(args, cfg, self.is_video)
            if self.is_video:
                self.predictor.set_file_name(self.input)

    def _parse_input(self, image_file, image_dir, video_file, video_dir,
                     camera_id, rtsp):

        # parse input as is_video and multi_camera

        if image_file is not None or image_dir is not None:
            input = get_test_images(image_dir, image_file)
            self.is_video = False
            self.multi_camera = False

        elif video_file is not None:
            assert os.path.exists(
                video_file
            ) or 'rtsp' in video_file, "video_file does not exist and is not an rtsp link."
            self.multi_camera = False
            input = video_file
            self.is_video = True

        elif video_dir is not None:
            videof = [os.path.join(video_dir, x) for x in os.listdir(video_dir)]
            if len(videof) > 1:
                self.multi_camera = True
                videof.sort()
                input = videof
            else:
                input = videof[0]
            self.is_video = True

        elif rtsp is not None:
            if len(rtsp) > 1:
                rtsp = [rtsp_item for rtsp_item in rtsp if 'rtsp' in rtsp_item]
                self.multi_camera = True
                input = rtsp
            else:
                self.multi_camera = False
                input = rtsp[0]
            self.is_video = True

        elif camera_id != -1:
            self.multi_camera = False
            input = camera_id
            self.is_video = True

        else:
            raise ValueError(
                "Illegal Input, please set one of ['video_file', 'camera_id', 'image_file', 'image_dir']"
            )

        return input

    def run_multithreads(self):
        import threading
        if self.multi_camera:
            multi_res = []
            threads = []
            for idx, (predictor,
                      input) in enumerate(zip(self.predictor, self.input)):
                thread = threading.Thread(
                    name=str(idx).zfill(3),
                    target=predictor.run,
                    args=(input, idx))
                threads.append(thread)

            for thread in threads:
                thread.start()

            for predictor, thread in zip(self.predictor, threads):
                thread.join()
                collector_data = predictor.get_result()
                multi_res.append(collector_data)

            if self.enable_mtmct:
                mtmct_process(
                    multi_res,
                    self.input,
                    mtmct_vis=self.vis_result,
                    output_dir=self.output_dir)

        else:
            self.predictor.run(self.input)

    def run(self):
        if self.multi_camera:
            multi_res = []
            for predictor, input in zip(self.predictor, self.input):
                predictor.run(input)
                collector_data = predictor.get_result()
                multi_res.append(collector_data)
            if self.enable_mtmct:
                mtmct_process(
                    multi_res,
                    self.input,
                    mtmct_vis=self.vis_result,
                    output_dir=self.output_dir)

        else:
            self.predictor.run(self.input)


def get_model_dir(cfg):
    """ 
        Auto download inference model if the model_path is a url link. 
        Otherwise it will use the model_path directly.
    """
    for key in cfg.keys():
        if isinstance(cfg[key], dict) and \
                ("enable" not in cfg[key] or cfg[key]['enable']):

            if "model_dir" in cfg[key].keys():
                model_dir = cfg[key]["model_dir"]
                downloaded_model_dir = auto_download_model(model_dir)
                if downloaded_model_dir:
                    model_dir = downloaded_model_dir
                    cfg[key]["model_dir"] = model_dir
                print(key, " model dir: ", model_dir)
            elif key == "VEHICLE_PLATE":
                det_model_dir = cfg[key]["det_model_dir"]
                downloaded_det_model_dir = auto_download_model(det_model_dir)
                if downloaded_det_model_dir:
                    det_model_dir = downloaded_det_model_dir
                    cfg[key]["det_model_dir"] = det_model_dir
                print("det_model_dir model dir: ", det_model_dir)

                rec_model_dir = cfg[key]["rec_model_dir"]
                downloaded_rec_model_dir = auto_download_model(rec_model_dir)
                if downloaded_rec_model_dir:
                    rec_model_dir = downloaded_rec_model_dir
                    cfg[key]["rec_model_dir"] = rec_model_dir
                print("rec_model_dir model dir: ", rec_model_dir)

        elif key == "MOT":  # for idbased and skeletonbased actions
            model_dir = cfg[key]["model_dir"]
            downloaded_model_dir = auto_download_model(model_dir)
            if downloaded_model_dir:
                model_dir = downloaded_model_dir
                cfg[key]["model_dir"] = model_dir
            print("mot_model_dir model_dir: ", model_dir)


class PipePredictor(object):
    """
    Predictor for a single camera.

    The pipeline for image input:

        1. Detection
        2. Detection -> Attribute

    The pipeline for video input: 

        1. Tracking
        2. Tracking -> Attribute
        3. Tracking -> KeyPoint -> SkeletonAction Recognition
        4. VideoAction Recognition

    Args:
        args (argparse.Namespace): arguments in pipeline, which contain environment and runtime settings
        cfg (dict): config of models in pipeline
        is_video (bool): whether the input is video, default as True
        multi_camera (bool): whether to use multi camera in pipeline, 
            default as False
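
    Example (a hedged usage sketch; the video path is illustrative):
        predictor = PipePredictor(args, cfg, is_video=True)
        predictor.set_file_name("demo.mp4")
        predictor.run("demo.mp4")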
    """

    def __init__(self, args, cfg, is_video=True, multi_camera=False):
        # general module for pphuman and ppvehicle
        self.with_mot = cfg.get('MOT', False)['enable'] if cfg.get(
            'MOT', False) else False
        self.with_human_attr = cfg.get('ATTR', False)['enable'] if cfg.get(
            'ATTR', False) else False
        if self.with_mot:
            print('Multi-Object Tracking enabled')
        if self.with_human_attr:
            print('Human Attribute Recognition enabled')

        # only for pphuman
        self.with_skeleton_action = cfg.get(
            'SKELETON_ACTION', False)['enable'] if cfg.get('SKELETON_ACTION',
                                                           False) else False
        self.with_video_action = cfg.get(
            'VIDEO_ACTION', False)['enable'] if cfg.get('VIDEO_ACTION',
                                                        False) else False
        self.with_idbased_detaction = cfg.get(
            'ID_BASED_DETACTION', False)['enable'] if cfg.get(
                'ID_BASED_DETACTION', False) else False
        self.with_idbased_clsaction = cfg.get(
            'ID_BASED_CLSACTION', False)['enable'] if cfg.get(
                'ID_BASED_CLSACTION', False) else False
        self.with_mtmct = cfg.get('REID', False)['enable'] if cfg.get(
            'REID', False) else False

        if self.with_skeleton_action:
            print('SkeletonAction Recognition enabled')
        if self.with_video_action:
            print('VideoAction Recognition enabled')
        if self.with_idbased_detaction:
            print('IDBASED Detection Action Recognition enabled')
        if self.with_idbased_clsaction:
            print('IDBASED Classification Action Recognition enabled')
        if self.with_mtmct:
            print("MTMCT enabled")

        # only for ppvehicle
        self.with_vehicleplate = cfg.get(
            'VEHICLE_PLATE', False)['enable'] if cfg.get('VEHICLE_PLATE',
                                                         False) else False
        if self.with_vehicleplate:
            print('Vehicle Plate Recognition enabled')

        self.with_vehicle_attr = cfg.get(
            'VEHICLE_ATTR', False)['enable'] if cfg.get('VEHICLE_ATTR',
                                                        False) else False
        if self.with_vehicle_attr:
            print('Vehicle Attribute Recognition enabled')

        self.modebase = {
            "framebased": False,
            "videobased": False,
            "idbased": False,
            "skeletonbased": False
        }

        self.basemode = {
            "MOT": "idbased",
            "ATTR": "idbased",
            "VIDEO_ACTION": "videobased",
            "SKELETON_ACTION": "skeletonbased",
            "ID_BASED_DETACTION": "idbased",
            "ID_BASED_CLSACTION": "idbased",
            "REID": "idbased",
            "VEHICLE_PLATE": "idbased",
            "VEHICLE_ATTR": "idbased",
        }
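
        # Each enabled module switches on its base mode in self.modebase when
        # it is initialized below; e.g. enabling ATTR sets
        # modebase["idbased"] = True, which in turn forces the MOT predictor
        # to be created for video inputs.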

        self.is_video = is_video
        self.multi_camera = multi_camera
        self.cfg = cfg

        self.output_dir = args.output_dir
        self.draw_center_traj = args.draw_center_traj
        self.secs_interval = args.secs_interval
        self.do_entrance_counting = args.do_entrance_counting
        self.do_break_in_counting = args.do_break_in_counting
        self.region_type = args.region_type
        self.region_polygon = args.region_polygon
        self.illegal_parking_time = args.illegal_parking_time

        self.warmup_frame = self.cfg['warmup_frame']
        self.pipeline_res = Result()
        self.pipe_timer = PipeTimer()
        self.file_name = None
        self.collector = DataCollector()

        # auto download inference model
        get_model_dir(self.cfg)

        if self.with_vehicleplate:
            vehicleplate_cfg = self.cfg['VEHICLE_PLATE']
            self.vehicleplate_detector = PlateRecognizer(args, vehicleplate_cfg)
            basemode = self.basemode['VEHICLE_PLATE']
            self.modebase[basemode] = True

        if self.with_human_attr:
            attr_cfg = self.cfg['ATTR']
            basemode = self.basemode['ATTR']
            self.modebase[basemode] = True
            self.attr_predictor = AttrDetector.init_with_cfg(args, attr_cfg)

        if self.with_vehicle_attr:
            vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
            basemode = self.basemode['VEHICLE_ATTR']
            self.modebase[basemode] = True
            self.vehicle_attr_predictor = VehicleAttr.init_with_cfg(
                args, vehicleattr_cfg)

        if not is_video:
            det_cfg = self.cfg['DET']
            model_dir = det_cfg['model_dir']
            batch_size = det_cfg['batch_size']
            self.det_predictor = Detector(
                model_dir, args.device, args.run_mode, batch_size,
                args.trt_min_shape, args.trt_max_shape, args.trt_opt_shape,
                args.trt_calib_mode, args.cpu_threads, args.enable_mkldnn)
        else:
            if self.with_idbased_detaction:
                idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION']
                basemode = self.basemode['ID_BASED_DETACTION']
                self.modebase[basemode] = True

                self.det_action_predictor = DetActionRecognizer.init_with_cfg(
                    args, idbased_detaction_cfg)
                self.det_action_visual_helper = ActionVisualHelper(1)

            if self.with_idbased_clsaction:
                idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION']
                basemode = self.basemode['ID_BASED_CLSACTION']
                self.modebase[basemode] = True

                self.cls_action_predictor = ClsActionRecognizer.init_with_cfg(
                    args, idbased_clsaction_cfg)
                self.cls_action_visual_helper = ActionVisualHelper(1)

            if self.with_skeleton_action:
                skeleton_action_cfg = self.cfg['SKELETON_ACTION']
                display_frames = skeleton_action_cfg['display_frames']
                self.coord_size = skeleton_action_cfg['coord_size']
                basemode = self.basemode['SKELETON_ACTION']
                self.modebase[basemode] = True
                skeleton_action_frames = skeleton_action_cfg['max_frames']

                self.skeleton_action_predictor = SkeletonActionRecognizer.init_with_cfg(
                    args, skeleton_action_cfg)
                self.skeleton_action_visual_helper = ActionVisualHelper(
                    display_frames)

                kpt_cfg = self.cfg['KPT']
                kpt_model_dir = kpt_cfg['model_dir']
                kpt_batch_size = kpt_cfg['batch_size']
                self.kpt_predictor = KeyPointDetector(
                    kpt_model_dir,
                    args.device,
                    args.run_mode,
                    kpt_batch_size,
                    args.trt_min_shape,
                    args.trt_max_shape,
                    args.trt_opt_shape,
                    args.trt_calib_mode,
                    args.cpu_threads,
                    args.enable_mkldnn,
                    use_dark=False)
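                # KeyPointBuff accumulates per-track keypoint sequences until
                # skeleton_action_frames frames are collected; it is consumed
                # in predict_video below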
                self.kpt_buff = KeyPointBuff(skeleton_action_frames)

            if self.with_vehicleplate:
                vehicleplate_cfg = self.cfg['VEHICLE_PLATE']
                self.vehicleplate_detector = PlateRecognizer(args,
                                                             vehicleplate_cfg)
                basemode = self.basemode['VEHICLE_PLATE']
                self.modebase[basemode] = True

            if self.with_mtmct:
                reid_cfg = self.cfg['REID']
                basemode = self.basemode['REID']
                self.modebase[basemode] = True
                self.reid_predictor = ReID.init_with_cfg(args, reid_cfg)

            if self.with_mot or self.modebase["idbased"] or self.modebase[
                    "skeletonbased"]:
                mot_cfg = self.cfg['MOT']
                model_dir = mot_cfg['model_dir']
                tracker_config = mot_cfg['tracker_config']
                batch_size = mot_cfg['batch_size']
                skip_frame_num = mot_cfg.get('skip_frame_num', -1)
                basemode = self.basemode['MOT']
                self.modebase[basemode] = True
                self.mot_predictor = SDE_Detector(
                    model_dir,
                    tracker_config,
                    args.device,
                    args.run_mode,
                    batch_size,
                    args.trt_min_shape,
                    args.trt_max_shape,
                    args.trt_opt_shape,
                    args.trt_calib_mode,
                    args.cpu_threads,
                    args.enable_mkldnn,
                    skip_frame_num=skip_frame_num,
                    draw_center_traj=self.draw_center_traj,
                    secs_interval=self.secs_interval,
                    do_entrance_counting=self.do_entrance_counting,
                    do_break_in_counting=self.do_break_in_counting,
                    region_type=self.region_type,
                    region_polygon=self.region_polygon)

            if self.with_video_action:
                video_action_cfg = self.cfg['VIDEO_ACTION']
                basemode = self.basemode['VIDEO_ACTION']
                self.modebase[basemode] = True
                self.video_action_predictor = VideoActionRecognizer.init_with_cfg(
                    args, video_action_cfg)

    def set_file_name(self, path):
        if path is not None:
            self.file_name = os.path.split(path)[-1]
        else:
            # use camera id
            self.file_name = None

    def get_result(self):
        return self.collector.get_res()

    def run(self, input, thread_idx=0):
        if self.is_video:
            self.predict_video(input, thread_idx=thread_idx)
        else:
            self.predict_image(input)
        self.pipe_timer.info()

    def predict_image(self, input):
        # det
        # det -> attr
        batch_loop_cnt = math.ceil(
            float(len(input)) / self.det_predictor.batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * self.det_predictor.batch_size
            end_index = min((i + 1) * self.det_predictor.batch_size, len(input))
            batch_file = input[start_index:end_index]
            batch_input = [decode_image(f, {})[0] for f in batch_file]

            if i > self.warmup_frame:
                self.pipe_timer.total_time.start()
                self.pipe_timer.module_time['det'].start()
            # det output format: class, score, xmin, ymin, xmax, ymax
            det_res = self.det_predictor.predict_image(
                batch_input, visual=False)
            det_res = self.det_predictor.filter_box(det_res,
                                                    self.cfg['crop_thresh'])
            if i > self.warmup_frame:
                self.pipe_timer.module_time['det'].end()
                self.pipe_timer.track_num += len(det_res['boxes'])
            self.pipeline_res.update(det_res, 'det')

            if self.with_human_attr:
                crop_inputs = crop_image_with_det(batch_input, det_res)
                attr_res_list = []

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['attr'].start()

                for crop_input in crop_inputs:
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    attr_res_list.extend(attr_res['output'])

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['attr'].end()

                attr_res = {'output': attr_res_list}
                self.pipeline_res.update(attr_res, 'attr')

            if self.with_vehicle_attr:
                crop_inputs = crop_image_with_det(batch_input, det_res)
                vehicle_attr_res_list = []

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['vehicle_attr'].start()

                for crop_input in crop_inputs:
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    vehicle_attr_res_list.extend(attr_res['output'])

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['vehicle_attr'].end()

                attr_res = {'output': vehicle_attr_res_list}
                self.pipeline_res.update(attr_res, 'vehicle_attr')

            if self.with_vehicleplate:
                if i > self.warmup_frame:
                    self.pipe_timer.module_time['vehicleplate'].start()
                crop_inputs = crop_image_with_det(batch_input, det_res)
                platelicenses = []
                for crop_input in crop_inputs:
                    platelicense = self.vehicleplate_detector.get_platelicense(
                        crop_input)
                    platelicenses.extend(platelicense['plate'])
                if i > self.warmup_frame:
                    self.pipe_timer.module_time['vehicleplate'].end()
                vehicleplate_res = {'vehicleplate': platelicenses}
                self.pipeline_res.update(vehicleplate_res, 'vehicleplate')

            self.pipe_timer.img_num += len(batch_input)
            if i > self.warmup_frame:
                self.pipe_timer.total_time.end()

            if self.cfg['visual']:
                self.visualize_image(batch_file, batch_input, self.pipeline_res)

    def predict_video(self, video_file, thread_idx=0):
        # mot
        # mot -> attr
        # mot -> pose -> action
        capture = cv2.VideoCapture(video_file)
        video_out_name = 'output.mp4' if self.file_name is None else self.file_name
        if "rtsp" in video_file:
            video_out_name = video_out_name + "_t" + str(thread_idx).zfill(
                2) + "_rtsp.mp4"

        # Get video info: resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("video fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 0

        entrance, records, center_traj = None, None, None
        if self.draw_center_traj:
            center_traj = [{}]
        id_set = set()
        interval_id_set = set()
        in_id_list = list()
        out_id_list = list()
        prev_center = dict()
        records = list()
        if self.do_entrance_counting or self.do_break_in_counting or self.illegal_parking_time != -1:
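            # region_polygon is a flat list of vertex coords [x1, y1, x2, y2, ...];
            # an illustrative value for a 4-point region:
            #     [780, 330, 1150, 330, 1150, 570, 780, 570]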
            if self.region_type == 'horizontal':
                entrance = [0, height / 2., width, height / 2.]
            elif self.region_type == 'vertical':
                entrance = [width / 2, 0., width / 2, height]
            elif self.region_type == 'custom':
                entrance = []
                assert len(
                    self.region_polygon
                ) % 2 == 0, "region_polygon should contain pairs of (x, y) coords when doing break-in counting."
                assert len(
                    self.region_polygon
                ) >= 6, 'region_type is custom, region_polygon should be at least 3 pairs of point coords.'

                for i in range(0, len(self.region_polygon), 2):
                    entrance.append(
                        [self.region_polygon[i], self.region_polygon[i + 1]])
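                # the frame size is appended as the last element; downstream
                # region checks are assumed to treat entrance[:-1] as the
                # polygon and entrance[-1] as the image size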
                entrance.append([width, height])
            else:
                raise ValueError("region_type:{} unsupported.".format(
                    self.region_type))

        video_fps = fps

        video_action_imgs = []

        if self.with_video_action:
            short_size = self.cfg["VIDEO_ACTION"]["short_size"]
            scale = ShortSizeScale(short_size)

        # store info for vehicle parking in region
        object_in_region_info = {}
        illegal_parking_dict = None

        while True:
            if frame_id % 10 == 0:
                print('Thread: {}; frame id: {}'.format(thread_idx, frame_id))

            ret, frame = capture.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if frame_id > self.warmup_frame:
                self.pipe_timer.total_time.start()

            if self.modebase["idbased"] or self.modebase["skeletonbased"]:
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['mot'].start()

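                # with skip_frame_num = N > 1, the detector runs only on every
                # N-th frame; in-between frames reuse the previous detection
                # result and only the tracker is updated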
                mot_skip_frame_num = self.mot_predictor.skip_frame_num
                reuse_det_result = False
                if mot_skip_frame_num > 1 and frame_id > 0 and frame_id % mot_skip_frame_num > 0:
                    reuse_det_result = True
                res = self.mot_predictor.predict_image(
                    [copy.deepcopy(frame_rgb)],
                    visual=False,
                    reuse_det_result=reuse_det_result)

                # mot output format: id, class, score, xmin, ymin, xmax, ymax
                mot_res = parse_mot_res(res)
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['mot'].end()
                    self.pipe_timer.track_num += len(mot_res['boxes'])

                if frame_id % 10 == 0:
                    print("Thread: {}; trackid number: {}".format(
                        thread_idx, len(mot_res['boxes'])))

                # flow_statistic only support single class MOT
                boxes, scores, ids = res[0]  # batch size = 1 in MOT
                mot_result = (frame_id + 1, boxes[0], scores[0],
                              ids[0])  # single class
                statistic = flow_statistic(
                    mot_result,
                    self.secs_interval,
                    self.do_entrance_counting,
                    self.do_break_in_counting,
                    self.region_type,
                    video_fps,
                    entrance,
                    id_set,
                    interval_id_set,
                    in_id_list,
                    out_id_list,
                    prev_center,
                    records,
                    ids2names=self.mot_predictor.pred_config.labels)
                records = statistic['records']

                if self.illegal_parking_time != -1:
                    object_in_region_info, illegal_parking_dict = update_object_info(
                        object_in_region_info, mot_result, self.region_type,
                        entrance, video_fps, self.illegal_parking_time)
                    if len(illegal_parking_dict) != 0:
                        # build relationship between id and plate
                        for key, value in illegal_parking_dict.items():
                            plate = self.collector.get_carlp(key)
                            illegal_parking_dict[key]['plate'] = plate

                # nothing detected
                if len(mot_res['boxes']) == 0:
                    frame_id += 1
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.img_num += 1
                        self.pipe_timer.total_time.end()
                    if self.cfg['visual']:
                        _, _, fps = self.pipe_timer.get_total_time()
                        im = self.visualize_video(frame, mot_res, frame_id, fps,
                                                  entrance, records,
                                                  center_traj)  # visualize
                        writer.write(im)
                        if self.file_name is None:  # use camera_id
                            cv2.imshow('Paddle-Pipeline', im)
                            if cv2.waitKey(1) & 0xFF == ord('q'):
                                break
                    continue

                self.pipeline_res.update(mot_res, 'mot')
                crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
                    frame_rgb, mot_res)

                if self.with_vehicleplate and frame_id % 10 == 0:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].start()
                    plate_input, _, _ = crop_image_with_mot(
                        frame_rgb, mot_res, expand=False)
                    platelicense = self.vehicleplate_detector.get_platelicense(
                        plate_input)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].end()
                    self.pipeline_res.update(platelicense, 'vehicleplate')
                else:
                    self.pipeline_res.clear('vehicleplate')

                if self.with_human_attr:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['attr'].start()
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['attr'].end()
                    self.pipeline_res.update(attr_res, 'attr')

                if self.with_vehicle_attr:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].start()
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].end()
                    self.pipeline_res.update(attr_res, 'vehicle_attr')

                if self.with_idbased_detaction:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].start()
                    det_action_res = self.det_action_predictor.predict(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].end()
                    self.pipeline_res.update(det_action_res, 'det_action')

                    if self.cfg['visual']:
                        self.det_action_visual_helper.update(det_action_res)

                if self.with_idbased_clsaction:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].start()
                    cls_action_res = self.cls_action_predictor.predict_with_mot(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].end()
                    self.pipeline_res.update(cls_action_res, 'cls_action')

                    if self.cfg['visual']:
                        self.cls_action_visual_helper.update(cls_action_res)

                if self.with_skeleton_action:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].start()
                    kpt_pred = self.kpt_predictor.predict_image(
                        crop_input, visual=False)
                    keypoint_vector, score_vector = translate_to_ori_images(
                        kpt_pred, np.array(new_bboxes))
                    kpt_res = {}
                    kpt_res['keypoint'] = [
                        keypoint_vector.tolist(), score_vector.tolist()
                    ] if len(keypoint_vector) > 0 else [[], []]
                    kpt_res['bbox'] = ori_bboxes
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].end()

                    self.pipeline_res.update(kpt_res, 'kpt')

                    self.kpt_buff.update(kpt_res, mot_res)  # collect kpt output
                    # whether enough frames have been collected or the tracker is lost
                    state = self.kpt_buff.get_state()

                    skeleton_action_res = {}
                    if state:
                        if frame_id > self.warmup_frame:
                            self.pipe_timer.module_time[
                                'skeleton_action'].start()
                        # reorganize kpt output by tracker ID
                        collected_keypoint = self.kpt_buff.get_collected_keypoint()
                        skeleton_action_input = parse_mot_keypoint(
                            collected_keypoint, self.coord_size)
                        skeleton_action_res = self.skeleton_action_predictor.predict_skeleton_with_mot(
                            skeleton_action_input)
                        if frame_id > self.warmup_frame:
                            self.pipe_timer.module_time['skeleton_action'].end()
                        self.pipeline_res.update(skeleton_action_res,
                                                 'skeleton_action')

                    if self.cfg['visual']:
                        self.skeleton_action_visual_helper.update(
                            skeleton_action_res)

                if self.with_mtmct and frame_id % 10 == 0:
                    crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot(
                        frame_rgb, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].start()
                    reid_res = self.reid_predictor.predict_batch(crop_input)

                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].end()

                    reid_res_dict = {
                        'features': reid_res,
                        "qualities": img_qualities,
                        "rects": rects
                    }
                    self.pipeline_res.update(reid_res_dict, 'reid')
                else:
                    self.pipeline_res.clear('reid')

            if self.with_video_action:
                # get the params
                frame_len = self.cfg["VIDEO_ACTION"]["frame_len"]
                sample_freq = self.cfg["VIDEO_ACTION"]["sample_freq"]

                if sample_freq * frame_len > frame_count:  # video is too short
                    sample_freq = int(frame_count / frame_len)

                # filter the warmup frames
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['video_action'].start()

                # collect frames
                if frame_id % sample_freq == 0:
                    # Scale image
                    scaled_img = scale(frame_rgb)
                    video_action_imgs.append(scaled_img)

                # the number of collected frames is enough to predict video action
                if len(video_action_imgs) == frame_len:
                    classes, scores = self.video_action_predictor.predict(
                        video_action_imgs)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['video_action'].end()

                    video_action_res = {"class": classes[0], "score": scores[0]}
                    self.pipeline_res.update(video_action_res, 'video_action')

                    print("video_action_res:", video_action_res)

                    video_action_imgs.clear()  # next clip

            self.collector.append(frame_id, self.pipeline_res)

            if frame_id > self.warmup_frame:
                self.pipe_timer.img_num += 1
                self.pipe_timer.total_time.end()
            frame_id += 1

            if self.cfg['visual']:
                _, _, fps = self.pipe_timer.get_total_time()

                im = self.visualize_video(frame, self.pipeline_res,
                                          self.collector, frame_id, fps,
                                          entrance, records, center_traj,
                                          self.illegal_parking_time != -1,
                                          illegal_parking_dict)  # visualize
                writer.write(im)
                if self.file_name is None:  # use camera_id
                    cv2.imshow('Paddle-Pipeline', im)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

        writer.release()
        print('save result to {}'.format(out_path))

    def visualize_video(self,
                        image,
                        result,
                        collector,
                        frame_id,
                        fps,
                        entrance=None,
                        records=None,
                        center_traj=None,
                        do_illegal_parking_recognition=False,
                        illegal_parking_dict=None):
        mot_res = copy.deepcopy(result.get('mot'))
        if mot_res is not None:
            ids = mot_res['boxes'][:, 0]
            scores = mot_res['boxes'][:, 2]
            boxes = mot_res['boxes'][:, 3:]
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
        else:
            boxes = np.zeros([0, 4])
            ids = np.zeros([0])
            scores = np.zeros([0])

        # single class; still needs to be defaultdict for plotting
        num_classes = 1
        online_tlwhs = defaultdict(list)
        online_scores = defaultdict(list)
        online_ids = defaultdict(list)
        online_tlwhs[0] = boxes
        online_scores[0] = scores
        online_ids[0] = ids

        if mot_res is not None:
            image = plot_tracking_dict(
                image,
                num_classes,
                online_tlwhs,
                online_ids,
                online_scores,
                frame_id=frame_id,
                fps=fps,
                ids2names=self.mot_predictor.pred_config.labels,
                do_entrance_counting=self.do_entrance_counting,
                do_break_in_counting=self.do_break_in_counting,
                do_illegal_parking_recognition=do_illegal_parking_recognition,
                illegal_parking_dict=illegal_parking_dict,
                entrance=entrance,
                records=records,
                center_traj=center_traj)

        human_attr_res = result.get('attr')
        if human_attr_res is not None:
            boxes = mot_res['boxes'][:, 1:]
            human_attr_res = human_attr_res['output']
            image = visualize_attr(image, human_attr_res, boxes)
            image = np.array(image)

        vehicle_attr_res = result.get('vehicle_attr')
        if vehicle_attr_res is not None:
            boxes = mot_res['boxes'][:, 1:]
            vehicle_attr_res = vehicle_attr_res['output']
            image = visualize_attr(image, vehicle_attr_res, boxes)
            image = np.array(image)

        if mot_res is not None:
            vehicleplate = False
            plates = []
            for trackid in mot_res['boxes'][:, 0]:
                plate = collector.get_carlp(trackid)
                if plate is not None:
                    vehicleplate = True
                    plates.append(plate)
                else:
                    plates.append("")
            if vehicleplate:
                boxes = mot_res['boxes'][:, 1:]
                image = visualize_vehicleplate(image, plates, boxes)
                image = np.array(image)

        kpt_res = result.get('kpt')
        if kpt_res is not None:
            image = visualize_pose(
                image,
                kpt_res,
                visual_thresh=self.cfg['kpt_thresh'],
                returnimg=True)

        video_action_res = result.get('video_action')
        if video_action_res is not None:
            video_action_score = None
            if video_action_res and video_action_res["class"] == 1:
                video_action_score = video_action_res["score"]
            mot_boxes = None
            if mot_res:
                mot_boxes = mot_res['boxes']
            image = visualize_action(
                image,
                mot_boxes,
                action_visual_collector=None,
                action_text="SkeletonAction",
                video_action_score=video_action_score,
                video_action_text="Fight")

        visual_helper_for_display = []
        action_to_display = []

        skeleton_action_res = result.get('skeleton_action')
        if skeleton_action_res is not None:
            visual_helper_for_display.append(self.skeleton_action_visual_helper)
            action_to_display.append("Falling")

        det_action_res = result.get('det_action')
        if det_action_res is not None:
            visual_helper_for_display.append(self.det_action_visual_helper)
            action_to_display.append("Smoking")

        cls_action_res = result.get('cls_action')
        if cls_action_res is not None:
            visual_helper_for_display.append(self.cls_action_visual_helper)
            action_to_display.append("Calling")

        if len(visual_helper_for_display) > 0:
            image = visualize_action(image, mot_res['boxes'],
                                     visual_helper_for_display,
                                     action_to_display)

        return image

    def visualize_image(self, im_files, images, result):
        start_idx, boxes_num_i = 0, 0
        det_res = result.get('det')
        human_attr_res = result.get('attr')
        vehicle_attr_res = result.get('vehicle_attr')
        vehicleplate_res = result.get('vehicleplate')

        for i, (im_file, im) in enumerate(zip(im_files, images)):
            if det_res is not None:
                det_res_i = {}
                boxes_num_i = det_res['boxes_num'][i]
                det_res_i['boxes'] = det_res['boxes'][start_idx:start_idx +
                                                      boxes_num_i, :]
                im = visualize_box_mask(
                    im,
                    det_res_i,
                    labels=['target'],
                    threshold=self.cfg['crop_thresh'])
                im = np.ascontiguousarray(np.copy(im))
                im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
            if human_attr_res is not None:
                human_attr_res_i = human_attr_res['output'][start_idx:start_idx
                                                            + boxes_num_i]
                im = visualize_attr(im, human_attr_res_i, det_res_i['boxes'])
            if vehicle_attr_res is not None:
                vehicle_attr_res_i = vehicle_attr_res['output'][
                    start_idx:start_idx + boxes_num_i]
                im = visualize_attr(im, vehicle_attr_res_i, det_res_i['boxes'])
            if vehicleplate_res is not None:
                plates = vehicleplate_res['vehicleplate']
                det_res_i['boxes'][:, 4:6] = det_res_i[
                    'boxes'][:, 4:6] - det_res_i['boxes'][:, 2:4]
                im = visualize_vehicleplate(im, plates, det_res_i['boxes'])

            img_name = os.path.split(im_file)[-1]
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)
            out_path = os.path.join(self.output_dir, img_name)
            cv2.imwrite(out_path, im)
            print("save result to: " + out_path)
            start_idx += boxes_num_i


def main():
    cfg = merge_cfg(FLAGS)  # use command params to update config
    print_arguments(cfg)

    pipeline = Pipeline(FLAGS, cfg)
    # pipeline.run()
    pipeline.run_multithreads()


if __name__ == '__main__':
    paddle.enable_static()

    # parse params from the command line
    parser = argsparser()
    FLAGS = parser.parse_args()
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'], \
        "device should be CPU, GPU or XPU"

    main()