# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
import cv2
import numpy as np
import math
import paddle
import sys
import copy
from collections import defaultdict
from collections.abc import Sequence

from datacollector import DataCollector, Result

# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

from cfg_utils import argsparser, print_arguments, merge_cfg
from pipe_utils import PipeTimer
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint

from python.infer import Detector, DetectorPicoDet
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
from python.preprocess import decode_image, ShortSizeScale
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action, visualize_vehicleplate

from pptracking.python.mot_sde_infer import SDE_Detector
from pptracking.python.mot.visualize import plot_tracking_dict
from pptracking.python.mot.utils import flow_statistic, update_object_info

from pphuman.attr_infer import AttrDetector
from pphuman.video_action_infer import VideoActionRecognizer
from pphuman.action_infer import SkeletonActionRecognizer, DetActionRecognizer, ClsActionRecognizer
from pphuman.action_utils import KeyPointBuff, ActionVisualHelper
from pphuman.reid import ReID
from pphuman.mtmct import mtmct_process

from ppvehicle.vehicle_plate import PlateRecognizer
from ppvehicle.vehicle_attr import VehicleAttr

from download import auto_download_model

class Pipeline(object):
    """
    Pipeline

    Resolves the input source (images, a video file/stream, a directory of
    videos or a camera id) and drives one PipePredictor per camera.

    Args:
        args (argparse.Namespace): arguments in pipeline, which contains environment and runtime settings
        cfg (dict): config of models in pipeline
    """

    def __init__(self, args, cfg):
        self.multi_camera = False
        reid_cfg = cfg.get('REID', False)
        self.enable_mtmct = reid_cfg['enable'] if reid_cfg else False
        self.is_video = False
        self.output_dir = args.output_dir
        self.vis_result = cfg['visual']
        # _parse_input also sets self.is_video / self.multi_camera
        self.input = self._parse_input(args.image_file, args.image_dir,
                                       args.video_file, args.video_dir,
                                       args.camera_id)
        if self.multi_camera:
            # one predictor per video in video_dir
            self.predictor = []
            for name in self.input:
                predictor_item = PipePredictor(
                    args, cfg, is_video=True, multi_camera=True)
                predictor_item.set_file_name(name)
                self.predictor.append(predictor_item)

        else:
            self.predictor = PipePredictor(args, cfg, self.is_video)
            if self.is_video:
                # Name the output after the input that was actually resolved.
                # args.video_file is None when the single video came from
                # video_dir. A None file name makes predict_video fall back
                # to 'output.mp4' and to camera (imshow) mode. Camera ids
                # are not paths, so they still map to None.
                video_path = self.input if isinstance(self.input,
                                                      str) else None
                self.predictor.set_file_name(video_path)

    def _parse_input(self, image_file, image_dir, video_file, video_dir,
                     camera_id):
        """Resolve the input source and set ``is_video``/``multi_camera``.

        Sources are checked in priority order: images (image_file or
        image_dir), video_file, video_dir, then camera_id.

        Returns:
            the image list from get_test_images, a video path / stream URL,
            a sorted list of video paths (multi-camera), or the camera id.

        Raises:
            ValueError: if no input is given or video_dir is empty.
        """
        if image_file is not None or image_dir is not None:
            parsed = get_test_images(image_dir, image_file)
            self.is_video = False
            self.multi_camera = False

        elif video_file is not None:
            assert os.path.exists(
                video_file
            ) or 'rtsp' in video_file, "video_file not exists and not an rtsp site."
            self.multi_camera = False
            parsed = video_file
            self.is_video = True

        elif video_dir is not None:
            # os.listdir order is arbitrary; sort for a stable camera order
            videof = sorted(
                os.path.join(video_dir, x) for x in os.listdir(video_dir))
            if not videof:
                raise ValueError("video_dir {} contains no files.".format(
                    video_dir))
            if len(videof) > 1:
                self.multi_camera = True
                parsed = videof
            else:
                self.multi_camera = False
                parsed = videof[0]
            self.is_video = True

        elif camera_id != -1:
            self.multi_camera = False
            parsed = camera_id
            self.is_video = True

        else:
            raise ValueError(
                "Illegal Input, please set one of ['video_file', 'camera_id', 'image_file', 'image_dir']"
            )

        return parsed

    def run(self):
        """Run every predictor and, for multi-camera input, optionally MTMCT."""
        if self.multi_camera:
            multi_res = []
            for predictor, camera_input in zip(self.predictor, self.input):
                predictor.run(camera_input)
                multi_res.append(predictor.get_result())
            if self.enable_mtmct:
                mtmct_process(
                    multi_res,
                    self.input,
                    mtmct_vis=self.vis_result,
                    output_dir=self.output_dir)

        else:
            self.predictor.run(self.input)


def _resolve_model_path(section, field):
    """Download ``section[field]`` if needed and return the path to use.

    If ``auto_download_model`` returns a truthy path, ``section[field]`` is
    overwritten in place with it. Otherwise the configured value is returned
    unchanged.
    """
    model_dir = section[field]
    downloaded_model_dir = auto_download_model(model_dir)
    if downloaded_model_dir:
        model_dir = downloaded_model_dir
        section[field] = model_dir
    return model_dir


def get_model_dir(cfg):
    """
    Auto download inference model if the model_path is a url link.
    Otherwise it will use the model_path directly.

    ``cfg`` is updated in place: every downloaded path replaces the URL it
    came from. Non-dict entries (e.g. 'visual', 'warmup_frame') are skipped.
    """
    for key, section in cfg.items():
        if not isinstance(section, dict):
            continue

        # a section without an 'enable' flag counts as enabled
        if section.get('enable', True):
            if "model_dir" in section:
                model_dir = _resolve_model_path(section, "model_dir")
                print(key, " model dir: ", model_dir)
            elif key == "VEHICLE_PLATE":
                det_model_dir = _resolve_model_path(section, "det_model_dir")
                print("det_model_dir model dir: ", det_model_dir)

                rec_model_dir = _resolve_model_path(section, "rec_model_dir")
                print("rec_model_dir model dir: ", rec_model_dir)

        elif key == "MOT" and "model_dir" in section:
            # The tracker is also built for idbased and skeletonbased
            # actions when MOT itself is disabled, so fetch it anyway.
            model_dir = _resolve_model_path(section, "model_dir")
            print("mot_model_dir model_dir: ", model_dir)
class PipePredictor(object):
    """
    Predictor in single camera

    The pipeline for image input:

        1. Detection
        2. Detection -> Attribute
        3. Detection -> Vehicle Attribute / Vehicle Plate

    The pipeline for video input:

        1. Tracking
        2. Tracking -> Attribute
        3. Tracking -> KeyPoint -> SkeletonAction Recognition
        4. VideoAction Recognition
        5. Tracking -> ID-based Detection / Classification Action Recognition
        6. Tracking -> Vehicle Attribute / Vehicle Plate

    Args:
        args (argparse.Namespace): arguments in pipeline, which contains environment and runtime settings
        cfg (dict): config of models in pipeline
        is_video (bool): whether the input is video, default as True
        multi_camera (bool): whether to use multi camera in pipeline,
            default as False
    """
    def __init__(self, args, cfg, is_video=True, multi_camera=False):
        """Read the module switches from ``cfg`` and build the predictors.

        Args:
            args (argparse.Namespace): arguments in pipeline, which contains
                environment and runtime settings
            cfg (dict): config of models in pipeline
            is_video (bool): build the video modules (tracking, actions)
                instead of the single-image detector, default as True
            multi_camera (bool): whether this predictor is one camera of a
                multi-camera pipeline, default as False
        """

        def module_enabled(name):
            # A module is on only when its section exists, is truthy and
            # carries a truthy 'enable' entry.
            section = cfg.get(name, False)
            return section['enable'] if section else False

        # general module for pphuman and ppvehicle
        self.with_mot = module_enabled('MOT')
        self.with_human_attr = module_enabled('ATTR')
        if self.with_mot:
            print('Multi-Object Tracking enabled')
        if self.with_human_attr:
            print('Human Attribute Recognition enabled')

        # only for pphuman
        self.with_skeleton_action = module_enabled('SKELETON_ACTION')
        self.with_video_action = module_enabled('VIDEO_ACTION')
        self.with_idbased_detaction = module_enabled('ID_BASED_DETACTION')
        self.with_idbased_clsaction = module_enabled('ID_BASED_CLSACTION')
        self.with_mtmct = module_enabled('REID')

        if self.with_skeleton_action:
            print('SkeletonAction Recognition enabled')
        if self.with_video_action:
            print('VideoAction Recognition enabled')
        if self.with_idbased_detaction:
            print('IDBASED Detection Action Recognition enabled')
        if self.with_idbased_clsaction:
            print('IDBASED Classification Action Recognition enabled')
        if self.with_mtmct:
            print("MTMCT enabled")

        # only for ppvehicle
        self.with_vehicleplate = module_enabled('VEHICLE_PLATE')
        if self.with_vehicleplate:
            print('Vehicle Plate Recognition enabled')

        self.with_vehicle_attr = module_enabled('VEHICLE_ATTR')
        if self.with_vehicle_attr:
            print('Vehicle Attribute Recognition enabled')

        # processing modes switched on as the modules below are built
        self.modebase = {
            "framebased": False,
            "videobased": False,
            "idbased": False,
            "skeletonbased": False
        }

        # processing mode each config section belongs to
        self.basemode = {
            "MOT": "idbased",
            "ATTR": "idbased",
            "VIDEO_ACTION": "videobased",
            "SKELETON_ACTION": "skeletonbased",
            "ID_BASED_DETACTION": "idbased",
            "ID_BASED_CLSACTION": "idbased",
            "REID": "idbased",
            "VEHICLE_PLATE": "idbased",
            "VEHICLE_ATTR": "idbased",
        }

        self.is_video = is_video
        self.multi_camera = multi_camera
        self.cfg = cfg

        self.output_dir = args.output_dir
        self.draw_center_traj = args.draw_center_traj
        self.secs_interval = args.secs_interval
        self.do_entrance_counting = args.do_entrance_counting
        self.do_break_in_counting = args.do_break_in_counting
        self.region_type = args.region_type
        self.region_polygon = args.region_polygon
        self.illegal_parking_time = args.illegal_parking_time

        self.warmup_frame = self.cfg['warmup_frame']
        self.pipeline_res = Result()
        self.pipe_timer = PipeTimer()
        self.file_name = None
        self.collector = DataCollector()

        # auto download inference model
        get_model_dir(self.cfg)

        # Built once here for both image and video input. The video branch
        # used to construct a second PlateRecognizer, loading its models twice.
        if self.with_vehicleplate:
            vehicleplate_cfg = self.cfg['VEHICLE_PLATE']
            self.vehicleplate_detector = PlateRecognizer(args, vehicleplate_cfg)
            basemode = self.basemode['VEHICLE_PLATE']
            self.modebase[basemode] = True

        if self.with_human_attr:
            attr_cfg = self.cfg['ATTR']
            basemode = self.basemode['ATTR']
            self.modebase[basemode] = True
            self.attr_predictor = AttrDetector.init_with_cfg(args, attr_cfg)

        if self.with_vehicle_attr:
            vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
            basemode = self.basemode['VEHICLE_ATTR']
            self.modebase[basemode] = True
            self.vehicle_attr_predictor = VehicleAttr.init_with_cfg(
                args, vehicleattr_cfg)

        if not is_video:
            det_cfg = self.cfg['DET']
            model_dir = det_cfg['model_dir']
            batch_size = det_cfg['batch_size']
            self.det_predictor = Detector(
                model_dir, args.device, args.run_mode, batch_size,
                args.trt_min_shape, args.trt_max_shape, args.trt_opt_shape,
                args.trt_calib_mode, args.cpu_threads, args.enable_mkldnn)
        else:
            if self.with_idbased_detaction:
                idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION']
                basemode = self.basemode['ID_BASED_DETACTION']
                self.modebase[basemode] = True

                self.det_action_predictor = DetActionRecognizer.init_with_cfg(
                    args, idbased_detaction_cfg)
                self.det_action_visual_helper = ActionVisualHelper(1)

            if self.with_idbased_clsaction:
                idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION']
                basemode = self.basemode['ID_BASED_CLSACTION']
                self.modebase[basemode] = True

                self.cls_action_predictor = ClsActionRecognizer.init_with_cfg(
                    args, idbased_clsaction_cfg)
                self.cls_action_visual_helper = ActionVisualHelper(1)

            if self.with_skeleton_action:
                skeleton_action_cfg = self.cfg['SKELETON_ACTION']
                display_frames = skeleton_action_cfg['display_frames']
                self.coord_size = skeleton_action_cfg['coord_size']
                basemode = self.basemode['SKELETON_ACTION']
                self.modebase[basemode] = True
                skeleton_action_frames = skeleton_action_cfg['max_frames']

                self.skeleton_action_predictor = SkeletonActionRecognizer.init_with_cfg(
                    args, skeleton_action_cfg)
                self.skeleton_action_visual_helper = ActionVisualHelper(
                    display_frames)

                # skeleton actions need keypoints for every tracked person
                kpt_cfg = self.cfg['KPT']
                kpt_model_dir = kpt_cfg['model_dir']
                kpt_batch_size = kpt_cfg['batch_size']
                self.kpt_predictor = KeyPointDetector(
                    kpt_model_dir,
                    args.device,
                    args.run_mode,
                    kpt_batch_size,
                    args.trt_min_shape,
                    args.trt_max_shape,
                    args.trt_opt_shape,
                    args.trt_calib_mode,
                    args.cpu_threads,
                    args.enable_mkldnn,
                    use_dark=False)
                self.kpt_buff = KeyPointBuff(skeleton_action_frames)

            if self.with_mtmct:
                reid_cfg = self.cfg['REID']
                basemode = self.basemode['REID']
                self.modebase[basemode] = True
                self.reid_predictor = ReID.init_with_cfg(args, reid_cfg)

            # The tracker is needed by every id-based or skeleton-based
            # module, so it is built even when MOT itself is not enabled.
            if self.with_mot or self.modebase["idbased"] or self.modebase[
                    "skeletonbased"]:
                mot_cfg = self.cfg['MOT']
                model_dir = mot_cfg['model_dir']
                tracker_config = mot_cfg['tracker_config']
                batch_size = mot_cfg['batch_size']
                skip_frame_num = mot_cfg.get('skip_frame_num', -1)
                basemode = self.basemode['MOT']
                self.modebase[basemode] = True
                self.mot_predictor = SDE_Detector(
                    model_dir,
                    tracker_config,
                    args.device,
                    args.run_mode,
                    batch_size,
                    args.trt_min_shape,
                    args.trt_max_shape,
                    args.trt_opt_shape,
                    args.trt_calib_mode,
                    args.cpu_threads,
                    args.enable_mkldnn,
                    skip_frame_num=skip_frame_num,
                    draw_center_traj=self.draw_center_traj,
                    secs_interval=self.secs_interval,
                    do_entrance_counting=self.do_entrance_counting,
                    do_break_in_counting=self.do_break_in_counting,
                    region_type=self.region_type,
                    region_polygon=self.region_polygon)

            if self.with_video_action:
                video_action_cfg = self.cfg['VIDEO_ACTION']
                basemode = self.basemode['VIDEO_ACTION']
                self.modebase[basemode] = True
                self.video_action_predictor = VideoActionRecognizer.init_with_cfg(
                    args, video_action_cfg)
    def set_file_name(self, path):
W
wangguanzhong 已提交
431 432 433 434 435
        if path is not None:
            self.file_name = os.path.split(path)[-1]
        else:
            # use camera id
            self.file_name = None
436

437
    def get_result(self):
Z
zhiboniu 已提交
438
        return self.collector.get_res()
439 440 441 442 443 444

    def run(self, input):
        if self.is_video:
            self.predict_video(input)
        else:
            self.predict_image(input)
445
        self.pipe_timer.info()
446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463

    def predict_image(self, input):
        """Run the image pipeline over a list of image files.

        Images are processed in batches of ``self.det_predictor.batch_size``.
        Detection runs first. Then each enabled module (human attribute,
        vehicle attribute, vehicle plate) runs on the detected crops. Results
        go into ``self.pipeline_res`` and are optionally visualized.

        Args:
            input (list): image file paths.
        """
        # det
        # det -> attr
        batch_size = self.det_predictor.batch_size
        for i, start_index in enumerate(range(0, len(input), batch_size)):
            batch_file = input[start_index:start_index + batch_size]
            batch_input = [decode_image(f, {})[0] for f in batch_file]
            # batches up to and including warmup_frame are not timed
            timed = i > self.warmup_frame

            if timed:
                self.pipe_timer.total_time.start()
                self.pipe_timer.module_time['det'].start()
            # det output format: class, score, xmin, ymin, xmax, ymax
            det_res = self.det_predictor.predict_image(
                batch_input, visual=False)
            det_res = self.det_predictor.filter_box(det_res,
                                                    self.cfg['crop_thresh'])
            if timed:
                self.pipe_timer.module_time['det'].end()
                self.pipe_timer.track_num += len(det_res['boxes'])
            self.pipeline_res.update(det_res, 'det')

            if self.with_human_attr:
                crop_inputs = crop_image_with_det(batch_input, det_res)
                attr_res_list = []

                if timed:
                    self.pipe_timer.module_time['attr'].start()

                for crop_input in crop_inputs:
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    attr_res_list.extend(attr_res['output'])

                if timed:
                    self.pipe_timer.module_time['attr'].end()

                attr_res = {'output': attr_res_list}
                self.pipeline_res.update(attr_res, 'attr')

            if self.with_vehicle_attr:
                crop_inputs = crop_image_with_det(batch_input, det_res)
                vehicle_attr_res_list = []

                if timed:
                    self.pipe_timer.module_time['vehicle_attr'].start()

                for crop_input in crop_inputs:
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    vehicle_attr_res_list.extend(attr_res['output'])

                if timed:
                    self.pipe_timer.module_time['vehicle_attr'].end()

                attr_res = {'output': vehicle_attr_res_list}
                self.pipeline_res.update(attr_res, 'vehicle_attr')

            if self.with_vehicleplate:
                if timed:
                    self.pipe_timer.module_time['vehicleplate'].start()
                crop_inputs = crop_image_with_det(batch_input, det_res)
                platelicenses = []
                for crop_input in crop_inputs:
                    platelicense = self.vehicleplate_detector.get_platelicense(
                        crop_input)
                    platelicenses.extend(platelicense['plate'])
                if timed:
                    self.pipe_timer.module_time['vehicleplate'].end()
                vehicleplate_res = {'vehicleplate': platelicenses}
                self.pipeline_res.update(vehicleplate_res, 'vehicleplate')

            if timed:
                # Count only images whose time was measured. predict_video
                # does the same for frames, so img_num stays in step with
                # total_time. NOTE(review): assumes PipeTimer derives
                # per-image latency from total_time / img_num; confirm.
                self.pipe_timer.img_num += len(batch_input)
                self.pipe_timer.total_time.end()

            if self.cfg['visual']:
                self.visualize_image(batch_file, batch_input, self.pipeline_res)
    def predict_video(self, video_file):
529 530 531
        # mot
        # mot -> attr
        # mot -> pose -> action
Z
zhiboniu 已提交
532
        capture = cv2.VideoCapture(video_file)
533
        video_out_name = 'output.mp4' if self.file_name is None else self.file_name
Z
zhiboniu 已提交
534 535
        if "rtsp" in video_file:
            video_out_name = video_out_name + "_rtsp.mp4"
536 537 538 539 540 541

        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
542
        print("video fps: %d, frame_count: %d" % (fps, frame_count))
543 544 545 546 547 548 549

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 0
550 551 552 553 554 555 556 557 558 559

        entrance, records, center_traj = None, None, None
        if self.draw_center_traj:
            center_traj = [{}]
        id_set = set()
        interval_id_set = set()
        in_id_list = list()
        out_id_list = list()
        prev_center = dict()
        records = list()
560
        if self.do_entrance_counting or self.do_break_in_counting or self.illegal_parking_time != -1:
561 562 563 564 565 566 567 568 569
            if self.region_type == 'horizontal':
                entrance = [0, height / 2., width, height / 2.]
            elif self.region_type == 'vertical':
                entrance = [width / 2, 0., width / 2, height]
            elif self.region_type == 'custom':
                entrance = []
                assert len(
                    self.region_polygon
                ) % 2 == 0, "region_polygon should be pairs of coords points when do break_in counting."
J
JYChen 已提交
570 571 572 573
                assert len(
                    self.region_polygon
                ) > 6, 'region_type is custom, region_polygon should be at least 3 pairs of point coords.'

574 575 576 577 578 579 580 581
                for i in range(0, len(self.region_polygon), 2):
                    entrance.append(
                        [self.region_polygon[i], self.region_polygon[i + 1]])
                entrance.append([width, height])
            else:
                raise ValueError("region_type:{} unsupported.".format(
                    self.region_type))

582 583
        video_fps = fps

584 585
        video_action_imgs = []

586 587 588 589
        if self.with_video_action:
            short_size = self.cfg["VIDEO_ACTION"]["short_size"]
            scale = ShortSizeScale(short_size)

590 591 592 593
        object_in_region_info = {
        }  # store info for vehicle parking in region       
        illegal_parking_dict = None

594 595 596
        while (1):
            if frame_id % 10 == 0:
                print('frame id: ', frame_id)
597

598 599 600
            ret, frame = capture.read()
            if not ret:
                break
601
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
Z
zhiboniu 已提交
602 603
            if frame_id > self.warmup_frame:
                self.pipe_timer.total_time.start()
604

605
            if self.modebase["idbased"] or self.modebase["skeletonbased"]:
606
                if frame_id > self.warmup_frame:
607
                    self.pipe_timer.module_time['mot'].start()
608

609 610 611 612 613 614 615 616
                mot_skip_frame_num = self.mot_predictor.skip_frame_num
                reuse_det_result = False
                if mot_skip_frame_num > 1 and frame_id > 0 and frame_id % mot_skip_frame_num > 0:
                    reuse_det_result = True
                res = self.mot_predictor.predict_image(
                    [copy.deepcopy(frame_rgb)],
                    visual=False,
                    reuse_det_result=reuse_det_result)
617 618 619

                # mot output format: id, class, score, xmin, ymin, xmax, ymax
                mot_res = parse_mot_res(res)
Z
zhiboniu 已提交
620 621 622
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['mot'].end()
                    self.pipe_timer.track_num += len(mot_res['boxes'])
623 624 625 626 627 628

                # flow_statistic only support single class MOT
                boxes, scores, ids = res[0]  # batch size = 1 in MOT
                mot_result = (frame_id + 1, boxes[0], scores[0],
                              ids[0])  # single class
                statistic = flow_statistic(
F
Feng Ni 已提交
629 630 631 632 633 634 635 636 637 638 639 640 641 642
                    mot_result,
                    self.secs_interval,
                    self.do_entrance_counting,
                    self.do_break_in_counting,
                    self.region_type,
                    video_fps,
                    entrance,
                    id_set,
                    interval_id_set,
                    in_id_list,
                    out_id_list,
                    prev_center,
                    records,
                    ids2names=self.mot_predictor.pred_config.labels)
643 644
                records = statistic['records']

645 646 647 648 649 650 651 652 653 654
                if self.illegal_parking_time != -1:
                    object_in_region_info, illegal_parking_dict = update_object_info(
                        object_in_region_info, mot_result, self.region_type,
                        entrance, video_fps, self.illegal_parking_time)
                    if len(illegal_parking_dict) != 0:
                        # build relationship between id and plate
                        for key, value in illegal_parking_dict.items():
                            plate = self.collector.get_carlp(key)
                            illegal_parking_dict[key]['plate'] = plate

655 656 657
                # nothing detected
                if len(mot_res['boxes']) == 0:
                    frame_id += 1
J
JYChen 已提交
658
                    if frame_id > self.warmup_frame:
659 660 661 662 663 664 665 666 667
                        self.pipe_timer.img_num += 1
                        self.pipe_timer.total_time.end()
                    if self.cfg['visual']:
                        _, _, fps = self.pipe_timer.get_total_time()
                        im = self.visualize_video(frame, mot_res, frame_id, fps,
                                                  entrance, records,
                                                  center_traj)  # visualize
                        writer.write(im)
                        if self.file_name is None:  # use camera_id
Z
zhiboniu 已提交
668
                            cv2.imshow('Paddle-Pipeline', im)
669 670 671 672 673
                            if cv2.waitKey(1) & 0xFF == ord('q'):
                                break
                    continue

                self.pipeline_res.update(mot_res, 'mot')
J
JYChen 已提交
674
                crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
675
                    frame_rgb, mot_res)
676

677
                if self.with_vehicleplate and frame_id % 10 == 0:
Z
zhiboniu 已提交
678 679
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].start()
Z
zhiboniu 已提交
680 681
                    plate_input, _, _ = crop_image_with_mot(
                        frame_rgb, mot_res, expand=False)
Z
zhiboniu 已提交
682
                    platelicense = self.vehicleplate_detector.get_platelicense(
Z
zhiboniu 已提交
683
                        plate_input)
Z
zhiboniu 已提交
684 685
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].end()
Z
zhiboniu 已提交
686
                    self.pipeline_res.update(platelicense, 'vehicleplate')
687 688
                else:
                    self.pipeline_res.clear('vehicleplate')
Z
zhiboniu 已提交
689

690
                if self.with_human_attr:
J
JYChen 已提交
691
                    if frame_id > self.warmup_frame:
692 693 694 695 696 697 698
                        self.pipe_timer.module_time['attr'].start()
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['attr'].end()
                    self.pipeline_res.update(attr_res, 'attr')

699 700 701 702 703 704 705 706 707
                if self.with_vehicle_attr:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].start()
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].end()
                    self.pipeline_res.update(attr_res, 'vehicle_attr')

Z
zhiboniu 已提交
708
                if self.with_idbased_detaction:
J
JYChen 已提交
709 710 711 712 713 714 715 716 717 718
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].start()
                    det_action_res = self.det_action_predictor.predict(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].end()
                    self.pipeline_res.update(det_action_res, 'det_action')

                    if self.cfg['visual']:
                        self.det_action_visual_helper.update(det_action_res)
Z
zhiboniu 已提交
719 720

                if self.with_idbased_clsaction:
J
JYChen 已提交
721 722 723 724 725 726 727 728 729 730
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].start()
                    cls_action_res = self.cls_action_predictor.predict_with_mot(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].end()
                    self.pipeline_res.update(cls_action_res, 'cls_action')

                    if self.cfg['visual']:
                        self.cls_action_visual_helper.update(cls_action_res)
Z
zhiboniu 已提交
731

Z
zhiboniu 已提交
732
                if self.with_skeleton_action:
Z
zhiboniu 已提交
733 734 735 736 737 738 739 740 741 742 743 744 745
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].start()
                    kpt_pred = self.kpt_predictor.predict_image(
                        crop_input, visual=False)
                    keypoint_vector, score_vector = translate_to_ori_images(
                        kpt_pred, np.array(new_bboxes))
                    kpt_res = {}
                    kpt_res['keypoint'] = [
                        keypoint_vector.tolist(), score_vector.tolist()
                    ] if len(keypoint_vector) > 0 else [[], []]
                    kpt_res['bbox'] = ori_bboxes
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].end()
746

Z
zhiboniu 已提交
747
                    self.pipeline_res.update(kpt_res, 'kpt')
748

Z
zhiboniu 已提交
749
                    self.kpt_buff.update(kpt_res, mot_res)  # collect kpt output
750 751 752
                    state = self.kpt_buff.get_state(
                    )  # whether frame num is enough or lost tracker

Z
zhiboniu 已提交
753
                    skeleton_action_res = {}
754 755
                    if state:
                        if frame_id > self.warmup_frame:
Z
zhiboniu 已提交
756 757
                            self.pipe_timer.module_time[
                                'skeleton_action'].start()
758 759
                        collected_keypoint = self.kpt_buff.get_collected_keypoint(
                        )  # reoragnize kpt output with ID
Z
zhiboniu 已提交
760 761 762 763
                        skeleton_action_input = parse_mot_keypoint(
                            collected_keypoint, self.coord_size)
                        skeleton_action_res = self.skeleton_action_predictor.predict_skeleton_with_mot(
                            skeleton_action_input)
764
                        if frame_id > self.warmup_frame:
Z
zhiboniu 已提交
765 766 767
                            self.pipe_timer.module_time['skeleton_action'].end()
                        self.pipeline_res.update(skeleton_action_res,
                                                 'skeleton_action')
768 769

                    if self.cfg['visual']:
Z
zhiboniu 已提交
770 771
                        self.skeleton_action_visual_helper.update(
                            skeleton_action_res)
772 773 774

                if self.with_mtmct and frame_id % 10 == 0:
                    crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot(
775
                        frame_rgb, mot_res)
776 777 778 779 780 781
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].start()
                    reid_res = self.reid_predictor.predict_batch(crop_input)

                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].end()
J
JYChen 已提交
782

783 784 785 786 787 788 789 790
                    reid_res_dict = {
                        'features': reid_res,
                        "qualities": img_qualities,
                        "rects": rects
                    }
                    self.pipeline_res.update(reid_res_dict, 'reid')
                else:
                    self.pipeline_res.clear('reid')
Z
zhiboniu 已提交
791

Z
zhiboniu 已提交
792
            if self.with_video_action:
793 794 795 796 797 798 799 800 801 802 803 804 805
                # get the params
                frame_len = self.cfg["VIDEO_ACTION"]["frame_len"]
                sample_freq = self.cfg["VIDEO_ACTION"]["sample_freq"]

                if sample_freq * frame_len > frame_count:  # video is too short
                    sample_freq = int(frame_count / frame_len)

                # filter the warmup frames
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['video_action'].start()

                # collect frames
                if frame_id % sample_freq == 0:
806
                    # Scale image
807
                    scaled_img = scale(frame_rgb)
808
                    video_action_imgs.append(scaled_img)
809 810 811 812 813 814 815 816 817 818 819 820 821 822

                # the number of collected frames is enough to predict video action
                if len(video_action_imgs) == frame_len:
                    classes, scores = self.video_action_predictor.predict(
                        video_action_imgs)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['video_action'].end()

                    video_action_res = {"class": classes[0], "score": scores[0]}
                    self.pipeline_res.update(video_action_res, 'video_action')

                    print("video_action_res:", video_action_res)

                    video_action_imgs.clear()  # next clip
Z
zhiboniu 已提交
823 824

            self.collector.append(frame_id, self.pipeline_res)
825 826 827 828 829 830 831

            if frame_id > self.warmup_frame:
                self.pipe_timer.img_num += 1
                self.pipe_timer.total_time.end()
            frame_id += 1

            if self.cfg['visual']:
832
                _, _, fps = self.pipe_timer.get_total_time()
833 834 835 836 837 838

                im = self.visualize_video(frame, self.pipeline_res,
                                          self.collector, frame_id, fps,
                                          entrance, records, center_traj,
                                          self.illegal_parking_time != -1,
                                          illegal_parking_dict)  # visualize
839
                writer.write(im)
W
wangguanzhong 已提交
840
                if self.file_name is None:  # use camera_id
Z
zhiboniu 已提交
841
                    cv2.imshow('Paddle-Pipeline', im)
W
wangguanzhong 已提交
842 843
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
844 845 846 847

        writer.release()
        print('save result to {}'.format(out_path))

848 849 850
    def visualize_video(self,
                        image,
                        result,
                        collector,
                        frame_id,
                        fps,
                        entrance=None,
                        records=None,
                        center_traj=None,
                        do_illegal_parking_recognition=False,
                        illegal_parking_dict=None):
        """Draw every available per-frame result onto a video frame.

        Args:
            image: the frame to draw on.
            result: result container queried with ``result.get(key)`` for
                'mot', 'attr', 'vehicle_attr', 'kpt', 'video_action',
                'skeleton_action', 'det_action' and 'cls_action'. A key that
                yields ``None`` skips the corresponding overlay.
            collector: object whose ``get_carlp(track_id)`` returns the plate
                recorded for a track, or ``None`` when there is none.
            frame_id: frame index forwarded to the tracking overlay.
            fps: throughput value forwarded to the tracking overlay.
            entrance, records, center_traj: forwarded to
                ``plot_tracking_dict`` (flow counting / trajectories).
            do_illegal_parking_recognition, illegal_parking_dict: forwarded
                to ``plot_tracking_dict``.

        Returns:
            The annotated image. Overlays that are anchored on track boxes
            (attributes, plates, id-based actions) are skipped when there is
            no 'mot' result.
        """
        # Deep copy: the boxes are converted to tlwh in place below, and the
        # caller's tracking result must stay untouched.
        mot_res = copy.deepcopy(result.get('mot'))
        track_boxes = None  # full box rows, id in column 0
        overlay_boxes = None  # rows without the id column
        if mot_res is not None:
            # Columns read here: 0 = track id, 2 = score, 3: = x1, y1, x2, y2
            # (column 1 is not read directly in this method).
            track_boxes = mot_res['boxes']
            ids = track_boxes[:, 0]
            scores = track_boxes[:, 2]
            boxes = track_boxes[:, 3:]
            # x2, y2 -> w, h. ``boxes`` is a numpy view, so track_boxes (and
            # overlay_boxes below) also carry tlwh boxes afterwards.
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            overlay_boxes = track_boxes[:, 1:]
        else:
            boxes = np.zeros([0, 4])
            ids = np.zeros([0])
            scores = np.zeros([0])

        # single class, still need to be defaultdict type for plotting
        num_classes = 1
        online_tlwhs = defaultdict(list)
        online_scores = defaultdict(list)
        online_ids = defaultdict(list)
        online_tlwhs[0] = boxes
        online_scores[0] = scores
        online_ids[0] = ids

        if mot_res is not None:
            image = plot_tracking_dict(
                image,
                num_classes,
                online_tlwhs,
                online_ids,
                online_scores,
                frame_id=frame_id,
                fps=fps,
                ids2names=self.mot_predictor.pred_config.labels,
                do_entrance_counting=self.do_entrance_counting,
                do_break_in_counting=self.do_break_in_counting,
                do_illegal_parking_recognition=do_illegal_parking_recognition,
                illegal_parking_dict=illegal_parking_dict,
                entrance=entrance,
                records=records,
                center_traj=center_traj)

        # Human attributes first, then vehicle attributes (drawing order).
        for attr_key in ('attr', 'vehicle_attr'):
            attr_res = result.get(attr_key)
            if attr_res is not None and overlay_boxes is not None:
                image = np.array(
                    visualize_attr(image, attr_res['output'], overlay_boxes))

        if track_boxes is not None:
            plates = [
                collector.get_carlp(track_id) for track_id in track_boxes[:, 0]
            ]
            # Only draw when at least one track has a recognised plate;
            # tracks without one get an empty label.
            if any(plate is not None for plate in plates):
                plates = [
                    plate if plate is not None else "" for plate in plates
                ]
                image = np.array(
                    visualize_vehicleplate(image, plates, overlay_boxes))

        kpt_res = result.get('kpt')
        if kpt_res is not None:
            image = visualize_pose(
                image,
                kpt_res,
                visual_thresh=self.cfg['kpt_thresh'],
                returnimg=True)

        video_action_res = result.get('video_action')
        if video_action_res is not None:
            video_action_score = None
            # Only class 1 is rendered (labelled "Fight" below).
            if video_action_res and video_action_res["class"] == 1:
                video_action_score = video_action_res["score"]
            image = visualize_action(
                image,
                track_boxes,  # None when there is no tracking result
                action_visual_collector=None,
                action_text="SkeletonAction",
                video_action_score=video_action_score,
                video_action_text="Fight")

        visual_helper_for_display = []
        action_to_display = []

        if result.get('skeleton_action') is not None:
            visual_helper_for_display.append(self.skeleton_action_visual_helper)
            action_to_display.append("Falling")

        if result.get('det_action') is not None:
            visual_helper_for_display.append(self.det_action_visual_helper)
            action_to_display.append("Smoking")

        if result.get('cls_action') is not None:
            visual_helper_for_display.append(self.cls_action_visual_helper)
            action_to_display.append("Calling")

        # id-based action labels are placed per track, so they need boxes.
        if visual_helper_for_display and track_boxes is not None:
            image = visualize_action(image, track_boxes,
                                     visual_helper_for_display,
                                     action_to_display)

        return image

    def visualize_image(self, im_files, images, result):
        """Draw detection-level results on a batch of images and save them.

        Args:
            im_files: source file paths; only the basename of each is used
                for the output file name.
            images: decoded images, paired one-to-one with ``im_files``.
            result: result container queried with ``result.get(key)`` for
                'det', 'attr', 'vehicle_attr' and 'vehicleplate'. The 'det'
                entry holds ``boxes`` for the whole batch, concatenated, and
                ``boxes_num`` giving how many rows belong to each image.
                Per-box outputs are indexed with the same running offset.

        Each annotated image is written with ``cv2.imwrite`` to
        ``self.output_dir`` (created if missing) under its original basename.
        ``result`` itself is not modified.
        """
        start_idx, boxes_num_i = 0, 0
        det_res = result.get('det')
        human_attr_res = result.get('attr')
        vehicle_attr_res = result.get('vehicle_attr')
        vehicleplate_res = result.get('vehicleplate')

        for i, (im_file, im) in enumerate(zip(im_files, images)):
            det_res_i = None
            if det_res is not None:
                boxes_num_i = det_res['boxes_num'][i]
                end_idx = start_idx + boxes_num_i
                # Copy this image's rows: the plate overlay below converts the
                # boxes in place, which must not leak back into ``result``.
                det_res_i = {
                    'boxes': det_res['boxes'][start_idx:end_idx, :].copy()
                }
                im = visualize_box_mask(
                    im,
                    det_res_i,
                    labels=['target'],
                    threshold=self.cfg['crop_thresh'])
                im = np.ascontiguousarray(np.copy(im))
                im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

            # The remaining overlays are anchored on this image's detection
            # boxes; without a detection result there is nothing to draw on.
            if det_res_i is not None:
                end_idx = start_idx + boxes_num_i
                if human_attr_res is not None:
                    human_attr_res_i = human_attr_res['output'][start_idx:
                                                                end_idx]
                    im = visualize_attr(im, human_attr_res_i,
                                        det_res_i['boxes'])
                if vehicle_attr_res is not None:
                    vehicle_attr_res_i = vehicle_attr_res['output'][start_idx:
                                                                    end_idx]
                    im = visualize_attr(im, vehicle_attr_res_i,
                                        det_res_i['boxes'])
                if vehicleplate_res is not None:
                    # NOTE(review): plates are assumed to be flattened over the
                    # batch like the attribute outputs above — confirm against
                    # the producer in predict_image.
                    plates = vehicleplate_res['vehicleplate'][start_idx:
                                                              end_idx]
                    boxes_i = det_res_i['boxes']
                    # Columns 4:6 become width/height by subtracting the
                    # top-left corner stored in columns 2:4.
                    boxes_i[:, 4:6] = boxes_i[:, 4:6] - boxes_i[:, 2:4]
                    im = visualize_vehicleplate(im, plates, boxes_i)

            img_name = os.path.split(im_file)[-1]
            # exist_ok avoids the exists()/makedirs() race.
            os.makedirs(self.output_dir, exist_ok=True)
            out_path = os.path.join(self.output_dir, img_name)
            cv2.imwrite(out_path, im)
            print("save result to: " + out_path)
            start_idx += boxes_num_i


def main():
    """Entry point: merge the command-line FLAGS into the config and run.

    Reads the module-level ``FLAGS`` assigned in the ``__main__`` guard.
    """
    # update the config with the command-line params, then echo it
    merged_cfg = merge_cfg(FLAGS)
    print_arguments(merged_cfg)
    Pipeline(FLAGS, merged_cfg).run()


if __name__ == '__main__':
    paddle.enable_static()

    # parse params from command
    parser = argsparser()
    FLAGS = parser.parse_args()
    FLAGS.device = FLAGS.device.upper()
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if FLAGS.device not in ['CPU', 'GPU', 'XPU']:
        raise ValueError("device should be CPU, GPU or XPU")

    main()