# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import os.path as osp
import sys

import numpy as np
from PIL import Image as Image

from models import load_model
from models.utils.visualize import get_color_map_list


def parse_args(argv=None):
    """Parse the command-line options of the RemoteSensing prediction demo.

    Args:
        argv: Optional list of argument strings, without the program name.
            When ``None`` (the default) the arguments are taken from
            ``sys.argv[1:]``, which is the behaviour existing callers rely on.

    Returns:
        argparse.Namespace: namespace with the attributes ``single_img``,
        ``data_dir``, ``file_list``, ``load_model_dir``, ``save_img_dir``
        and ``color_map``. ``color_map`` is the sentinel ``-1`` when the
        option is not given, otherwise a list of ints.

    Raises:
        SystemExit: with status 1, after printing the help text, when no
            arguments are supplied at all. argparse itself also exits on
            unknown or malformed options.
    """
    if argv is None:
        argv = sys.argv[1:]

    parser = argparse.ArgumentParser(description='RemoteSensing predict')
    parser.add_argument(
        '--single_img',
        dest='single_img',
        help='single image path to predict',
        default=None,
        type=str)
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='dataset directory',
        default=None,
        type=str)
    parser.add_argument(
        '--file_list',
        dest='file_list',
        help='file name of predict file list',
        default=None,
        type=str)
    parser.add_argument(
        '--load_model_dir',
        dest='load_model_dir',
        help='model load directory',
        default=None,
        type=str)
    parser.add_argument(
        '--save_img_dir',
        dest='save_img_dir',
        help='save directory name of predict results',
        default='predict_results',
        type=str)
    # The default is the scalar -1 (not a list) so that callers can tell
    # "option omitted" apart from any list the user may have supplied.
    parser.add_argument(
        '--color_map',
        dest='color_map',
        help='color map of predict results',
        type=int,
        nargs='*',
        default=-1)

    # Running without any option is treated as a request for usage help.
    if not argv:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(argv)


def _save_pseudo_color_png(label_map, src_path, out_dir, palette):
    """Save a label map as a palettised PNG named after the source image.

    Args:
        label_map: array-like exposing ``astype``; it is cast to ``uint8``
            before being handed to PIL (presumably a numpy array of class
            indices -- confirm against ``model.predict``).
        src_path: path of the image the prediction belongs to; only its
            base name without extension is used for the output file name.
        out_dir: directory the PNG is written into.
        palette: colour palette passed unchanged to ``Image.putpalette``.

    Returns:
        The path of the PNG file that was written.
    """
    stem, _ = osp.splitext(osp.basename(src_path))
    out_path = osp.join(out_dir, stem + '.png')
    mask = Image.fromarray(label_map.astype(np.uint8), mode='P')
    mask.putpalette(palette)
    mask.save(out_path)
    print('Predict result is saved in {}'.format(out_path))
    return out_path


args = parse_args()
data_dir = args.data_dir
file_list = args.file_list
single_img = args.single_img
load_model_dir = args.load_model_dir
save_img_dir = args.save_img_dir
if not osp.exists(save_img_dir):
    os.makedirs(save_img_dir)
# -1 is the parser's "option not given" sentinel; fall back to the palette
# produced by get_color_map_list(256) in that case.
if args.color_map == -1:
    color_map = get_color_map_list(256)
else:
    color_map = args.color_map

# predict
model = load_model(load_model_dir)

if single_img is not None:
    pred = model.predict(single_img)
    # Save the prediction result as a pseudo-color PNG image.
    pred_path = _save_pseudo_color_png(pred['label_map'], single_img,
                                       save_img_dir, color_map)
elif (file_list is not None) and (data_dir is not None):
    # Read the whole list first so the file is closed before the
    # (potentially long) prediction loop starts.
    with open(osp.join(data_dir, file_list)) as f:
        lines = f.readlines()
    for line in lines:
        # Strip the line terminator: an entry without a label column would
        # otherwise keep its trailing newline inside the image path.
        line = line.strip()
        if not line:
            # Blank lines carry no image path.
            continue
        # Only the first space-separated field (the image path) is used.
        img_path = line.split(' ')[0]
        img_path_ = osp.join(data_dir, img_path)

        pred = model.predict(img_path_)

        # Save the prediction result as a pseudo-color PNG image.
        pred_path = _save_pseudo_color_png(pred['label_map'], img_path,
                                           save_img_dir, color_map)
else:
    raise Exception(
        'You should either set the parameter single_img, or set the parameters data_dir and file_list.'
    )