pipeline.py 40.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
import cv2
import numpy as np
import math
import paddle
import sys
Z
zhiboniu 已提交
23
import copy
Z
zhiboniu 已提交
24
from collections import Sequence, defaultdict
Z
zhiboniu 已提交
25
from datacollector import DataCollector, Result
26 27 28 29 30

# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

31 32
from cfg_utils import argsparser, print_arguments, merge_cfg
from pipe_utils import PipeTimer
Z
zhiboniu 已提交
33 34
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint

35
from python.infer import Detector, DetectorPicoDet
J
JYChen 已提交
36 37
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
38
from python.preprocess import decode_image, ShortSizeScale
Z
zhiboniu 已提交
39
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action, visualize_vehicleplate
40 41

from pptracking.python.mot_sde_infer import SDE_Detector
42 43
from pptracking.python.mot.visualize import plot_tracking_dict
from pptracking.python.mot.utils import flow_statistic
44

Z
zhiboniu 已提交
45 46 47 48 49 50 51
from pphuman.attr_infer import AttrDetector
from pphuman.video_action_infer import VideoActionRecognizer
from pphuman.action_infer import SkeletonActionRecognizer, DetActionRecognizer, ClsActionRecognizer
from pphuman.action_utils import KeyPointBuff, ActionVisualHelper
from pphuman.reid import ReID
from pphuman.mtmct import mtmct_process

52 53 54
from ppvehicle.vehicle_plate import PlateRecognizer
from ppvehicle.vehicle_attr import VehicleAttr

55 56
from download import auto_download_model

57 58 59 60 61 62

class Pipeline(object):
    """
    Pipeline

    Parses the input source given on the command line and dispatches it to
    a single PipePredictor, or to one PipePredictor per video when several
    videos are found in ``video_dir`` (multi-camera mode).

    Args:
        args (argparse.Namespace): arguments in pipeline, which contains environment and runtime settings
        cfg (dict): config of models in pipeline
    """

    def __init__(self, args, cfg):
        self.multi_camera = False
        reid_cfg = cfg.get('REID', False)
        self.enable_mtmct = reid_cfg['enable'] if reid_cfg else False
        self.is_video = False
        self.output_dir = args.output_dir
        self.vis_result = cfg['visual']
        # _parse_input also updates self.is_video and self.multi_camera
        self.input = self._parse_input(args.image_file, args.image_dir,
                                       args.video_file, args.video_dir,
                                       args.camera_id)
        if self.multi_camera:
            self.predictor = []
            for name in self.input:
                predictor_item = PipePredictor(
                    args, cfg, is_video=True, multi_camera=True)
                predictor_item.set_file_name(name)
                self.predictor.append(predictor_item)

        else:
            self.predictor = PipePredictor(args, cfg, self.is_video)
            if self.is_video:
                # Use the resolved input rather than args.video_file, so a
                # single video discovered through video_dir keeps its file
                # name. A camera id is not a path, hence None is passed.
                video_path = self.input if isinstance(self.input,
                                                      str) else None
                self.predictor.set_file_name(video_path)

    def _parse_input(self, image_file, image_dir, video_file, video_dir,
                     camera_id):
        """
        Select the input source and set ``is_video`` / ``multi_camera``.

        The first option that is set wins, in this order: image_file or
        image_dir, video_file, video_dir, camera_id.

        Args:
            image_file (str|None): path of a single image
            image_dir (str|None): directory of images
            video_file (str|None): path of a video, or a string containing 'rtsp'
            video_dir (str|None): directory of videos
            camera_id (int): camera device id, -1 means unset

        Returns:
            The result of get_test_images for image input, the video path
            for a single video, the sorted list of paths when video_dir
            holds more than one entry, or the camera id.

        Raises:
            AssertionError: if video_file neither exists nor contains 'rtsp'.
            ValueError: if video_dir is empty or no input option is set.
        """
        if image_file is not None or image_dir is not None:
            source = get_test_images(image_dir, image_file)
            self.is_video = False
            self.multi_camera = False

        elif video_file is not None:
            assert os.path.exists(
                video_file
            ) or 'rtsp' in video_file, "video_file not exists and not an rtsp site."
            self.multi_camera = False
            source = video_file
            self.is_video = True

        elif video_dir is not None:
            videos = sorted(
                os.path.join(video_dir, x) for x in os.listdir(video_dir))
            if not videos:
                # fail with a clear message instead of an IndexError below
                raise ValueError("video_dir '{}' contains no files.".format(
                    video_dir))
            if len(videos) > 1:
                self.multi_camera = True
                source = videos
            else:
                source = videos[0]
            self.is_video = True

        elif camera_id != -1:
            self.multi_camera = False
            source = camera_id
            self.is_video = True

        else:
            raise ValueError(
                "Illegal Input, please set one of ['video_file', 'video_dir', 'camera_id', 'image_file', 'image_dir']"
            )

        return source

    def run(self):
        """
        Run prediction on the parsed input.

        In multi-camera mode every predictor processes its own video and,
        when REID is enabled, the collected results are passed to
        mtmct_process.
        """
        if self.multi_camera:
            multi_res = []
            for predictor, source in zip(self.predictor, self.input):
                predictor.run(source)
                multi_res.append(predictor.get_result())
            if self.enable_mtmct:
                mtmct_process(
                    multi_res,
                    self.input,
                    mtmct_vis=self.vis_result,
                    output_dir=self.output_dir)

        else:
            self.predictor.run(self.input)


148
def _download_model_dir(section_cfg, field):
    """Resolve section_cfg[field] through auto_download_model and return the dir in use."""
    model_dir = section_cfg[field]
    downloaded_model_dir = auto_download_model(model_dir)
    if downloaded_model_dir:
        # write the resolved local dir back so later readers of cfg see it
        model_dir = downloaded_model_dir
        section_cfg[field] = model_dir
    return model_dir


def get_model_dir(cfg):
    """ 
        Auto download inference model if the model_path is a url link. 
        Otherwise it will use the model_path directly.

        Only dict sections that are enabled (or have no "enable" switch)
        are processed, plus the MOT section even when it is disabled.
        cfg is updated in place with the resolved model directories.

        Args:
            cfg (dict): config of models in pipeline
    """
    for key, section in cfg.items():
        # a section without an "enable" switch is treated as enabled
        if isinstance(section, dict) and section.get("enable", True):

            if "model_dir" in section:
                model_dir = _download_model_dir(section, "model_dir")
                print(key, " model dir: ", model_dir)
            elif key == "VEHICLE_PLATE":
                # plate recognition uses two models instead of one model_dir
                det_model_dir = _download_model_dir(section, "det_model_dir")
                print("det_model_dir model dir: ", det_model_dir)

                rec_model_dir = _download_model_dir(section, "rec_model_dir")
                print("rec_model_dir model dir: ", rec_model_dir)

        elif key == "MOT":  # for idbased and skeletonbased actions
            model_dir = _download_model_dir(section, "model_dir")
            print("mot_model_dir model_dir: ", model_dir)
187 188


189 190 191 192 193 194 195 196 197 198 199 200 201
class PipePredictor(object):
    """
    Predictor in single camera
    
    The pipeline for image input: 

        1. Detection
        2. Detection -> Attribute

    The pipeline for video input: 

        1. Tracking
        2. Tracking -> Attribute
Z
zhiboniu 已提交
202
        3. Tracking -> KeyPoint -> SkeletonAction Recognition
203
        4. VideoAction Recognition
204 205

    Args:
J
JYChen 已提交
206
        args (argparse.Namespace): arguments in pipeline, which contains environment and runtime settings
207 208 209 210 211 212
        cfg (dict): config of models in pipeline
        is_video (bool): whether the input is video, default as True
        multi_camera (bool): whether to use multi camera in pipeline, 
            default as False
    """

Z
zhiboniu 已提交
213 214 215 216
    def __init__(self, args, cfg, is_video=True, multi_camera=False):
        """
        Read the module switches from cfg, keep the runtime settings from
        args and build the predictors needed by the enabled modules.
        """

        def enabled(section_name):
            # a missing or empty section counts as disabled
            section = cfg.get(section_name, False)
            return section['enable'] if section else False

        # general module for pphuman and ppvehicle
        self.with_mot = enabled('MOT')
        self.with_human_attr = enabled('ATTR')
        if self.with_mot:
            print('Multi-Object Tracking enabled')
        if self.with_human_attr:
            print('Human Attribute Recognition enabled')

        # only for pphuman
        self.with_skeleton_action = enabled('SKELETON_ACTION')
        self.with_video_action = enabled('VIDEO_ACTION')
        self.with_idbased_detaction = enabled('ID_BASED_DETACTION')
        self.with_idbased_clsaction = enabled('ID_BASED_CLSACTION')
        self.with_mtmct = enabled('REID')

        if self.with_skeleton_action:
            print('SkeletonAction Recognition enabled')
        if self.with_video_action:
            print('VideoAction Recognition enabled')
        if self.with_idbased_detaction:
            print('IDBASED Detection Action Recognition enabled')
        if self.with_idbased_clsaction:
            print('IDBASED Classification Action Recognition enabled')
        if self.with_mtmct:
            print("MTMCT enabled")

        # only for ppvehicle
        self.with_vehicleplate = enabled('VEHICLE_PLATE')
        if self.with_vehicleplate:
            print('Vehicle Plate Recognition enabled')

        self.with_vehicle_attr = enabled('VEHICLE_ATTR')
        if self.with_vehicle_attr:
            print('Vehicle Attribute Recognition enabled')

        # which processing modes are active; flipped to True below
        self.modebase = {
            "framebased": False,
            "videobased": False,
            "idbased": False,
            "skeletonbased": False
        }

        # processing mode required by each config section
        self.basemode = {
            "MOT": "idbased",
            "ATTR": "idbased",
            "VIDEO_ACTION": "videobased",
            "SKELETON_ACTION": "skeletonbased",
            "ID_BASED_DETACTION": "idbased",
            "ID_BASED_CLSACTION": "idbased",
            "REID": "idbased",
            "VEHICLE_PLATE": "idbased",
            "VEHICLE_ATTR": "idbased",
        }

        self.is_video = is_video
        self.multi_camera = multi_camera
        self.cfg = cfg
        self.output_dir = args.output_dir
        self.draw_center_traj = args.draw_center_traj
        self.secs_interval = args.secs_interval
        self.do_entrance_counting = args.do_entrance_counting
        self.do_break_in_counting = args.do_break_in_counting
        self.region_type = args.region_type
        self.region_polygon = args.region_polygon

        self.warmup_frame = self.cfg['warmup_frame']
        self.pipeline_res = Result()
        self.pipe_timer = PipeTimer()
        self.file_name = None
        self.collector = DataCollector()

        # auto download inference model
        get_model_dir(self.cfg)

        if self.with_vehicleplate:
            self.vehicleplate_detector = PlateRecognizer(
                args, self.cfg['VEHICLE_PLATE'])
            self.modebase[self.basemode['VEHICLE_PLATE']] = True

        if self.with_human_attr:
            self.modebase[self.basemode['ATTR']] = True
            self.attr_predictor = AttrDetector.init_with_cfg(args,
                                                             self.cfg['ATTR'])

        if self.with_vehicle_attr:
            self.modebase[self.basemode['VEHICLE_ATTR']] = True
            self.vehicle_attr_predictor = VehicleAttr.init_with_cfg(
                args, self.cfg['VEHICLE_ATTR'])

        if not is_video:
            # image input only needs the plain detector
            det_cfg = self.cfg['DET']
            self.det_predictor = Detector(
                det_cfg['model_dir'], args.device, args.run_mode,
                det_cfg['batch_size'], args.trt_min_shape,
                args.trt_max_shape, args.trt_opt_shape, args.trt_calib_mode,
                args.cpu_threads, args.enable_mkldnn)
            return

        if self.with_idbased_detaction:
            self.modebase[self.basemode['ID_BASED_DETACTION']] = True
            self.det_action_predictor = DetActionRecognizer.init_with_cfg(
                args, self.cfg['ID_BASED_DETACTION'])
            self.det_action_visual_helper = ActionVisualHelper(1)

        if self.with_idbased_clsaction:
            self.modebase[self.basemode['ID_BASED_CLSACTION']] = True
            self.cls_action_predictor = ClsActionRecognizer.init_with_cfg(
                args, self.cfg['ID_BASED_CLSACTION'])
            self.cls_action_visual_helper = ActionVisualHelper(1)

        if self.with_skeleton_action:
            action_cfg = self.cfg['SKELETON_ACTION']
            shown_frames = action_cfg['display_frames']
            self.coord_size = action_cfg['coord_size']
            self.modebase[self.basemode['SKELETON_ACTION']] = True
            buffered_frames = action_cfg['max_frames']

            self.skeleton_action_predictor = SkeletonActionRecognizer.init_with_cfg(
                args, action_cfg)
            self.skeleton_action_visual_helper = ActionVisualHelper(
                shown_frames)

            # skeleton actions are recognized from keypoints, so a keypoint
            # detector and a keypoint buffer are created as well
            kpt_cfg = self.cfg['KPT']
            self.kpt_predictor = KeyPointDetector(
                kpt_cfg['model_dir'],
                args.device,
                args.run_mode,
                kpt_cfg['batch_size'],
                args.trt_min_shape,
                args.trt_max_shape,
                args.trt_opt_shape,
                args.trt_calib_mode,
                args.cpu_threads,
                args.enable_mkldnn,
                use_dark=False)
            self.kpt_buff = KeyPointBuff(buffered_frames)

        if self.with_mtmct:
            self.modebase[self.basemode['REID']] = True
            self.reid_predictor = ReID.init_with_cfg(args, self.cfg['REID'])

        # the tracker is needed by MOT itself and by every id-based or
        # skeleton-based module enabled above
        needs_tracker = (self.with_mot or self.modebase["idbased"] or
                         self.modebase["skeletonbased"])
        if needs_tracker:
            mot_cfg = self.cfg['MOT']
            mot_model_dir = mot_cfg['model_dir']
            tracker_config = mot_cfg['tracker_config']
            mot_batch_size = mot_cfg['batch_size']
            self.modebase[self.basemode['MOT']] = True
            self.mot_predictor = SDE_Detector(
                mot_model_dir,
                tracker_config,
                args.device,
                args.run_mode,
                mot_batch_size,
                args.trt_min_shape,
                args.trt_max_shape,
                args.trt_opt_shape,
                args.trt_calib_mode,
                args.cpu_threads,
                args.enable_mkldnn,
                draw_center_traj=self.draw_center_traj,
                secs_interval=self.secs_interval,
                do_entrance_counting=self.do_entrance_counting,
                do_break_in_counting=self.do_break_in_counting,
                region_type=self.region_type,
                region_polygon=self.region_polygon)

        if self.with_video_action:
            self.modebase[self.basemode['VIDEO_ACTION']] = True
            self.video_action_predictor = VideoActionRecognizer.init_with_cfg(
                args, self.cfg['VIDEO_ACTION'])
418

419
    def set_file_name(self, path):
W
wangguanzhong 已提交
420 421 422 423 424
        if path is not None:
            self.file_name = os.path.split(path)[-1]
        else:
            # use camera id
            self.file_name = None
425

426
    def get_result(self):
Z
zhiboniu 已提交
427
        return self.collector.get_res()
428 429 430 431 432 433

    def run(self, input):
        if self.is_video:
            self.predict_video(input)
        else:
            self.predict_image(input)
434
        self.pipe_timer.info()
435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452

    def predict_image(self, input):
        """
        Run the image pipeline batch by batch.

        Each batch goes through detection first; the enabled attribute and
        plate modules then run on the crops of the detected boxes. Timers
        are only used for batches whose index is above warmup_frame.

        Args:
            input: sequence of image files, decoded with decode_image
        """
        # det
        # det -> attr
        batch_loop_cnt = math.ceil(
            float(len(input)) / self.det_predictor.batch_size)
        for batch_idx in range(batch_loop_cnt):
            step = self.det_predictor.batch_size
            # slicing clamps the upper bound to len(input) for the last batch
            batch_file = input[batch_idx * step:(batch_idx + 1) * step]
            batch_input = [decode_image(path, {})[0] for path in batch_file]

            # timers only run once the warmup batches are over
            timed = batch_idx > self.warmup_frame

            if timed:
                self.pipe_timer.total_time.start()
                self.pipe_timer.module_time['det'].start()
            # det output format: class, score, xmin, ymin, xmax, ymax
            det_res = self.det_predictor.predict_image(
                batch_input, visual=False)
            det_res = self.det_predictor.filter_box(det_res,
                                                    self.cfg['crop_thresh'])
            if timed:
                self.pipe_timer.module_time['det'].end()
            self.pipeline_res.update(det_res, 'det')

            if self.with_human_attr:
                human_crops = crop_image_with_det(batch_input, det_res)
                human_attrs = []

                if timed:
                    self.pipe_timer.module_time['attr'].start()

                for crop in human_crops:
                    crop_res = self.attr_predictor.predict_image(
                        crop, visual=False)
                    human_attrs.extend(crop_res['output'])

                if timed:
                    self.pipe_timer.module_time['attr'].end()

                self.pipeline_res.update({'output': human_attrs}, 'attr')

            if self.with_vehicle_attr:
                vehicle_crops = crop_image_with_det(batch_input, det_res)
                vehicle_attrs = []

                if timed:
                    self.pipe_timer.module_time['vehicle_attr'].start()

                for crop in vehicle_crops:
                    crop_res = self.vehicle_attr_predictor.predict_image(
                        crop, visual=False)
                    vehicle_attrs.extend(crop_res['output'])

                if timed:
                    self.pipe_timer.module_time['vehicle_attr'].end()

                self.pipeline_res.update({'output': vehicle_attrs},
                                         'vehicle_attr')

            if self.with_vehicleplate:
                if timed:
                    self.pipe_timer.module_time['vehicleplate'].start()
                plate_crops = crop_image_with_det(batch_input, det_res)
                plates = []
                for crop in plate_crops:
                    plate_res = self.vehicleplate_detector.get_platelicense(
                        crop)
                    plates.extend(plate_res['plate'])
                if timed:
                    self.pipe_timer.module_time['vehicleplate'].end()
                self.pipeline_res.update({'vehicleplate': plates},
                                         'vehicleplate')

            self.pipe_timer.img_num += len(batch_input)
            if timed:
                self.pipe_timer.total_time.end()

            if self.cfg['visual']:
                self.visualize_image(batch_file, batch_input, self.pipeline_res)

Z
zhiboniu 已提交
516
    def predict_video(self, video_file):
517 518 519
        # mot
        # mot -> attr
        # mot -> pose -> action
Z
zhiboniu 已提交
520
        capture = cv2.VideoCapture(video_file)
521
        video_out_name = 'output.mp4' if self.file_name is None else self.file_name
Z
zhiboniu 已提交
522 523
        if "rtsp" in video_file:
            video_out_name = video_out_name + "_rtsp.mp4"
524 525 526 527 528 529

        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
530
        print("video fps: %d, frame_count: %d" % (fps, frame_count))
531 532 533 534 535 536 537

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 0
538 539 540 541 542 543 544 545 546 547

        entrance, records, center_traj = None, None, None
        if self.draw_center_traj:
            center_traj = [{}]
        id_set = set()
        interval_id_set = set()
        in_id_list = list()
        out_id_list = list()
        prev_center = dict()
        records = list()
548 549 550 551 552 553 554 555 556 557
        if self.do_entrance_counting or self.do_break_in_counting:
            if self.region_type == 'horizontal':
                entrance = [0, height / 2., width, height / 2.]
            elif self.region_type == 'vertical':
                entrance = [width / 2, 0., width / 2, height]
            elif self.region_type == 'custom':
                entrance = []
                assert len(
                    self.region_polygon
                ) % 2 == 0, "region_polygon should be pairs of coords points when do break_in counting."
J
JYChen 已提交
558 559 560 561
                assert len(
                    self.region_polygon
                ) > 6, 'region_type is custom, region_polygon should be at least 3 pairs of point coords.'

562 563 564 565 566 567 568 569
                for i in range(0, len(self.region_polygon), 2):
                    entrance.append(
                        [self.region_polygon[i], self.region_polygon[i + 1]])
                entrance.append([width, height])
            else:
                raise ValueError("region_type:{} unsupported.".format(
                    self.region_type))

570 571
        video_fps = fps

572 573
        video_action_imgs = []

574 575 576 577
        if self.with_video_action:
            short_size = self.cfg["VIDEO_ACTION"]["short_size"]
            scale = ShortSizeScale(short_size)

578 579 580
        while (1):
            if frame_id % 10 == 0:
                print('frame id: ', frame_id)
581

582 583 584
            ret, frame = capture.read()
            if not ret:
                break
585
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
586

587
            if self.modebase["idbased"] or self.modebase["skeletonbased"]:
588
                if frame_id > self.warmup_frame:
589 590 591
                    self.pipe_timer.total_time.start()
                    self.pipe_timer.module_time['mot'].start()
                res = self.mot_predictor.predict_image(
592
                    [copy.deepcopy(frame_rgb)], visual=False)
593

J
JYChen 已提交
594
                if frame_id > self.warmup_frame:
595 596 597 598 599 600 601 602 603 604 605
                    self.pipe_timer.module_time['mot'].end()

                # mot output format: id, class, score, xmin, ymin, xmax, ymax
                mot_res = parse_mot_res(res)

                # flow_statistic only support single class MOT
                boxes, scores, ids = res[0]  # batch size = 1 in MOT
                mot_result = (frame_id + 1, boxes[0], scores[0],
                              ids[0])  # single class
                statistic = flow_statistic(
                    mot_result, self.secs_interval, self.do_entrance_counting,
606 607 608
                    self.do_break_in_counting, self.region_type, video_fps,
                    entrance, id_set, interval_id_set, in_id_list, out_id_list,
                    prev_center, records)
609 610 611 612 613
                records = statistic['records']

                # nothing detected
                if len(mot_res['boxes']) == 0:
                    frame_id += 1
J
JYChen 已提交
614
                    if frame_id > self.warmup_frame:
615 616 617 618 619 620 621 622 623
                        self.pipe_timer.img_num += 1
                        self.pipe_timer.total_time.end()
                    if self.cfg['visual']:
                        _, _, fps = self.pipe_timer.get_total_time()
                        im = self.visualize_video(frame, mot_res, frame_id, fps,
                                                  entrance, records,
                                                  center_traj)  # visualize
                        writer.write(im)
                        if self.file_name is None:  # use camera_id
Z
zhiboniu 已提交
624
                            cv2.imshow('Paddle-Pipeline', im)
625 626 627 628 629
                            if cv2.waitKey(1) & 0xFF == ord('q'):
                                break
                    continue

                self.pipeline_res.update(mot_res, 'mot')
J
JYChen 已提交
630
                crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
631
                    frame_rgb, mot_res)
632

633
                if self.with_vehicleplate and frame_id % 10 == 0:
Z
zhiboniu 已提交
634 635
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].start()
Z
zhiboniu 已提交
636 637
                    plate_input, _, _ = crop_image_with_mot(
                        frame_rgb, mot_res, expand=False)
Z
zhiboniu 已提交
638
                    platelicense = self.vehicleplate_detector.get_platelicense(
Z
zhiboniu 已提交
639
                        plate_input)
Z
zhiboniu 已提交
640 641
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].end()
Z
zhiboniu 已提交
642
                    self.pipeline_res.update(platelicense, 'vehicleplate')
643 644
                else:
                    self.pipeline_res.clear('vehicleplate')
Z
zhiboniu 已提交
645

646
                if self.with_human_attr:
J
JYChen 已提交
647
                    if frame_id > self.warmup_frame:
648 649 650 651 652 653 654
                        self.pipe_timer.module_time['attr'].start()
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['attr'].end()
                    self.pipeline_res.update(attr_res, 'attr')

655 656 657 658 659 660 661 662 663
                if self.with_vehicle_attr:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].start()
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].end()
                    self.pipeline_res.update(attr_res, 'vehicle_attr')

Z
zhiboniu 已提交
664
                if self.with_idbased_detaction:
J
JYChen 已提交
665 666 667 668 669 670 671 672 673 674
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].start()
                    det_action_res = self.det_action_predictor.predict(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].end()
                    self.pipeline_res.update(det_action_res, 'det_action')

                    if self.cfg['visual']:
                        self.det_action_visual_helper.update(det_action_res)
Z
zhiboniu 已提交
675 676

                if self.with_idbased_clsaction:
J
JYChen 已提交
677 678 679 680 681 682 683 684 685 686
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].start()
                    cls_action_res = self.cls_action_predictor.predict_with_mot(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].end()
                    self.pipeline_res.update(cls_action_res, 'cls_action')

                    if self.cfg['visual']:
                        self.cls_action_visual_helper.update(cls_action_res)
Z
zhiboniu 已提交
687

Z
zhiboniu 已提交
688
                if self.with_skeleton_action:
Z
zhiboniu 已提交
689 690 691 692 693 694 695 696 697 698 699 700 701
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].start()
                    kpt_pred = self.kpt_predictor.predict_image(
                        crop_input, visual=False)
                    keypoint_vector, score_vector = translate_to_ori_images(
                        kpt_pred, np.array(new_bboxes))
                    kpt_res = {}
                    kpt_res['keypoint'] = [
                        keypoint_vector.tolist(), score_vector.tolist()
                    ] if len(keypoint_vector) > 0 else [[], []]
                    kpt_res['bbox'] = ori_bboxes
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].end()
702

Z
zhiboniu 已提交
703
                    self.pipeline_res.update(kpt_res, 'kpt')
704

Z
zhiboniu 已提交
705
                    self.kpt_buff.update(kpt_res, mot_res)  # collect kpt output
706 707 708
                    state = self.kpt_buff.get_state(
                    )  # whether frame num is enough or lost tracker

Z
zhiboniu 已提交
709
                    skeleton_action_res = {}
710 711
                    if state:
                        if frame_id > self.warmup_frame:
Z
zhiboniu 已提交
712 713
                            self.pipe_timer.module_time[
                                'skeleton_action'].start()
714 715
                        collected_keypoint = self.kpt_buff.get_collected_keypoint(
                        )  # reoragnize kpt output with ID
Z
zhiboniu 已提交
716 717 718 719
                        skeleton_action_input = parse_mot_keypoint(
                            collected_keypoint, self.coord_size)
                        skeleton_action_res = self.skeleton_action_predictor.predict_skeleton_with_mot(
                            skeleton_action_input)
720
                        if frame_id > self.warmup_frame:
Z
zhiboniu 已提交
721 722 723
                            self.pipe_timer.module_time['skeleton_action'].end()
                        self.pipeline_res.update(skeleton_action_res,
                                                 'skeleton_action')
724 725

                    if self.cfg['visual']:
Z
zhiboniu 已提交
726 727
                        self.skeleton_action_visual_helper.update(
                            skeleton_action_res)
728 729 730

                if self.with_mtmct and frame_id % 10 == 0:
                    crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot(
731
                        frame_rgb, mot_res)
732 733 734 735 736 737
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].start()
                    reid_res = self.reid_predictor.predict_batch(crop_input)

                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].end()
J
JYChen 已提交
738

739 740 741 742 743 744 745 746
                    reid_res_dict = {
                        'features': reid_res,
                        "qualities": img_qualities,
                        "rects": rects
                    }
                    self.pipeline_res.update(reid_res_dict, 'reid')
                else:
                    self.pipeline_res.clear('reid')
Z
zhiboniu 已提交
747

Z
zhiboniu 已提交
748
            if self.with_video_action:
749 750 751 752 753 754 755 756 757 758 759 760 761
                # get the params
                frame_len = self.cfg["VIDEO_ACTION"]["frame_len"]
                sample_freq = self.cfg["VIDEO_ACTION"]["sample_freq"]

                if sample_freq * frame_len > frame_count:  # video is too short
                    sample_freq = int(frame_count / frame_len)

                # filter the warmup frames
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['video_action'].start()

                # collect frames
                if frame_id % sample_freq == 0:
762
                    # Scale image
763
                    scaled_img = scale(frame_rgb)
764
                    video_action_imgs.append(scaled_img)
765 766 767 768 769 770 771 772 773 774 775 776 777 778

                # the number of collected frames is enough to predict video action
                if len(video_action_imgs) == frame_len:
                    classes, scores = self.video_action_predictor.predict(
                        video_action_imgs)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['video_action'].end()

                    video_action_res = {"class": classes[0], "score": scores[0]}
                    self.pipeline_res.update(video_action_res, 'video_action')

                    print("video_action_res:", video_action_res)

                    video_action_imgs.clear()  # next clip
Z
zhiboniu 已提交
779 780

            self.collector.append(frame_id, self.pipeline_res)
781 782 783 784 785 786 787

            if frame_id > self.warmup_frame:
                self.pipe_timer.img_num += 1
                self.pipe_timer.total_time.end()
            frame_id += 1

            if self.cfg['visual']:
788
                _, _, fps = self.pipe_timer.get_total_time()
789 790 791
                im = self.visualize_video(
                    frame, self.pipeline_res, self.collector, frame_id, fps,
                    entrance, records, center_traj)  # visualize
792
                writer.write(im)
W
wangguanzhong 已提交
793
                if self.file_name is None:  # use camera_id
Z
zhiboniu 已提交
794
                    cv2.imshow('Paddle-Pipeline', im)
W
wangguanzhong 已提交
795 796
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
797 798 799 800

        writer.release()
        print('save result to {}'.format(out_path))

801 802 803
    def visualize_video(self,
                        image,
                        result,
                        collector,
                        frame_id,
                        fps,
                        entrance=None,
                        records=None,
                        center_traj=None):
        """Draw every result available for one video frame onto ``image``.

        Args:
            image: frame to draw on; passed through the visualize helpers.
            result: per-frame result container queried with ``get(name)``
                for 'mot', 'attr', 'vehicle_attr', 'kpt', 'video_action',
                'skeleton_action', 'det_action' and 'cls_action'.
            collector: object exposing ``get_carlp(track_id)``; queried once
                per tracked box for a license plate.
            frame_id: forwarded to ``plot_tracking_dict``.
            fps: forwarded to ``plot_tracking_dict``.
            entrance: forwarded to ``plot_tracking_dict``.
            records: forwarded to ``plot_tracking_dict``.
            center_traj: forwarded to ``plot_tracking_dict``.

        Returns:
            The image after all applicable overlays were drawn. Overlays
            that need tracking boxes are skipped when ``result`` holds no
            'mot' entry (previously that case raised ``TypeError``).
        """
        # Private copy: the box columns are rewritten in place below and the
        # caller's result must stay untouched.
        mot_res = copy.deepcopy(result.get('mot'))
        has_mot = mot_res is not None

        if has_mot:
            ids = mot_res['boxes'][:, 0]
            scores = mot_res['boxes'][:, 2]
            boxes = mot_res['boxes'][:, 3:]
            # Subtract the first two box columns from the last two.
            # NOTE(review): presumably converts x1,y1,x2,y2 to x,y,w,h; if
            # 'boxes' is a NumPy array this slice is a view, so the later
            # mot_res['boxes'][:, 1:] reads see the converted values too.
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]

            # single class, still need to be defaultdict type for plotting
            num_classes = 1
            online_tlwhs = defaultdict(list)
            online_scores = defaultdict(list)
            online_ids = defaultdict(list)
            online_tlwhs[0] = boxes
            online_scores[0] = scores
            online_ids[0] = ids

            image = plot_tracking_dict(
                image,
                num_classes,
                online_tlwhs,
                online_ids,
                online_scores,
                frame_id=frame_id,
                fps=fps,
                ids2names=self.mot_predictor.pred_config.labels,
                do_entrance_counting=self.do_entrance_counting,
                do_break_in_counting=self.do_break_in_counting,
                entrance=entrance,
                records=records,
                center_traj=center_traj)

        human_attr_res = result.get('attr')
        if human_attr_res is not None and has_mot:
            image = visualize_attr(image, human_attr_res['output'],
                                   mot_res['boxes'][:, 1:])
            image = np.array(image)

        vehicle_attr_res = result.get('vehicle_attr')
        if vehicle_attr_res is not None and has_mot:
            image = visualize_attr(image, vehicle_attr_res['output'],
                                   mot_res['boxes'][:, 1:])
            image = np.array(image)

        if has_mot:
            found_plates = [
                collector.get_carlp(trackid)
                for trackid in mot_res['boxes'][:, 0]
            ]
            # Only draw when at least one track has a plate; tracks without
            # one get an empty string so indices stay aligned with the boxes.
            if any(plate is not None for plate in found_plates):
                plates = [
                    "" if plate is None else plate for plate in found_plates
                ]
                image = visualize_vehicleplate(image, plates,
                                               mot_res['boxes'][:, 1:])
                image = np.array(image)

        kpt_res = result.get('kpt')
        if kpt_res is not None:
            image = visualize_pose(
                image,
                kpt_res,
                visual_thresh=self.cfg['kpt_thresh'],
                returnimg=True)

        video_action_res = result.get('video_action')
        if video_action_res is not None:
            video_action_score = None
            # The score is only passed on when the predicted class is 1.
            if video_action_res and video_action_res["class"] == 1:
                video_action_score = video_action_res["score"]
            mot_boxes = mot_res['boxes'] if has_mot else None
            image = visualize_action(
                image,
                mot_boxes,
                action_visual_collector=None,
                action_text="SkeletonAction",
                video_action_score=video_action_score,
                video_action_text="Fight")

        # Collect the visual helper and label of every id-based action
        # result that is present for this frame.
        visual_helper_for_display = []
        action_to_display = []

        skeleton_action_res = result.get('skeleton_action')
        if skeleton_action_res is not None:
            visual_helper_for_display.append(self.skeleton_action_visual_helper)
            action_to_display.append("Falling")

        det_action_res = result.get('det_action')
        if det_action_res is not None:
            visual_helper_for_display.append(self.det_action_visual_helper)
            action_to_display.append("Smoking")

        cls_action_res = result.get('cls_action')
        if cls_action_res is not None:
            visual_helper_for_display.append(self.cls_action_visual_helper)
            action_to_display.append("Calling")

        if visual_helper_for_display and has_mot:
            image = visualize_action(image, mot_res['boxes'],
                                     visual_helper_for_display,
                                     action_to_display)

        return image

    def visualize_image(self, im_files, images, result):
        """Draw the results of a batch of images and save each image.

        Args:
            im_files: source file paths; only the file name part is used,
                to name the output image inside ``self.output_dir``.
            images: images matching ``im_files`` one to one.
            result: batch result container queried with ``get(name)`` for
                'det', 'attr', 'vehicle_attr' and 'vehicleplate'.

        The 'det' boxes of the whole batch are stored in a single array;
        ``det_res['boxes_num'][i]`` gives how many rows belong to image
        ``i``. Box-based overlays are skipped when there is no 'det' entry
        (previously that case raised ``NameError``).
        """
        det_res = result.get('det')
        human_attr_res = result.get('attr')
        vehicle_attr_res = result.get('vehicle_attr')
        vehicleplate_res = result.get('vehicleplate')

        start_idx, boxes_num_i = 0, 0
        for i, (im_file, im) in enumerate(zip(im_files, images)):
            det_res_i = None
            if det_res is not None:
                boxes_num_i = det_res['boxes_num'][i]
                det_res_i = {
                    'boxes':
                    det_res['boxes'][start_idx:start_idx + boxes_num_i, :]
                }
                im = visualize_box_mask(
                    im,
                    det_res_i,
                    labels=['target'],
                    threshold=self.cfg['crop_thresh'])
                im = np.ascontiguousarray(np.copy(im))
                im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

            if det_res_i is not None:
                end_idx = start_idx + boxes_num_i
                if human_attr_res is not None:
                    human_attr_res_i = human_attr_res['output'][start_idx:
                                                                end_idx]
                    im = visualize_attr(im, human_attr_res_i,
                                        det_res_i['boxes'])
                if vehicle_attr_res is not None:
                    vehicle_attr_res_i = vehicle_attr_res['output'][start_idx:
                                                                    end_idx]
                    im = visualize_attr(im, vehicle_attr_res_i,
                                        det_res_i['boxes'])
                if vehicleplate_res is not None:
                    # NOTE(review): assumes the plate list holds one entry
                    # per detected box across the batch, like the attribute
                    # outputs above -- confirm against the producer.
                    plates = vehicleplate_res['vehicleplate'][start_idx:
                                                              end_idx]
                    # Convert columns 4:6 on a copy so the shared 'det'
                    # boxes in ``result`` are not modified by drawing.
                    plate_boxes = np.copy(det_res_i['boxes'])
                    plate_boxes[:, 4:6] = plate_boxes[:, 4:6] - plate_boxes[:, 2:
                                                                            4]
                    im = visualize_vehicleplate(im, plates, plate_boxes)

            img_name = os.path.split(im_file)[-1]
            # exist_ok avoids the check-then-create race of exists()+makedirs()
            os.makedirs(self.output_dir, exist_ok=True)
            out_path = os.path.join(self.output_dir, img_name)
            cv2.imwrite(out_path, im)
            print("save result to: " + out_path)
            start_idx += boxes_num_i


def main():
    """Build the pipeline from the parsed command-line flags and run it."""
    # command-line params are merged into the config before it is printed
    merged_cfg = merge_cfg(FLAGS)
    print_arguments(merged_cfg)

    Pipeline(FLAGS, merged_cfg).run()


if __name__ == '__main__':
    paddle.enable_static()

    # parse params from command
    parser = argsparser()
    FLAGS = parser.parse_args()
    FLAGS.device = FLAGS.device.upper()
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if FLAGS.device not in ('CPU', 'GPU', 'XPU'):
        raise ValueError("device should be CPU, GPU or XPU")

    main()