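"""HumanSeg video inference.

Runs a human segmentation model on a video file, or on the default camera when
--video_path is not given. Per-frame score maps are refined with DIS optical
flow and the segmented person is composited onto a white background.

Usage (paths below are placeholders):
    python video_infer.py --model_dir ./saved_model --video_path ./input.avi --save_dir ./output
"""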
import argparse
import os
import os.path as osp
import cv2
import numpy as np

from utils.humanseg_postprocess import postprocess, threshold_mask
import models
import transforms


def parse_args():
    parser = argparse.ArgumentParser(description='HumanSeg inference for video')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for inference',
        type=str)
    parser.add_argument(
        '--video_path',
        dest='video_path',
        help='Video path for inference; the camera is used when no path is given',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output')

    return parser.parse_args()


def predict(img, model, test_transforms):
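    """Run the model on a single image and return the (H, W, C) score map
    together with the transform metadata needed to undo resize/padding."""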
    model.arrange_transform(transforms=test_transforms, mode='test')
    img, im_info = test_transforms(img)
    img = np.expand_dims(img, axis=0)
    result = model.exe.run(
        model.test_prog,
        feed={'image': img},
        fetch_list=list(model.test_outputs.values()))
    score_map = result[1]
    score_map = np.squeeze(score_map, axis=0)
    score_map = np.transpose(score_map, (1, 2, 0))
    return score_map, im_info


def recover(img, im_info):
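    """Undo the test-time transforms in reverse order: crop away padding and
    resize back to the original frame size recorded in im_info."""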
    keys = list(im_info.keys())
    for k in keys[::-1]:
        if k == 'shape_before_resize':
            h, w = im_info[k][0], im_info[k][1]
            img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
        elif k == 'shape_before_padding':
            h, w = im_info[k][0], im_info[k][1]
            img = img[0:h, 0:w]
    return img


def video_infer(args):
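    """Segment every frame of the input video (or camera stream) and either
    write the result to save_dir or show it in a preview window."""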

    resize_h = 192
    resize_w = 192

    test_transforms = transforms.Compose(
        [transforms.Resize((resize_w, resize_h)),
         transforms.Normalize()])
    model = models.load_model(args.model_dir)
    if not args.video_path:
        cap = cv2.VideoCapture(0)
    else:
        cap = cv2.VideoCapture(args.video_path)
    if not cap.isOpened():
        raise IOError("Error opening video stream or file: check whether "
                      "--video_path exists ({}) or the camera is "
                      "working".format(args.video_path))

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
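    # prev_gray / prev_cfd keep the previous frame's gray image and confidence
    # map; together with DIS optical flow they temporally smooth the masks.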
    prev_gray = np.zeros((resize_h, resize_w), np.uint8)
    prev_cfd = np.zeros((resize_h, resize_w), np.float32)
    is_init = True

    fps = cap.get(cv2.CAP_PROP_FPS)
    if args.video_path:
        # Video writer for saving the prediction results
        if not osp.exists(args.save_dir):
            os.makedirs(args.save_dir)
        out = cv2.VideoWriter(
            osp.join(args.save_dir, 'result.avi'),
            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))
        # Start reading video frames
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                score_map, im_info = predict(frame, model, test_transforms)
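                # Fuse the foreground probability (scaled to 0-255) with the
                # result propagated from the previous frame via optical flow.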
                cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                scoremap = 255 * score_map[:, :, 1]
                optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd, \
                        disflow, is_init)
                prev_gray = cur_gray.copy()
                prev_cfd = optflow_map.copy()
                is_init = False
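                # Blur the fused confidence map and threshold it into an alpha
                # matte (below thresh_bg -> 0, above thresh_fg -> 1).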
                optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
                optflow_map = threshold_mask(
                    optflow_map, thresh_bg=0.2, thresh_fg=0.8)
                img_mat = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
                img_mat = recover(img_mat, im_info)
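                # Composite: alpha * frame + (1 - alpha) * white background.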
                bg_im = np.ones_like(img_mat) * 255
                comb = (img_mat * frame + (1 - img_mat) * bg_im).astype(
                    np.uint8)
                out.write(comb)
            else:
                break
        cap.release()
        out.release()

    else:
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                score_map, im_info = predict(frame, model, test_transforms)
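                # Same post-processing as the file branch: fuse with optical
                # flow, build an alpha matte, and composite onto white.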
                cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                scoremap = 255 * score_map[:, :, 1]
                optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd, \
                                          disflow, is_init)
                prev_gray = cur_gray.copy()
                prev_cfd = optflow_map.copy()
                is_init = False
                # optflow_map = optflow_map/255.0
                optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
                optflow_map = threshold_mask(
                    optflow_map, thresh_bg=0.2, thresh_fg=0.8)
                img_mat = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
                img_mat = recover(img_mat, im_info)
                bg_im = np.ones_like(img_mat) * 255
                comb = (img_mat * frame + (1 - img_mat) * bg_im).astype(
                    np.uint8)
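                # Show a live preview; press 'q' to quit.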
                cv2.imshow('HumanSegmentation', comb)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            else:
                break
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    args = parse_args()
    video_infer(args)