# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
import cv2
import numpy as np
import math
import paddle
import sys
import copy
from collections import defaultdict
from collections.abc import Sequence
from datacollector import DataCollector, Result

# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

from pipe_utils import argsparser, print_arguments, merge_cfg, PipeTimer
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint

from python.infer import Detector, DetectorPicoDet
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
from python.preprocess import decode_image, ShortSizeScale
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action, visualize_vehicleplate

from pptracking.python.mot_sde_infer import SDE_Detector
from pptracking.python.mot.visualize import plot_tracking_dict
from pptracking.python.mot.utils import flow_statistic

from pphuman.attr_infer import AttrDetector
from pphuman.video_action_infer import VideoActionRecognizer
from pphuman.action_infer import SkeletonActionRecognizer, DetActionRecognizer, ClsActionRecognizer
from pphuman.action_utils import KeyPointBuff, ActionVisualHelper
from pphuman.reid import ReID
from pphuman.mtmct import mtmct_process

from ppvehicle.vehicle_plate import PlateRecognizer
from ppvehicle.vehicle_attr import VehicleAttr

from download import auto_download_model


class Pipeline(object):
    """
    Pipeline

    Args:
        cfg (dict): config of models in pipeline
        image_file (string|None): the path of image file, default as None
        image_dir (string|None): the path of the image directory; if not None,
            all images in the directory will be predicted, default as None
        video_file (string|None): the path of video file, default as None
        camera_id (int): the device id of camera to predict, default as -1
        device (string): the device to predict, options are: CPU/GPU/XPU, 
            default as CPU
        run_mode (string): the mode of prediction, options are: 
            paddle/trt_fp32/trt_fp16, default as paddle
        trt_min_shape (int): min shape for dynamic shape in trt, default as 1
        trt_max_shape (int): max shape for dynamic shape in trt, default as 1280
        trt_opt_shape (int): opt shape for dynamic shape in trt, default as 640
        trt_calib_mode (bool): if the model is produced by TRT offline quantization
            calibration, trt_calib_mode needs to be set to True, default as False
        cpu_threads (int): cpu threads, default as 1
        enable_mkldnn (bool): whether to open MKLDNN, default as False
        output_dir (string): The path of output, default as 'output'
        draw_center_traj (bool): whether to draw the trajectory of the box center,
            default as False
        secs_interval (int): the seconds interval to count after tracking, default as 10
        do_entrance_counting (bool): whether to count the number of identifiers
            entering or exiting the entrance, default as False; only single-class
            counting is supported in MOT
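
    Examples:
        # a minimal usage sketch, assuming args is the Namespace returned by
        # argsparser().parse_args() and cfg is the dict built by merge_cfg(args)
        pipeline = Pipeline(args, cfg)
        pipeline.run()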
    """

    def __init__(self, args, cfg):
        self.multi_camera = False
        reid_cfg = cfg.get('REID', False)
        self.enable_mtmct = reid_cfg['enable'] if reid_cfg else False
        self.is_video = False
        self.output_dir = args.output_dir
        self.vis_result = cfg['visual']
        self.input = self._parse_input(args.image_file, args.image_dir,
                                       args.video_file, args.video_dir,
                                       args.camera_id)
        if self.multi_camera:
            self.predictor = []
            for name in self.input:
                predictor_item = PipePredictor(
                    args, cfg, is_video=True, multi_camera=True)
                predictor_item.set_file_name(name)
                self.predictor.append(predictor_item)

        else:
            self.predictor = PipePredictor(args, cfg, self.is_video)
            if self.is_video:
                self.predictor.set_file_name(args.video_file)

        self.output_dir = args.output_dir
        self.draw_center_traj = args.draw_center_traj
        self.secs_interval = args.secs_interval
        self.do_entrance_counting = args.do_entrance_counting

    def _parse_input(self, image_file, image_dir, video_file, video_dir,
                     camera_id):

        # parse input and set is_video / multi_camera accordingly

        if image_file is not None or image_dir is not None:
            input = get_test_images(image_dir, image_file)
            self.is_video = False
            self.multi_camera = False

        elif video_file is not None:
            assert os.path.exists(video_file), "video_file does not exist."
            self.multi_camera = False
            input = video_file
            self.is_video = True

        elif video_dir is not None:
            videof = [os.path.join(video_dir, x) for x in os.listdir(video_dir)]
            if len(videof) > 1:
                self.multi_camera = True
                videof.sort()
                input = videof
            else:
                input = videof[0]
            self.is_video = True

        elif camera_id != -1:
            self.multi_camera = False
            input = camera_id
            self.is_video = True

        else:
            raise ValueError(
                "Illegal Input, please set one of ['video_file', 'camera_id', 'image_file', 'image_dir']"
            )

        return input
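
    # _parse_input maps the mutually exclusive CLI inputs onto
    # (input, is_video, multi_camera); the values below are illustrative:
    #   image_dir='imgs/'       -> input=list of image paths, is_video=False
    #   video_file='demo.mp4'   -> input='demo.mp4',          is_video=True
    #   video_dir with >1 file  -> input=sorted video paths,  multi_camera=True
    #   camera_id=0             -> input=0,                   is_video=True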

    def run(self):
        if self.multi_camera:
            multi_res = []
            for predictor, input in zip(self.predictor, self.input):
                predictor.run(input)
                collector_data = predictor.get_result()
                multi_res.append(collector_data)
            if self.enable_mtmct:
                mtmct_process(
                    multi_res,
                    self.input,
                    mtmct_vis=self.vis_result,
                    output_dir=self.output_dir)

        else:
            self.predictor.run(self.input)


def get_model_dir(cfg):
    # auto download inference model
    model_dir_dict = {}
    for key in cfg.keys():
        if isinstance(cfg[key], dict) and \
            ("enable" in cfg[key].keys() and cfg[key]['enable']
                or "enable" not in cfg[key].keys()):

            if "model_dir" in cfg[key].keys():
                model_dir = cfg[key]["model_dir"]
                downloaded_model_dir = auto_download_model(model_dir)
                if downloaded_model_dir:
                    model_dir = downloaded_model_dir
                model_dir_dict[key] = model_dir
                print(key, "model dir:", model_dir)
            elif key == "VEHICLE_PLATE":
                det_model_dir = cfg[key]["det_model_dir"]
                downloaded_det_model_dir = auto_download_model(det_model_dir)
                if downloaded_det_model_dir:
                    det_model_dir = downloaded_det_model_dir
                model_dir_dict["det_model_dir"] = det_model_dir
                print("det_model_dir model dir:", det_model_dir)

                rec_model_dir = cfg[key]["rec_model_dir"]
                downloaded_rec_model_dir = auto_download_model(rec_model_dir)
                if downloaded_rec_model_dir:
                    rec_model_dir = downloaded_rec_model_dir
                model_dir_dict["rec_model_dir"] = rec_model_dir
                print("rec_model_dir model dir:", rec_model_dir)
        elif key == "MOT":  # for idbased and skeletonbased actions
            model_dir = cfg[key]["model_dir"]
            downloaded_model_dir = auto_download_model(model_dir)
            if downloaded_model_dir:
                model_dir = downloaded_model_dir
            model_dir_dict[key] = model_dir

    return model_dir_dict
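
# Sketch of how get_model_dir resolves paths (the keys below are illustrative,
# not a real config): a remote "model_dir" is auto-downloaded and replaced by
# its local cache directory, while an existing local path is kept as-is;
# VEHICLE_PLATE contributes separate det/rec model dirs, and MOT is resolved
# even when disabled because id-based and skeleton-based actions reuse it.
#
#   cfg = {'ATTR': {'enable': True, 'model_dir': 'https://example.com/attr.zip'},
#          'MOT': {'model_dir': 'output_inference/mot_model/'}}
#   get_model_dir(cfg)
#   # -> {'ATTR': <local cache dir>, 'MOT': 'output_inference/mot_model/'}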


class PipePredictor(object):
    """
    Predictor for a single camera or input source
    
    The pipeline for image input: 

        1. Detection
        2. Detection -> Attribute

    The pipeline for video input: 

        1. Tracking
        2. Tracking -> Attribute
        3. Tracking -> KeyPoint -> SkeletonAction Recognition
        4. VideoAction Recognition

    Args:
        cfg (dict): config of models in pipeline
        is_video (bool): whether the input is video, default as False
        multi_camera (bool): whether to use multi camera in pipeline, 
            default as False
        camera_id (int): the device id of camera to predict, default as -1
        device (string): the device to predict, options are: CPU/GPU/XPU, 
            default as CPU
        run_mode (string): the mode of prediction, options are: 
            paddle/trt_fp32/trt_fp16, default as paddle
        trt_min_shape (int): min shape for dynamic shape in trt, default as 1
        trt_max_shape (int): max shape for dynamic shape in trt, default as 1280
        trt_opt_shape (int): opt shape for dynamic shape in trt, default as 640
        trt_calib_mode (bool): if the model is produced by TRT offline quantization
            calibration, trt_calib_mode needs to be set to True, default as False
        cpu_threads (int): cpu threads, default as 1
        enable_mkldnn (bool): whether to open MKLDNN, default as False
        output_dir (string): The path of output, default as 'output'
        draw_center_traj (bool): whether to draw the trajectory of the box center,
            default as False
        secs_interval (int): the seconds interval to count after tracking, default as 10
        do_entrance_counting (bool): whether to count the number of identifiers
            entering or exiting the entrance, default as False; only single-class
            counting is supported in MOT
    """

    def __init__(self, args, cfg, is_video=True, multi_camera=False):
        device = args.device
        run_mode = args.run_mode
        trt_min_shape = args.trt_min_shape
        trt_max_shape = args.trt_max_shape
        trt_opt_shape = args.trt_opt_shape
        trt_calib_mode = args.trt_calib_mode
        cpu_threads = args.cpu_threads
        enable_mkldnn = args.enable_mkldnn
        output_dir = args.output_dir
        draw_center_traj = args.draw_center_traj
        secs_interval = args.secs_interval
        do_entrance_counting = args.do_entrance_counting

        # general module for pphuman and ppvehicle
        self.with_mot = cfg.get('MOT', False)['enable'] if cfg.get(
            'MOT', False) else False
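        # pattern used throughout __init__: cfg.get(KEY, False) yields the
        # module's sub-dict when configured, and the module is active only if
        # its 'enable' flag is set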
        self.with_human_attr = cfg.get('ATTR', False)['enable'] if cfg.get(
            'ATTR', False) else False
        if self.with_mot:
            print('Multi-Object Tracking enabled')
        if self.with_human_attr:
            print('Human Attribute Recognition enabled')

        # only for pphuman
        self.with_skeleton_action = cfg.get(
            'SKELETON_ACTION', False)['enable'] if cfg.get('SKELETON_ACTION',
                                                           False) else False
        self.with_video_action = cfg.get(
            'VIDEO_ACTION', False)['enable'] if cfg.get('VIDEO_ACTION',
                                                        False) else False
        self.with_idbased_detaction = cfg.get(
            'ID_BASED_DETACTION', False)['enable'] if cfg.get(
                'ID_BASED_DETACTION', False) else False
        self.with_idbased_clsaction = cfg.get(
            'ID_BASED_CLSACTION', False)['enable'] if cfg.get(
                'ID_BASED_CLSACTION', False) else False
        self.with_mtmct = cfg.get('REID', False)['enable'] if cfg.get(
            'REID', False) else False

        if self.with_skeleton_action:
            print('SkeletonAction Recognition enabled')
        if self.with_video_action:
            print('VideoAction Recognition enabled')
        if self.with_idbased_detaction:
            print('IDBASED Detection Action Recognition enabled')
        if self.with_idbased_clsaction:
            print('IDBASED Classification Action Recognition enabled')
        if self.with_mtmct:
            print("MTMCT enabled")

        # only for ppvehicle
        self.with_vehicleplate = cfg.get(
            'VEHICLE_PLATE', False)['enable'] if cfg.get('VEHICLE_PLATE',
                                                         False) else False
        if self.with_vehicleplate:
            print('Vehicle Plate Recognition enabled')

        self.with_vehicle_attr = cfg.get(
            'VEHICLE_ATTR', False)['enable'] if cfg.get('VEHICLE_ATTR',
                                                        False) else False
        if self.with_vehicle_attr:
            print('Vehicle Attribute Recognition enabled')

        self.modebase = {
            "framebased": False,
            "videobased": False,
            "idbased": False,
            "skeletonbased": False
        }
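        # every module declares a 'basemode' in its config; the flags above are
        # switched on as modules are created and decide which branches
        # (frame/video/id/skeleton based) the prediction loop must run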

        self.is_video = is_video
        self.multi_camera = multi_camera
        self.cfg = cfg
        self.output_dir = output_dir
        self.draw_center_traj = draw_center_traj
        self.secs_interval = secs_interval
        self.do_entrance_counting = do_entrance_counting

        self.warmup_frame = self.cfg['warmup_frame']
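        # timing statistics skip the first warmup_frame frames (or images)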
        self.pipeline_res = Result()
        self.pipe_timer = PipeTimer()
        self.file_name = None
        self.collector = DataCollector()

        # auto download inference model
        model_dir_dict = get_model_dir(self.cfg)

        if not is_video:
            det_cfg = self.cfg['DET']
            model_dir = model_dir_dict['DET']
            batch_size = det_cfg['batch_size']
            self.det_predictor = Detector(
                model_dir, device, run_mode, batch_size, trt_min_shape,
                trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                enable_mkldnn)
            if self.with_human_attr:
                attr_cfg = self.cfg['ATTR']
                model_dir = model_dir_dict['ATTR']
                batch_size = attr_cfg['batch_size']
                basemode = attr_cfg['basemode']
                self.modebase[basemode] = True
                self.attr_predictor = AttrDetector(
                    model_dir, device, run_mode, batch_size, trt_min_shape,
                    trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                    enable_mkldnn)

            if self.with_vehicle_attr:
                vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
                model_dir = model_dir_dict['VEHICLE_ATTR']
                batch_size = vehicleattr_cfg['batch_size']
                color_threshold = vehicleattr_cfg['color_threshold']
                type_threshold = vehicleattr_cfg['type_threshold']
                basemode = vehicleattr_cfg['basemode']
                self.modebase[basemode] = True
                self.vehicle_attr_predictor = VehicleAttr(
                    model_dir, device, run_mode, batch_size, trt_min_shape,
                    trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                    enable_mkldnn, color_threshold, type_threshold)

        else:
            if self.with_human_attr:
                attr_cfg = self.cfg['ATTR']
                model_dir = model_dir_dict['ATTR']
                batch_size = attr_cfg['batch_size']
                basemode = attr_cfg['basemode']
                self.modebase[basemode] = True
                self.attr_predictor = AttrDetector(
                    model_dir, device, run_mode, batch_size, trt_min_shape,
                    trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                    enable_mkldnn)
            if self.with_idbased_detaction:
                idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION']
                model_dir = model_dir_dict['ID_BASED_DETACTION']
                batch_size = idbased_detaction_cfg['batch_size']
                basemode = idbased_detaction_cfg['basemode']
                threshold = idbased_detaction_cfg['threshold']
                display_frames = idbased_detaction_cfg['display_frames']
                skip_frame_num = idbased_detaction_cfg['skip_frame_num']
                self.modebase[basemode] = True

                self.det_action_predictor = DetActionRecognizer(
                    model_dir,
                    device,
                    run_mode,
                    batch_size,
                    trt_min_shape,
                    trt_max_shape,
                    trt_opt_shape,
                    trt_calib_mode,
                    cpu_threads,
                    enable_mkldnn,
                    threshold=threshold,
                    display_frames=display_frames,
                    skip_frame_num=skip_frame_num)
                self.det_action_visual_helper = ActionVisualHelper(1)

            if self.with_idbased_clsaction:
                idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION']
                model_dir = model_dir_dict['ID_BASED_CLSACTION']
                batch_size = idbased_clsaction_cfg['batch_size']
                basemode = idbased_clsaction_cfg['basemode']
                threshold = idbased_clsaction_cfg['threshold']
                self.modebase[basemode] = True
                display_frames = idbased_clsaction_cfg['display_frames']
                skip_frame_num = idbased_clsaction_cfg['skip_frame_num']

                self.cls_action_predictor = ClsActionRecognizer(
                    model_dir,
                    device,
                    run_mode,
                    batch_size,
                    trt_min_shape,
                    trt_max_shape,
                    trt_opt_shape,
                    trt_calib_mode,
                    cpu_threads,
                    enable_mkldnn,
                    threshold=threshold,
                    display_frames=display_frames,
                    skip_frame_num=skip_frame_num)
                self.cls_action_visual_helper = ActionVisualHelper(1)

            if self.with_skeleton_action:
                skeleton_action_cfg = self.cfg['SKELETON_ACTION']
                skeleton_action_model_dir = model_dir_dict['SKELETON_ACTION']
                skeleton_action_batch_size = skeleton_action_cfg['batch_size']
                skeleton_action_frames = skeleton_action_cfg['max_frames']
                display_frames = skeleton_action_cfg['display_frames']
                self.coord_size = skeleton_action_cfg['coord_size']
                basemode = skeleton_action_cfg['basemode']
                self.modebase[basemode] = True

                self.skeleton_action_predictor = SkeletonActionRecognizer(
                    skeleton_action_model_dir,
                    device,
                    run_mode,
                    skeleton_action_batch_size,
                    trt_min_shape,
                    trt_max_shape,
                    trt_opt_shape,
                    trt_calib_mode,
                    cpu_threads,
                    enable_mkldnn,
                    window_size=skeleton_action_frames)
                self.skeleton_action_visual_helper = ActionVisualHelper(
                    display_frames)

                if self.modebase["skeletonbased"]:
                    kpt_cfg = self.cfg['KPT']
                    kpt_model_dir = model_dir_dict['KPT']
                    kpt_batch_size = kpt_cfg['batch_size']
                    self.kpt_predictor = KeyPointDetector(
                        kpt_model_dir,
                        device,
                        run_mode,
                        kpt_batch_size,
                        trt_min_shape,
                        trt_max_shape,
                        trt_opt_shape,
                        trt_calib_mode,
                        cpu_threads,
                        enable_mkldnn,
                        use_dark=False)
                    self.kpt_buff = KeyPointBuff(skeleton_action_frames)
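                    # KeyPointBuff accumulates per-track keypoints until
                    # max_frames are available for one skeleton-action window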

            if self.with_vehicleplate:
                vehicleplate_cfg = self.cfg['VEHICLE_PLATE']
                self.vehicleplate_detector = PlateRecognizer(args,
                                                             vehicleplate_cfg)
                basemode = vehicleplate_cfg['basemode']
                self.modebase[basemode] = True

            if self.with_vehicle_attr:
                vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
                model_dir = model_dir_dict['VEHICLE_ATTR']
                batch_size = vehicleattr_cfg['batch_size']
                color_threshold = vehicleattr_cfg['color_threshold']
                type_threshold = vehicleattr_cfg['type_threshold']
                basemode = vehicleattr_cfg['basemode']
                self.modebase[basemode] = True
                self.vehicle_attr_predictor = VehicleAttr(
                    model_dir, device, run_mode, batch_size, trt_min_shape,
                    trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                    enable_mkldnn, color_threshold, type_threshold)

            if self.with_mtmct:
                reid_cfg = self.cfg['REID']
                model_dir = model_dir_dict['REID']
                batch_size = reid_cfg['batch_size']
                basemode = reid_cfg['basemode']
                self.modebase[basemode] = True
                self.reid_predictor = ReID(
                    model_dir, device, run_mode, batch_size, trt_min_shape,
                    trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                    enable_mkldnn)

            if self.with_mot or self.modebase["idbased"] or self.modebase[
                    "skeletonbased"]:
                mot_cfg = self.cfg['MOT']
                model_dir = model_dir_dict['MOT']
                tracker_config = mot_cfg['tracker_config']
                batch_size = mot_cfg['batch_size']
                basemode = mot_cfg['basemode']
                self.modebase[basemode] = True
                self.mot_predictor = SDE_Detector(
                    model_dir,
                    tracker_config,
                    device,
                    run_mode,
                    batch_size,
                    trt_min_shape,
                    trt_max_shape,
                    trt_opt_shape,
                    trt_calib_mode,
                    cpu_threads,
                    enable_mkldnn,
                    draw_center_traj=draw_center_traj,
                    secs_interval=secs_interval,
                    do_entrance_counting=do_entrance_counting)

            if self.with_video_action:
                video_action_cfg = self.cfg['VIDEO_ACTION']

                basemode = video_action_cfg['basemode']
                self.modebase[basemode] = True

                video_action_model_dir = model_dir_dict['VIDEO_ACTION']
                video_action_batch_size = video_action_cfg['batch_size']
                short_size = video_action_cfg["short_size"]
                target_size = video_action_cfg["target_size"]

                self.video_action_predictor = VideoActionRecognizer(
                    model_dir=video_action_model_dir,
                    short_size=short_size,
                    target_size=target_size,
                    device=device,
                    run_mode=run_mode,
                    batch_size=video_action_batch_size,
                    trt_min_shape=trt_min_shape,
                    trt_max_shape=trt_max_shape,
                    trt_opt_shape=trt_opt_shape,
                    trt_calib_mode=trt_calib_mode,
                    cpu_threads=cpu_threads,
                    enable_mkldnn=enable_mkldnn)

    def set_file_name(self, path):
        if path is not None:
            self.file_name = os.path.split(path)[-1]
        else:
            # use camera id
            self.file_name = None

    def get_result(self):
        return self.collector.get_res()

    def run(self, input):
        if self.is_video:
            self.predict_video(input)
        else:
            self.predict_image(input)
        self.pipe_timer.info()

    def predict_image(self, input):
        # det
        # det -> attr
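        # images are consumed in detector-sized batches, e.g. (hypothetically)
        # 10 inputs with batch_size 4 -> ceil(10 / 4) = 3 loops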
        batch_loop_cnt = math.ceil(
            float(len(input)) / self.det_predictor.batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * self.det_predictor.batch_size
            end_index = min((i + 1) * self.det_predictor.batch_size, len(input))
            batch_file = input[start_index:end_index]
            batch_input = [decode_image(f, {})[0] for f in batch_file]

            if i > self.warmup_frame:
                self.pipe_timer.total_time.start()
                self.pipe_timer.module_time['det'].start()
            # det output format: class, score, xmin, ymin, xmax, ymax
            det_res = self.det_predictor.predict_image(
                batch_input, visual=False)
            det_res = self.det_predictor.filter_box(det_res,
                                                    self.cfg['crop_thresh'])
            if i > self.warmup_frame:
                self.pipe_timer.module_time['det'].end()
            self.pipeline_res.update(det_res, 'det')

            if self.with_human_attr:
                crop_inputs = crop_image_with_det(batch_input, det_res)
                attr_res_list = []

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['attr'].start()

                for crop_input in crop_inputs:
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    attr_res_list.extend(attr_res['output'])

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['attr'].end()

                attr_res = {'output': attr_res_list}
                self.pipeline_res.update(attr_res, 'attr')

            if self.with_vehicle_attr:
                crop_inputs = crop_image_with_det(batch_input, det_res)
                vehicle_attr_res_list = []

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['vehicle_attr'].start()

                for crop_input in crop_inputs:
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    vehicle_attr_res_list.extend(attr_res['output'])

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['vehicle_attr'].end()

                attr_res = {'output': vehicle_attr_res_list}
                self.pipeline_res.update(attr_res, 'vehicle_attr')

            self.pipe_timer.img_num += len(batch_input)
            if i > self.warmup_frame:
                self.pipe_timer.total_time.end()

            if self.cfg['visual']:
                self.visualize_image(batch_file, batch_input, self.pipeline_res)

    def predict_video(self, video_file):
        # mot
        # mot -> attr
        # mot -> pose -> action
        capture = cv2.VideoCapture(video_file)
        video_out_name = 'output.mp4' if self.file_name is None else self.file_name

        # get video info: resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("video fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 0

        entrance, records, center_traj = None, None, None
        if self.draw_center_traj:
            center_traj = [{}]
        id_set = set()
        interval_id_set = set()
        in_id_list = list()
        out_id_list = list()
        prev_center = dict()
        records = list()
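        # the default entrance is a horizontal [x1, y1, x2, y2] line across the
        # middle of the frame; flow_statistic counts track ids crossing it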
        entrance = [0, height / 2., width, height / 2.]
        video_fps = fps

        video_action_imgs = []

        if self.with_video_action:
            short_size = self.cfg["VIDEO_ACTION"]["short_size"]
            scale = ShortSizeScale(short_size)
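            # ShortSizeScale resizes each frame so its short side equals
            # short_size while keeping the aspect ratio, matching the video
            # action model's preprocessing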

        while True:
            if frame_id % 10 == 0:
                print('frame id: ', frame_id)

            ret, frame = capture.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if self.modebase["idbased"] or self.modebase["skeletonbased"]:
                if frame_id > self.warmup_frame:
                    self.pipe_timer.total_time.start()
                    self.pipe_timer.module_time['mot'].start()
                res = self.mot_predictor.predict_image(
                    [copy.deepcopy(frame_rgb)], visual=False)

                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['mot'].end()

                # mot output format: id, class, score, xmin, ymin, xmax, ymax
                mot_res = parse_mot_res(res)

                # flow_statistic only support single class MOT
                boxes, scores, ids = res[0]  # batch size = 1 in MOT
                mot_result = (frame_id + 1, boxes[0], scores[0],
                              ids[0])  # single class
                statistic = flow_statistic(
                    mot_result, self.secs_interval, self.do_entrance_counting,
                    video_fps, entrance, id_set, interval_id_set, in_id_list,
                    out_id_list, prev_center, records)
                records = statistic['records']

                # nothing detected
                if len(mot_res['boxes']) == 0:
                    frame_id += 1
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.img_num += 1
                        self.pipe_timer.total_time.end()
                    if self.cfg['visual']:
                        _, _, fps = self.pipe_timer.get_total_time()
                        im = self.visualize_video(frame, mot_res, frame_id, fps,
                                                  entrance, records,
                                                  center_traj)  # visualize
                        writer.write(im)
                        if self.file_name is None:  # use camera_id
                            cv2.imshow('Paddle-Pipeline', im)
                            if cv2.waitKey(1) & 0xFF == ord('q'):
                                break
                    continue

                self.pipeline_res.update(mot_res, 'mot')
                crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
                    frame_rgb, mot_res)

                if self.with_vehicleplate:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].start()
                    platelicense = self.vehicleplate_detector.get_platelicense(
                        crop_input)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].end()
                    self.pipeline_res.update(platelicense, 'vehicleplate')

                if self.with_human_attr:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['attr'].start()
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['attr'].end()
                    self.pipeline_res.update(attr_res, 'attr')

                if self.with_vehicle_attr:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].start()
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].end()
                    self.pipeline_res.update(attr_res, 'vehicle_attr')

                if self.with_idbased_detaction:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].start()
                    det_action_res = self.det_action_predictor.predict(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].end()
                    self.pipeline_res.update(det_action_res, 'det_action')

                    if self.cfg['visual']:
                        self.det_action_visual_helper.update(det_action_res)

                if self.with_idbased_clsaction:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].start()
                    cls_action_res = self.cls_action_predictor.predict_with_mot(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].end()
                    self.pipeline_res.update(cls_action_res, 'cls_action')

                    if self.cfg['visual']:
                        self.cls_action_visual_helper.update(cls_action_res)

                if self.with_skeleton_action:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].start()
                    kpt_pred = self.kpt_predictor.predict_image(
                        crop_input, visual=False)
                    keypoint_vector, score_vector = translate_to_ori_images(
                        kpt_pred, np.array(new_bboxes))
                    kpt_res = {}
                    kpt_res['keypoint'] = [
                        keypoint_vector.tolist(), score_vector.tolist()
                    ] if len(keypoint_vector) > 0 else [[], []]
                    kpt_res['bbox'] = ori_bboxes
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].end()

                    self.pipeline_res.update(kpt_res, 'kpt')

                    self.kpt_buff.update(kpt_res, mot_res)  # collect kpt output
                    state = self.kpt_buff.get_state(
                    )  # whether enough frames are collected or the tracker is lost

                    skeleton_action_res = {}
                    if state:
                        if frame_id > self.warmup_frame:
                            self.pipe_timer.module_time[
                                'skeleton_action'].start()
                        collected_keypoint = self.kpt_buff.get_collected_keypoint(
                        )  # reorganize kpt output by track id
                        skeleton_action_input = parse_mot_keypoint(
                            collected_keypoint, self.coord_size)
                        skeleton_action_res = self.skeleton_action_predictor.predict_skeleton_with_mot(
                            skeleton_action_input)
                        if frame_id > self.warmup_frame:
                            self.pipe_timer.module_time['skeleton_action'].end()
                        self.pipeline_res.update(skeleton_action_res,
                                                 'skeleton_action')

                    if self.cfg['visual']:
                        self.skeleton_action_visual_helper.update(
                            skeleton_action_res)

                if self.with_mtmct and frame_id % 10 == 0:
                    crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot(
                        frame_rgb, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].start()
                    reid_res = self.reid_predictor.predict_batch(crop_input)

                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].end()

                    reid_res_dict = {
                        'features': reid_res,
                        "qualities": img_qualities,
                        "rects": rects
                    }
                    self.pipeline_res.update(reid_res_dict, 'reid')
                else:
                    self.pipeline_res.clear('reid')

            if self.with_video_action:
                # get the params
                frame_len = self.cfg["VIDEO_ACTION"]["frame_len"]
                sample_freq = self.cfg["VIDEO_ACTION"]["sample_freq"]
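                # one frame is kept every sample_freq frames until frame_len
                # frames form a clip; e.g. (hypothetical values) frame_len=8,
                # sample_freq=7 spans roughly 2 seconds at 25 fps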

                if sample_freq * frame_len > frame_count:  # video is too short
                    sample_freq = int(frame_count / frame_len)

                # filter the warmup frames
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['video_action'].start()

                # collect frames
                if frame_id % sample_freq == 0:
                    # Scale image
                    scaled_img = scale(frame_rgb)
                    video_action_imgs.append(scaled_img)

                # enough frames are collected to predict the video action
                if len(video_action_imgs) == frame_len:
                    classes, scores = self.video_action_predictor.predict(
                        video_action_imgs)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['video_action'].end()

                    video_action_res = {"class": classes[0], "score": scores[0]}
                    self.pipeline_res.update(video_action_res, 'video_action')

                    print("video_action_res:", video_action_res)

                    video_action_imgs.clear()  # next clip

            self.collector.append(frame_id, self.pipeline_res)

            if frame_id > self.warmup_frame:
                self.pipe_timer.img_num += 1
                self.pipe_timer.total_time.end()
            frame_id += 1

            if self.cfg['visual']:
                _, _, fps = self.pipe_timer.get_total_time()
                im = self.visualize_video(frame, self.pipeline_res, frame_id,
                                          fps, entrance, records,
                                          center_traj)  # visualize
                writer.write(im)
                if self.file_name is None:  # use camera_id
                    cv2.imshow('Paddle-Pipeline', im)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

        writer.release()
        print('save result to {}'.format(out_path))

    def visualize_video(self,
                        image,
                        result,
                        frame_id,
                        fps,
                        entrance=None,
                        records=None,
                        center_traj=None):
        mot_res = copy.deepcopy(result.get('mot'))
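        # mot 'boxes' rows are [id, class, score, x1, y1, x2, y2]; the corner
        # coordinates are converted to tlwh (x, y, w, h) below for plotting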
        if mot_res is not None:
            ids = mot_res['boxes'][:, 0]
            scores = mot_res['boxes'][:, 2]
            boxes = mot_res['boxes'][:, 3:]
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
        else:
            boxes = np.zeros([0, 4])
            ids = np.zeros([0])
            scores = np.zeros([0])

        # single class; still needs to be defaultdict type for plotting
        num_classes = 1
        online_tlwhs = defaultdict(list)
        online_scores = defaultdict(list)
        online_ids = defaultdict(list)
        online_tlwhs[0] = boxes
        online_scores[0] = scores
        online_ids[0] = ids

        if mot_res is not None:
            image = plot_tracking_dict(
                image,
                num_classes,
                online_tlwhs,
                online_ids,
                online_scores,
                frame_id=frame_id,
                fps=fps,
                do_entrance_counting=self.do_entrance_counting,
                entrance=entrance,
                records=records,
                center_traj=center_traj)

        human_attr_res = result.get('attr')
        if human_attr_res is not None:
            boxes = mot_res['boxes'][:, 1:]
            human_attr_res = human_attr_res['output']
            image = visualize_attr(image, human_attr_res, boxes)
            image = np.array(image)

        vehicle_attr_res = result.get('vehicle_attr')
        if vehicle_attr_res is not None:
            boxes = mot_res['boxes'][:, 1:]
            vehicle_attr_res = vehicle_attr_res['output']
            image = visualize_attr(image, vehicle_attr_res, boxes)
            image = np.array(image)

        vehicleplate_res = result.get('vehicleplate')
        if vehicleplate_res:
            boxes = mot_res['boxes'][:, 1:]
            image = visualize_vehicleplate(image, vehicleplate_res['plate'],
                                           boxes)
            image = np.array(image)

        kpt_res = result.get('kpt')
        if kpt_res is not None:
            image = visualize_pose(
                image,
                kpt_res,
                visual_thresh=self.cfg['kpt_thresh'],
                returnimg=True)

        video_action_res = result.get('video_action')
        if video_action_res is not None:
            video_action_score = None
            if video_action_res and video_action_res["class"] == 1:
                video_action_score = video_action_res["score"]
            mot_boxes = None
            if mot_res:
                mot_boxes = mot_res['boxes']
            image = visualize_action(
                image,
                mot_boxes,
                action_visual_collector=None,
                action_text="SkeletonAction",
                video_action_score=video_action_score,
                video_action_text="Fight")

        visual_helper_for_display = []
        action_to_display = []

        skeleton_action_res = result.get('skeleton_action')
        if skeleton_action_res is not None:
            visual_helper_for_display.append(self.skeleton_action_visual_helper)
            action_to_display.append("Falling")

        det_action_res = result.get('det_action')
        if det_action_res is not None:
            visual_helper_for_display.append(self.det_action_visual_helper)
            action_to_display.append("Smoking")

        cls_action_res = result.get('cls_action')
        if cls_action_res is not None:
            visual_helper_for_display.append(self.cls_action_visual_helper)
            action_to_display.append("Calling")

        if len(visual_helper_for_display) > 0:
            image = visualize_action(image, mot_res['boxes'],
                                     visual_helper_for_display,
                                     action_to_display)

        return image

    def visualize_image(self, im_files, images, result):
        start_idx, boxes_num_i = 0, 0
        det_res = result.get('det')
        human_attr_res = result.get('attr')
        vehicle_attr_res = result.get('vehicle_attr')
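        # det 'boxes' are concatenated across the whole batch; boxes_num[i]
        # gives the row count of image i, so start_idx slices out each image's
        # boxes in turn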

        for i, (im_file, im) in enumerate(zip(im_files, images)):
            if det_res is not None:
                det_res_i = {}
                boxes_num_i = det_res['boxes_num'][i]
                det_res_i['boxes'] = det_res['boxes'][start_idx:start_idx +
                                                      boxes_num_i, :]
                im = visualize_box_mask(
                    im,
                    det_res_i,
                    labels=['person'],
                    threshold=self.cfg['crop_thresh'])
                im = np.ascontiguousarray(np.copy(im))
                im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
            if human_attr_res is not None:
                human_attr_res_i = human_attr_res['output'][start_idx:start_idx
                                                            + boxes_num_i]
                im = visualize_attr(im, human_attr_res_i, det_res_i['boxes'])
            if vehicle_attr_res is not None:
                vehicle_attr_res_i = vehicle_attr_res['output'][
                    start_idx:start_idx + boxes_num_i]
                im = visualize_attr(im, vehicle_attr_res_i, det_res_i['boxes'])

            img_name = os.path.split(im_file)[-1]
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)
            out_path = os.path.join(self.output_dir, img_name)
            cv2.imwrite(out_path, im)
            print("save result to: " + out_path)
            start_idx += boxes_num_i


def main():
    cfg = merge_cfg(FLAGS)
    print_arguments(cfg)
    pipeline = Pipeline(FLAGS, cfg)
    pipeline.run()


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'
                            ], "device should be CPU, GPU or XPU"

    main()
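
# Typical invocation (a sketch following the PaddleDetection pipeline docs;
# the exact config path and flags may differ per release):
#
#   python pipeline.py --config config/infer_cfg_pphuman.yml \
#       --video_file=test_video.mp4 --device=gpu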