# infer.py
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Module imports.
import os
import cv2
import numpy as np
from utils.util import get_arguments
from utils.palette import get_palette
from PIL import Image as PILImage
import importlib

# Module-level setup; the statement order below is kept exactly as in the
# original script.
# NOTE(review): presumably the `config` module depends on the command-line
# arguments having been parsed first -- confirm before reordering.
args = get_arguments()
config = importlib.import_module('config')
cfg = getattr(config, 'cfg')

# Paddle garbage-collection strategy FLAG. The ACE2P model is large, so
# enabling this is recommended when GPU memory is insufficient.
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'

# Imported only after the flag above has been written to the environment.
# NOTE(review): presumably Paddle reads FLAGS_* at import time -- confirm.
import paddle.fluid as fluid


# Dataset class for the images to run prediction on
class TestDataSet():
    """Images named in ``cfg.data_list_file`` and located in ``cfg.data_dir``.

    Attributes:
        data_dir: directory the image names are joined onto.
        data_list_file: text file with one image name per line.
        data_list: image paths built by ``get_data_list``.
        data_num: number of entries in ``data_list``.
    """

    def __init__(self):
        self.data_dir = cfg.data_dir
        self.data_list_file = cfg.data_list_file
        self.data_list = self.get_data_list()
        self.data_num = len(self.data_list)

    def get_data_list(self):
        """Build the list of image paths to predict on.

        Every non-blank line of ``self.data_list_file`` is stripped and
        treated as an image name. A name containing no '.' gets '.jpg'
        appended. The name is then joined onto ``self.data_dir``.

        Returns:
            list of path strings, in file order.
        """
        data_list = []
        # `with` guarantees the list file is closed, even if reading fails.
        with open(self.data_list_file, 'r') as data_file_handler:
            for line in data_file_handler:
                img_name = line.strip()
                if not img_name:
                    # A blank line would otherwise become '<data_dir>/.jpg'.
                    continue
                if '.' not in img_name:
                    img_name = img_name + '.jpg'
                data_list.append(os.path.join(self.data_dir, img_name))
        return data_list

    def preprocess(self, img):
        """Turn a decoded image into the array fed to the model.

        For ``cfg.example == 'ACE2P'`` the work is delegated to the
        ``preprocess`` function of the ``reader`` module. Otherwise the
        image is resized to ``cfg.input_size``, converted to float32,
        normalised with ``cfg.MEAN`` / ``cfg.STD``, transposed with
        ``(2, 0, 1)`` and given a new leading axis.

        Args:
            img: image array as returned by ``cv2.imread``.

        Returns:
            The preprocessed array.
        """
        if cfg.example == 'ACE2P':
            reader = importlib.import_module('reader')
            ACE2P_preprocess = getattr(reader, 'preprocess')
            img = ACE2P_preprocess(img)
        else:
            img = cv2.resize(img, cfg.input_size).astype(np.float32)
            img -= np.array(cfg.MEAN)
            img /= np.array(cfg.STD)
            # Move the last axis to the front, then add a leading axis.
            img = img.transpose((2, 0, 1))
            img = np.expand_dims(img, axis=0)
        return img

    def get_data(self, index):
        """Load and preprocess the image at position ``index``.

        Args:
            index: position in ``self.data_list``.

        Returns:
            ``(img, img_process, name_prefix, img_shape)``: the decoded
            image, its preprocessed form, the file name without directory
            and final extension, and the first two entries of ``img.shape``.
            If the image cannot be read, returns
            ``(None, None, img_path, None)`` so the caller can report the
            offending path.
        """
        img_path = self.data_list[index]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return img, img, img_path, None

        # splitext removes only the final extension; the previous
        # str.replace removed every occurrence of the extension text.
        name_prefix = os.path.splitext(os.path.basename(img_path))[0]
        img_shape = img.shape[:2]
        img_process = self.preprocess(img)

        return img, img_process, name_prefix, img_shape


def infer():
    """Run the exported inference model on every listed image and save results.

    Creates ``cfg.vis_dir`` if missing, loads the model from
    ``cfg.model_path``, then processes each image of ``TestDataSet``:

    * ``cfg.example == 'ACE2P'``: prediction is done by
      ``reader.multi_scale_test``.
    * otherwise: a single ``exe.run`` call; the label map is the argmax over
      the first axis of ``result[0][0]``, resized to the original image size.

    For ``cfg.example == 'HumanSeg'`` the output is the original image with
    the thresholded channel-1 score appended as a fourth channel. For every
    other example it is a palettised label image. Outputs are written to
    ``<cfg.vis_dir>/<image name>.png``. Images that cannot be read are
    reported and skipped.

    Returns:
        0, always.
    """
    if not os.path.exists(cfg.vis_dir):
        os.makedirs(cfg.vis_dir)
    palette = get_palette(cfg.class_num)
    # Display threshold for the portrait-segmentation result
    thresh = 120

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load the inference model
    test_prog, feed_name, fetch_list = fluid.io.load_inference_model(
        dirname=cfg.model_path, executor=exe, params_filename='__params__')

    # Load the prediction dataset
    test_dataset = TestDataSet()
    data_num = test_dataset.data_num

    # Resolve the ACE2P multi-scale predictor once instead of per image.
    multi_scale_test = None
    if cfg.example == 'ACE2P':
        reader = importlib.import_module('reader')
        multi_scale_test = getattr(reader, 'multi_scale_test')

    for idx in range(data_num):
        # Fetch the data
        ori_img, image, im_name, im_shape = test_dataset.get_data(idx)
        if image is None:
            # On read failure get_data puts the image path in im_name.
            print(im_name, 'is None')
            continue

        # Predict
        if cfg.example == 'ACE2P':
            # The ACE2P model uses multi-scale prediction
            parsing, _ = multi_scale_test(exe, test_prog, feed_name,
                                          fetch_list, image, im_shape)
        else:
            # HumanSeg and RoadLine models use single-scale prediction
            result = exe.run(
                program=test_prog,
                feed={feed_name[0]: image},
                fetch_list=fetch_list)
            parsing = np.argmax(result[0][0], axis=0)
            # im_shape is img.shape[:2]; reversed to match cv2.resize's dsize.
            parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])

        # Save the prediction result
        result_path = os.path.join(cfg.vis_dir, im_name + '.png')
        if cfg.example == 'HumanSeg':
            # `result` is always bound here: HumanSeg takes the
            # single-scale branch above.
            logits = result[0][0][1] * 255
            logits = cv2.resize(logits, im_shape[::-1])
            ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
            # Stretch (thresh, 255] onto (0, 255]. Pixels zeroed by the
            # threshold become negative here.
            # NOTE(review): this appears to rely on cv2.imwrite clamping
            # them to 0 when it converts to 8-bit -- confirm.
            logits = 255 * (logits - thresh) / (255 - thresh)
            # Append the segmentation result as the alpha channel
            rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
                                  axis=2)
            cv2.imwrite(result_path, rgba)
        else:
            output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
            output_im.putpalette(palette)
            output_im.save(result_path)

        if (idx + 1) % 100 == 0:
            print('%d  processd' % (idx + 1))

    # data_num equals the old `idx + 1` after a full loop, and stays defined
    # when the list is empty (where `idx` was unbound).
    print('%d  processd done' % data_num)

    return 0


# Run the prediction pipeline only when this file is executed as a script.
if __name__ == "__main__":
    infer()