infer.py 3.5 KB
Newer Older
P
pennypm 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
# -*- coding: utf-8 -*-
import os
import numpy as np
from utils.util import get_arguments
from utils.palette import get_palette
from utils.data_util import Cluster, pad_img
from PIL import Image as PILImage
import importlib
import paddle.fluid as fluid

# Result of utils.util.get_arguments(); evaluated at import time.
# NOTE(review): `args` is not referenced again in the visible part of this file.
args = get_arguments()

# Import the `config` module by name and take its `cfg` attribute, which the
# rest of this file reads its settings from.
config = importlib.import_module('config')
cfg = config.cfg

# Single module-level Cluster instance, used by infer() below.
cluster = Cluster()

# Dataset class used for inference
class TestDataSet():
    """Image list reader and preprocessor used at inference time.

    The constructor reads ``cfg.data_dir`` and ``cfg.data_list_file`` and
    builds the list of image paths; ``get_data`` then loads and preprocesses
    one image per call.
    """

    # Height and width are padded up to the next multiple of this value.
    _PAD_MULTIPLE = 32

    def __init__(self):
        self.data_dir = cfg.data_dir
        self.data_list_file = cfg.data_list_file
        self.data_list = self.get_data_list()
        self.data_num = len(self.data_list)

    def get_data_list(self):
        """Return the image paths named in ``self.data_list_file``.

        Each non-blank line of the list file is one image name. A name that
        contains no ``'.'`` gets ``'.jpg'`` appended. Every name is joined
        onto ``self.data_dir``. Blank lines are skipped so that they do not
        produce a bogus ``<data_dir>/.jpg`` entry.

        Returns:
            list of str: one path per listed image, in file order.
        """
        data_list = []
        # The context manager guarantees the list file is closed again.
        with open(self.data_list_file, 'r') as data_file_handler:
            for line in data_file_handler:
                img_name = line.strip()
                if not img_name:
                    continue
                if '.' not in img_name:
                    img_name = img_name + '.jpg'
                data_list.append(os.path.join(self.data_dir, img_name))
        return data_list

    def preprocess(self, img):
        """Pad, scale and re-layout an image array for the model.

        Args:
            img: 3-D array laid out as (height, width, channels).

        Returns:
            float32 array of shape (1, channels, padded_h, padded_w), where
            the padded sizes are the next multiples of 32 and the values are
            the input divided by 255.
        """
        h, w = img.shape[:2]
        multiple = self._PAD_MULTIPLE
        # Ceiling division: round each spatial size up to a multiple of 32.
        h_new = -(-h // multiple) * multiple
        w_new = -(-w // multiple) * multiple
        # Pad only at the bottom/right, repeating the edge pixels.
        img = np.pad(img, ((0, h_new - h), (0, w_new - w), (0, 0)), 'edge')

        img = img.astype(np.float32) / 255.0
        # HWC -> CHW, then add a leading batch axis of size 1.
        img = img.transpose((2, 0, 1))
        return np.expand_dims(img, axis=0)

    def get_data(self, index):
        """Load and preprocess the image at position ``index``.

        Args:
            index: position in ``self.data_list``.

        Returns:
            tuple ``(img_process, name_prefix, img_shape)``:
            ``img_process`` is the output of ``preprocess``, ``name_prefix``
            is the file name without directory and extension, and
            ``img_shape`` is the original ``(height, width)``.
            If the image file cannot be opened or decoded, ``img_process``
            and ``img_shape`` are ``None`` so the caller can skip the entry.
        """
        img_path = self.data_list[index]
        name_prefix = os.path.splitext(os.path.basename(img_path))[0]

        try:
            # Decode inside the `with` so the underlying file is released.
            with PILImage.open(img_path) as pil_img:
                img = np.array(pil_img)
        except (IOError, OSError):
            # Same arity as the success path; the caller tests the first item.
            return None, name_prefix, None

        img_shape = img.shape[:2]
        img_process = self.preprocess(img)

        return img_process, name_prefix, img_shape


def infer():
    """Run the saved inference model over every listed image and save results.

    For each image served by ``TestDataSet`` the model output is clustered
    with the module-level ``cluster`` object, cropped back to the original
    image size, and written as a palettised PNG into ``cfg.vis_dir``.
    Entries whose image is ``None`` are reported and skipped.

    Returns:
        int: always 0.
    """
    if not os.path.exists(cfg.vis_dir):
        os.makedirs(cfg.vis_dir)

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load the inference model
    test_prog, feed_name, fetch_list = fluid.io.load_inference_model(
        dirname=cfg.model_path, executor=exe, params_filename='__params__')

    # Load the inference dataset
    test_dataset = TestDataSet()
    data_num = test_dataset.data_num

    for idx in range(data_num):
        # Fetch one preprocessed image
        image, im_name, im_shape = test_dataset.get_data(idx)
        if image is None:
            print(im_name, 'is None')
            continue

        # Run the model and cluster its first output
        output = exe.run(program=test_prog, feed={feed_name[0]: image},
                         fetch_list=fetch_list)
        instance_map, predictions = cluster.cluster(
            output[0][0], n_sigma=cfg.n_sigma,
            min_pixel=cfg.min_pixel, threshold=cfg.threshold)

        # Save the result: bring the map to the padded input size, then crop
        # away the padding so it matches the original image.
        instance_map = pad_img(instance_map, image.shape[2:])
        instance_map = instance_map[:im_shape[0], :im_shape[1]]
        # NOTE(review): the uint8 cast presumably assumes at most 255
        # instance ids per image - confirm against Cluster.cluster.
        output_im = PILImage.fromarray(np.asarray(instance_map, dtype=np.uint8))
        palette = get_palette(len(predictions) + 1)
        output_im.putpalette(palette)
        result_path = os.path.join(cfg.vis_dir, im_name + '.png')
        output_im.save(result_path)

        if (idx + 1) % 100 == 0:
            print('%d  processd' % (idx + 1))

    # Report data_num rather than the loop variable: the two agree after a
    # non-empty loop, and the loop variable is unbound when the list is empty.
    print('%d  processd done' % data_num)

    return 0


# Run inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    infer()