infer.py 4.4 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
from utils.util import get_arguments
from utils.palette import get_palette
from PIL import Image as PILImage
import importlib

# Parse the command-line arguments and load the config module of the selected
# example package (`<example>.config`); that module must expose a `cfg` object.
args = get_arguments()
config = importlib.import_module(args.example+'.config')
cfg = getattr(config, 'cfg')

# Paddle garbage-collection strategy flag. The ACE2P model is large, so enabling
# this is recommended when GPU memory is insufficient.
os.environ['FLAGS_eager_delete_tensor_gb']='0.0'

# NOTE(review): paddle is imported only after the flag above has been set,
# presumably so that the flag is picked up when paddle initializes - keep this
# ordering.
import paddle.fluid as fluid

# Dataset class for inference
class TestDataSet():
    """Collection of images to run inference on.

    The image paths are built by joining ``cfg.data_dir`` with every entry of
    the list file ``cfg.data_list_file`` (one image name per line).

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # Directory that contains the images.
        self.data_dir = cfg.data_dir
        # Text file listing one image name per line.
        self.data_list_file = cfg.data_list_file
        # Full paths of all images, in list-file order.
        self.data_list = self.get_data_list()
        # Number of images in the dataset.
        self.data_num = len(self.data_list)

    def get_data_list(self):
        """Read the list file and return the list of image paths.

        Blank lines are skipped. An entry that contains no ``.`` at all is
        treated as having no extension and gets ``.jpg`` appended.

        Returns:
            A list with ``os.path.join(self.data_dir, name)`` for every entry.
        """
        data_list = []
        # The context manager closes the list file even if reading fails.
        with open(self.data_list_file, 'r') as data_file_handler:
            for line in data_file_handler:
                img_name = line.strip()
                if not img_name:
                    # An empty line would otherwise become "<data_dir>/.jpg".
                    continue
                if '.' not in img_name:
                    img_name = img_name + '.jpg'
                data_list.append(os.path.join(self.data_dir, img_name))
        return data_list

    def preprocess(self, img):
        """Turn an image read by OpenCV into the network input.

        For the ``ACE2P`` example the ``preprocess`` function of the example's
        ``reader`` module is used. Otherwise the image is resized to
        ``cfg.input_size``, normalized with ``cfg.MEAN`` / ``cfg.STD``, moved
        from HWC to CHW layout and given a leading batch axis.

        Args:
            img: Image array as returned by ``cv2.imread``.

        Returns:
            The preprocessed array.
        """
        if cfg.example == 'ACE2P':
            reader = importlib.import_module(args.example+'.reader')
            ACE2P_preprocess = getattr(reader, 'preprocess')
            return ACE2P_preprocess(img)

        img = cv2.resize(img, cfg.input_size).astype(np.float32)
        # In-place operators keep the array float32; `img - mean` would
        # promote it to float64.
        img -= np.array(cfg.MEAN)
        img /= np.array(cfg.STD)
        img = img.transpose((2, 0, 1))
        return np.expand_dims(img, axis=0)

    def get_data(self, index):
        """Load and preprocess the image at position ``index``.

        Args:
            index: Position in ``self.data_list``.

        Returns:
            A tuple ``(img, img_process, name_prefix, img_shape)`` holding the
            original image, the preprocessed image, the file name without its
            extension and the original ``(height, width)``.
            If the image cannot be read the tuple is
            ``(None, None, img_path, None)`` so the caller can report the path.
        """
        img_path = self.data_list[index]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return img, img, img_path, None

        # Strip only the final extension: str.replace would remove every
        # occurrence of ".ext" inside the name. basename also understands
        # both path separators on Windows.
        name_prefix = os.path.splitext(os.path.basename(img_path))[0]
        img_shape = img.shape[:2]
        img_process = self.preprocess(img)

        return img, img_process, name_prefix, img_shape


def infer():
    """Run the inference model on every image of the test dataset.

    The model is loaded from ``cfg.model_path`` and one PNG per readable image
    is written to ``cfg.vis_dir`` (created if missing):

    * ``HumanSeg``: the original image with the thresholded foreground score
      appended as a fourth (alpha) channel.
    * any other example: the label map saved as a palette image.

    Images that cannot be read are reported and skipped.

    Returns:
        0 once all images have been processed.
    """
    if not os.path.exists(cfg.vis_dir):
        os.makedirs(cfg.vis_dir)
    palette = get_palette(cfg.class_num)
    # Display threshold for the portrait segmentation result.
    thresh = 120

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load the inference model.
    test_prog, feed_name, fetch_list = fluid.io.load_inference_model(
        dirname=cfg.model_path, executor=exe, params_filename='__params__')

    # Load the inference dataset.
    test_dataset = TestDataSet()
    data_num = test_dataset.data_num

    # The ACE2P model predicts at multiple scales; resolve its helper once
    # instead of on every loop iteration.
    multi_scale_test = None
    if cfg.example == 'ACE2P':
        reader = importlib.import_module(args.example+'.reader')
        multi_scale_test = getattr(reader, 'multi_scale_test')

    for idx in range(data_num):
        # Fetch the data.
        ori_img, image, im_name, im_shape = test_dataset.get_data(idx)
        if image is None:
            # For unreadable images get_data returns the path as im_name.
            print(im_name, 'is None')
            continue

        # Predict.
        if cfg.example == 'ACE2P':
            parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape)
        else:
            # HumanSeg and RoadLine models predict at a single scale.
            result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list)
            parsing = np.argmax(result[0][0], axis=0)
            # im_shape is (height, width); cv2.resize expects (width, height).
            parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])

        # Save the prediction.
        result_path = os.path.join(cfg.vis_dir, im_name + '.png')
        if cfg.example == 'HumanSeg':
            logits = result[0][0][1]*255
            logits = cv2.resize(logits, im_shape[::-1])
            # Zero everything that does not exceed the threshold, then rescale.
            _, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
            # NOTE(review): pixels zeroed above become negative here; this
            # presumably relies on cv2.imwrite saturating them to 0 when it
            # converts to 8 bit - confirm.
            logits = 255 *(logits - thresh)/(255 - thresh)
            # Append the segmentation result as the alpha channel.
            rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2)
            cv2.imwrite(result_path, rgba)
        else:
            output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
            output_im.putpalette(palette)
            output_im.save(result_path)

        if (idx + 1) % 100 == 0:
            print('%d  processd' % (idx + 1))

    # Use data_num rather than the loop variable: idx is unbound when the
    # dataset is empty, and equals data_num - 1 otherwise.
    print('%d  processd done' % data_num)

    return 0


# Run inference when this file is executed as a script.
if __name__ == "__main__":
    infer()