# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import numpy as np
import cv2

import paddle.fluid as fluid


def load_model(model_dir, use_gpu=False):
    """
    Load model files and init paddle predictor.

    Args:
        model_dir: directory holding the exported ``__model__`` program file
            and the ``__params__`` parameter file.
        use_gpu: when truthy, call ``enable_use_gpu(100, 0)`` and turn on IR
            optimization; otherwise GPU execution is disabled.
            NOTE(review): the (100, 0) arguments are presumably the initial
            GPU memory pool in MB and the device id; confirm against the
            Paddle ``AnalysisConfig`` docs.

    Returns:
        The predictor built by ``fluid.core.create_paddle_predictor``.
    """
    config = fluid.core.AnalysisConfig(
        os.path.join(model_dir, '__model__'),
        os.path.join(model_dir, '__params__'))

    if not use_gpu:
        config.disable_gpu()
    else:
        config.enable_use_gpu(100, 0)
        config.switch_ir_optim(True)

    # Options shared by the CPU and GPU paths.
    config.disable_glog_info()
    config.switch_specify_input_names(True)
    config.enable_memory_optim()

    predictor = fluid.core.create_paddle_predictor(config)
    return predictor


class HumanSeg:
    """
    Human Segmentation Class.

    Wraps a Paddle predictor: resizes and normalizes an input image, runs the
    predictor, and blends the original image over a white background using
    the predicted score map.
    """

    def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
        """
        Args:
            model_dir: directory passed to ``load_model``.
            mean: 3 per-channel values subtracted from the resized image, in
                the same channel order as the images given to ``run_predict``.
            scale: 3 per-channel multipliers applied after mean subtraction.
            eval_size: ``(width, height)`` the image is resized to before
                inference (``cv2.resize`` dsize order).
            use_gpu: forwarded to ``load_model``.
        """
        # Store mean/scale as float32 so preprocessing stays in the dtype fed
        # to the predictor instead of promoting every frame to float64.
        self.mean = np.array(mean, dtype=np.float32).reshape((3, 1, 1))
        self.scale = np.array(scale, dtype=np.float32).reshape((3, 1, 1))
        self.eval_size = eval_size
        self.predictor = load_model(model_dir, use_gpu)

    def preprocess(self, image):
        """
        Resize and normalize an image into the model input layout (NCHW).

        No color-channel reordering is performed; ``self.mean`` and
        ``self.scale`` must be in the same channel order as ``image``.

        Args:
            image: HWC image array with 3 channels.

        Returns:
            C-contiguous float32 array of shape (1, 3, H, W), where
            (W, H) = ``self.eval_size``, holding ``(resized - mean) * scale``.
        """
        img_mat = cv2.resize(
            image, self.eval_size, fx=0, fy=0, interpolation=cv2.INTER_CUBIC)
        # HWC -> CHW. transpose() returns a strided view, not a copy.
        img_mat = img_mat.transpose((2, 0, 1)).astype(np.float32)
        img_mat = (img_mat - self.mean) * self.scale
        # Add the batch axis. Force a C-contiguous float32 buffer, because the
        # transposed data would otherwise keep its HWC memory order.
        return np.ascontiguousarray(img_mat[np.newaxis], dtype=np.float32)

    def postprocess(self, image, output_data):
        """
        Blend ``image`` over a white background using the model output.

        Channel 1 of the first batch item is resized to the image size and
        used as the per-pixel weight of the original image. The remaining
        weight goes to white (255). Presumably channel 1 is the
        person/foreground probability; confirm against the model export.

        Args:
            image: HWC image the score map is blended with.
            output_data: predictor output indexed as ``[batch, channel, h, w]``.

        Returns:
            uint8 array with the same shape as ``image``.
        """
        mask = output_data[0, 1, :, :]
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
        # (H, W, 1) broadcasts over the color channels of ``image``.
        scoremap = mask[:, :, np.newaxis]
        merge_im = scoremap * image + (1 - scoremap) * 255
        # Clip before the cast so out-of-range scores cannot wrap around.
        return np.clip(merge_im, 0, 255).astype(np.uint8)

    def run_predict(self, image):
        """
        Run prediction and return the segmentation image mat.

        Args:
            image: HWC image with 3 channels (e.g. a frame read by ``cv2``).

        Returns:
            uint8 image of the same shape as ``image``, with the
            low-score pixels blended toward white.
        """
        # preprocess() builds new arrays and never writes to ``image``, so
        # the input frame does not need a defensive copy.
        im_mat = self.preprocess(image)
        # preprocess() already returns a fresh C-contiguous float32 array.
        im_tensor = fluid.core.PaddleTensor(im_mat)
        output_data = self.predictor.run([im_tensor])[0].as_ndarray()
        return self.postprocess(image, output_data)


def predict_image(seg, image_path, output_path='result.jpeg'):
    """
    Run prediction on a single image file and save the result.

    Args:
        seg: object exposing ``run_predict(image)`` (e.g. a ``HumanSeg``).
        image_path: path of the image to read with ``cv2.imread``.
        output_path: file the result is written to. Defaults to
            'result.jpeg', the path that was previously hard-coded.
    """
    img_mat = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising. Bail
    # out the same way predict_video does for an unopenable input.
    if img_mat is None:
        print("Error opening image file")
        return
    img_mat = seg.run_predict(img_mat)
    cv2.imwrite(output_path, img_mat)


def predict_video(seg, video_path, output_path='result.avi'):
    """
    Run prediction on every frame of a video and write the results to a file.

    The output is encoded with the 'MJPG' fourcc at the input's frame size and
    FPS.

    Args:
        seg: object exposing ``run_predict(frame)`` (e.g. a ``HumanSeg``).
        video_path: anything accepted by ``cv2.VideoCapture``.
        output_path: destination video file. Defaults to 'result.avi', the
            path that was previously hard-coded.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error opening video stream or file")
        return
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Result Video Writer
    out = cv2.VideoWriter(output_path,
                          cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
                          (width, height))
    try:
        # A writer that failed to open silently drops frames, so stop here
        # instead of running inference on the whole video for nothing.
        if not out.isOpened():
            print("Error opening output video file")
            return
        # Start capturing from video
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            out.write(seg.run_predict(frame))
    finally:
        # Release both handles even if inference raises mid-stream.
        cap.release()
        out.release()


def predict_camera(seg, camera_id=0):
    """
    Run prediction on a live camera stream and display each result.

    Results are shown in a window titled 'Frame'. Pressing 'q' stops the loop.

    Args:
        seg: object exposing ``run_predict(frame)`` (e.g. a ``HumanSeg``).
        camera_id: device index passed to ``cv2.VideoCapture``. Defaults to
            0, the device that was previously hard-coded.
    """
    cap = cv2.VideoCapture(camera_id)
    if not cap.isOpened():
        print("Error opening video stream or file")
        return
    try:
        # Start capturing from video
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imshow('Frame', seg.run_predict(frame))
            # Wait 1 ms for a key press. Only the low byte of the key code is
            # compared.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        # Close the preview window opened by cv2.imshow above.
        cv2.destroyAllWindows()


def main(argv):
    """
    Entrypoint of the script.

    Args:
        argv: command-line style list
            ``[prog, model_dir, input_path, (use_gpu)]``, where the optional
            ``use_gpu`` is an integer flag (non-zero enables the GPU).
    """
    if len(argv) < 3:
        print('Usage: python infer.py /path/to/model/ /path/to/video [use_gpu]')
        return

    # Read the ``argv`` parameter rather than ``sys.argv``, so main() honours
    # the list it is given (e.g. when called programmatically).
    model_dir = argv[1]
    input_path = argv[2]
    use_gpu = bool(int(argv[3])) if len(argv) >= 4 else False
    # Init model
    mean = [104.008, 116.669, 122.675]
    scale = [1.0, 1.0, 1.0]
    eval_size = (513, 513)
    seg = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)
    # Run Predicting on a video and result will be saved as result.avi.
    # Call predict_camera(seg) instead to run on a live camera stream.
    predict_video(seg, input_path)


if __name__ == '__main__':
    # Script entry: hand the raw command line to main(), which validates it.
    main(argv=sys.argv)