pipeline.py 42.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
import cv2
import numpy as np
import math
import paddle
import sys
Z
zhiboniu 已提交
23
import copy
Z
zhiboniu 已提交
24
from collections import Sequence, defaultdict
Z
zhiboniu 已提交
25
from datacollector import DataCollector, Result
26 27 28 29 30

# add deploy path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

31 32
from cfg_utils import argsparser, print_arguments, merge_cfg
from pipe_utils import PipeTimer
Z
zhiboniu 已提交
33 34
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint

35
from python.infer import Detector, DetectorPicoDet
J
JYChen 已提交
36 37
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
38
from python.preprocess import decode_image, ShortSizeScale
Z
zhiboniu 已提交
39
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action, visualize_vehicleplate
40 41

from pptracking.python.mot_sde_infer import SDE_Detector
42
from pptracking.python.mot.visualize import plot_tracking_dict
43
from pptracking.python.mot.utils import flow_statistic, update_object_info
44

Z
zhiboniu 已提交
45 46 47 48 49 50 51
from pphuman.attr_infer import AttrDetector
from pphuman.video_action_infer import VideoActionRecognizer
from pphuman.action_infer import SkeletonActionRecognizer, DetActionRecognizer, ClsActionRecognizer
from pphuman.action_utils import KeyPointBuff, ActionVisualHelper
from pphuman.reid import ReID
from pphuman.mtmct import mtmct_process

52 53 54
from ppvehicle.vehicle_plate import PlateRecognizer
from ppvehicle.vehicle_attr import VehicleAttr

55 56
from download import auto_download_model

57 58 59 60 61 62

class Pipeline(object):
    """
    Pipeline

    Args:
        args (argparse.Namespace): arguments in pipeline, which contains environment and runtime settings
        cfg (dict): config of models in pipeline
    """

    def __init__(self, args, cfg):
        self.multi_camera = False
        reid_cfg = cfg.get('REID', False)
        self.enable_mtmct = reid_cfg['enable'] if reid_cfg else False
        self.is_video = False
        self.output_dir = args.output_dir
        self.vis_result = cfg['visual']
        # _parse_input also sets self.is_video / self.multi_camera
        self.input = self._parse_input(args.image_file, args.image_dir,
                                       args.video_file, args.video_dir,
                                       args.camera_id)
        if self.multi_camera:
            # one predictor per video file, each named after its source
            self.predictor = []
            for name in self.input:
                predictor_item = PipePredictor(
                    args, cfg, is_video=True, multi_camera=True)
                predictor_item.set_file_name(name)
                self.predictor.append(predictor_item)
        else:
            self.predictor = PipePredictor(args, cfg, self.is_video)
            if self.is_video:
                # Name the output after the resolved source video (this also
                # covers a single-file video_dir). A camera id is an int and
                # has no file name, so None is passed for it.
                source_name = self.input if isinstance(self.input,
                                                       str) else None
                self.predictor.set_file_name(source_name)

    def _parse_input(self, image_file, image_dir, video_file, video_dir,
                     camera_id):
        """Resolve the input options into what the predictors consume.

        Options are checked in priority order: image_file/image_dir,
        video_file, video_dir, camera_id. Sets ``self.is_video`` and
        ``self.multi_camera`` as a side effect.

        Returns:
            the result of ``get_test_images`` for image input, the video path
            (or rtsp address) for a single video, a sorted list of paths when
            ``video_dir`` holds several entries, or the camera id.

        Raises:
            ValueError: no input option was set, or ``video_dir`` is empty.
        """
        if image_file is not None or image_dir is not None:
            parsed = get_test_images(image_dir, image_file)
            self.is_video = False
            self.multi_camera = False

        elif video_file is not None:
            assert os.path.exists(
                video_file
            ) or 'rtsp' in video_file, "video_file not exists and not an rtsp site."
            parsed = video_file
            self.is_video = True
            self.multi_camera = False

        elif video_dir is not None:
            videof = sorted(
                os.path.join(video_dir, x) for x in os.listdir(video_dir))
            if not videof:
                raise ValueError("video_dir '{}' is empty, no video to run.".
                                 format(video_dir))
            # several videos in the directory are treated as several cameras
            self.multi_camera = len(videof) > 1
            parsed = videof if self.multi_camera else videof[0]
            self.is_video = True

        elif camera_id != -1:
            parsed = camera_id
            self.is_video = True
            self.multi_camera = False

        else:
            raise ValueError(
                "Illegal Input, please set one of ['video_file', 'camera_id', 'image_file', 'image_dir']"
            )

        return parsed

    def run(self):
        """Run the predictor(s) on the parsed input.

        With several cameras, results of every camera are collected and, when
        REID is enabled in the config, passed to ``mtmct_process``.
        """
        if not self.multi_camera:
            self.predictor.run(self.input)
            return

        multi_res = []
        for predictor, source in zip(self.predictor, self.input):
            predictor.run(source)
            multi_res.append(predictor.get_result())
        if self.enable_mtmct:
            mtmct_process(
                multi_res,
                self.input,
                mtmct_vis=self.vis_result,
                output_dir=self.output_dir)


148
def _resolve_model_path(section, field):
    """Pass ``section[field]`` to auto_download_model and return the path to use.

    When auto_download_model returns a truthy value, that value replaces
    ``section[field]`` in place and is returned. Otherwise the configured
    value is returned unchanged.
    """
    model_dir = section[field]
    downloaded_model_dir = auto_download_model(model_dir)
    if downloaded_model_dir:
        model_dir = downloaded_model_dir
        section[field] = model_dir
    return model_dir


def get_model_dir(cfg):
    """
        Auto download inference model if the model_path is a url link.
        Otherwise it will use the model_path directly.

        Every dict section of ``cfg`` that is enabled (an ``enable`` key that
        is truthy, or no ``enable`` key at all) has its ``model_dir`` resolved.
        VEHICLE_PLATE sections without ``model_dir`` resolve
        ``det_model_dir`` and ``rec_model_dir`` instead. The MOT model is
        resolved even when MOT itself is disabled, because id-based and
        skeleton-based actions still need the tracker. ``cfg`` is updated in
        place.
    """
    for key, section in cfg.items():
        enabled = isinstance(section, dict) and section.get("enable", True)
        if enabled:
            if "model_dir" in section:
                model_dir = _resolve_model_path(section, "model_dir")
                print(key, " model dir: ", model_dir)
            elif key == "VEHICLE_PLATE":
                det_model_dir = _resolve_model_path(section, "det_model_dir")
                print("det_model_dir model dir: ", det_model_dir)

                rec_model_dir = _resolve_model_path(section, "rec_model_dir")
                print("rec_model_dir model dir: ", rec_model_dir)

        elif key == "MOT":  # for idbased and skeletonbased actions
            model_dir = _resolve_model_path(section, "model_dir")
            print("mot_model_dir model_dir: ", model_dir)
187 188


189 190 191 192 193 194 195 196 197 198 199 200 201
class PipePredictor(object):
    """
    Predictor in single camera
    
    The pipeline for image input: 

        1. Detection
        2. Detection -> Attribute

    The pipeline for video input: 

        1. Tracking
        2. Tracking -> Attribute
Z
zhiboniu 已提交
202
        3. Tracking -> KeyPoint -> SkeletonAction Recognition
203
        4. VideoAction Recognition
204 205

    Args:
J
JYChen 已提交
206
        args (argparse.Namespace): arguments in pipeline, which contains environment and runtime settings
207 208 209 210 211 212
        cfg (dict): config of models in pipeline
        is_video (bool): whether the input is video, default as False
        multi_camera (bool): whether to use multi camera in pipeline, 
            default as False
    """

Z
zhiboniu 已提交
213 214 215 216
    def __init__(self, args, cfg, is_video=True, multi_camera=False):
        """Read the module switches from ``cfg`` and build the enabled predictors.

        Args:
            args (argparse.Namespace): runtime settings (device, run_mode,
                trt_* shapes, cpu_threads, enable_mkldnn, counting/region
                options, output_dir, ...)
            cfg (dict): config of models in pipeline; module sections such as
                'MOT' or 'ATTR' carry an 'enable' flag
            is_video (bool): build the video modules (tracking, actions)
                instead of the image detector
            multi_camera (bool): whether this predictor is one camera of a
                multi-camera pipeline
        """
        # general module for pphuman and ppvehicle
        self.with_mot = self._module_enabled(cfg, 'MOT')
        self.with_human_attr = self._module_enabled(cfg, 'ATTR')
        if self.with_mot:
            print('Multi-Object Tracking enabled')
        if self.with_human_attr:
            print('Human Attribute Recognition enabled')

        # only for pphuman
        self.with_skeleton_action = self._module_enabled(cfg,
                                                         'SKELETON_ACTION')
        self.with_video_action = self._module_enabled(cfg, 'VIDEO_ACTION')
        self.with_idbased_detaction = self._module_enabled(
            cfg, 'ID_BASED_DETACTION')
        self.with_idbased_clsaction = self._module_enabled(
            cfg, 'ID_BASED_CLSACTION')
        self.with_mtmct = self._module_enabled(cfg, 'REID')

        if self.with_skeleton_action:
            print('SkeletonAction Recognition enabled')
        if self.with_video_action:
            print('VideoAction Recognition enabled')
        if self.with_idbased_detaction:
            print('IDBASED Detection Action Recognition enabled')
        if self.with_idbased_clsaction:
            print('IDBASED Classification Action Recognition enabled')
        if self.with_mtmct:
            print("MTMCT enabled")

        # only for ppvehicle
        self.with_vehicleplate = self._module_enabled(cfg, 'VEHICLE_PLATE')
        if self.with_vehicleplate:
            print('Vehicle Plate Recognition enabled')

        self.with_vehicle_attr = self._module_enabled(cfg, 'VEHICLE_ATTR')
        if self.with_vehicle_attr:
            print('Vehicle Attribute Recognition enabled')

        # which processing modes are required; flipped to True below as the
        # modules that need them are built
        self.modebase = {
            "framebased": False,
            "videobased": False,
            "idbased": False,
            "skeletonbased": False
        }

        # config section -> processing mode it requires
        self.basemode = {
            "MOT": "idbased",
            "ATTR": "idbased",
            "VIDEO_ACTION": "videobased",
            "SKELETON_ACTION": "skeletonbased",
            "ID_BASED_DETACTION": "idbased",
            "ID_BASED_CLSACTION": "idbased",
            "REID": "idbased",
            "VEHICLE_PLATE": "idbased",
            "VEHICLE_ATTR": "idbased",
        }

        self.is_video = is_video
        self.multi_camera = multi_camera
        self.cfg = cfg

        self.output_dir = args.output_dir
        self.draw_center_traj = args.draw_center_traj
        self.secs_interval = args.secs_interval
        self.do_entrance_counting = args.do_entrance_counting
        self.do_break_in_counting = args.do_break_in_counting
        self.region_type = args.region_type
        self.region_polygon = args.region_polygon
        self.illegal_parking_time = args.illegal_parking_time

        self.warmup_frame = self.cfg['warmup_frame']
        self.pipeline_res = Result()
        self.pipe_timer = PipeTimer()
        self.file_name = None
        self.collector = DataCollector()

        # auto download inference model
        get_model_dir(self.cfg)

        # The plate recognizer serves both the image and the video path, so it
        # is built exactly once here; the video branch below reuses it.
        if self.with_vehicleplate:
            vehicleplate_cfg = self.cfg['VEHICLE_PLATE']
            self.vehicleplate_detector = PlateRecognizer(args, vehicleplate_cfg)
            basemode = self.basemode['VEHICLE_PLATE']
            self.modebase[basemode] = True

        if self.with_human_attr:
            attr_cfg = self.cfg['ATTR']
            basemode = self.basemode['ATTR']
            self.modebase[basemode] = True
            self.attr_predictor = AttrDetector.init_with_cfg(args, attr_cfg)

        if self.with_vehicle_attr:
            vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
            basemode = self.basemode['VEHICLE_ATTR']
            self.modebase[basemode] = True
            self.vehicle_attr_predictor = VehicleAttr.init_with_cfg(
                args, vehicleattr_cfg)

        if not is_video:
            det_cfg = self.cfg['DET']
            model_dir = det_cfg['model_dir']
            batch_size = det_cfg['batch_size']
            self.det_predictor = Detector(
                model_dir, args.device, args.run_mode, batch_size,
                args.trt_min_shape, args.trt_max_shape, args.trt_opt_shape,
                args.trt_calib_mode, args.cpu_threads, args.enable_mkldnn)
        else:
            if self.with_idbased_detaction:
                idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION']
                basemode = self.basemode['ID_BASED_DETACTION']
                self.modebase[basemode] = True

                self.det_action_predictor = DetActionRecognizer.init_with_cfg(
                    args, idbased_detaction_cfg)
                self.det_action_visual_helper = ActionVisualHelper(1)

            if self.with_idbased_clsaction:
                idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION']
                basemode = self.basemode['ID_BASED_CLSACTION']
                self.modebase[basemode] = True

                self.cls_action_predictor = ClsActionRecognizer.init_with_cfg(
                    args, idbased_clsaction_cfg)
                self.cls_action_visual_helper = ActionVisualHelper(1)

            if self.with_skeleton_action:
                skeleton_action_cfg = self.cfg['SKELETON_ACTION']
                display_frames = skeleton_action_cfg['display_frames']
                self.coord_size = skeleton_action_cfg['coord_size']
                basemode = self.basemode['SKELETON_ACTION']
                self.modebase[basemode] = True
                skeleton_action_frames = skeleton_action_cfg['max_frames']

                self.skeleton_action_predictor = SkeletonActionRecognizer.init_with_cfg(
                    args, skeleton_action_cfg)
                self.skeleton_action_visual_helper = ActionVisualHelper(
                    display_frames)

                # skeleton actions need keypoints for every tracked person
                kpt_cfg = self.cfg['KPT']
                kpt_model_dir = kpt_cfg['model_dir']
                kpt_batch_size = kpt_cfg['batch_size']
                self.kpt_predictor = KeyPointDetector(
                    kpt_model_dir,
                    args.device,
                    args.run_mode,
                    kpt_batch_size,
                    args.trt_min_shape,
                    args.trt_max_shape,
                    args.trt_opt_shape,
                    args.trt_calib_mode,
                    args.cpu_threads,
                    args.enable_mkldnn,
                    use_dark=False)
                self.kpt_buff = KeyPointBuff(skeleton_action_frames)

            if self.with_mtmct:
                reid_cfg = self.cfg['REID']
                basemode = self.basemode['REID']
                self.modebase[basemode] = True
                self.reid_predictor = ReID.init_with_cfg(args, reid_cfg)

            # The tracker is built whenever any id-based or skeleton-based
            # module was enabled above, even if MOT itself is switched off.
            if self.with_mot or self.modebase["idbased"] or self.modebase[
                    "skeletonbased"]:
                mot_cfg = self.cfg['MOT']
                model_dir = mot_cfg['model_dir']
                tracker_config = mot_cfg['tracker_config']
                batch_size = mot_cfg['batch_size']
                skip_frame_num = mot_cfg.get('skip_frame_num', -1)
                basemode = self.basemode['MOT']
                self.modebase[basemode] = True
                self.mot_predictor = SDE_Detector(
                    model_dir,
                    tracker_config,
                    args.device,
                    args.run_mode,
                    batch_size,
                    args.trt_min_shape,
                    args.trt_max_shape,
                    args.trt_opt_shape,
                    args.trt_calib_mode,
                    args.cpu_threads,
                    args.enable_mkldnn,
                    skip_frame_num=skip_frame_num,
                    draw_center_traj=self.draw_center_traj,
                    secs_interval=self.secs_interval,
                    do_entrance_counting=self.do_entrance_counting,
                    do_break_in_counting=self.do_break_in_counting,
                    region_type=self.region_type,
                    region_polygon=self.region_polygon)

            if self.with_video_action:
                video_action_cfg = self.cfg['VIDEO_ACTION']
                basemode = self.basemode['VIDEO_ACTION']
                self.modebase[basemode] = True
                self.video_action_predictor = VideoActionRecognizer.init_with_cfg(
                    args, video_action_cfg)

    @staticmethod
    def _module_enabled(cfg, key):
        """Return ``cfg[key]['enable']`` if section ``key`` exists and is truthy, else False."""
        section = cfg.get(key, False)
        return section['enable'] if section else False
429

430
    def set_file_name(self, path):
W
wangguanzhong 已提交
431 432 433 434 435
        if path is not None:
            self.file_name = os.path.split(path)[-1]
        else:
            # use camera id
            self.file_name = None
436

437
    def get_result(self):
Z
zhiboniu 已提交
438
        return self.collector.get_res()
439 440 441 442 443 444

    def run(self, input):
        if self.is_video:
            self.predict_video(input)
        else:
            self.predict_image(input)
445
        self.pipe_timer.info()
446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463

    def predict_image(self, input):
        """Run the image pipeline over ``input`` in detector-sized batches.

        Per batch: detection (filtered by ``cfg['crop_thresh']``), then, when
        enabled, human attribute, vehicle attribute and vehicle plate
        recognition on the detected crops. Each result is pushed into
        ``self.pipeline_res``; with ``cfg['visual']`` the batch is visualized.
        Timers only run for batches after ``self.warmup_frame``.

        Args:
            input (list): items accepted by ``decode_image`` (presumably image
                file paths from ``get_test_images``)
        """
        batch_size = self.det_predictor.batch_size
        num_inputs = len(input)
        batch_loop_cnt = math.ceil(float(num_inputs) / batch_size)
        for i in range(batch_loop_cnt):
            timing = i > self.warmup_frame
            start_index = i * batch_size
            end_index = min((i + 1) * batch_size, num_inputs)
            batch_file = input[start_index:end_index]
            batch_input = [decode_image(f, {})[0] for f in batch_file]

            if timing:
                self.pipe_timer.total_time.start()
                self.pipe_timer.module_time['det'].start()
            # det output format: class, score, xmin, ymin, xmax, ymax
            det_res = self.det_predictor.predict_image(
                batch_input, visual=False)
            det_res = self.det_predictor.filter_box(det_res,
                                                    self.cfg['crop_thresh'])
            if timing:
                self.pipe_timer.module_time['det'].end()
                self.pipe_timer.track_num += len(det_res['boxes'])
            self.pipeline_res.update(det_res, 'det')

            # Every crop-based module below consumes the same crops, so they
            # are built once per batch instead of once per module.
            crop_inputs = None
            if (self.with_human_attr or self.with_vehicle_attr or
                    self.with_vehicleplate):
                crop_inputs = crop_image_with_det(batch_input, det_res)

            if self.with_human_attr:
                attr_res = self._predict_crop_attr(
                    self.attr_predictor, crop_inputs, 'attr', timing)
                self.pipeline_res.update(attr_res, 'attr')

            if self.with_vehicle_attr:
                attr_res = self._predict_crop_attr(self.vehicle_attr_predictor,
                                                   crop_inputs, 'vehicle_attr',
                                                   timing)
                self.pipeline_res.update(attr_res, 'vehicle_attr')

            if self.with_vehicleplate:
                if timing:
                    self.pipe_timer.module_time['vehicleplate'].start()
                platelicenses = []
                for crop_input in crop_inputs:
                    platelicense = self.vehicleplate_detector.get_platelicense(
                        crop_input)
                    platelicenses.extend(platelicense['plate'])
                if timing:
                    self.pipe_timer.module_time['vehicleplate'].end()
                vehicleplate_res = {'vehicleplate': platelicenses}
                self.pipeline_res.update(vehicleplate_res, 'vehicleplate')

            self.pipe_timer.img_num += len(batch_input)
            if timing:
                self.pipe_timer.total_time.end()

            if self.cfg['visual']:
                self.visualize_image(batch_file, batch_input, self.pipeline_res)

    def _predict_crop_attr(self, predictor, crop_inputs, timer_key, timing):
        """Run ``predictor.predict_image`` on every crop and concatenate the 'output' lists.

        The ``module_time[timer_key]`` timer wraps the loop when ``timing`` is
        True. Returns ``{'output': [...]}`` ready for ``pipeline_res.update``.
        """
        if timing:
            self.pipe_timer.module_time[timer_key].start()
        outputs = []
        for crop_input in crop_inputs:
            res = predictor.predict_image(crop_input, visual=False)
            outputs.extend(res['output'])
        if timing:
            self.pipe_timer.module_time[timer_key].end()
        return {'output': outputs}

Z
zhiboniu 已提交
528
    def predict_video(self, video_file):
529 530 531
        # mot
        # mot -> attr
        # mot -> pose -> action
Z
zhiboniu 已提交
532
        capture = cv2.VideoCapture(video_file)
533
        video_out_name = 'output.mp4' if self.file_name is None else self.file_name
Z
zhiboniu 已提交
534 535
        if "rtsp" in video_file:
            video_out_name = video_out_name + "_rtsp.mp4"
536 537 538 539 540 541

        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
542
        print("video fps: %d, frame_count: %d" % (fps, frame_count))
543 544 545 546 547 548 549

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 0
550 551 552 553 554 555 556 557 558 559

        entrance, records, center_traj = None, None, None
        if self.draw_center_traj:
            center_traj = [{}]
        id_set = set()
        interval_id_set = set()
        in_id_list = list()
        out_id_list = list()
        prev_center = dict()
        records = list()
560
        if self.do_entrance_counting or self.do_break_in_counting or self.illegal_parking_time != -1:
561 562 563 564 565 566 567 568 569
            if self.region_type == 'horizontal':
                entrance = [0, height / 2., width, height / 2.]
            elif self.region_type == 'vertical':
                entrance = [width / 2, 0., width / 2, height]
            elif self.region_type == 'custom':
                entrance = []
                assert len(
                    self.region_polygon
                ) % 2 == 0, "region_polygon should be pairs of coords points when do break_in counting."
J
JYChen 已提交
570 571 572 573
                assert len(
                    self.region_polygon
                ) > 6, 'region_type is custom, region_polygon should be at least 3 pairs of point coords.'

574 575 576 577 578 579 580 581
                for i in range(0, len(self.region_polygon), 2):
                    entrance.append(
                        [self.region_polygon[i], self.region_polygon[i + 1]])
                entrance.append([width, height])
            else:
                raise ValueError("region_type:{} unsupported.".format(
                    self.region_type))

582 583
        video_fps = fps

584 585
        video_action_imgs = []

586 587 588 589
        if self.with_video_action:
            short_size = self.cfg["VIDEO_ACTION"]["short_size"]
            scale = ShortSizeScale(short_size)

590 591 592 593
        object_in_region_info = {
        }  # store info for vehicle parking in region       
        illegal_parking_dict = None

594 595 596
        while (1):
            if frame_id % 10 == 0:
                print('frame id: ', frame_id)
597

598 599 600
            ret, frame = capture.read()
            if not ret:
                break
601
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
Z
zhiboniu 已提交
602 603
            if frame_id > self.warmup_frame:
                self.pipe_timer.total_time.start()
604

605
            if self.modebase["idbased"] or self.modebase["skeletonbased"]:
606
                if frame_id > self.warmup_frame:
607
                    self.pipe_timer.module_time['mot'].start()
608

609 610 611 612 613 614 615 616
                mot_skip_frame_num = self.mot_predictor.skip_frame_num
                reuse_det_result = False
                if mot_skip_frame_num > 1 and frame_id > 0 and frame_id % mot_skip_frame_num > 0:
                    reuse_det_result = True
                res = self.mot_predictor.predict_image(
                    [copy.deepcopy(frame_rgb)],
                    visual=False,
                    reuse_det_result=reuse_det_result)
617 618 619

                # mot output format: id, class, score, xmin, ymin, xmax, ymax
                mot_res = parse_mot_res(res)
Z
zhiboniu 已提交
620 621 622
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['mot'].end()
                    self.pipe_timer.track_num += len(mot_res['boxes'])
623 624 625 626 627 628 629

                # flow_statistic only support single class MOT
                boxes, scores, ids = res[0]  # batch size = 1 in MOT
                mot_result = (frame_id + 1, boxes[0], scores[0],
                              ids[0])  # single class
                statistic = flow_statistic(
                    mot_result, self.secs_interval, self.do_entrance_counting,
630 631 632
                    self.do_break_in_counting, self.region_type, video_fps,
                    entrance, id_set, interval_id_set, in_id_list, out_id_list,
                    prev_center, records)
633 634
                records = statistic['records']

635 636 637 638 639 640 641 642 643 644
                if self.illegal_parking_time != -1:
                    object_in_region_info, illegal_parking_dict = update_object_info(
                        object_in_region_info, mot_result, self.region_type,
                        entrance, video_fps, self.illegal_parking_time)
                    if len(illegal_parking_dict) != 0:
                        # build relationship between id and plate
                        for key, value in illegal_parking_dict.items():
                            plate = self.collector.get_carlp(key)
                            illegal_parking_dict[key]['plate'] = plate

645 646 647
                # nothing detected
                if len(mot_res['boxes']) == 0:
                    frame_id += 1
J
JYChen 已提交
648
                    if frame_id > self.warmup_frame:
649 650 651 652 653 654 655 656 657
                        self.pipe_timer.img_num += 1
                        self.pipe_timer.total_time.end()
                    if self.cfg['visual']:
                        _, _, fps = self.pipe_timer.get_total_time()
                        im = self.visualize_video(frame, mot_res, frame_id, fps,
                                                  entrance, records,
                                                  center_traj)  # visualize
                        writer.write(im)
                        if self.file_name is None:  # use camera_id
Z
zhiboniu 已提交
658
                            cv2.imshow('Paddle-Pipeline', im)
659 660 661 662 663
                            if cv2.waitKey(1) & 0xFF == ord('q'):
                                break
                    continue

                self.pipeline_res.update(mot_res, 'mot')
J
JYChen 已提交
664
                crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
665
                    frame_rgb, mot_res)
666

667
                if self.with_vehicleplate and frame_id % 10 == 0:
Z
zhiboniu 已提交
668 669
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].start()
Z
zhiboniu 已提交
670 671
                    plate_input, _, _ = crop_image_with_mot(
                        frame_rgb, mot_res, expand=False)
Z
zhiboniu 已提交
672
                    platelicense = self.vehicleplate_detector.get_platelicense(
Z
zhiboniu 已提交
673
                        plate_input)
Z
zhiboniu 已提交
674 675
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].end()
Z
zhiboniu 已提交
676
                    self.pipeline_res.update(platelicense, 'vehicleplate')
677 678
                else:
                    self.pipeline_res.clear('vehicleplate')
Z
zhiboniu 已提交
679

680
                if self.with_human_attr:
J
JYChen 已提交
681
                    if frame_id > self.warmup_frame:
682 683 684 685 686 687 688
                        self.pipe_timer.module_time['attr'].start()
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['attr'].end()
                    self.pipeline_res.update(attr_res, 'attr')

689 690 691 692 693 694 695 696 697
                if self.with_vehicle_attr:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].start()
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].end()
                    self.pipeline_res.update(attr_res, 'vehicle_attr')

Z
zhiboniu 已提交
698
                if self.with_idbased_detaction:
J
JYChen 已提交
699 700 701 702 703 704 705 706 707 708
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].start()
                    det_action_res = self.det_action_predictor.predict(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].end()
                    self.pipeline_res.update(det_action_res, 'det_action')

                    if self.cfg['visual']:
                        self.det_action_visual_helper.update(det_action_res)
Z
zhiboniu 已提交
709 710

                if self.with_idbased_clsaction:
J
JYChen 已提交
711 712 713 714 715 716 717 718 719 720
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].start()
                    cls_action_res = self.cls_action_predictor.predict_with_mot(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].end()
                    self.pipeline_res.update(cls_action_res, 'cls_action')

                    if self.cfg['visual']:
                        self.cls_action_visual_helper.update(cls_action_res)
Z
zhiboniu 已提交
721

Z
zhiboniu 已提交
722
                if self.with_skeleton_action:
Z
zhiboniu 已提交
723 724 725 726 727 728 729 730 731 732 733 734 735
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].start()
                    kpt_pred = self.kpt_predictor.predict_image(
                        crop_input, visual=False)
                    keypoint_vector, score_vector = translate_to_ori_images(
                        kpt_pred, np.array(new_bboxes))
                    kpt_res = {}
                    kpt_res['keypoint'] = [
                        keypoint_vector.tolist(), score_vector.tolist()
                    ] if len(keypoint_vector) > 0 else [[], []]
                    kpt_res['bbox'] = ori_bboxes
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].end()
736

Z
zhiboniu 已提交
737
                    self.pipeline_res.update(kpt_res, 'kpt')
738

Z
zhiboniu 已提交
739
                    self.kpt_buff.update(kpt_res, mot_res)  # collect kpt output
740 741 742
                    state = self.kpt_buff.get_state(
                    )  # whether frame num is enough or lost tracker

Z
zhiboniu 已提交
743
                    skeleton_action_res = {}
744 745
                    if state:
                        if frame_id > self.warmup_frame:
Z
zhiboniu 已提交
746 747
                            self.pipe_timer.module_time[
                                'skeleton_action'].start()
748 749
                        collected_keypoint = self.kpt_buff.get_collected_keypoint(
                        )  # reoragnize kpt output with ID
Z
zhiboniu 已提交
750 751 752 753
                        skeleton_action_input = parse_mot_keypoint(
                            collected_keypoint, self.coord_size)
                        skeleton_action_res = self.skeleton_action_predictor.predict_skeleton_with_mot(
                            skeleton_action_input)
754
                        if frame_id > self.warmup_frame:
Z
zhiboniu 已提交
755 756 757
                            self.pipe_timer.module_time['skeleton_action'].end()
                        self.pipeline_res.update(skeleton_action_res,
                                                 'skeleton_action')
758 759

                    if self.cfg['visual']:
Z
zhiboniu 已提交
760 761
                        self.skeleton_action_visual_helper.update(
                            skeleton_action_res)
762 763 764

                if self.with_mtmct and frame_id % 10 == 0:
                    crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot(
765
                        frame_rgb, mot_res)
766 767 768 769 770 771
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].start()
                    reid_res = self.reid_predictor.predict_batch(crop_input)

                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].end()
J
JYChen 已提交
772

773 774 775 776 777 778 779 780
                    reid_res_dict = {
                        'features': reid_res,
                        "qualities": img_qualities,
                        "rects": rects
                    }
                    self.pipeline_res.update(reid_res_dict, 'reid')
                else:
                    self.pipeline_res.clear('reid')
Z
zhiboniu 已提交
781

Z
zhiboniu 已提交
782
            if self.with_video_action:
783 784 785 786 787 788 789 790 791 792 793 794 795
                # get the params
                frame_len = self.cfg["VIDEO_ACTION"]["frame_len"]
                sample_freq = self.cfg["VIDEO_ACTION"]["sample_freq"]

                if sample_freq * frame_len > frame_count:  # video is too short
                    sample_freq = int(frame_count / frame_len)

                # filter the warmup frames
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['video_action'].start()

                # collect frames
                if frame_id % sample_freq == 0:
796
                    # Scale image
797
                    scaled_img = scale(frame_rgb)
798
                    video_action_imgs.append(scaled_img)
799 800 801 802 803 804 805 806 807 808 809 810 811 812

                # the number of collected frames is enough to predict video action
                if len(video_action_imgs) == frame_len:
                    classes, scores = self.video_action_predictor.predict(
                        video_action_imgs)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['video_action'].end()

                    video_action_res = {"class": classes[0], "score": scores[0]}
                    self.pipeline_res.update(video_action_res, 'video_action')

                    print("video_action_res:", video_action_res)

                    video_action_imgs.clear()  # next clip
Z
zhiboniu 已提交
813 814

            self.collector.append(frame_id, self.pipeline_res)
815 816 817 818 819 820 821

            if frame_id > self.warmup_frame:
                self.pipe_timer.img_num += 1
                self.pipe_timer.total_time.end()
            frame_id += 1

            if self.cfg['visual']:
822
                _, _, fps = self.pipe_timer.get_total_time()
823 824 825 826 827 828

                im = self.visualize_video(frame, self.pipeline_res,
                                          self.collector, frame_id, fps,
                                          entrance, records, center_traj,
                                          self.illegal_parking_time != -1,
                                          illegal_parking_dict)  # visualize
829
                writer.write(im)
W
wangguanzhong 已提交
830
                if self.file_name is None:  # use camera_id
Z
zhiboniu 已提交
831
                    cv2.imshow('Paddle-Pipeline', im)
W
wangguanzhong 已提交
832 833
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
834 835 836 837

        writer.release()
        print('save result to {}'.format(out_path))

838 839 840
    def visualize_video(self,
                        image,
                        result,
                        collector,
                        frame_id,
                        fps,
                        entrance=None,
                        records=None,
                        center_traj=None,
                        do_illegal_parking_recognition=False,
                        illegal_parking_dict=None):
        """Draw the pipeline results of one video frame onto the frame.

        Args:
            image: the frame to draw on.
            result: pipeline result container, queried with ``result.get(key)``
                for 'mot', 'attr', 'vehicle_attr', 'kpt', 'video_action',
                'skeleton_action', 'det_action' and 'cls_action'.
            collector: object whose ``get_carlp(track_id)`` returns the plate
                text of a track, or None when no plate is known.
            frame_id: frame index, forwarded to ``plot_tracking_dict``.
            fps: frames-per-second value, forwarded to ``plot_tracking_dict``.
            entrance, records, center_traj: forwarded unchanged to
                ``plot_tracking_dict``.
            do_illegal_parking_recognition (bool): forwarded to
                ``plot_tracking_dict``.
            illegal_parking_dict: forwarded to ``plot_tracking_dict``.

        Returns:
            The annotated frame.
        """
        # Deep-copy: the box conversion below writes into the array, and the
        # caller's tracking result must stay untouched.
        mot_res = copy.deepcopy(result.get('mot'))
        if mot_res is not None:
            ids = mot_res['boxes'][:, 0]
            scores = mot_res['boxes'][:, 2]
            boxes = mot_res['boxes'][:, 3:]
            # Convert the last two coordinates to width/height (tlwh layout
            # for `online_tlwhs`).
            # NOTE: when mot_res['boxes'] is a numpy array, `[:, 3:]` is a
            # view, so this also rewrites columns 5-6 of mot_res['boxes'].
            # Every later use of mot_res['boxes'] in this method (attribute,
            # plate and action drawing) therefore sees the converted values;
            # keep this ordering.
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
        else:
            boxes = np.zeros([0, 4])
            ids = np.zeros([0])
            scores = np.zeros([0])

        # single class, still need to be defaultdict type for plotting
        num_classes = 1
        online_tlwhs = defaultdict(list)
        online_scores = defaultdict(list)
        online_ids = defaultdict(list)
        online_tlwhs[0] = boxes
        online_scores[0] = scores
        online_ids[0] = ids

        if mot_res is not None:
            image = plot_tracking_dict(
                image,
                num_classes,
                online_tlwhs,
                online_ids,
                online_scores,
                frame_id=frame_id,
                fps=fps,
                ids2names=self.mot_predictor.pred_config.labels,
                do_entrance_counting=self.do_entrance_counting,
                do_break_in_counting=self.do_break_in_counting,
                do_illegal_parking_recognition=do_illegal_parking_recognition,
                illegal_parking_dict=illegal_parking_dict,
                entrance=entrance,
                records=records,
                center_traj=center_traj)

        # NOTE(review): the attribute/action branches below index
        # mot_res['boxes'] without a None check; they presumably only receive
        # results when tracking produced boxes — confirm against the caller.
        human_attr_res = result.get('attr')
        if human_attr_res is not None:
            boxes = mot_res['boxes'][:, 1:]
            human_attr_res = human_attr_res['output']
            image = visualize_attr(image, human_attr_res, boxes)
            image = np.array(image)

        vehicle_attr_res = result.get('vehicle_attr')
        if vehicle_attr_res is not None:
            boxes = mot_res['boxes'][:, 1:]
            vehicle_attr_res = vehicle_attr_res['output']
            image = visualize_attr(image, vehicle_attr_res, boxes)
            image = np.array(image)

        if mot_res is not None:
            # Collect one plate string per track ("" when unknown); only draw
            # if at least one track has a plate.
            vehicleplate = False
            plates = []
            for trackid in mot_res['boxes'][:, 0]:
                plate = collector.get_carlp(trackid)
                if plate is not None:
                    vehicleplate = True
                    plates.append(plate)
                else:
                    plates.append("")
            if vehicleplate:
                boxes = mot_res['boxes'][:, 1:]
                image = visualize_vehicleplate(image, plates, boxes)
                image = np.array(image)

        kpt_res = result.get('kpt')
        if kpt_res is not None:
            image = visualize_pose(
                image,
                kpt_res,
                visual_thresh=self.cfg['kpt_thresh'],
                returnimg=True)

        video_action_res = result.get('video_action')
        if video_action_res is not None:
            # Only class 1 carries a score to display.
            video_action_score = None
            if video_action_res and video_action_res["class"] == 1:
                video_action_score = video_action_res["score"]
            mot_boxes = None
            if mot_res:
                mot_boxes = mot_res['boxes']
            image = visualize_action(
                image,
                mot_boxes,
                action_visual_collector=None,
                action_text="SkeletonAction",
                video_action_score=video_action_score,
                video_action_text="Fight")

        # Gather the per-track action helpers that produced results this frame.
        visual_helper_for_display = []
        action_to_display = []

        skeleton_action_res = result.get('skeleton_action')
        if skeleton_action_res is not None:
            visual_helper_for_display.append(self.skeleton_action_visual_helper)
            action_to_display.append("Falling")

        det_action_res = result.get('det_action')
        if det_action_res is not None:
            visual_helper_for_display.append(self.det_action_visual_helper)
            action_to_display.append("Smoking")

        cls_action_res = result.get('cls_action')
        if cls_action_res is not None:
            visual_helper_for_display.append(self.cls_action_visual_helper)
            action_to_display.append("Calling")

        if len(visual_helper_for_display) > 0:
            image = visualize_action(image, mot_res['boxes'],
                                     visual_helper_for_display,
                                     action_to_display)

        return image

    def visualize_image(self, im_files, images, result):
        """Draw per-image pipeline results and save each image to ``self.output_dir``.

        Args:
            im_files: source file paths; only the base name is used as the
                output file name.
            images: the batch images, aligned with ``im_files``.
            result: pipeline result container, queried with ``result.get``
                for 'det', 'attr', 'vehicle_attr' and 'vehicleplate'.

        Per-box outputs are stored flat for the whole batch:
        ``det['boxes_num'][i]`` consecutive rows belong to image ``i``, so
        every per-box list is sliced with the same running offset
        ``start_idx``.
        """
        start_idx, boxes_num_i = 0, 0
        det_res = result.get('det')
        human_attr_res = result.get('attr')
        vehicle_attr_res = result.get('vehicle_attr')
        vehicleplate_res = result.get('vehicleplate')

        for i, (im_file, im) in enumerate(zip(im_files, images)):
            if det_res is not None:
                det_res_i = {}
                boxes_num_i = det_res['boxes_num'][i]
                # NOTE: for a numpy array this slice is a view into
                # det_res['boxes']; do not modify it in place.
                det_res_i['boxes'] = det_res['boxes'][start_idx:start_idx +
                                                      boxes_num_i, :]
                im = visualize_box_mask(
                    im,
                    det_res_i,
                    labels=['target'],
                    threshold=self.cfg['crop_thresh'])
                im = np.ascontiguousarray(np.copy(im))
                im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
            end_idx = start_idx + boxes_num_i
            if human_attr_res is not None:
                human_attr_res_i = human_attr_res['output'][start_idx:end_idx]
                im = visualize_attr(im, human_attr_res_i, det_res_i['boxes'])
            if vehicle_attr_res is not None:
                vehicle_attr_res_i = vehicle_attr_res['output'][
                    start_idx:end_idx]
                im = visualize_attr(im, vehicle_attr_res_i, det_res_i['boxes'])
            if vehicleplate_res is not None:
                # Plate strings are assumed to follow the same flat per-box
                # layout as the attribute outputs, so take only this image's
                # share instead of the whole batch's list.
                plates = vehicleplate_res['vehicleplate'][start_idx:end_idx]
                # Convert columns 4:6 to width/height on a copy so the
                # caller's detection result is not rewritten through the view.
                plate_boxes = det_res_i['boxes'].copy()
                plate_boxes[:, 4:6] = plate_boxes[:, 4:6] - plate_boxes[:, 2:4]
                im = visualize_vehicleplate(im, plates, plate_boxes)

            img_name = os.path.split(im_file)[-1]
            os.makedirs(self.output_dir, exist_ok=True)
            out_path = os.path.join(self.output_dir, img_name)
            cv2.imwrite(out_path, im)
            print("save result to: " + out_path)
            start_idx += boxes_num_i


def main(flags=None):
    """Merge the command-line flags into the config, print it, and run the pipeline.

    Args:
        flags: parsed command-line arguments. When omitted, the module-level
            ``FLAGS`` assigned by the ``__main__`` block is used, so running
            the file as a script behaves as before.
    """
    if flags is None:
        flags = FLAGS  # only assigned when this file runs as a script
    cfg = merge_cfg(flags)  # use command params to update config
    print_arguments(cfg)

    pipeline = Pipeline(flags, cfg)
    pipeline.run()


if __name__ == '__main__':
    paddle.enable_static()

    # parse params from command
    parser = argsparser()
    FLAGS = parser.parse_args()
    FLAGS.device = FLAGS.device.upper()
    # Validate explicitly: an `assert` is stripped when Python runs with -O.
    if FLAGS.device not in ('CPU', 'GPU', 'XPU'):
        raise ValueError("device should be CPU, GPU or XPU")

    main()