# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python inference solution for realtime human segmentation."""

import os
S
sjtubinlong 已提交
19
import argparse
S
sjtubinlong 已提交
20 21 22 23 24 25
import numpy as np
import cv2

import paddle.fluid as fluid


S
sjtubinlong 已提交
26
def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
    """Propagate the previous confidence map along the optical flow.

    Every pixel of the previous frame is moved by the (rounded) forward flow.
    The move is kept only if it lands inside the image and the backward flow
    at the landing position approximately undoes it (forward/backward
    consistency check).

    Args:
        pre_gray: Grayscale of previous frame, 2-D array.
        cur_gray: Grayscale of current frame.
        prev_cfd: Confidence map of the previous frame, indexed [row, col].
        dl_weights: Merged weights data. It is modified IN PLACE: pixels that
            did not move in either direction get the weight 0.05.
        disflow: Object exposing `calc(prev, next, None)` that returns a
            (height, width, 2) flow array holding (dx, dy) per pixel.
    Returns:
        track_cfd: `prev_cfd` values moved to their tracked position,
            zero where nothing was tracked.
        is_track: Binary graph (same dtype/shape as `pre_gray`), 1 where a
            pixel was matched with an optical flow point.
        dl_weights: The same array object that was passed in.
    """
    check_thres = 8
    hgt, wdh = pre_gray.shape[:2]
    track_cfd = np.zeros_like(prev_cfd)
    is_track = np.zeros_like(pre_gray)
    # Forward flow first, then backward flow (same call order as before).
    flow_fw = _round_half_away_from_zero(disflow.calc(pre_gray, cur_gray, None))
    flow_bw = _round_half_away_from_zero(disflow.calc(cur_gray, pre_gray, None))

    # (row, col) -> (cur_y, cur_x) for every pixel at once.
    rows, cols = np.indices((hgt, wdh))
    cur_x = cols + flow_fw[..., 0]
    cur_y = rows + flow_fw[..., 1]
    inside = (cur_x >= 0) & (cur_x < wdh) & (cur_y >= 0) & (cur_y < hgt)

    # Boolean-mask indexing yields the pixels in row-major order, i.e. the
    # same order in which the former nested loops visited them.
    src_row, src_col = rows[inside], cols[inside]
    dst_row, dst_col = cur_y[inside], cur_x[inside]
    step_fw = flow_fw[src_row, src_col]
    step_bw = flow_bw[dst_row, dst_col]

    # Filter the optical flow points with a threshold on the squared
    # forward + backward round-trip error.
    round_trip = step_fw + step_bw
    keep = (round_trip * round_trip).sum(axis=1) < check_thres
    src_row, src_col = src_row[keep], src_col[keep]
    dst_row, dst_col = dst_row[keep], dst_col[keep]
    step_fw, step_bw = step_fw[keep], step_bw[keep]

    # Downgrade still points (no motion in either direction).
    still = (step_fw == 0).all(axis=1) & (step_bw == 0).all(axis=1)
    dl_weights[dst_row[still], dst_col[still]] = 0.05
    is_track[dst_row, dst_col] = 1

    # Several source pixels may land on the same target. Keep the LAST one in
    # row-major order, exactly like sequential assignment in a loop would;
    # plain fancy-index assignment gives no ordering guarantee for duplicates.
    flat_dst = dst_row * wdh + dst_col
    _, first_in_reversed = np.unique(flat_dst[::-1], return_index=True)
    winners = flat_dst.size - 1 - first_in_reversed
    track_cfd[dst_row[winners], dst_col[winners]] = \
        prev_cfd[src_row[winners], src_col[winners]]
    return track_cfd, is_track, dl_weights


def _round_half_away_from_zero(flow):
    """Round each flow component to the nearest integer, halves away from zero."""
    flow = np.asarray(flow)
    return np.trunc(flow + np.where(flow >= 0, 0.5, -0.5)).astype(np.int64)


def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
    """Fuse the optical flow track with the segmentation result.

    Only pixels flagged in `is_track` are blended; all others keep their
    segmentation score. All arrays are indexed consistently as [row, col].

    Args:
        track_cfd: Optical flow track (confidence moved from previous frame).
        dl_cfd: Segmentation result of current frame. Not modified.
        dl_weights: Merged weights data, the share given to `dl_cfd` when the
            segmentation score is neither > 0.9 nor < 0.1.
        is_track: Binary graph, whether a pixel matched with an optical flow
            point (values > 0 count as tracked).
    Returns:
        cur_cfd: Fusion of optical flow track and segmentation result, a copy
            of `dl_cfd` (same dtype) with the tracked pixels replaced.
    """
    cur_cfd = dl_cfd.copy()
    tracked = is_track > 0
    if not tracked.any():
        return cur_cfd

    dl_score = dl_cfd[tracked]
    track_score = track_cfd[tracked]
    weight = dl_weights[tracked]

    # Scores > 0.9 or < 0.1 use fixed blend factors chosen by the weight;
    # every other score is blended with the per-pixel weight itself.
    confident = (dl_score > 0.9) | (dl_score < 0.1)
    low_weight = weight < 0.1
    dl_share = np.where(confident, np.where(low_weight, 0.3, 0.4), weight)
    track_share = np.where(confident, np.where(low_weight, 0.7, 0.6),
                           1 - weight)
    cur_cfd[tracked] = dl_share * dl_score + track_share * track_score
    return cur_cfd


def threshold_mask(img, thresh_bg, thresh_fg):
    """Threshold mask for image foreground and background.

    The image is scaled from [0, 255] to [0, 1], then stretched linearly so
    that `thresh_bg` maps to 0 and `thresh_fg` maps to 1, and finally clipped
    to the range [0, 1].

    Args:
        img : Original image with values in [0, 255] (array or array-like).
        thresh_bg : Threshold for background, set to 0 when less than it.
        thresh_fg : Threshold for foreground, set to 1 when greater than it.
    Returns:
        dst : Image after applying the threshold mask, an instance of
            np.float32 array with the same shape as `img`.
    """
    dst = (np.asarray(img) / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
    return np.clip(dst, 0, 1).astype(np.float32)


S
sjtubinlong 已提交
114
def optflow_handle(cur_gray, scoremap, is_init,
                   prev_gray=None, prev_cfd=None, disflow=None):
    """Processing optical flow and segmentation result.

    Args:
        cur_gray : Grayscale of current frame, same height/width as `scoremap`.
        scoremap : Segmentation result of current frame, indexed [row, col].
        is_init : True only when process the first frame of a video. In that
            case no tracking is done and the score map is only blurred.
        prev_gray : Optional grayscale of the previous frame. When omitted an
            all-zero frame is used (the historical behaviour).
        prev_cfd : Optional fused confidence map of the previous frame. When
            omitted an all-zero map is used (the historical behaviour).
        disflow : Optional DIS optical flow object to reuse between calls.
            When omitted a new ULTRAFAST-preset instance is created.
    Returns:
        fusion_cfd : Fused confidence map after a 3x3 Gaussian blur.
    """
    # NOTE(review): without `prev_gray` / `prev_cfd` nothing is remembered
    # between calls, so tracking runs against an all-zero previous frame.
    # Callers that process a video should keep the previous gray frame and
    # the previous return value and pass them in on the next call.
    # numpy shape order is (rows, cols), i.e. (height, width).
    height, width = scoremap.shape[:2]
    if disflow is None:
        disflow = cv2.DISOpticalFlow_create(
            cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
    cur_cfd = scoremap.copy()
    if is_init:
        # Choose the finest pyramid scale from the map size. This only has a
        # lasting effect when the caller supplied (and reuses) `disflow`.
        if height <= 64 or width <= 64:
            disflow.setFinestScale(1)
        elif height <= 160 or width <= 160:
            disflow.setFinestScale(2)
        else:
            disflow.setFinestScale(3)
        fusion_cfd = cur_cfd
    else:
        if prev_gray is None:
            prev_gray = np.zeros((height, width), np.uint8)
        if prev_cfd is None:
            prev_cfd = np.zeros((height, width), np.float32)
        # Default share of the segmentation score for every pixel.
        weights = np.full((height, width), 0.3, np.float32)
        track_cfd, is_track, weights = human_seg_tracking(
            prev_gray, cur_gray, prev_cfd, weights, disflow)
        fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights, is_track)
    return cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
S
sjtubinlong 已提交
145 146 147


class HumanSeg:
    """Human Segmentation Class
    This Class instance will load the inference model and do inference
    on input image object.

    It includes the key stages for a object segmentation inference task.
    Call run_predict on your image and it will return a processed image.
    """

    def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
        """Store the preprocessing settings and load the model.
        Args:
            model_dir: Directory holding the inference model files.
            mean: Three per-channel values subtracted from the image.
            scale: Three per-channel factors applied after the subtraction.
            eval_size: Size handed to `cv2.resize` as `dsize`,
                i.e. (width, height).
            use_gpu: Enable gpu if use_gpu is True.
        """
        # Shape (3, 1, 1) so both broadcast over a CHW image.
        self.mean = np.array(mean).reshape((3, 1, 1))
        self.scale = np.array(scale).reshape((3, 1, 1))
        self.eval_size = eval_size
        self.load_model(model_dir, use_gpu)

    def load_model(self, model_dir, use_gpu):
        """Load paddle inference model.
        Args:
            model_dir: The inference model path includes `__model__` and `__params__`.
            use_gpu: Enable gpu if use_gpu is True
        """
        prog_file = os.path.join(model_dir, '__model__')
        params_file = os.path.join(model_dir, '__params__')
        config = fluid.core.AnalysisConfig(prog_file, params_file)
        if use_gpu:
            # NOTE(review): presumably 100 is the initial GPU memory pool in
            # MB and 0 the device id -- confirm against the Paddle docs.
            config.enable_use_gpu(100, 0)
            config.switch_ir_optim(True)
        else:
            config.disable_gpu()
        config.disable_glog_info()
        config.switch_specify_input_names(True)
        config.enable_memory_optim()
        self.predictor = fluid.core.create_paddle_predictor(config)

    def preprocess(self, image):
        """Preprocess input image.
        Resize to `eval_size`, convert HWC to CHW, normalize with mean/scale
        and add a leading batch axis. The channel order is left unchanged.
        Args:
            image: The input opencv image object.
        Returns:
            A preprocessed image object of shape (1, channels, height, width).
        """
        img_mat = cv2.resize(
            image, self.eval_size, interpolation=cv2.INTER_LINEAR)
        # HWC -> CHW, then convert to float
        img_mat = img_mat.transpose((2, 0, 1)).astype('float32')
        # img_mat = (img_mat - mean) * scale
        img_mat = (img_mat - self.mean) * self.scale
        return img_mat[np.newaxis, :, :, :]

    def postprocess(self, image, output_data):
        """Postprocess the inference result and original input image.
        Args:
             image: The original opencv image object.
             output_data: The inference output of paddle's humansegmentation model.
        Returns:
             The result merged original image and segmentation result with optical-flow improvement.
        """
        # Channel 1 of the first batch item is used as the score map.
        scoremap = output_data[0, 1, :, :]
        scoremap = (scoremap * 255).astype(np.uint8)
        ori_h, ori_w = image.shape[0], image.shape[1]
        # Take the evaluation size from the score map itself so the gray
        # frame always matches it, also for non-square `eval_size`.
        evl_h, evl_w = scoremap.shape[:2]
        # optical flow processing
        cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        cur_gray = cv2.resize(cur_gray, (evl_w, evl_h))
        # NOTE(review): is_init is always False and no previous-frame state is
        # passed, so every frame is tracked against the defaults chosen by
        # optflow_handle.
        optflow_map = optflow_handle(cur_gray, scoremap, False)
        optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
        optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8)
        optflow_map = cv2.resize(optflow_map, (ori_w, ori_h))
        optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
        # Blend the original image over a white (255) background.
        bg_im = np.ones_like(optflow_map) * 255
        comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(np.uint8)
        return comb

    def run_predict(self, image):
        """Run Predicting on an opencv image object.
        Preprocess the image, do inference, and then postprocess the infering output.
        Args:
             image: A valid opencv image object.
        Returns:
             The segmentation result which represents as an opencv image object.
        """
        im_mat = self.preprocess(image)
        im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
        output_data = self.predictor.run([im_tensor])[0]
        output_data = output_data.as_ndarray()
        return self.postprocess(image, output_data)
S
sjtubinlong 已提交
239

S
sjtubinlong 已提交
240

S
sjtubinlong 已提交
241
def predict_image(seg, image_path):
    """Do Predicting on a image file.
    Decoding the image file and do predicting on it.
    The result will be saved as `result.jpeg`.
    Args:
        seg: The HumanSeg Object which holds a inference model.
            Do preprocessing / predicting / postprocessing on a input image object.
        image_path: Path of the image file needs to be processed.
    """
    img_mat = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None
    # instead of raising, so check before running inference.
    if img_mat is None:
        print("Error opening image file: {}".format(image_path))
        return
    cv2.imwrite('result.jpeg', seg.run_predict(img_mat))
S
sjtubinlong 已提交
253 254


S
sjtubinlong 已提交
255
def predict_video(seg, video_path):
    """Do Predicting on a video file.
    Decoding the video file and do predicting on each frame.
    All result will be saved as `result.avi`.
    Args:
        seg: The HumanSeg Object which holds a inference model.
            Do preprocessing / predicting / postprocessing on a input image object.
        video_path: Path of a video file needs to be processed.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error opening video stream or file")
        cap.release()
        return
    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Result Video Writer
        out = cv2.VideoWriter('result.avi',
                              cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
                              (width, height))
        # Start capturing from video
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            out.write(seg.run_predict(frame))
    finally:
        # Release both handles even if inference raises, so the partially
        # written result file is closed.
        cap.release()
        if out is not None:
            out.release()
S
sjtubinlong 已提交
285

S
sjtubinlong 已提交
286

S
sjtubinlong 已提交
287
def predict_camera(seg):
    """Do Predicting on a camera video stream.
    Capturing each video frame from camera and do predicting on it.
    All result frames will be shown in a GUI window; press `q` to stop.
    Args:
        seg: The HumanSeg Object which holds a inference model.
            Do preprocessing / predicting / postprocessing on a input image object.
    """
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error opening video stream or file")
        cap.release()
        return
    try:
        # Start capturing from video
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imshow('HumanSegmentation', seg.run_predict(frame))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Free the camera and close the preview window even on errors.
        cap.release()
        cv2.destroyAllWindows()

S
sjtubinlong 已提交
311

S
sjtubinlong 已提交
312
def main(args):
    """Real Entrypoint of the script.
    Load the human segmentation inference model and do predicting on the input resource.
    Support three types of input: camera stream / video file / image file.
    Args:
      args: The command-line args for inference model.
           Open camera and do predicting on camera stream while `args.use_camera` is true.
           Open the video file and do predicting on it while `args.video_path` is valid.
           Open the image file and do predicting on it while `args.img_path` is valid.
           The first of these three that is set wins, in that order.
    """
    # Nothing to do without an input source; say so instead of loading the
    # model and exiting silently.
    if not (args.use_camera or args.video_path or args.img_path):
        print("No input specified: "
              "use --use_camera, --video_path or --img_path")
        return

    model_dir = args.model_dir
    use_gpu = args.use_gpu

    # Init model
    mean = [104.008, 116.669, 122.675]
    scale = [1.0, 1.0, 1.0]
    eval_size = (192, 192)
    seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)
    if args.use_camera:
        # if enable input video stream from camera
        predict_camera(seg)
    elif args.video_path:
        # if video_path valid, do predicting on the video
        predict_video(seg, args.video_path)
    else:
        # if img_path valid, do predicting on the image
        predict_image(seg, args.img_path)
S
sjtubinlong 已提交
339 340


S
sjtubinlong 已提交
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367
def parse_args():
    """Parsing command-line arguments.

    The boolean options accept an explicit value (`--use_gpu True`,
    `--use_gpu false`, `--use_gpu 0`, ...) or can be given as a bare flag
    (`--use_gpu`), which means True.

    Returns:
        The parsed `argparse.Namespace` with `model_dir`, `img_path`,
        `video_path`, `use_camera` and `use_gpu`.
    """
    parser = argparse.ArgumentParser('Realtime Human Segmentation')
    parser.add_argument('--model_dir',
                        type=str,
                        default='',
                        help='path of human segmentation model')
    parser.add_argument('--img_path',
                        type=str,
                        default='',
                        help='path of input image')
    parser.add_argument('--video_path',
                        type=str,
                        default='',
                        help='path of input video')
    # `type=bool` would turn every non-empty string (including "False")
    # into True, so the text is parsed explicitly instead.
    parser.add_argument('--use_camera',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        default=False,
                        help='input video stream from camera')
    parser.add_argument('--use_gpu',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        default=False,
                        help='enable gpu')
    return parser.parse_args()


def _str2bool(value):
    """Convert a command-line word such as "true" / "0" / "no" to a bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, as it was with `type=bool`.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


S
sjtubinlong 已提交
368
# Script entry point: parse the command line and run the inference.
if __name__ == "__main__":
    args = parse_args()
    main(args)