infer.py 4.6 KB
Newer Older
S
sjtubinlong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import ast
import time
import json
import argparse

import numpy as np
import cv2

import paddle.fluid as fluid


def LoadModel(model_dir, use_gpu=False):
    """Create a Paddle predictor from an exported inference model directory.

    Args:
        model_dir: directory holding the serialized program (``__model__``)
            and its combined parameters file (``__params__``).
        use_gpu: when truthy, run on GPU device 0 with IR optimization
            switched on; otherwise GPU is disabled explicitly.

    Returns:
        The predictor built by ``fluid.core.create_paddle_predictor``.
    """
    config = fluid.core.AnalysisConfig(
        os.path.join(model_dir, '__model__'),
        os.path.join(model_dir, '__params__'))
    if not use_gpu:
        config.disable_gpu()
    else:
        # Arguments are presumably (initial memory pool MB, device id) —
        # see the AnalysisConfig docs.
        config.enable_use_gpu(100, 0)
        config.switch_ir_optim(True)
    # Options applied regardless of device.
    config.disable_glog_info()
    config.switch_specify_input_names(True)
    config.enable_memory_optim()
    return fluid.core.create_paddle_predictor(config)


class HumanSeg:
    """Segment a person from a BGR frame and composite it onto white.

    Frames are resized to ``eval_size`` and normalized as
    ``(im - mean) * scale``, then run through the Paddle predictor.
    Channel 1 of the network output is presumably the person/foreground
    probability. It is used as the per-pixel weight of the original image
    against a white background.
    """

    def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
        """
        Args:
            model_dir: exported model directory, as passed to ``LoadModel``.
            mean: three per-channel values subtracted from the input
                (presumably in OpenCV's BGR order — confirm against training).
            scale: three per-channel multipliers applied after the mean
                subtraction.
            eval_size: ``(width, height)`` network input size, as used by
                ``cv2.resize``.
            use_gpu: forwarded to ``LoadModel``.
        """
        self.mean = np.array(mean).reshape((3, 1, 1))
        self.scale = np.array(scale).reshape((3, 1, 1))
        self.eval_size = eval_size
        self.predictor = LoadModel(model_dir, use_gpu)

    def Preprocess(self, image):
        """Turn an HxWx3 image into a normalized (1, 3, h, w) batch.

        The returned array is a transposed view. It is not C-contiguous and
        may be float64, because ``mean``/``scale`` are float64 arrays.
        ``Predict`` converts it before handing it to Paddle.
        """
        im = cv2.resize(
            image, self.eval_size, fx=0, fy=0, interpolation=cv2.INTER_CUBIC)
        # HWC -> CHW, then convert to float.
        im = im.transpose((2, 0, 1)).astype('float32')
        # im = (im - mean) * scale
        im = (im - self.mean) * self.scale
        return im[np.newaxis, :, :, :]

    def Postprocess(self, image, output_data):
        """Resize the channel-1 score map to ``image`` and blend onto white."""
        mask = output_data[0, 1, :, :]
        # cv2.resize takes (width, height).
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
        return self._composite_on_white(image, mask)

    @staticmethod
    def _composite_on_white(image, mask):
        """Blend ``image`` (H, W, 3) over a white background weighted by ``mask``.

        ``mask`` (H, W) is clipped to [0, 1] first. Without the clip, a score
        slightly outside that range would push the blend below 0 or above
        255, and the ``uint8`` cast would wrap it around into a garbage color.
        """
        scoremap = np.clip(mask, 0, 1)[:, :, np.newaxis]
        merge_im = scoremap * image + (1 - scoremap) * 255
        return merge_im.astype(np.uint8)

    def Predict(self, image):
        """Segment ``image`` and return the white-background composite."""
        im = self.Preprocess(image)
        # Preprocess yields a non-contiguous view. Give Paddle a fresh,
        # C-contiguous float32 buffer (this is what the original
        # ``.copy().astype('float32')`` achieved implicitly).
        im_tensor = fluid.core.PaddleTensor(
            np.ascontiguousarray(im, dtype=np.float32))
        output_data = self.predictor.run([im_tensor])[0].as_ndarray()
        return self.Postprocess(image, output_data)


# Run prediction on a single image file
def PredictImage(seg, image_path, save_path='result.jpeg'):
    """Segment the image at ``image_path`` and write the composite to disk.

    Args:
        seg: a ``HumanSeg`` instance.
        image_path: path of the image to read with ``cv2.imread``.
        save_path: output file; defaults to ``result.jpeg`` as before.
    """
    # Read the parameter, not the global ``input_path`` the original
    # accidentally used (it only worked when run from __main__).
    im = cv2.imread(image_path)
    if im is None:
        # cv2.imread returns None instead of raising on unreadable files.
        print("Error opening image file")
        return
    cv2.imwrite(save_path, seg.Predict(im))


# Run prediction on a video file
def PredictVideo(seg, video_path, save_path='result.avi'):
    """Segment every frame of ``video_path`` and write an MJPG video.

    Args:
        seg: a ``HumanSeg`` instance.
        video_path: video file (or stream URL) opened with cv2.VideoCapture.
        save_path: output file; defaults to ``result.avi`` as before.

    The output keeps the source frame size and FPS. Capture and writer are
    released even if prediction raises.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error opening video stream or file")
        return
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Result video writer
    out = cv2.VideoWriter(save_path,
                          cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
                          (w, h))
    try:
        if not out.isOpened():
            # Otherwise every out.write() below would be silently dropped.
            print("Error opening video writer")
            return
        # Start capturing from video
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            out.write(seg.Predict(frame))
    finally:
        cap.release()
        out.release()


# Run prediction on a live camera stream
def PredictCamera(seg, camera_id=0):
    """Show segmented camera frames in a window until 'q' is pressed.

    Args:
        seg: a ``HumanSeg`` instance.
        camera_id: device index passed to cv2.VideoCapture (default 0,
            as before).

    The capture is released and the preview window destroyed even if
    prediction raises.
    """
    cap = cv2.VideoCapture(camera_id)
    if not cap.isOpened():
        print("Error opening video stream or file")
        return
    try:
        # Start capturing from video
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imshow('Frame', seg.Predict(frame))
            # waitKey also pumps the GUI event loop so the window refreshes.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        # Without this the 'Frame' window stays open after the loop exits.
        cv2.destroyAllWindows()


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print('Usage: python infer.py /path/to/model/ '
              '{/path/to/image|/path/to/video|camera} [use_gpu]')
        # Non-zero status on misuse. sys.exit, unlike the site-provided
        # exit(), is always available.
        sys.exit(1)

    model_dir = sys.argv[1]
    input_path = sys.argv[2]
    use_gpu = int(sys.argv[3]) if len(sys.argv) >= 4 else 0

    # Init model: per-channel normalization (im - mean) * scale and the
    # (width, height) the network input is resized to.
    mean = [104.008, 116.669, 122.675]
    scale = [1.0, 1.0, 1.0]
    eval_size = (513, 513)
    seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)

    # Dispatch on the input argument (previously it was ignored and the
    # camera always ran):
    #   'camera'    -> live preview from camera 0, press 'q' to quit
    #   image file  -> result saved as result.jpeg
    #   otherwise   -> treated as a video, result saved as result.avi
    image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    if input_path == 'camera':
        PredictCamera(seg)
    elif os.path.splitext(input_path)[1].lower() in image_exts:
        PredictImage(seg, input_path)
    else:
        PredictVideo(seg, input_path)