mot_keypoint_unite_infer.py 10.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
import os
import subprocess

import cv2
import numpy as np
import paddle

from mot_keypoint_unite_utils import argsparser
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
from visualize import draw_pose
from benchmark_utils import PaddleInferBenchmark
from utils import Timer

from tracker import JDETracker
from mot_jde_infer import JDE_Detector, write_mot_results
from infer import Detector, PredictConfig, print_arguments, get_test_images
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.modeling.mot.utils import Timer as FPSTimer
from utils import get_current_memory_mb
from det_keypoint_unite_infer import predict_with_given_det, bench_log

# Supported keypoint architectures, keyed by architecture name. The value
# selects the inference path used by the predict functions in this file:
# 'keypoint_topdown' runs on boxes converted from the tracker output, any
# other value runs the keypoint model directly on the whole frame.
KEYPOINT_SUPPORT_MODELS = dict(
    HigherHRNet='keypoint_bottomup',
    HRNet='keypoint_topdown', )
G
George Ni 已提交
41

42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58

def convert_mot_to_det(tlwhs, scores):
    """Convert tracker boxes into the detection-result layout.

    Each tracked box given as ``[x, y, w, h]`` (top-left corner plus size)
    becomes one row ``[class_id, score, x1, y1, x2, y2]`` where
    ``x2 = x + w`` and ``y2 = y + h``. Only a single class is supported,
    so ``class_id`` is always 0.

    Args:
        tlwhs: sequence of N boxes, each with 4 values ``[x, y, w, h]``.
            May be a list of arrays, a list of lists or an ``(N, 4)``
            ndarray. The input is never modified.
        scores: sequence with at least N scores; ``scores[i]`` belongs to
            ``tlwhs[i]``.

    Returns:
        dict: ``{'boxes': ndarray}`` with a float64 array of shape
        ``(N, 6)``. For N == 0 the array has shape ``(0, 6)``.
    """
    num_mot = len(tlwhs)
    if num_mot == 0:
        # np.vstack raises on an empty list; a frame without tracks is a
        # normal situation, so hand back an empty result instead.
        return {'boxes': np.zeros((0, 6), dtype=np.float64)}

    # np.array copies, so the caller's boxes stay untouched.
    xyxys = np.array(tlwhs, dtype=np.float64).reshape(num_mot, 4)
    # tlwh -> xyxy: bottom-right corner = top-left corner + size
    xyxys[:, 2:] += xyxys[:, :2]

    boxes = np.zeros((num_mot, 6), dtype=np.float64)
    # column 0 stays 0: support single class now
    boxes[:, 1] = np.asarray(
        [scores[i] for i in range(num_mot)], dtype=np.float64).reshape(num_mot)
    boxes[:, 2:] = xyxys
    return {'boxes': boxes}


def mot_keypoint_unite_predict_image(mot_model,
                                     keypoint_model,
                                     image_list,
                                     keypoint_batch_size=1):
    """Run tracking and keypoint estimation on a list of image files.

    For every image the tracker is run first. With a top-down keypoint
    architecture the tracked boxes are converted with
    ``convert_mot_to_det`` and handed to ``predict_with_given_det``;
    otherwise the keypoint model is run directly on the whole frame.

    When ``FLAGS.run_benchmark`` is set, both models are run with
    ``warmup=10, repeats=10`` and the values returned by
    ``get_current_memory_mb`` are accumulated on each model; nothing is
    drawn or saved. Otherwise the poses and tracks are drawn and, if
    ``FLAGS.save_images`` is set, written to ``FLAGS.output_dir`` under
    the image's own file name.

    Args:
        mot_model: tracker whose ``predict`` returns
            ``(online_tlwhs, online_scores, online_ids)``.
        keypoint_model: keypoint predictor; ``pred_config.arch`` must be
            a key of ``KEYPOINT_SUPPORT_MODELS``.
        image_list (list): image file paths. Sorted in place.
        keypoint_batch_size (int): forwarded to
            ``predict_with_given_det`` on the top-down path.
    """
    image_list.sort()

    # The architecture does not change between images, so resolve it once.
    keypoint_arch = keypoint_model.pred_config.arch
    is_topdown = KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown'

    for i, img_file in enumerate(image_list):
        frame = cv2.imread(img_file)
        if frame is None:
            # cv2.imread signals a missing/undecodable file by returning
            # None instead of raising; skip it rather than crash in predict.
            print("skip unreadable image: " + img_file)
            continue

        if FLAGS.run_benchmark:
            online_tlwhs, online_scores, online_ids = mot_model.predict(
                [frame], FLAGS.mot_threshold, warmup=10, repeats=10)
            cm, gm, gu = get_current_memory_mb()
            mot_model.cpu_mem += cm
            mot_model.gpu_mem += gm
            mot_model.gpu_util += gu
        else:
            online_tlwhs, online_scores, online_ids = mot_model.predict(
                [frame], FLAGS.mot_threshold)

        if is_topdown:
            results = convert_mot_to_det(online_tlwhs, online_scores)
            keypoint_results = predict_with_given_det(
                frame, results, keypoint_model, keypoint_batch_size,
                FLAGS.mot_threshold, FLAGS.keypoint_threshold,
                FLAGS.run_benchmark)
        else:
            warmup = 10 if FLAGS.run_benchmark else 0
            repeats = 10 if FLAGS.run_benchmark else 1
            keypoint_results = keypoint_model.predict(
                [frame],
                FLAGS.keypoint_threshold,
                warmup=warmup,
                repeats=repeats)

        if FLAGS.run_benchmark:
            cm, gm, gu = get_current_memory_mb()
            keypoint_model.cpu_mem += cm
            keypoint_model.gpu_mem += gm
            keypoint_model.gpu_util += gu
            continue

        # Track ids are only passed to the pose drawing on the top-down
        # path, where the keypoints come from the tracked boxes.
        im = draw_pose(
            frame,
            keypoint_results,
            visual_thread=FLAGS.keypoint_threshold,
            returnimg=True,
            ids=online_ids if is_topdown else None)

        online_im = mot_vis.plot_tracking(
            im, online_tlwhs, online_ids, online_scores, frame_id=i)

        if FLAGS.save_images:
            # exist_ok avoids the check-then-create race
            os.makedirs(FLAGS.output_dir, exist_ok=True)
            img_name = os.path.split(img_file)[-1]
            out_path = os.path.join(FLAGS.output_dir, img_name)
            cv2.imwrite(out_path, online_im)
            print("save result to: " + out_path)
117 118


119 120 121 122
def mot_keypoint_unite_predict_video(mot_model,
                                     keypoint_model,
                                     camera_id,
                                     keypoint_batch_size=1):
    """Run tracking and keypoint estimation on a video file or camera.

    Frames are read until the stream ends (or, for a camera, until 'q' is
    pressed in the preview window). Each frame is tracked, keypoints are
    predicted (from the tracked boxes for a top-down architecture, on the
    whole frame otherwise), and poses plus tracks are drawn onto it.

    Output, all under ``FLAGS.output_dir``:
      * ``FLAGS.save_images`` unset: frames are written straight to a
        video named after the input (``output.mp4`` for a camera).
      * ``FLAGS.save_images`` set: frames are saved as ``%05d.jpg`` in a
        sub-directory named after the video, then assembled into the
        video with ``ffmpeg``.
      * ``FLAGS.save_mot_txts`` set: tracking results are also written
        with ``write_mot_results`` to ``<video stem>.txt``.

    Args:
        mot_model: tracker whose ``predict`` returns
            ``(online_tlwhs, online_scores, online_ids)``.
        keypoint_model: keypoint predictor; ``pred_config.arch`` must be
            a key of ``KEYPOINT_SUPPORT_MODELS``.
        camera_id (int): camera index, or -1 to read ``FLAGS.video_file``.
        keypoint_batch_size (int): forwarded to
            ``predict_with_given_det`` on the top-down path.
    """
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
        video_name = 'output.mp4'
    else:
        capture = cv2.VideoCapture(FLAGS.video_file)
        video_name = os.path.split(FLAGS.video_file)[-1]
    # splitext only strips the final extension, so names with several dots
    # keep their full stem and names without a dot do not raise.
    video_stem = os.path.splitext(video_name)[0]

    fps = 30
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    print('frame_count', frame_count)
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # yapf: disable
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # yapf: enable
    os.makedirs(FLAGS.output_dir, exist_ok=True)
    out_path = os.path.join(FLAGS.output_dir, video_name)
    save_dir = os.path.join(FLAGS.output_dir, video_stem)

    writer = None
    if FLAGS.save_images:
        os.makedirs(save_dir, exist_ok=True)
    else:
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

    # The architecture does not change between frames, so resolve it once.
    keypoint_arch = keypoint_model.pred_config.arch
    is_topdown = KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown'

    frame_id = 0
    timer_mot_kp = FPSTimer()
    mot_results = []
    try:
        while True:
            ret, frame = capture.read()
            if not ret:
                break

            timer_mot_kp.tic()
            online_tlwhs, online_scores, online_ids = mot_model.predict(
                [frame], FLAGS.mot_threshold)
            mot_results.append(
                (frame_id + 1, online_tlwhs, online_scores, online_ids))

            if is_topdown:
                results = convert_mot_to_det(online_tlwhs, online_scores)
                keypoint_results = predict_with_given_det(
                    frame, results, keypoint_model, keypoint_batch_size,
                    FLAGS.mot_threshold, FLAGS.keypoint_threshold,
                    FLAGS.run_benchmark)
            else:
                keypoint_results = keypoint_model.predict(
                    [frame], FLAGS.keypoint_threshold)
            timer_mot_kp.toc()
            # Guard the division: the averaged time can be 0 on clocks
            # with coarse resolution.
            mot_kp_fps = 1. / max(1e-5, timer_mot_kp.average_time)

            im = draw_pose(
                frame,
                keypoint_results,
                visual_thread=FLAGS.keypoint_threshold,
                returnimg=True,
                ids=online_ids if is_topdown else None)

            online_im = mot_vis.plot_tracking(
                im,
                online_tlwhs,
                online_ids,
                online_scores,
                frame_id=frame_id,
                fps=mot_kp_fps)

            im = np.array(online_im)

            frame_id += 1
            print('detect frame:%d' % (frame_id))

            if FLAGS.save_images:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
            else:
                writer.write(im)
            if camera_id != -1:
                cv2.imshow('Tracking and keypoint results', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always give the capture device / output file back, even when a
        # model call raises in the middle of the stream.
        capture.release()
        if writer is not None:
            writer.release()

    if FLAGS.save_mot_txts:
        result_filename = os.path.join(FLAGS.output_dir, video_stem + '.txt')
        write_mot_results(result_filename, mot_results)

    if FLAGS.save_images:
        # Argument list instead of a shell string, so paths containing
        # spaces or shell metacharacters are passed through unchanged.
        cmd = [
            'ffmpeg', '-f', 'image2', '-i',
            os.path.join(save_dir, '%05d.jpg'), out_path
        ]
        try:
            succeeded = subprocess.run(cmd).returncode == 0
        except OSError as err:
            # e.g. ffmpeg is not installed; the saved frames are still there
            print('Failed to run ffmpeg: {}'.format(err))
            succeeded = False
        if succeeded:
            print('Save video in {}.'.format(out_path))
        else:
            print('ffmpeg did not produce {}; frames are kept in {}.'.format(
                out_path, save_dir))
225 226 227


def main():
    """Build the MOT and keypoint predictors from FLAGS and run inference.

    A video file or camera stream is processed when ``FLAGS.video_file``
    is set or ``FLAGS.camera_id`` is not -1; otherwise the images resolved
    from ``FLAGS.image_dir`` / ``FLAGS.image_file`` are processed. After
    image inference either the averaged timing of both models is printed
    or, with ``FLAGS.run_benchmark``, a benchmark log is emitted for each
    model via ``bench_log``.
    """
    # Runtime options passed identically to both predictors.
    shared_opts = dict(
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn)

    mot_model = JDE_Detector(
        PredictConfig(FLAGS.mot_model_dir), FLAGS.mot_model_dir, **shared_opts)

    keypoint_model = KeyPoint_Detector(
        PredictConfig_KeyPoint(FLAGS.keypoint_model_dir),
        FLAGS.keypoint_model_dir,
        batch_size=FLAGS.keypoint_batch_size,
        use_dark=FLAGS.use_dark,
        **shared_opts)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        mot_keypoint_unite_predict_video(mot_model, keypoint_model,
                                         FLAGS.camera_id,
                                         FLAGS.keypoint_batch_size)
        return

    # predict from image
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    mot_keypoint_unite_predict_image(mot_model, keypoint_model, img_list,
                                     FLAGS.keypoint_batch_size)

    if not FLAGS.run_benchmark:
        mot_model.det_times.info(average=True)
        keypoint_model.det_times.info(average=True)
        return

    # The precision label is the last '_'-separated token of the run mode.
    precision = FLAGS.run_mode.split('_')[-1]
    benchmarked = (
        (mot_model, FLAGS.mot_model_dir, 'MOT'),
        (keypoint_model, FLAGS.keypoint_model_dir, 'KeyPoint'), )
    for model, model_dir, label in benchmarked:
        model_info = {
            # last path component of the model directory
            'model_name': model_dir.strip('/').split('/')[-1],
            'precision': precision
        }
        # NOTE(review): bench_log's signature is not visible in this file.
        # The label is passed by keyword for both models so it cannot land
        # in a different positional parameter - confirm against bench_log.
        bench_log(model, img_list, model_info, name=label)
285 286 287 288 289 290 291 292 293 294 295 296


if __name__ == '__main__':
    # Script entry point: parse the command line into the module-level
    # FLAGS that the functions above read, then run inference.
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    # Explicit raise rather than assert: assert statements are removed when
    # Python runs with -O, which would let an invalid device through.
    if FLAGS.device not in ('CPU', 'GPU', 'XPU'):
        raise ValueError("device should be CPU, GPU or XPU")

    main()