pipeline.py 42.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
import cv2
import numpy as np
import math
import paddle
import sys
Z
zhiboniu 已提交
23
import copy
Z
zhiboniu 已提交
24
from collections import Sequence, defaultdict
Z
zhiboniu 已提交
25
from datacollector import DataCollector, Result
26 27 28 29 30

# Put the PaddleDetection deploy directory (the parent of the directory that
# contains this file) at the front of sys.path, so that the `python.*` and
# `pptracking.*` modules imported below can be resolved.
parent_path = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
sys.path.insert(0, parent_path)

Z
zhiboniu 已提交
31 32 33
from pipe_utils import argsparser, print_arguments, merge_cfg, PipeTimer
from pipe_utils import get_test_images, crop_image_with_det, crop_image_with_mot, parse_mot_res, parse_mot_keypoint

34
from python.infer import Detector, DetectorPicoDet
J
JYChen 已提交
35 36
from python.keypoint_infer import KeyPointDetector
from python.keypoint_postprocess import translate_to_ori_images
37
from python.preprocess import decode_image, ShortSizeScale
Z
zhiboniu 已提交
38
from python.visualize import visualize_box_mask, visualize_attr, visualize_pose, visualize_action, visualize_vehicleplate
39 40

from pptracking.python.mot_sde_infer import SDE_Detector
41 42
from pptracking.python.mot.visualize import plot_tracking_dict
from pptracking.python.mot.utils import flow_statistic
43

Z
zhiboniu 已提交
44 45 46 47 48 49 50
from pphuman.attr_infer import AttrDetector
from pphuman.video_action_infer import VideoActionRecognizer
from pphuman.action_infer import SkeletonActionRecognizer, DetActionRecognizer, ClsActionRecognizer
from pphuman.action_utils import KeyPointBuff, ActionVisualHelper
from pphuman.reid import ReID
from pphuman.mtmct import mtmct_process

51 52 53
from ppvehicle.vehicle_plate import PlateRecognizer
from ppvehicle.vehicle_attr import VehicleAttr

54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77

class Pipeline(object):
    """
    Pipeline

    Resolves the input source (images, a single video, a directory of videos
    or a camera) and drives one PipePredictor per stream. When a directory
    holds several videos, every video is treated as a separate camera and,
    if cfg['REID'] is enabled, the per-camera results are fused by
    ``mtmct_process``.

    Args:
        args: parsed command line options (read by attribute access); the
            fields used here or forwarded to PipePredictor are listed below
        cfg (dict): config of models in pipeline

    Fields read from ``args``:
        image_file (string|None): the path of image file, default as None
        image_dir (string|None): the path of image directory, if not None,
            then all the images in directory will be predicted, default as None
        video_file (string|None): the path of video file, default as None
        video_dir (string|None): the path of a directory of videos; more than
            one video enables multi-camera mode, default as None
        camera_id (int): the device id of camera to predict, default as -1
        device (string): the device to predict, options are: CPU/GPU/XPU,
            default as CPU
        run_mode (string): the mode of prediction, options are:
            paddle/trt_fp32/trt_fp16, default as paddle
        trt_min_shape (int): min shape for dynamic shape in trt, default as 1
        trt_max_shape (int): max shape for dynamic shape in trt, default as 1280
        trt_opt_shape (int): opt shape for dynamic shape in trt, default as 640
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True. default as False
        cpu_threads (int): cpu threads, default as 1
        enable_mkldnn (bool): whether to open MKLDNN, default as False
        output_dir (string): The path of output, default as 'output'
        draw_center_traj (bool): Whether drawing the trajectory of center, default as False
        secs_interval (int): The seconds interval to count after tracking, default as 10
        do_entrance_counting(bool): Whether counting the numbers of identifiers entering
            or getting out from the entrance, default as False, only support single class
            counting in MOT.
    """

    def __init__(self, args, cfg):
        self.multi_camera = False
        reid_cfg = cfg.get('REID', False)
        self.enable_mtmct = reid_cfg['enable'] if reid_cfg else False
        self.is_video = False
        self.output_dir = args.output_dir
        self.vis_result = cfg['visual']
        # _parse_input also sets self.is_video and self.multi_camera.
        self.input = self._parse_input(args.image_file, args.image_dir,
                                       args.video_file, args.video_dir,
                                       args.camera_id)
        if self.multi_camera:
            # One predictor per video file; each names its output after it.
            self.predictor = []
            for name in self.input:
                predictor_item = PipePredictor(
                    args, cfg, is_video=True, multi_camera=True)
                predictor_item.set_file_name(name)
                self.predictor.append(predictor_item)
        else:
            self.predictor = PipePredictor(args, cfg, self.is_video)
            if self.is_video:
                # Use the resolved input, not args.video_file: for a video_dir
                # holding a single file args.video_file is None, which would
                # make the predictor treat the video as a camera stream.
                # A camera id (int) keeps the file name as None.
                video_path = self.input if isinstance(self.input,
                                                      str) else None
                self.predictor.set_file_name(video_path)

        self.draw_center_traj = args.draw_center_traj
        self.secs_interval = args.secs_interval
        self.do_entrance_counting = args.do_entrance_counting

    def _parse_input(self, image_file, image_dir, video_file, video_dir,
                     camera_id):
        """Resolve the pipeline input and set ``is_video``/``multi_camera``.

        The options are checked in this order: images, video_file,
        video_dir, camera_id; the first one that is set wins.

        Returns:
            A list of image paths (images), a single video path (video_file,
            or a video_dir holding one file), a sorted list of video paths
            (video_dir holding several files, multi camera) or the camera id.

        Raises:
            AssertionError: ``video_file`` does not exist.
            ValueError: ``video_dir`` is empty, or no input option is set.
        """
        if image_file is not None or image_dir is not None:
            parsed = get_test_images(image_dir, image_file)
            self.is_video = False
            self.multi_camera = False

        elif video_file is not None:
            assert os.path.exists(video_file), "video_file not exists."
            parsed = video_file
            self.is_video = True
            self.multi_camera = False

        elif video_dir is not None:
            # os.listdir returns entries in arbitrary order; sort so camera
            # order is deterministic across runs.
            videof = sorted(
                os.path.join(video_dir, x) for x in os.listdir(video_dir))
            if not videof:
                raise ValueError("video_dir '{}' contains no files.".format(
                    video_dir))
            if len(videof) > 1:
                self.multi_camera = True
                parsed = videof
            else:
                self.multi_camera = False
                parsed = videof[0]
            self.is_video = True

        elif camera_id != -1:
            parsed = camera_id
            self.is_video = True
            self.multi_camera = False

        else:
            raise ValueError(
                "Illegal Input, please set one of ['video_file', 'video_dir', 'camera_id', 'image_file', 'image_dir']"
            )

        return parsed

    def run(self):
        """Run every predictor; fuse multi-camera results with MTMCT if enabled."""
        if self.multi_camera:
            multi_res = []
            for predictor, video_path in zip(self.predictor, self.input):
                predictor.run(video_path)
                multi_res.append(predictor.get_result())
            if self.enable_mtmct:
                mtmct_process(
                    multi_res,
                    self.input,
                    mtmct_vis=self.vis_result,
                    output_dir=self.output_dir)

        else:
            self.predictor.run(self.input)


class PipePredictor(object):
    """
    Predictor in single camera
    
    The pipeline for image input: 

        1. Detection
        2. Detection -> Attribute

    The pipeline for video input: 

        1. Tracking
        2. Tracking -> Attribute
Z
zhiboniu 已提交
182
        3. Tracking -> KeyPoint -> SkeletonAction Recognition
183
        4. VideoAction Recognition
184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202

    Args:
        cfg (dict): config of models in pipeline
        is_video (bool): whether the input is video, default as False
        multi_camera (bool): whether to use multi camera in pipeline, 
            default as False
        camera_id (int): the device id of camera to predict, default as -1
        device (string): the device to predict, options are: CPU/GPU/XPU, 
            default as CPU
        run_mode (string): the mode of prediction, options are: 
            paddle/trt_fp32/trt_fp16, default as paddle
        trt_min_shape (int): min shape for dynamic shape in trt, default as 1
        trt_max_shape (int): max shape for dynamic shape in trt, default as 1280
        trt_opt_shape (int): opt shape for dynamic shape in trt, default as 640
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True. default as False
        cpu_threads (int): cpu threads, default as 1
        enable_mkldnn (bool): whether to open MKLDNN, default as False
        output_dir (string): The path of output, default as 'output'
203 204 205
        draw_center_traj (bool): Whether drawing the trajectory of center, default as False
        secs_interval (int): The seconds interval to count after tracking, default as 10
        do_entrance_counting(bool): Whether counting the numbers of identifiers entering 
206
            or getting out from the entrance, default as False, only support single class
207
            counting in MOT.
208 209
    """

Z
zhiboniu 已提交
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
    def __init__(self, args, cfg, is_video=True, multi_camera=False):
        """Read the module switches from ``cfg`` and build enabled predictors.

        Args:
            args: parsed command line options; the runtime options (device,
                run_mode, trt_*, cpu_threads, enable_mkldnn, output_dir,
                draw_center_traj, secs_interval, do_entrance_counting) are
                read from it.
            cfg (dict): pipeline config. Each optional module is a sub-dict
                keyed by module name (e.g. 'MOT', 'ATTR') with an 'enable'
                flag; 'warmup_frame' and, for images, 'DET' are required.
            is_video (bool): build the video pipeline (tracking based) when
                True, the image pipeline (detection based) otherwise.
            multi_camera (bool): whether this predictor is one camera of a
                multi-camera pipeline.
        """
        device = args.device
        run_mode = args.run_mode
        trt_min_shape = args.trt_min_shape
        trt_max_shape = args.trt_max_shape
        trt_opt_shape = args.trt_opt_shape
        trt_calib_mode = args.trt_calib_mode
        cpu_threads = args.cpu_threads
        enable_mkldnn = args.enable_mkldnn
        output_dir = args.output_dir
        draw_center_traj = args.draw_center_traj
        secs_interval = args.secs_interval
        do_entrance_counting = args.do_entrance_counting

        # general module for pphuman and ppvehicle
        self.with_mot = self._module_enabled(cfg, 'MOT')
        self.with_human_attr = self._module_enabled(cfg, 'ATTR')
        if self.with_mot:
            print('Multi-Object Tracking enabled')
        if self.with_human_attr:
            print('Human Attribute Recognition enabled')

        # only for pphuman
        self.with_skeleton_action = self._module_enabled(cfg,
                                                         'SKELETON_ACTION')
        self.with_video_action = self._module_enabled(cfg, 'VIDEO_ACTION')
        self.with_idbased_detaction = self._module_enabled(
            cfg, 'ID_BASED_DETACTION')
        self.with_idbased_clsaction = self._module_enabled(
            cfg, 'ID_BASED_CLSACTION')
        self.with_mtmct = self._module_enabled(cfg, 'REID')

        if self.with_skeleton_action:
            print('SkeletonAction Recognition enabled')
        if self.with_video_action:
            print('VideoAction Recognition enabled')
        if self.with_idbased_detaction:
            print('IDBASED Detection Action Recognition enabled')
        if self.with_idbased_clsaction:
            print('IDBASED Classification Action Recognition enabled')
        if self.with_mtmct:
            print("MTMCT enabled")

        # only for ppvehicle
        self.with_vehicleplate = self._module_enabled(cfg, 'VEHICLE_PLATE')
        if self.with_vehicleplate:
            print('Vehicle Plate Recognition enabled')

        self.with_vehicle_attr = self._module_enabled(cfg, 'VEHICLE_ATTR')
        if self.with_vehicle_attr:
            print('Vehicle Attribute Recognition enabled')

        # Each enabled module flags its cfg 'basemode' here; the flags decide
        # which shared stages (tracking, keypoints) are built below.
        self.modebase = {
            "framebased": False,
            "videobased": False,
            "idbased": False,
            "skeletonbased": False
        }

        self.is_video = is_video
        self.multi_camera = multi_camera
        self.cfg = cfg
        self.output_dir = output_dir
        self.draw_center_traj = draw_center_traj
        self.secs_interval = secs_interval
        self.do_entrance_counting = do_entrance_counting

        self.warmup_frame = self.cfg['warmup_frame']
        self.pipeline_res = Result()
        self.pipe_timer = PipeTimer()
        self.file_name = None
        self.collector = DataCollector()

        if not is_video:
            det_cfg = self.cfg['DET']
            model_dir = det_cfg['model_dir']
            batch_size = det_cfg['batch_size']
            self.det_predictor = Detector(
                model_dir, device, run_mode, batch_size, trt_min_shape,
                trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                enable_mkldnn)
            if self.with_human_attr:
                self._create_human_attr_predictor(args)
            if self.with_vehicle_attr:
                self._create_vehicle_attr_predictor(args)

        else:
            if self.with_human_attr:
                self._create_human_attr_predictor(args)

            if self.with_idbased_detaction:
                idbased_detaction_cfg = self.cfg['ID_BASED_DETACTION']
                model_dir = idbased_detaction_cfg['model_dir']
                batch_size = idbased_detaction_cfg['batch_size']
                basemode = idbased_detaction_cfg['basemode']
                threshold = idbased_detaction_cfg['threshold']
                display_frames = idbased_detaction_cfg['display_frames']
                self.modebase[basemode] = True
                self.det_action_predictor = DetActionRecognizer(
                    model_dir,
                    device,
                    run_mode,
                    batch_size,
                    trt_min_shape,
                    trt_max_shape,
                    trt_opt_shape,
                    trt_calib_mode,
                    cpu_threads,
                    enable_mkldnn,
                    threshold=threshold,
                    display_frames=display_frames)
                self.det_action_visual_helper = ActionVisualHelper(1)

            if self.with_idbased_clsaction:
                idbased_clsaction_cfg = self.cfg['ID_BASED_CLSACTION']
                model_dir = idbased_clsaction_cfg['model_dir']
                batch_size = idbased_clsaction_cfg['batch_size']
                basemode = idbased_clsaction_cfg['basemode']
                threshold = idbased_clsaction_cfg['threshold']
                self.modebase[basemode] = True
                display_frames = idbased_clsaction_cfg['display_frames']
                self.cls_action_predictor = ClsActionRecognizer(
                    model_dir,
                    device,
                    run_mode,
                    batch_size,
                    trt_min_shape,
                    trt_max_shape,
                    trt_opt_shape,
                    trt_calib_mode,
                    cpu_threads,
                    enable_mkldnn,
                    threshold=threshold,
                    display_frames=display_frames)
                self.cls_action_visual_helper = ActionVisualHelper(1)

            if self.with_skeleton_action:
                skeleton_action_cfg = self.cfg['SKELETON_ACTION']
                skeleton_action_model_dir = skeleton_action_cfg['model_dir']
                skeleton_action_batch_size = skeleton_action_cfg['batch_size']
                skeleton_action_frames = skeleton_action_cfg['max_frames']
                display_frames = skeleton_action_cfg['display_frames']
                self.coord_size = skeleton_action_cfg['coord_size']
                basemode = skeleton_action_cfg['basemode']
                self.modebase[basemode] = True

                self.skeleton_action_predictor = SkeletonActionRecognizer(
                    skeleton_action_model_dir,
                    device,
                    run_mode,
                    skeleton_action_batch_size,
                    trt_min_shape,
                    trt_max_shape,
                    trt_opt_shape,
                    trt_calib_mode,
                    cpu_threads,
                    enable_mkldnn,
                    window_size=skeleton_action_frames)
                self.skeleton_action_visual_helper = ActionVisualHelper(
                    display_frames)

                # The skeleton model consumes keypoints, so a keypoint
                # detector and a buffer of max_frames frames are needed.
                if self.modebase["skeletonbased"]:
                    kpt_cfg = self.cfg['KPT']
                    kpt_model_dir = kpt_cfg['model_dir']
                    kpt_batch_size = kpt_cfg['batch_size']
                    self.kpt_predictor = KeyPointDetector(
                        kpt_model_dir,
                        device,
                        run_mode,
                        kpt_batch_size,
                        trt_min_shape,
                        trt_max_shape,
                        trt_opt_shape,
                        trt_calib_mode,
                        cpu_threads,
                        enable_mkldnn,
                        use_dark=False)
                    self.kpt_buff = KeyPointBuff(skeleton_action_frames)

            if self.with_vehicleplate:
                vehicleplate_cfg = self.cfg['VEHICLE_PLATE']
                self.vehicleplate_detector = PlateRecognizer(args,
                                                             vehicleplate_cfg)
                basemode = vehicleplate_cfg['basemode']
                self.modebase[basemode] = True

            if self.with_vehicle_attr:
                self._create_vehicle_attr_predictor(args)

            # Tracking is needed when requested explicitly or when any
            # enabled module works per track id / per skeleton sequence.
            if self.with_mot or self.modebase["idbased"] or self.modebase[
                    "skeletonbased"]:
                mot_cfg = self.cfg['MOT']
                model_dir = mot_cfg['model_dir']
                tracker_config = mot_cfg['tracker_config']
                batch_size = mot_cfg['batch_size']
                basemode = mot_cfg['basemode']
                self.modebase[basemode] = True
                self.mot_predictor = SDE_Detector(
                    model_dir,
                    tracker_config,
                    device,
                    run_mode,
                    batch_size,
                    trt_min_shape,
                    trt_max_shape,
                    trt_opt_shape,
                    trt_calib_mode,
                    cpu_threads,
                    enable_mkldnn,
                    draw_center_traj=draw_center_traj,
                    secs_interval=secs_interval,
                    do_entrance_counting=do_entrance_counting)

            if self.with_video_action:
                video_action_cfg = self.cfg['VIDEO_ACTION']

                basemode = video_action_cfg['basemode']
                self.modebase[basemode] = True

                video_action_model_dir = video_action_cfg['model_dir']
                video_action_batch_size = video_action_cfg['batch_size']
                short_size = video_action_cfg["short_size"]
                target_size = video_action_cfg["target_size"]

                self.video_action_predictor = VideoActionRecognizer(
                    model_dir=video_action_model_dir,
                    short_size=short_size,
                    target_size=target_size,
                    device=device,
                    run_mode=run_mode,
                    batch_size=video_action_batch_size,
                    trt_min_shape=trt_min_shape,
                    trt_max_shape=trt_max_shape,
                    trt_opt_shape=trt_opt_shape,
                    trt_calib_mode=trt_calib_mode,
                    cpu_threads=cpu_threads,
                    enable_mkldnn=enable_mkldnn)

            if self.with_mtmct:
                reid_cfg = self.cfg['REID']
                model_dir = reid_cfg['model_dir']
                batch_size = reid_cfg['batch_size']
                self.reid_predictor = ReID(
                    model_dir, device, run_mode, batch_size, trt_min_shape,
                    trt_max_shape, trt_opt_shape, trt_calib_mode, cpu_threads,
                    enable_mkldnn)

    @staticmethod
    def _module_enabled(cfg, name):
        """Return ``cfg[name]['enable']``, or False when the module is absent or empty."""
        module_cfg = cfg.get(name, False)
        return module_cfg['enable'] if module_cfg else False

    def _create_human_attr_predictor(self, args):
        """Build ``self.attr_predictor`` from cfg['ATTR'] and flag its basemode."""
        attr_cfg = self.cfg['ATTR']
        model_dir = attr_cfg['model_dir']
        batch_size = attr_cfg['batch_size']
        basemode = attr_cfg['basemode']
        self.modebase[basemode] = True
        self.attr_predictor = AttrDetector(
            model_dir, args.device, args.run_mode, batch_size,
            args.trt_min_shape, args.trt_max_shape, args.trt_opt_shape,
            args.trt_calib_mode, args.cpu_threads, args.enable_mkldnn)

    def _create_vehicle_attr_predictor(self, args):
        """Build ``self.vehicle_attr_predictor`` from cfg['VEHICLE_ATTR'] and flag its basemode."""
        vehicleattr_cfg = self.cfg['VEHICLE_ATTR']
        model_dir = vehicleattr_cfg['model_dir']
        batch_size = vehicleattr_cfg['batch_size']
        color_threshold = vehicleattr_cfg['color_threshold']
        type_threshold = vehicleattr_cfg['type_threshold']
        basemode = vehicleattr_cfg['basemode']
        self.modebase[basemode] = True
        self.vehicle_attr_predictor = VehicleAttr(
            model_dir, args.device, args.run_mode, batch_size,
            args.trt_min_shape, args.trt_max_shape, args.trt_opt_shape,
            args.trt_calib_mode, args.cpu_threads, args.enable_mkldnn,
            color_threshold, type_threshold)
504

505
    def set_file_name(self, path):
W
wangguanzhong 已提交
506 507 508 509 510
        if path is not None:
            self.file_name = os.path.split(path)[-1]
        else:
            # use camera id
            self.file_name = None
511

512
    def get_result(self):
Z
zhiboniu 已提交
513
        return self.collector.get_res()
514 515 516 517 518 519

    def run(self, input):
        if self.is_video:
            self.predict_video(input)
        else:
            self.predict_image(input)
520
        self.pipe_timer.info()
521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538

    def predict_image(self, input):
        # det
        # det -> attr
        batch_loop_cnt = math.ceil(
            float(len(input)) / self.det_predictor.batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * self.det_predictor.batch_size
            end_index = min((i + 1) * self.det_predictor.batch_size, len(input))
            batch_file = input[start_index:end_index]
            batch_input = [decode_image(f, {})[0] for f in batch_file]

            if i > self.warmup_frame:
                self.pipe_timer.total_time.start()
                self.pipe_timer.module_time['det'].start()
            # det output format: class, score, xmin, ymin, xmax, ymax
            det_res = self.det_predictor.predict_image(
                batch_input, visual=False)
539 540
            det_res = self.det_predictor.filter_box(det_res,
                                                    self.cfg['crop_thresh'])
541 542 543 544
            if i > self.warmup_frame:
                self.pipe_timer.module_time['det'].end()
            self.pipeline_res.update(det_res, 'det')

545
            if self.with_human_attr:
546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562
                crop_inputs = crop_image_with_det(batch_input, det_res)
                attr_res_list = []

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['attr'].start()

                for crop_input in crop_inputs:
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    attr_res_list.extend(attr_res['output'])

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['attr'].end()

                attr_res = {'output': attr_res_list}
                self.pipeline_res.update(attr_res, 'attr')

563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580
            if self.with_vehicle_attr:
                crop_inputs = crop_image_with_det(batch_input, det_res)
                vehicle_attr_res_list = []

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['vehicle_attr'].start()

                for crop_input in crop_inputs:
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    vehicle_attr_res_list.extend(attr_res['output'])

                if i > self.warmup_frame:
                    self.pipe_timer.module_time['vehicle_attr'].end()

                attr_res = {'output': vehicle_attr_res_list}
                self.pipeline_res.update(attr_res, 'vehicle_attr')

581 582 583 584 585 586 587
            self.pipe_timer.img_num += len(batch_input)
            if i > self.warmup_frame:
                self.pipe_timer.total_time.end()

            if self.cfg['visual']:
                self.visualize_image(batch_file, batch_input, self.pipeline_res)

Z
zhiboniu 已提交
588
    def predict_video(self, video_file):
589 590 591
        # mot
        # mot -> attr
        # mot -> pose -> action
Z
zhiboniu 已提交
592
        capture = cv2.VideoCapture(video_file)
593
        video_out_name = 'output.mp4' if self.file_name is None else self.file_name
594 595 596 597 598 599

        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
600
        print("video fps: %d, frame_count: %d" % (fps, frame_count))
601 602 603 604 605 606 607

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 0
608 609 610 611 612 613 614 615 616 617 618 619 620

        entrance, records, center_traj = None, None, None
        if self.draw_center_traj:
            center_traj = [{}]
        id_set = set()
        interval_id_set = set()
        in_id_list = list()
        out_id_list = list()
        prev_center = dict()
        records = list()
        entrance = [0, height / 2., width, height / 2.]
        video_fps = fps

621 622
        video_action_imgs = []

623 624 625 626
        if self.with_video_action:
            short_size = self.cfg["VIDEO_ACTION"]["short_size"]
            scale = ShortSizeScale(short_size)

627 628 629
        while (1):
            if frame_id % 10 == 0:
                print('frame id: ', frame_id)
630

631 632 633
            ret, frame = capture.read()
            if not ret:
                break
634
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
635

636
            if self.modebase["idbased"] or self.modebase["skeletonbased"]:
637
                if frame_id > self.warmup_frame:
638 639 640
                    self.pipe_timer.total_time.start()
                    self.pipe_timer.module_time['mot'].start()
                res = self.mot_predictor.predict_image(
641
                    [copy.deepcopy(frame_rgb)], visual=False)
642

J
JYChen 已提交
643
                if frame_id > self.warmup_frame:
644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661
                    self.pipe_timer.module_time['mot'].end()

                # mot output format: id, class, score, xmin, ymin, xmax, ymax
                mot_res = parse_mot_res(res)

                # flow_statistic only support single class MOT
                boxes, scores, ids = res[0]  # batch size = 1 in MOT
                mot_result = (frame_id + 1, boxes[0], scores[0],
                              ids[0])  # single class
                statistic = flow_statistic(
                    mot_result, self.secs_interval, self.do_entrance_counting,
                    video_fps, entrance, id_set, interval_id_set, in_id_list,
                    out_id_list, prev_center, records)
                records = statistic['records']

                # nothing detected
                if len(mot_res['boxes']) == 0:
                    frame_id += 1
J
JYChen 已提交
662
                    if frame_id > self.warmup_frame:
663 664 665 666 667 668 669 670 671
                        self.pipe_timer.img_num += 1
                        self.pipe_timer.total_time.end()
                    if self.cfg['visual']:
                        _, _, fps = self.pipe_timer.get_total_time()
                        im = self.visualize_video(frame, mot_res, frame_id, fps,
                                                  entrance, records,
                                                  center_traj)  # visualize
                        writer.write(im)
                        if self.file_name is None:  # use camera_id
Z
zhiboniu 已提交
672
                            cv2.imshow('Paddle-Pipeline', im)
673 674 675 676 677
                            if cv2.waitKey(1) & 0xFF == ord('q'):
                                break
                    continue

                self.pipeline_res.update(mot_res, 'mot')
J
JYChen 已提交
678
                crop_input, new_bboxes, ori_bboxes = crop_image_with_mot(
679
                    frame_rgb, mot_res)
680

Z
zhiboniu 已提交
681
                if self.with_vehicleplate:
Z
zhiboniu 已提交
682 683
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].start()
Z
zhiboniu 已提交
684 685
                    platelicense = self.vehicleplate_detector.get_platelicense(
                        crop_input)
Z
zhiboniu 已提交
686 687
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicleplate'].end()
Z
zhiboniu 已提交
688 689
                    self.pipeline_res.update(platelicense, 'vehicleplate')

690
                if self.with_human_attr:
J
JYChen 已提交
691
                    if frame_id > self.warmup_frame:
692 693 694 695 696 697 698
                        self.pipe_timer.module_time['attr'].start()
                    attr_res = self.attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['attr'].end()
                    self.pipeline_res.update(attr_res, 'attr')

699 700 701 702 703 704 705 706 707
                if self.with_vehicle_attr:
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].start()
                    attr_res = self.vehicle_attr_predictor.predict_image(
                        crop_input, visual=False)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['vehicle_attr'].end()
                    self.pipeline_res.update(attr_res, 'vehicle_attr')

Z
zhiboniu 已提交
708
                if self.with_idbased_detaction:
J
JYChen 已提交
709 710 711 712 713 714 715 716 717 718
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].start()
                    det_action_res = self.det_action_predictor.predict(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['det_action'].end()
                    self.pipeline_res.update(det_action_res, 'det_action')

                    if self.cfg['visual']:
                        self.det_action_visual_helper.update(det_action_res)
Z
zhiboniu 已提交
719 720

                if self.with_idbased_clsaction:
J
JYChen 已提交
721 722 723 724 725 726 727 728 729 730
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].start()
                    cls_action_res = self.cls_action_predictor.predict_with_mot(
                        crop_input, mot_res)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['cls_action'].end()
                    self.pipeline_res.update(cls_action_res, 'cls_action')

                    if self.cfg['visual']:
                        self.cls_action_visual_helper.update(cls_action_res)
Z
zhiboniu 已提交
731

Z
zhiboniu 已提交
732
                if self.with_skeleton_action:
Z
zhiboniu 已提交
733 734 735 736 737 738 739 740 741 742 743 744 745
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].start()
                    kpt_pred = self.kpt_predictor.predict_image(
                        crop_input, visual=False)
                    keypoint_vector, score_vector = translate_to_ori_images(
                        kpt_pred, np.array(new_bboxes))
                    kpt_res = {}
                    kpt_res['keypoint'] = [
                        keypoint_vector.tolist(), score_vector.tolist()
                    ] if len(keypoint_vector) > 0 else [[], []]
                    kpt_res['bbox'] = ori_bboxes
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['kpt'].end()
746

Z
zhiboniu 已提交
747
                    self.pipeline_res.update(kpt_res, 'kpt')
748

Z
zhiboniu 已提交
749
                    self.kpt_buff.update(kpt_res, mot_res)  # collect kpt output
750 751 752
                    state = self.kpt_buff.get_state(
                    )  # whether frame num is enough or lost tracker

Z
zhiboniu 已提交
753
                    skeleton_action_res = {}
754 755
                    if state:
                        if frame_id > self.warmup_frame:
Z
zhiboniu 已提交
756 757
                            self.pipe_timer.module_time[
                                'skeleton_action'].start()
758 759
                        collected_keypoint = self.kpt_buff.get_collected_keypoint(
                        )  # reoragnize kpt output with ID
Z
zhiboniu 已提交
760 761 762 763
                        skeleton_action_input = parse_mot_keypoint(
                            collected_keypoint, self.coord_size)
                        skeleton_action_res = self.skeleton_action_predictor.predict_skeleton_with_mot(
                            skeleton_action_input)
764
                        if frame_id > self.warmup_frame:
Z
zhiboniu 已提交
765 766 767
                            self.pipe_timer.module_time['skeleton_action'].end()
                        self.pipeline_res.update(skeleton_action_res,
                                                 'skeleton_action')
768 769

                    if self.cfg['visual']:
Z
zhiboniu 已提交
770 771
                        self.skeleton_action_visual_helper.update(
                            skeleton_action_res)
772 773 774

                if self.with_mtmct and frame_id % 10 == 0:
                    crop_input, img_qualities, rects = self.reid_predictor.crop_image_with_mot(
775
                        frame_rgb, mot_res)
776 777 778 779 780 781
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].start()
                    reid_res = self.reid_predictor.predict_batch(crop_input)

                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['reid'].end()
J
JYChen 已提交
782

783 784 785 786 787 788 789 790
                    reid_res_dict = {
                        'features': reid_res,
                        "qualities": img_qualities,
                        "rects": rects
                    }
                    self.pipeline_res.update(reid_res_dict, 'reid')
                else:
                    self.pipeline_res.clear('reid')
Z
zhiboniu 已提交
791

Z
zhiboniu 已提交
792
            if self.with_video_action:
793 794 795 796 797 798 799 800 801 802 803 804 805
                # get the params
                frame_len = self.cfg["VIDEO_ACTION"]["frame_len"]
                sample_freq = self.cfg["VIDEO_ACTION"]["sample_freq"]

                if sample_freq * frame_len > frame_count:  # video is too short
                    sample_freq = int(frame_count / frame_len)

                # filter the warmup frames
                if frame_id > self.warmup_frame:
                    self.pipe_timer.module_time['video_action'].start()

                # collect frames
                if frame_id % sample_freq == 0:
806
                    # Scale image
807
                    scaled_img = scale(frame_rgb)
808
                    video_action_imgs.append(scaled_img)
809 810 811 812 813 814 815 816 817 818 819 820 821 822

                # the number of collected frames is enough to predict video action
                if len(video_action_imgs) == frame_len:
                    classes, scores = self.video_action_predictor.predict(
                        video_action_imgs)
                    if frame_id > self.warmup_frame:
                        self.pipe_timer.module_time['video_action'].end()

                    video_action_res = {"class": classes[0], "score": scores[0]}
                    self.pipeline_res.update(video_action_res, 'video_action')

                    print("video_action_res:", video_action_res)

                    video_action_imgs.clear()  # next clip
Z
zhiboniu 已提交
823 824

            self.collector.append(frame_id, self.pipeline_res)
825 826 827 828 829 830 831

            if frame_id > self.warmup_frame:
                self.pipe_timer.img_num += 1
                self.pipe_timer.total_time.end()
            frame_id += 1

            if self.cfg['visual']:
832 833
                _, _, fps = self.pipe_timer.get_total_time()
                im = self.visualize_video(frame, self.pipeline_res, frame_id,
834 835
                                          fps, entrance, records,
                                          center_traj)  # visualize
836
                writer.write(im)
W
wangguanzhong 已提交
837
                if self.file_name is None:  # use camera_id
Z
zhiboniu 已提交
838
                    cv2.imshow('Paddle-Pipeline', im)
W
wangguanzhong 已提交
839 840
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
841 842 843 844

        writer.release()
        print('save result to {}'.format(out_path))

845 846 847 848 849 850 851 852
    def visualize_video(self,
                        image,
                        result,
                        frame_id,
                        fps,
                        entrance=None,
                        records=None,
                        center_traj=None):
        """Draw the per-frame pipeline results onto one video frame.

        Args:
            image: the frame to draw on (in this file's video loop it is the
                frame returned by ``capture.read()``).
            result: object queried with ``.get(key)`` for the keys 'mot',
                'attr', 'vehicle_attr', 'vehicleplate', 'kpt', 'video_action',
                'skeleton_action', 'det_action' and 'cls_action'; a key that
                yields ``None`` (or, for 'vehicleplate', a falsy value) skips
                the corresponding overlay.
            frame_id: frame index forwarded to ``plot_tracking_dict``.
            fps: frame rate forwarded to ``plot_tracking_dict``.
            entrance: forwarded unchanged to ``plot_tracking_dict``.
            records: forwarded unchanged to ``plot_tracking_dict``.
            center_traj: forwarded unchanged to ``plot_tracking_dict``.

        Returns:
            The annotated image.
        """
        # Deep-copy so the in-place box conversion below never touches the
        # caller's result object.
        mot_res = copy.deepcopy(result.get('mot'))
        if mot_res is not None:
            # mot box rows: [id, class, score, xmin, ymin, xmax, ymax]
            ids = mot_res['boxes'][:, 0]
            scores = mot_res['boxes'][:, 2]
            boxes = mot_res['boxes'][:, 3:]
            # Convert xyxy -> tlwh in place. Assuming mot_res['boxes'] is a
            # numpy array, ``boxes`` is a view, so this also rewrites the
            # copied mot_res['boxes'] to [id, class, score, xmin, ymin, w, h];
            # the overlays below receive that converted layout.
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            mot_boxes = mot_res['boxes']
        else:
            boxes = np.zeros([0, 4])
            ids = np.zeros([0])
            scores = np.zeros([0])
            # No tracking output, e.g. a pipeline running only video action
            # recognition; overlays that need per-object boxes are skipped.
            mot_boxes = None

        # single class, still need to be defaultdict type for plotting
        num_classes = 1
        online_tlwhs = defaultdict(list)
        online_scores = defaultdict(list)
        online_ids = defaultdict(list)
        online_tlwhs[0] = boxes
        online_scores[0] = scores
        online_ids[0] = ids

        if mot_res is not None:
            image = plot_tracking_dict(
                image,
                num_classes,
                online_tlwhs,
                online_ids,
                online_scores,
                frame_id=frame_id,
                fps=fps,
                do_entrance_counting=self.do_entrance_counting,
                entrance=entrance,
                records=records,
                center_traj=center_traj)

        human_attr_res = result.get('attr')
        if human_attr_res is not None and mot_boxes is not None:
            boxes = mot_boxes[:, 1:]
            human_attr_res = human_attr_res['output']
            image = visualize_attr(image, human_attr_res, boxes)
            image = np.array(image)

        vehicle_attr_res = result.get('vehicle_attr')
        if vehicle_attr_res is not None and mot_boxes is not None:
            boxes = mot_boxes[:, 1:]
            vehicle_attr_res = vehicle_attr_res['output']
            image = visualize_attr(image, vehicle_attr_res, boxes)
            image = np.array(image)

        vehicleplate_res = result.get('vehicleplate')
        if vehicleplate_res and mot_boxes is not None:
            boxes = mot_boxes[:, 1:]
            image = visualize_vehicleplate(image, vehicleplate_res['plate'],
                                           boxes)
            image = np.array(image)

        kpt_res = result.get('kpt')
        if kpt_res is not None:
            image = visualize_pose(
                image,
                kpt_res,
                visual_thresh=self.cfg['kpt_thresh'],
                returnimg=True)

        video_action_res = result.get('video_action')
        if video_action_res is not None:
            video_action_score = None
            if video_action_res and video_action_res["class"] == 1:
                video_action_score = video_action_res["score"]
            # mot_boxes may be None here: video action can run without MOT.
            # Indexing mot_res directly crashed in that configuration.
            image = visualize_action(
                image,
                mot_boxes,
                action_visual_collector=None,
                action_text="SkeletonAction",
                video_action_score=video_action_score,
                video_action_text="Fight")

        visual_helper_for_display = []
        action_to_display = []

        skeleton_action_res = result.get('skeleton_action')
        if skeleton_action_res is not None:
            visual_helper_for_display.append(self.skeleton_action_visual_helper)
            action_to_display.append("Falling")

        det_action_res = result.get('det_action')
        if det_action_res is not None:
            visual_helper_for_display.append(self.det_action_visual_helper)
            action_to_display.append("Smoking")

        cls_action_res = result.get('cls_action')
        if cls_action_res is not None:
            visual_helper_for_display.append(self.cls_action_visual_helper)
            action_to_display.append("Calling")

        # Per-ID action labels are anchored on tracked boxes.
        if visual_helper_for_display and mot_boxes is not None:
            image = visualize_action(image, mot_boxes,
                                     visual_helper_for_display,
                                     action_to_display)

        return image

    def visualize_image(self, im_files, images, result):
        """Draw detection and attribute results on each image and save it.

        Args:
            im_files: source file paths; only the basename of each is used to
                name the output file inside ``self.output_dir``.
            images: the decoded images, parallel to ``im_files``.
            result: object queried with ``.get(key)`` for 'det', 'attr' and
                'vehicle_attr'. ``det['boxes']`` holds the boxes of all images
                concatenated, and ``det['boxes_num'][i]`` gives how many
                belong to image ``i``. Attribute ``'output'`` lists are sliced
                with the same offsets.
        """
        start_idx, boxes_num_i = 0, 0
        det_res = result.get('det')
        human_attr_res = result.get('attr')
        vehicle_attr_res = result.get('vehicle_attr')

        # Create the output directory once; exist_ok avoids the
        # check-then-create race of testing os.path.exists first.
        os.makedirs(self.output_dir, exist_ok=True)

        for i, (im_file, im) in enumerate(zip(im_files, images)):
            det_res_i = None
            if det_res is not None:
                boxes_num_i = det_res['boxes_num'][i]
                # Slice this image's boxes out of the concatenated batch.
                det_res_i = {
                    'boxes':
                    det_res['boxes'][start_idx:start_idx + boxes_num_i, :]
                }
                im = visualize_box_mask(
                    im,
                    det_res_i,
                    labels=['person'],
                    threshold=self.cfg['crop_thresh'])
                im = np.ascontiguousarray(np.copy(im))
                im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

            # Attribute results are drawn at detection boxes. Without
            # detections there is nothing to anchor them to. Previously this
            # path raised NameError on the unbound det_res_i.
            if det_res_i is not None:
                if human_attr_res is not None:
                    human_attr_res_i = human_attr_res['output'][
                        start_idx:start_idx + boxes_num_i]
                    im = visualize_attr(im, human_attr_res_i,
                                        det_res_i['boxes'])
                if vehicle_attr_res is not None:
                    vehicle_attr_res_i = vehicle_attr_res['output'][
                        start_idx:start_idx + boxes_num_i]
                    im = visualize_attr(im, vehicle_attr_res_i,
                                        det_res_i['boxes'])

            img_name = os.path.split(im_file)[-1]
            out_path = os.path.join(self.output_dir, img_name)
            cv2.imwrite(out_path, im)
            print("save result to: " + out_path)
            start_idx += boxes_num_i


def main():
    """Merge the command-line FLAGS into a config, echo it, and run the pipeline.

    Reads the module-level ``FLAGS`` assigned in the ``__main__`` block.
    """
    merged_cfg = merge_cfg(FLAGS)
    print_arguments(merged_cfg)
    Pipeline(FLAGS, merged_cfg).run()


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    FLAGS.device = FLAGS.device.upper()
    # Validate explicitly rather than with ``assert``, which is stripped when
    # Python runs with -O. parser.error prints usage plus the message and
    # exits with a non-zero status.
    if FLAGS.device not in ['CPU', 'GPU', 'XPU']:
        parser.error("device should be CPU, GPU or XPU")

    main()