Commit 085dabf3 authored by chenguowei01

add infer script to save prediction results

Parent 84426956
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp

import cv2
import numpy as np
import paddle.fluid as fluid
import tqdm
from paddle.fluid.dygraph.base import to_variable

import models
import transforms as T
import utils
import utils.logging as logging
from utils import get_environ_info

def parse_args():
    parser = argparse.ArgumentParser(description='Model prediction')
    # params of model
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help="Model type for prediction, which is one of ('UNet')",
        type=str,
        default='UNet')
    # params of dataset
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='The root directory of the dataset',
        type=str)
    parser.add_argument(
        '--test_list',
        dest='test_list',
        help='Test list file of the dataset',
        type=str,
        default=None)
    parser.add_argument(
        '--num_classes',
        dest='num_classes',
        help='Number of classes',
        type=int,
        default=2)
    # params of prediction
    parser.add_argument(
        "--input_size",
        dest="input_size",
        help="The image size for net inputs.",
        nargs=2,
        default=[512, 512],
        type=int)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size',
        type=int,
        default=2)
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='The directory of the model used for prediction',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output/result')
    return parser.parse_args()

def mkdir(path):
    sub_dir = osp.dirname(path)
    if not osp.exists(sub_dir):
        os.makedirs(sub_dir)

def infer(model, data_dir=None, test_list=None, model_dir=None,
          transforms=None):
    # Load trained parameters and switch the model to evaluation mode.
    ckpt_path = osp.join(model_dir, 'model')
    para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
    model.set_dict(para_state_dict)
    model.eval()

    # Output locations come from the global command-line args.
    added_saved_dir = osp.join(args.save_dir, 'added')
    pred_saved_dir = osp.join(args.save_dir, 'prediction')

    logging.info("Start to predict...")
    with open(test_list, 'r') as f:
        files = f.readlines()
    for file in tqdm.tqdm(files):
        file = file.strip()
        im_file = osp.join(data_dir, file)
        im, im_info = transforms(im_file)
        im = np.expand_dims(im, axis=0)
        im = to_variable(im)
        pred, _ = model(im, mode='test')
        pred = pred.numpy()
        pred = np.squeeze(pred).astype('uint8')

        # Undo the resize/padding transforms in reverse order so the
        # prediction matches the original image shape.
        keys = list(im_info.keys())
        for k in keys[::-1]:
            if k == 'shape_before_resize':
                h, w = im_info[k][0], im_info[k][1]
                pred = cv2.resize(
                    pred, (w, h), interpolation=cv2.INTER_NEAREST)
            elif k == 'shape_before_padding':
                h, w = im_info[k][0], im_info[k][1]
                pred = pred[0:h, 0:w]

        # save added image
        added_image = utils.visualize(im_file, pred, weight=0.6)
        added_image_path = osp.join(added_saved_dir, file)
        mkdir(added_image_path)
        cv2.imwrite(added_image_path, added_image)

        # save prediction
        pred_im = utils.visualize(im_file, pred, weight=0.0)
        pred_saved_path = osp.join(pred_saved_dir, file)
        mkdir(pred_saved_path)
        cv2.imwrite(pred_saved_path, pred_im)

def arrange_transform(transforms, mode='train'):
    # Ensure the transform pipeline ends with an ArrangeSegmenter op
    # configured for the given mode.
    arrange_op = T.ArrangeSegmenter
    if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
        transforms.transforms[-1] = arrange_op(mode=mode)
    else:
        transforms.transforms.append(arrange_op(mode=mode))

def main(args):
    test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
    arrange_transform(test_transforms, mode='test')
    if args.model_name == 'UNet':
        model = models.UNet(num_classes=args.num_classes)
    infer(
        model,
        data_dir=args.data_dir,
        test_list=args.test_list,
        model_dir=args.model_dir,
        transforms=test_transforms)

if __name__ == '__main__':
    args = parse_args()
    env_info = get_environ_info()
    if env_info['place'] == 'cpu':
        places = fluid.CPUPlace()
    else:
        places = fluid.CUDAPlace(0)
    with fluid.dygraph.guard(places):
        main(args)
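
For reference, a typical invocation of the new script with the flags defined in parse_args might look like the following; the script name and the dataset/model paths are placeholders for illustration, not part of this commit:

python infer.py \
    --model_name UNet \
    --data_dir ./data/your_dataset \
    --test_list ./data/your_dataset/test_list.txt \
    --num_classes 2 \
    --input_size 512 512 \
    --model_dir ./output/best_model \
    --save_dir ./output/result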
@@ -228,13 +228,12 @@ def visualize(image, result, save_dir=None, weight=0.6):
         save_dir: the directory for saving visual image
         weight: the image weight of visual image, and the result weight is (1 - weight)
     """
-    label_map = result['label_map']
     color_map = get_color_map_list(256)
     color_map = np.array(color_map).astype("uint8")
     # Use OpenCV LUT for color mapping
-    c1 = cv2.LUT(label_map, color_map[:, 0])
-    c2 = cv2.LUT(label_map, color_map[:, 1])
-    c3 = cv2.LUT(label_map, color_map[:, 2])
+    c1 = cv2.LUT(result, color_map[:, 0])
+    c2 = cv2.LUT(result, color_map[:, 1])
+    c3 = cv2.LUT(result, color_map[:, 2])
     pseudo_img = np.dstack((c1, c2, c3))
     im = cv2.imread(image)
......
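
The hunk above removes the result['label_map'] lookup, so visualize() now receives the label map array directly; that is why the infer script casts its prediction to uint8 before calling it (cv2.LUT expects an 8-bit source and a 256-entry table). A minimal, self-contained sketch of this LUT-based pseudo-coloring, using a made-up 2-class label map and color table purely for illustration:

import cv2
import numpy as np

# Hypothetical 2-class label map, uint8 as cv2.LUT requires
# (the infer script casts the network output with astype('uint8')).
label_map = np.zeros((4, 4), dtype="uint8")
label_map[2:, 2:] = 1

# 256-entry color table, one BGR triplet per class id (values made up).
color_map = np.zeros((256, 3), dtype="uint8")
color_map[1] = (0, 0, 255)  # class 1 -> red in BGR order

# Map class ids to colors channel by channel, then stack into an image.
c1 = cv2.LUT(label_map, color_map[:, 0])
c2 = cv2.LUT(label_map, color_map[:, 1])
c3 = cv2.LUT(label_map, color_map[:, 2])
pseudo_img = np.dstack((c1, c2, c3))  # shape (4, 4, 3), dtype uint8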