# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import sys
import numpy as np
from PIL import Image as Image
import argparse
from models import load_model


def parse_args(argv=None):
    """Parse command-line options for the RemoteSensing prediction demo.

    Args:
        argv: Optional list of argument strings, excluding the program name.
            When None (the default) ``sys.argv[1:]`` is used, so existing
            callers that invoke ``parse_args()`` behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``single_img``, ``data_dir``,
        ``file_list``, ``load_model_dir`` and ``save_img_dir``.

    Exits the process with status 1 (after printing the help text) when no
    arguments at all are supplied.
    """
    if argv is None:
        argv = sys.argv[1:]

    parser = argparse.ArgumentParser(description='RemoteSensing predict')
    parser.add_argument(
        '--single_img',
        dest='single_img',
        help='single image path to predict',
        default=None,
        type=str)
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='dataset directory',
        default=None,
        type=str)
    parser.add_argument(
        '--file_list',
        dest='file_list',
        help='file name of predict file list',
        default=None,
        type=str)
    parser.add_argument(
        '--load_model_dir',
        dest='load_model_dir',
        help='model load directory',
        default=None,
        type=str)
    parser.add_argument(
        '--save_img_dir',
        dest='save_img_dir',
        help='save directory name of predict results',
        default='predict_results',
        type=str)

    # With no arguments there is nothing useful to do; show usage instead of
    # silently running with all-None inputs.
    if not argv:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(argv)


def _save_pseudo_color(label_map, src_path, out_dir, palette):
    """Save ``label_map`` as a palettised ('P' mode) PNG in ``out_dir``.

    The output file is named after ``src_path``'s basename with its final
    extension replaced by ``.png`` (e.g. ``dir/a.npy`` -> ``out_dir/a.png``).
    """
    # os.path.splitext removes exactly the last extension. str.rstrip('npy')
    # strips any trailing 'n'/'p'/'y' characters, so names that do not end in
    # '.npy' came out malformed (e.g. 'a.tif' -> 'a.tifpng').
    stem = osp.splitext(osp.basename(src_path))[0]
    pred_path = osp.join(out_dir, stem + '.png')
    pred_mask = Image.fromarray(label_map.astype(np.uint8), mode='P')
    pred_mask.putpalette(palette)
    pred_mask.save(pred_path)


args = parse_args()
data_dir = args.data_dir
file_list = args.file_list
single_img = args.single_img
load_model_dir = args.load_model_dir
save_img_dir = args.save_img_dir

# Check the input mode first so an invalid invocation fails before the output
# directory is created and before the (potentially slow) model load.
if single_img is None and (file_list is None or data_dir is None):
    raise Exception(
        'You should either set the parameter single_img, or set the parameters data_dir, file_list.'
    )

if not osp.exists(save_img_dir):
    os.makedirs(save_img_dir)

# predict
model = load_model(load_model_dir)

# Flat RGB palette for the saved masks: index 0 -> (0, 0, 0) black,
# index 1 -> (0, 255, 0) green. No other indices are given colours here.
color_map = [0, 0, 0, 0, 255, 0]

if single_img is not None:
    pred = model.predict(single_img)
    # Save the prediction as a pseudo-colour PNG.
    _save_pseudo_color(pred['label_map'], single_img, save_img_dir, color_map)
else:
    # Each non-blank line of the list starts with an image path relative to
    # data_dir; any further whitespace-separated fields are ignored here.
    with open(osp.join(data_dir, file_list)) as f:
        for line in f:
            # split() also discards the trailing newline, which split(' ')
            # kept on single-column lines and which broke the image path.
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            img_path = fields[0]
            print('Predicting {}'.format(img_path))
            pred = model.predict(osp.join(data_dir, img_path))
            # Save the prediction as a pseudo-colour PNG.
            _save_pseudo_color(pred['label_map'], img_path, save_img_dir,
                               color_map)