# predict_demo.py
import os
import os.path as osp
import sys
import numpy as np
from PIL import Image as Image
import argparse
from models import load_model


def parse_args(argv=None):
    """Build the prediction demo's command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings, excluding the program name.
            Defaults to ``sys.argv[1:]``, which is what the script uses when
            run from the command line. Passing an explicit list lets callers
            (e.g. tests) drive the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``single_img``, ``data_dir``,
        ``file_list``, ``load_model_dir`` and ``save_img_dir``.

    Exits:
        Prints the help text and calls ``sys.exit(1)`` when no arguments
        are given at all.
    """
    if argv is None:
        argv = sys.argv[1:]

    parser = argparse.ArgumentParser(description='RemoteSensing predict')
    parser.add_argument(
        '--single_img',
        dest='single_img',
        help='single image path to predict',
        default=None,
        type=str)
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='dataset directory',
        default=None,
        type=str)
    parser.add_argument(
        '--file_list',
        dest='file_list',
        help='file name of predict file list',
        default=None,
        type=str)
    parser.add_argument(
        '--load_model_dir',
        dest='load_model_dir',
        help='model load directory',
        default=None,
        type=str)
    parser.add_argument(
        '--save_img_dir',
        dest='save_img_dir',
        help='save directory name of predict results',
        default='predict_results',
        type=str)

    # With no arguments there is nothing to predict, so show usage and stop
    # rather than failing later on a missing input.
    if not argv:
        parser.print_help()
        sys.exit(1)

    return parser.parse_args(argv)


# Parse the command line once and unpack the options into module-level names
# used by the prediction code below.
args = parse_args()
data_dir = args.data_dir
file_list = args.file_list
single_img = args.single_img
load_model_dir = args.load_model_dir
save_img_dir = args.save_img_dir

# exist_ok=True creates the output directory if needed without a separate
# existence check, avoiding the check-then-create race.
os.makedirs(save_img_dir, exist_ok=True)

# predict
# Load the model from --load_model_dir using the project's models.load_model.
# NOTE(review): what a None directory does here is up to load_model — confirm.
model = load_model(load_model_dir)

# Flat RGB palette for the pseudo-colour output PNGs: index 0 -> (0, 0, 0)
# black, index 1 -> (0, 255, 0) green. Only these two indices get explicit
# colours here.
color_map = [0, 0, 0, 0, 255, 0]


def _pred_png_name(img_path):
    """Return the PNG file name used to save the prediction for ``img_path``.

    The input's extension (e.g. ``.npy``) is replaced with ``.png``. This uses
    ``osp.splitext`` instead of ``str.rstrip('npy')``: ``rstrip`` removes a
    *set of characters*, not a suffix, so an input such as ``x.tif`` used to
    produce ``x.tifpng``.
    """
    stem, _ = osp.splitext(osp.basename(img_path))
    return stem + '.png'


def _save_pred(pred, img_path):
    """Save ``pred['label_map']`` as a palette-mode PNG in ``save_img_dir``.

    The label map is cast to uint8 and written as a 'P' image coloured by
    ``color_map`` (a pseudo-colour PNG of the prediction).
    """
    pred_path = osp.join(save_img_dir, _pred_png_name(img_path))
    pred_mask = Image.fromarray(pred['label_map'].astype(np.uint8), mode='P')
    pred_mask.putpalette(color_map)
    pred_mask.save(pred_path)


if single_img is not None:
    _save_pred(model.predict(single_img), single_img)
elif (file_list is not None) and (data_dir is not None):
    with open(osp.join(data_dir, file_list)) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # predicting on an empty path.
                continue
            # The first whitespace-separated column is the image path.
            # split() also drops the trailing newline, which split(' ') kept
            # whenever a line contained only a path.
            img_path = fields[0]
            print('Predicting {}'.format(img_path))
            pred = model.predict(osp.join(data_dir, img_path))
            _save_pred(pred, img_path)
else:
    raise Exception(
        'You should either set the parameter single_img, or set the parameters data_dir, file_list.'
    )