未验证 提交 04b3f1b0 编写于 作者: L LutaoChu 提交者: GitHub

Add whole process of remote sensing (#287)

* support .tif .img read

* add data analyse and check tools

* upgrade to adapt new dataset

* remove global

* optimize function, file names
上级 d16fc9a9
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -125,7 +125,7 @@ class HRNet(BaseModel): ...@@ -125,7 +125,7 @@ class HRNet(BaseModel):
train_reader, train_reader,
train_batch_size=2, train_batch_size=2,
eval_reader=None, eval_reader=None,
eval_best_metric='kappa', eval_best_metric='miou',
save_interval_epochs=1, save_interval_epochs=1,
log_interval_steps=2, log_interval_steps=2,
save_dir='output', save_dir='output',
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -127,7 +127,7 @@ class UNet(BaseModel): ...@@ -127,7 +127,7 @@ class UNet(BaseModel):
train_reader, train_reader,
train_batch_size=2, train_batch_size=2,
eval_reader=None, eval_reader=None,
eval_best_metric='kappa', eval_best_metric='miou',
save_interval_epochs=1, save_interval_epochs=1,
log_interval_steps=2, log_interval_steps=2,
save_dir='output', save_dir='output',
......
import os
import os.path as osp
import numpy as np
from PIL import Image as Image
def get_color_map_list(num_classes):
    """Build a flat RGB palette for visualizing a segmentation mask.

    The bits of each class index are distributed over the three colour
    channels, three bits at a time, filling every channel from its most
    significant bit downwards.

    Args:
        num_classes: Number of classes.

    Returns:
        A flat list of ``3 * num_classes`` integers laid out as
        ``[r0, g0, b0, r1, g1, b1, ...]``.
    """
    palette = [0, 0, 0] * num_classes
    for class_id in range(num_classes):
        base = 3 * class_id
        remaining = class_id
        shift = 7
        while remaining:
            # Bits 0, 1 and 2 of the remaining value go to R, G and B.
            for channel in range(3):
                palette[base + channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
    return palette
def splice_imgs(img_list, vis_path):
    """Paste the images side by side, separated by white gaps, and save them.

    The size of the first image defines the tile size used for the layout.

    Args:
        img_list: Images to splice, in left-to-right order.
        vis_path: Path the spliced picture is saved to.

    Returns:
        The result of ``Image.save`` for the spliced picture.
    """
    tile_width, tile_height = img_list[0].size
    gap = 20
    count = len(img_list)
    stride = tile_width + gap

    canvas = Image.new(
        'RGB', (count * tile_width + (count - 1) * gap, tile_height))
    spacer = Image.new('RGB', (gap, tile_height), (255, 255, 255))

    last = count - 1
    for index, tile in enumerate(img_list):
        left = index * stride
        canvas.paste(tile, (left, 0))
        # A white spacer follows every tile except the last one.
        if index != last:
            canvas.paste(spacer, (left + tile_width, 0))
    return canvas.save(vis_path)
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from PIL import Image as Image from PIL import Image as Image
import argparse import argparse
from models import load_model from models import load_model
from models.utils.visualize import get_color_map_list
def parse_args(): def parse_args():
...@@ -54,6 +55,13 @@ def parse_args(): ...@@ -54,6 +55,13 @@ def parse_args():
help='save directory name of predict results', help='save directory name of predict results',
default='predict_results', default='predict_results',
type=str) type=str)
parser.add_argument(
'--color_map',
dest='color_map',
help='color map of predict results',
type=int,
nargs='*',
default=-1)
if len(sys.argv) < 2: if len(sys.argv) < 2:
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
...@@ -68,37 +76,41 @@ load_model_dir = args.load_model_dir ...@@ -68,37 +76,41 @@ load_model_dir = args.load_model_dir
save_img_dir = args.save_img_dir save_img_dir = args.save_img_dir
if not osp.exists(save_img_dir): if not osp.exists(save_img_dir):
os.makedirs(save_img_dir) os.makedirs(save_img_dir)
if args.color_map == -1:
color_map = get_color_map_list(256)
else:
color_map = args.color_map
# predict # predict
model = load_model(load_model_dir) model = load_model(load_model_dir)
color_map = [0, 0, 0, 0, 255, 0]
if single_img is not None: if single_img is not None:
pred = model.predict(single_img) pred = model.predict(single_img)
# 以伪彩色png图片保存预测结果 # 以伪彩色png图片保存预测结果
pred_name = osp.basename(single_img).rstrip('npy') + 'png' pred_name, _ = osp.splitext(osp.basename(single_img))
pred_path = osp.join(save_img_dir, pred_name) pred_path = osp.join(save_img_dir, pred_name + '.png')
pred_mask = Image.fromarray(pred['label_map'].astype(np.uint8), mode='P') pred_mask = Image.fromarray(pred['label_map'].astype(np.uint8), mode='P')
pred_mask.putpalette(color_map) pred_mask.putpalette(color_map)
pred_mask.save(pred_path) pred_mask.save(pred_path)
print('Predict result is saved in {}'.format(pred_path))
elif (file_list is not None) and (data_dir is not None): elif (file_list is not None) and (data_dir is not None):
with open(osp.join(data_dir, file_list)) as f: with open(osp.join(data_dir, file_list)) as f:
lines = f.readlines() lines = f.readlines()
for line in lines: for line in lines:
img_path = line.split(' ')[0] img_path = line.split(' ')[0]
print('Predicting {}'.format(img_path))
img_path_ = osp.join(data_dir, img_path) img_path_ = osp.join(data_dir, img_path)
pred = model.predict(img_path_) pred = model.predict(img_path_)
# 以伪彩色png图片保存预测结果 # 以伪彩色png图片保存预测结果
pred_name = osp.basename(img_path).rstrip('npy') + 'png' pred_name, _ = osp.splitext(osp.basename(img_path))
pred_path = osp.join(save_img_dir, pred_name) pred_path = osp.join(save_img_dir, pred_name + '.png')
pred_mask = Image.fromarray( pred_mask = Image.fromarray(
pred['label_map'].astype(np.uint8), mode='P') pred['label_map'].astype(np.uint8), mode='P')
pred_mask.putpalette(color_map) pred_mask.putpalette(color_map)
pred_mask.save(pred_path) pred_mask.save(pred_path)
print('Predict result is saved in {}'.format(pred_path))
else: else:
raise Exception( raise Exception(
'You should either set the parameter single_img, or set the parameters data_dir, file_list.' 'You should either set the parameter single_img, or set the parameters data_dir and file_list.'
) )
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,19 +15,36 @@ ...@@ -15,19 +15,36 @@
from __future__ import absolute_import from __future__ import absolute_import
import os.path as osp import os.path as osp
import random import random
import imghdr
import gdal
import numpy as np
from utils import logging from utils import logging
from .base import BaseReader from .base import BaseReader
from .base import get_encoding from .base import get_encoding
from collections import OrderedDict from collections import OrderedDict
from .base import is_pic
def read_img(img_path):
    """Load an image file as a NumPy array.

    Files whose content is recognised as TIFF by ``imghdr`` and files with
    the ``.img`` extension are opened with GDAL and returned channel-last;
    ``.npy`` files are loaded with ``np.load`` and returned unchanged.

    Args:
        img_path: Path of the image file.

    Returns:
        The image data as a NumPy array.

    Raises:
        Exception: If GDAL cannot open the file or the format is unsupported.
    """
    img_format = imghdr.what(img_path)
    _, ext = osp.splitext(img_path)
    if img_format == 'tiff' or ext == '.img':
        dataset = gdal.Open(img_path)
        if dataset is None:
            raise Exception('Can not open {}'.format(img_path))
        im_data = dataset.ReadAsArray()
        # ReadAsArray yields a 2-D array for a single-band raster; give it a
        # trailing channel axis instead of failing on the 3-axis transpose.
        if im_data.ndim == 2:
            return im_data[:, :, np.newaxis]
        return im_data.transpose((1, 2, 0))
    if ext == '.npy':
        return np.load(img_path)
    raise Exception('Not support {} image format!'.format(ext))
class Reader(BaseReader): class Reader(BaseReader):
"""读取语分分割任务数据集,并对样本进行相应的处理。 """读取数据集,并对样本进行相应的处理。
Args: Args:
data_dir (str): 数据集所在的目录路径。 data_dir (str): 数据集所在的目录路径。
file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。 file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
label_list (str): 描述数据集包含的类别信息文件路径。 label_list (str): 描述数据集包含的类别信息文件路径。
transforms (list): 数据集中每个样本的预处理/增强算子。 transforms (list): 数据集中每个样本的预处理/增强算子。
num_workers (int): 数据集中样本在预处理过程中的线程或进程数。默认为4。 num_workers (int): 数据集中样本在预处理过程中的线程或进程数。默认为4。
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import os.path as osp
import sys
import argparse
from tqdm import tqdm
import pickle
from data_analyse_and_check import read_img
def parse_args():
    """Parse the command-line options of the normalization statistics tool.

    Prints the help text and exits with status 1 when the script is started
    without any argument.

    Returns:
        The parsed ``argparse.Namespace`` with ``data_dir``, ``pkl_path``,
        ``clip_min_value``, ``clip_max_value`` and ``separator``.
    """
    parser = argparse.ArgumentParser(
        description=
        'Compute normalization coefficient and clip percentage before training.'
    )
    # (flag, help text, extra keyword arguments), in help-output order.
    options = [
        ('--data_dir', 'Dataset directory', dict(default=None, type=str)),
        ('--pkl_path', 'Path of img_pixel_statistics.pkl',
         dict(default=None, type=str)),
        ('--clip_min_value', 'Min values for clipping data',
         dict(nargs='+', default=None, type=int)),
        ('--clip_max_value', 'Max values for clipping data',
         dict(nargs='+', default=None, type=int)),
        ('--separator', 'file list separator', dict(default=" ", type=str)),
    ]
    for flag, help_text, extra in options:
        parser.add_argument(
            flag, dest=flag.lstrip('-'), help=help_text, **extra)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
def compute_single_img(img, clip_min_value, clip_max_value):
    """Compute the per-channel mean and standard deviation of one image.

    When both clip lists are given, each channel is clipped to
    ``[clip_min_value[k], clip_max_value[k]]`` and rescaled with min-max
    normalization before the statistics are taken. When either list is
    ``None`` or empty the raw channel values are used.

    The input image is left untouched.

    Args:
        img: Image array indexed as ``img[:, :, k]`` per channel.
        clip_min_value: Per-channel lower clip bounds, or ``None``/empty.
        clip_max_value: Per-channel upper clip bounds, or ``None``/empty.

    Returns:
        A tuple ``(means, stds)`` of arrays with one entry per channel.
    """
    channel = img.shape[2]
    means = np.zeros(channel)
    stds = np.zeros(channel)
    # The command-line defaults are None, so test for that as well as for
    # empty lists before indexing the bounds.
    do_clip = (clip_min_value is not None and clip_max_value is not None
               and len(clip_min_value) > 0 and len(clip_max_value) > 0)
    for k in range(channel):
        img_k = img[:, :, k]
        if do_clip:
            low, high = clip_min_value[k], clip_max_value[k]
            # Clip into a new array so the caller's image is not modified.
            clipped = np.clip(img_k, low, high)
            # Rescaling (min-max normalization)
            img_k = (clipped.astype(np.float32, copy=False) - low) / (
                high - low)
        # count mean, std
        means[k] = np.mean(img_k)
        stds[k] = np.std(img_k)
    return means, stds
def cal_normalize_coefficient(data_dir, separator, clip_min_value,
                              clip_max_value):
    """Average the per-image channel mean and std over the dataset lists.

    Reads ``train.txt``, ``val.txt`` and ``test.txt`` in ``data_dir``, loads
    the image named by the first field of every line and averages the
    per-image statistics returned by ``compute_single_img``. The result is
    printed.

    Args:
        data_dir: Dataset directory containing the three file lists.
        separator: Field separator used in the file lists.
        clip_min_value: Per-channel lower clip bounds passed on to
            ``compute_single_img``.
        clip_max_value: Per-channel upper clip bounds passed on to
            ``compute_single_img``.

    Returns:
        A tuple ``(mean, std)`` of per-channel arrays, or ``None`` when no
        image was listed.
    """
    file_lists = [
        osp.join(data_dir, name)
        for name in ('train.txt', 'val.txt', 'test.txt')
    ]
    total_img_num = 0
    total_means = None
    total_stds = None
    for file_list in file_lists:
        with open(file_list, 'r') as fid:
            print("\n-----------------------------\nCheck {}...".format(
                file_list))
            lines = fid.readlines()
        if not lines:
            print("File list is empty!")
            continue
        for line in tqdm(lines):
            line = line.strip()
            if not line:
                continue
            # Only the image path is needed; lines without a label field
            # (e.g. in a test list) are therefore accepted too.
            img_name = line.split(separator)[0]
            img = read_img(os.path.join(data_dir, img_name))
            means, stds = compute_single_img(img, clip_min_value,
                                             clip_max_value)
            if total_means is None:
                total_means = np.zeros_like(means)
                total_stds = np.zeros_like(stds)
            total_means += means
            total_stds += stds
            total_img_num += 1

    if total_img_num == 0:
        print("\nNo image found, mean and std can not be computed.")
        return None

    # count mean, std
    total_means = total_means / total_img_num
    total_stds = total_stds / total_img_num
    print("\nCount the channel-by-channel mean and std of the image:\n"
          "mean = {}\nstd = {}".format(total_means, total_stds))
    return total_means, total_stds
def cal_clip_percentage(pkl_path, clip_min_value, clip_max_value):
    """Calculate the percentage of pixels to be clipped for every channel.

    The pickle file is expected to hold a pair whose second element is a
    per-channel list of pixel counts indexed by pixel value.

    Args:
        pkl_path: Path of the pixel statistics pickle file.
        clip_min_value: Per-channel lower clip bounds (inclusive).
        clip_max_value: Per-channel upper clip bounds (inclusive).

    Returns:
        A list with the clipped fraction of every channel. A channel whose
        histogram holds no pixels is reported as ``0.0``.
    """
    with open(pkl_path, 'rb') as f:
        # NOTE(review): pickle.load can execute arbitrary code from a crafted
        # file; only pass statistics files produced by a trusted run.
        _, img_value_num = pickle.load(f)

    clipped_ratios = []
    for k, counts in enumerate(img_value_num):
        low, high = clip_min_value[k], clip_max_value[k]
        range_pixel = sum(
            count for value, count in enumerate(counts)
            if low <= value <= high)
        sum_pixel = sum(counts)
        # Guard the division: an empty histogram has nothing to clip.
        ratio = 1 - range_pixel / sum_pixel if sum_pixel else 0.0
        print('channel {}, the percentage of pixels to be clipped = {}'.format(
            k, ratio))
        clipped_ratios.append(ratio)
    return clipped_ratios
def main():
    """Run the normalization and clip-percentage computations.

    Each stage only runs when the command-line options it needs were given,
    so the tool can be used for either computation on its own instead of
    failing on a missing option.
    """
    args = parse_args()
    clip_min_value = args.clip_min_value
    clip_max_value = args.clip_max_value

    if args.data_dir is not None:
        cal_normalize_coefficient(args.data_dir, args.separator,
                                  clip_min_value, clip_max_value)
    else:
        print("--data_dir is not set, skip the normalization coefficient.")

    if args.pkl_path is not None and clip_min_value and clip_max_value:
        cal_clip_percentage(args.pkl_path, clip_min_value, clip_max_value)
    else:
        print("--pkl_path, --clip_min_value and --clip_max_value are all "
              "required to compute the clip percentage, skip it.")


if __name__ == "__main__":
    main()
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import os.path as osp
import sys
import argparse
from PIL import Image
from tqdm import tqdm
import imghdr
import logging
import pickle
import gdal
def parse_args():
    """Parse the command-line options of the data analyse and check tool.

    Prints the help text and exits with status 1 when the script is started
    without any argument.

    Returns:
        The parsed ``argparse.Namespace`` with ``data_dir``, ``num_classes``
        and ``separator``.
    """
    parser = argparse.ArgumentParser(
        description='Data analyse and data check before training.')
    # (flag, help text, default, type), in help-output order.
    options = (
        ('--data_dir', 'Dataset directory', None, str),
        ('--num_classes', 'Number of classes', None, int),
        ('--separator', 'file list separator', " ", str),
    )
    for flag, help_text, default, value_type in options:
        parser.add_argument(
            flag,
            dest=flag.lstrip('-'),
            help=help_text,
            default=default,
            type=value_type)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
def read_img(img_path):
    """Load an image file as a NumPy array.

    Files whose content is recognised as TIFF by ``imghdr`` and files with
    the ``.img`` extension are opened with GDAL and returned channel-last;
    ``.npy`` files are loaded with ``np.load`` and returned unchanged.

    Args:
        img_path: Path of the image file.

    Returns:
        The image data as a NumPy array.

    Raises:
        Exception: If GDAL cannot open the file or the format is unsupported.
    """
    img_format = imghdr.what(img_path)
    _, ext = osp.splitext(img_path)
    if img_format == 'tiff' or ext == '.img':
        dataset = gdal.Open(img_path)
        if dataset is None:
            raise Exception('Can not open {}'.format(img_path))
        im_data = dataset.ReadAsArray()
        # ReadAsArray yields a 2-D array for a single-band raster; give it a
        # trailing channel axis instead of failing on the 3-axis transpose.
        if im_data.ndim == 2:
            return im_data[:, :, np.newaxis]
        return im_data.transpose((1, 2, 0))
    if ext == '.npy':
        return np.load(img_path)
    raise Exception('Not support {} image format!'.format(ext))
def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
    """Update the running pixel statistics with one image.

    For every channel the mean and std of this image are computed, the
    running min/max are updated and the per-value pixel counts are added to
    ``img_value_num[k]``, which is grown in place when a larger pixel value
    appears.

    Args:
        img: Image array indexed as ``img[:, :, k]`` per channel. Pixel
            values are used as list indices, so they are assumed to be
            non-negative integers -- TODO confirm for ``.npy`` inputs.
        img_value_num: Per-channel lists of pixel counts indexed by value.
        img_min_value: Per-channel running minimum.
        img_max_value: Per-channel running maximum.

    Returns:
        ``(means, stds, img_min_value, img_max_value, img_value_num)`` where
        ``means`` and ``stds`` describe this image only and the other three
        are the updated input lists.
    """
    channel = img.shape[2]
    means = np.zeros(channel)
    stds = np.zeros(channel)
    for k in range(channel):
        img_k = img[:, :, k]

        # count mean, std
        means[k] = np.mean(img_k)
        stds[k] = np.std(img_k)

        # count min, max
        min_value = np.min(img_k)
        max_value = np.max(img_k)
        if img_max_value[k] < max_value:
            img_max_value[k] = max_value
        if img_min_value[k] > min_value:
            img_min_value[k] = min_value

        # count the distribution of image value, value number
        unique, counts = np.unique(img_k, return_counts=True)
        histogram = img_value_num[k]
        # np.unique returns sorted values, so the last one is the maximum.
        # Convert it to a Python int first: arithmetic in a narrow dtype such
        # as uint8 can wrap around and leave the histogram too short.
        add_len = int(unique[-1]) - len(histogram) + 1
        if add_len > 0:
            histogram.extend([0] * add_len)
        for value, count in zip(unique.tolist(), counts.tolist()):
            histogram[value] += count
    return means, stds, img_min_value, img_max_value, img_value_num
def dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num,
                             img_min_value, img_max_value, total_img_num,
                             logger):
    """Report the pixel statistics accumulated over the whole dataset.

    Writes ``[ratios, img_value_num]`` to ``img_pixel_statistics.pkl`` in
    ``data_dir``, logs the value range and prints the averaged per-channel
    mean and std. Returns early when ``img_value_num`` is empty.

    Args:
        data_dir: Directory the pickle file is written to.
        total_means: Sum of the per-image channel means.
        total_stds: Sum of the per-image channel stds.
        img_value_num: Per-channel lists of pixel counts indexed by value.
        img_min_value: Per-channel minimum pixel value.
        img_max_value: Per-channel maximum pixel value.
        total_img_num: Number of images that were accumulated.
        logger: Logger used for the report.
    """
    logger.info("\n-----------------------------\nDataset pixel statistics...")

    # count the distribution of image value, value number
    if not img_value_num:
        return
    logger.info("\nImage pixel statistics:")

    total_ratio = []
    for channel_counts in img_value_num:
        pixel_total = sum(channel_counts)
        ratios = [count / pixel_total for count in channel_counts]
        total_ratio.append(np.around(ratios, decimals=4))

    pkl_path = os.path.join(data_dir, 'img_pixel_statistics.pkl')
    with open(pkl_path, 'wb') as f:
        pickle.dump([total_ratio, img_value_num], f)

    # print min value, max value
    value_range = "value range: \nimg_min_value = {} \nimg_max_value = {}".format(
        img_min_value, img_max_value)
    logger.info(value_range)

    # count mean, std
    mean_per_channel = total_means / total_img_num
    std_per_channel = total_stds / total_img_num
    print("\nCount the channel-by-channel mean and std of the image:\n"
          "mean = {}\nstd = {}".format(mean_per_channel, std_per_channel))
def error_print(str):
    """Return ``str`` prefixed with the "NOT PASS" marker used in reports."""
    return "\nNOT PASS " + str
def correct_print(str):
    """Return ``str`` prefixed with the "PASS" marker used in reports."""
    return "\nPASS " + str
def pil_imread(file_path):
    """Read a pseudo-color label image into a NumPy array.

    Args:
        file_path: Path of the label image.

    Returns:
        The image content as returned by ``np.asarray``.
    """
    # Convert inside the context manager so the underlying file is closed
    # as soon as the pixel data has been read.
    with Image.open(file_path) as im:
        return np.asarray(im)
def get_img_shape_range(img, max_width, max_height, min_width, min_height):
    """Update the running maximum and minimum image width and height.

    Args:
        img: Image array whose first two axes are height and width.
        max_width: Largest width seen so far.
        max_height: Largest height seen so far.
        min_width: Smallest width seen so far.
        min_height: Smallest height seen so far.

    Returns:
        The updated ``(max_width, max_height, min_width, min_height)``.
    """
    height, width = img.shape[0], img.shape[1]
    return (
        max(width, max_width),
        max(height, max_height),
        min(width, min_width),
        min(height, min_height),
    )
def get_image_dim(img, img_dim):
    """Record the channel count of ``img`` in ``img_dim`` if it is new.

    Args:
        img: Image array; the size of its last axis is the channel count.
        img_dim: List of channel counts seen so far, updated in place.

    Returns:
        The same ``img_dim`` list.
    """
    channels = img.shape[-1]
    if channels in img_dim:
        return img_dim
    img_dim.append(channels)
    return img_dim
def is_label_single_channel(label):
    """Tell whether the label is a grayscale image, i.e. a 2-D array."""
    return len(label.shape) == 2
def image_label_shape_check(img, label):
    """Check that the image and its label have the same height and width.

    Args:
        img: Image array whose first two axes are height and width.
        label: Label array whose first two axes are height and width.

    Returns:
        ``True`` when both sizes match, ``False`` otherwise.
    """
    img_size = (img.shape[0], img.shape[1])
    label_size = (label.shape[0], label.shape[1])
    return img_size == label_size
def ground_truth_check(label, label_path):
    """Check the label file format and count the pixels of every class.

    Args:
        label: Label image array.
        label_path: Path of the label image file.

    Returns:
        A tuple ``(png_format, unique, counts)``: whether the file is a PNG
        image, the label classes present and the pixel count of each class.
    """
    is_png = imghdr.what(label_path) == "png"
    classes, pixel_counts = np.unique(label, return_counts=True)
    return is_png, classes, pixel_counts
def sum_label_check(png_format, label_classes, num_of_each_class, ignore_index,
                    num_classes, png_format_right_num, png_format_wrong_num,
                    total_label_classes, total_num_of_each_class):
    """Fold one label image into the running format and class statistics.

    Args:
        png_format: Whether the label file is a PNG image.
        label_classes: Label classes present in the image.
        num_of_each_class: Pixel count of every class in ``label_classes``.
        ignore_index: Class value excluded from the range check.
        num_classes: Number of classes; valid labels are
            ``0 .. num_classes - 1``.
        png_format_right_num: Running count of PNG label files.
        png_format_wrong_num: Running count of non-PNG label files.
        total_label_classes: Classes seen so far, extended in place.
        total_num_of_each_class: Pixel totals matching
            ``total_label_classes``, updated in place.

    Returns:
        ``(is_label_correct, png_format_right_num, png_format_wrong_num,
        total_num_of_each_class, total_label_classes)``.
    """
    if png_format:
        png_format_right_num += 1
    else:
        png_format_wrong_num += 1

    # Range-check every class except the ignored one. A label made up of
    # ignored pixels only leaves nothing to check and counts as correct.
    checked_classes = np.asarray(label_classes)
    checked_classes = checked_classes[checked_classes != ignore_index]
    is_label_correct = True
    if checked_classes.size > 0:
        if (checked_classes.min() < 0
                or checked_classes.max() > num_classes - 1):
            is_label_correct = False

    new_classes = []
    new_counts = []
    for class_id, pixel_num in zip(label_classes, num_of_each_class):
        if class_id in total_label_classes:
            position = total_label_classes.index(class_id)
            total_num_of_each_class[position] += pixel_num
        else:
            new_classes.append(class_id)
            new_counts.append(pixel_num)
    total_num_of_each_class += new_counts
    total_label_classes += new_classes
    return is_label_correct, png_format_right_num, png_format_wrong_num, total_num_of_each_class, total_label_classes
def label_check_statistics(num_classes, png_format_wrong_image,
                           png_format_right_num, png_format_wrong_num,
                           total_label_classes, total_num_of_each_class,
                           wrong_labels, logger):
    """Log the result of the label format and label class checks.

    Args:
        num_classes: Number of classes; valid labels are
            ``0 .. num_classes - 1``.
        png_format_wrong_image: File list lines whose label is not PNG.
        png_format_right_num: Number of PNG label files.
        png_format_wrong_num: Number of non-PNG label files.
        total_label_classes: Label classes found in the dataset.
        total_num_of_each_class: Pixel totals matching
            ``total_label_classes``.
        wrong_labels: File list lines whose label classes are out of range.
        logger: Logger used for the report.
    """
    if png_format_wrong_num == 0:
        if png_format_right_num:
            logger.info(correct_print("label format check"))
        else:
            logger.info(error_print("label format check"))
            logger.info("No label image to check")
            return
    else:
        logger.info(error_print("label format check"))
    logger.info(
        "total {} label images are png format, {} label images are not png "
        "format".format(png_format_right_num, png_format_wrong_num))
    if len(png_format_wrong_image) > 0:
        for i in png_format_wrong_image:
            logger.debug(i)

    # Convert explicitly: dividing the plain list only works when its
    # elements happen to be NumPy scalars.
    pixel_totals = np.asarray(total_num_of_each_class)
    total_ratio = np.around(pixel_totals / pixel_totals.sum(), decimals=4)
    total_nc = sorted(
        zip(total_label_classes, total_ratio, total_num_of_each_class))

    # After sorting, the first entry holds the smallest label class, which
    # must be 0 for the class check to pass.
    if len(wrong_labels) == 0 and not total_nc[0][0]:
        logger.info(correct_print("label class check!"))
    else:
        logger.info(error_print("label class check!"))
        if total_nc[0][0]:
            logger.info("Warning: label classes should start from 0")
        if len(wrong_labels) > 0:
            logger.info(
                "fatal error: label class is out of range [0, {}]".format(
                    num_classes - 1))
            for i in wrong_labels:
                logger.debug(i)
    logger.info(
        "\nLabel pixel statistics:\n"
        "(label class, percentage, total pixel number) = {} ".format(total_nc))
def shape_check(shape_unequal_image, logger):
    """Log the result of the image/label shape check.

    Args:
        shape_unequal_image: File list lines whose image and label differ in
            size.
        logger: Logger used for the report.
    """
    if len(shape_unequal_image) != 0:
        logger.info(error_print("shape check"))
        logger.info(
            "Some images are not the same shape as the labels as follow: ")
        for entry in shape_unequal_image:
            logger.debug(entry)
        return
    logger.info(correct_print("shape check"))
    logger.info("All images are the same shape as the labels")
def separator_check(wrong_lines, file_list, separator, logger):
    """Log whether every line of the file list uses the expected separator.

    Args:
        wrong_lines: Entries that could not be split as expected.
        file_list: Path of the file list that was checked.
        separator: Separator the file list is expected to use.
        logger: Logger used for the report.
    """
    check_name = file_list.split(os.sep)[-1] + " DATASET.separator check"
    if len(wrong_lines) != 0:
        logger.info(error_print(check_name))
        logger.info(
            "The following list is not separated by {}".format(separator))
        for entry in wrong_lines:
            logger.debug(entry)
        return
    logger.info(correct_print(check_name))
def imread_check(imread_failed, logger):
    """Log whether all images of the file list could be read.

    Args:
        imread_failed: Entries describing the images that failed to load.
        logger: Logger used for the report.
    """
    failed_num = len(imread_failed)
    if failed_num != 0:
        logger.info(error_print("dataset reading check"))
        logger.info("Failed to read {} images".format(failed_num))
        for entry in imread_failed:
            logger.debug(entry)
        return
    logger.info(correct_print("dataset reading check"))
    logger.info("All images can be read successfully")
def single_channel_label_check(label_not_single_channel, logger):
    """Log whether all label images are single-channel.

    Args:
        label_not_single_channel: File list lines whose label has more than
            one channel.
        logger: Logger used for the report.
    """
    bad_num = len(label_not_single_channel)
    if bad_num != 0:
        logger.info(error_print("label single_channel check"))
        logger.info(
            "{} label images are not single_channel\nLabel pixel statistics may be insignificant"
            .format(bad_num))
        for entry in label_not_single_channel:
            logger.debug(entry)
        return
    logger.info(correct_print("label single_channel check"))
    logger.info("All label images are single_channel")
def img_shape_range_statistics(max_width, min_width, max_height, min_height,
                               logger):
    """Log the largest and smallest image width and height of a file list."""
    logger.info("\nImage size statistics:")
    size_summary = "max width = {} min width = {} max height = {} min height = {}".format(
        max_width, min_width, max_height, min_height)
    logger.info(size_summary)
def img_dim_statistics(img_dim, logger):
    """Log the distinct channel counts found in a file list."""
    channels = np.unique(img_dim)
    logger.info(
        "\nImage channels statistics\nImage channels = {}".format(channels))
def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
                           logger):
    """Check and analyse the train, val and test file lists of a dataset.

    For each of ``train.txt``, ``val.txt`` and ``test.txt`` in ``data_dir``
    every listed image (and label, when present) is read and checked, and a
    per-list report is logged. Pixel statistics are accumulated over all
    lists and reported once at the end.

    Args:
        data_dir: Dataset directory containing the three file lists.
        num_classes: Number of classes; valid labels are
            ``0 .. num_classes - 1``.
        separator: Field separator used in the file lists.
        ignore_index: Label value excluded from the class range check.
        logger: Logger used for the report.
    """
    train_file_list = osp.join(data_dir, 'train.txt')
    val_file_list = osp.join(data_dir, 'val.txt')
    test_file_list = osp.join(data_dir, 'test.txt')

    total_img_num = 0
    # NOTE(review): has_label is never reset, so once one list had labels the
    # label checks are also reported for later lists -- confirm intended.
    has_label = False
    # Dataset-wide accumulators. They are sized from the first readable
    # image; the defaults keep the final report safe when no image was read.
    total_means = None
    total_stds = None
    img_min_value = []
    img_max_value = []
    img_value_num = []

    for file_list in [train_file_list, val_file_list, test_file_list]:
        # initialization
        imread_failed = []
        max_width = 0
        max_height = 0
        min_width = sys.float_info.max
        min_height = sys.float_info.max
        label_not_single_channel = []
        shape_unequal_image = []
        png_format_wrong_image = []
        wrong_labels = []
        wrong_lines = []
        png_format_right_num = 0
        png_format_wrong_num = 0
        total_label_classes = []
        total_num_of_each_class = []
        img_dim = []

        with open(file_list, 'r') as fid:
            logger.info("\n-----------------------------\nCheck {}...".format(
                file_list))
            lines = fid.readlines()
        if not lines:
            logger.info("File list is empty!")
            continue

        for line in tqdm(lines):
            line = line.strip()
            # A blank line names no image; skip it instead of treating it as
            # an unlabelled entry.
            if not line:
                continue
            parts = line.split(separator)
            if len(parts) == 1:
                if file_list == train_file_list or file_list == val_file_list:
                    logger.info("Train or val list must have labels!")
                    break
                img_path = os.path.join(data_dir, parts[0])
                try:
                    img = read_img(img_path)
                except Exception as e:
                    imread_failed.append((line, str(e)))
                    continue
            elif len(parts) == 2:
                has_label = True
                img_name, label_name = parts[0], parts[1]
                img_path = os.path.join(data_dir, img_name)
                label_path = os.path.join(data_dir, label_name)
                try:
                    img = read_img(img_path)
                    label = pil_imread(label_path)
                except Exception as e:
                    imread_failed.append((line, str(e)))
                    continue

                is_single_channel = is_label_single_channel(label)
                if not is_single_channel:
                    label_not_single_channel.append(line)
                    continue
                is_equal_img_label_shape = image_label_shape_check(
                    img, label)
                if not is_equal_img_label_shape:
                    shape_unequal_image.append(line)
                png_format, label_classes, num_of_each_class = ground_truth_check(
                    label, label_path)
                if not png_format:
                    png_format_wrong_image.append(line)
                is_label_correct, png_format_right_num, png_format_wrong_num, total_num_of_each_class, total_label_classes = sum_label_check(
                    png_format, label_classes, num_of_each_class,
                    ignore_index, num_classes, png_format_right_num,
                    png_format_wrong_num, total_label_classes,
                    total_num_of_each_class)
                if not is_label_correct:
                    wrong_labels.append(line)
            else:
                # Record the offending line, not the whole file content.
                wrong_lines.append(line)
                continue

            if total_img_num == 0:
                channel = img.shape[2]
                total_means = np.zeros(channel)
                total_stds = np.zeros(channel)
                img_min_value = [sys.float_info.max] * channel
                img_max_value = [0] * channel
                img_value_num = [[] for _ in range(channel)]
            means, stds, img_min_value, img_max_value, img_value_num = img_pixel_statistics(
                img, img_value_num, img_min_value, img_max_value)
            total_means += means
            total_stds += stds
            max_width, max_height, min_width, min_height = get_img_shape_range(
                img, max_width, max_height, min_width, min_height)
            img_dim = get_image_dim(img, img_dim)
            total_img_num += 1

        # data check
        separator_check(wrong_lines, file_list, separator, logger)
        imread_check(imread_failed, logger)

        # data analyse on this file list
        img_dim_statistics(img_dim, logger)
        img_shape_range_statistics(max_width, min_width, max_height,
                                   min_height, logger)
        if has_label:
            single_channel_label_check(label_not_single_channel, logger)
            shape_check(shape_unequal_image, logger)
            label_check_statistics(
                num_classes, png_format_wrong_image, png_format_right_num,
                png_format_wrong_num, total_label_classes,
                total_num_of_each_class, wrong_labels, logger)

    # data analyse on the whole dataset
    dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num,
                             img_min_value, img_max_value, total_img_num,
                             logger)
def main():
    """Configure logging and run the dataset analysis and check.

    Messages of level INFO and above go to the console; everything from
    DEBUG upwards is written to ``data_analyse_and_check.log`` in the
    dataset directory.
    """
    args = parse_args()
    data_dir = args.data_dir
    ignore_index = 255
    log_path = os.path.join(data_dir, 'data_analyse_and_check.log')

    logger = logging.getLogger()
    logger.setLevel('DEBUG')
    formatter = logging.Formatter("%(message)s")

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel('INFO')

    file_handler = logging.FileHandler(log_path, 'w')
    file_handler.setFormatter(formatter)

    for handler in (console_handler, file_handler):
        logger.addHandler(handler)

    data_analyse_and_check(data_dir, args.num_classes, args.separator,
                           ignore_index, logger)

    print("\nDetailed error information can be viewed in {}.".format(
        os.path.join(data_dir, 'data_analyse_and_check.log')))


if __name__ == "__main__":
    main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import sys
import argparse
import matplotlib.pyplot as plt
def parse_args():
    """Parse the command-line options of the distribution visualizer.

    When the script is started without any command-line argument, the help
    text is printed and the process exits with status 1.

    Returns:
        argparse.Namespace: parsed options; ``pkl_path`` holds the value of
            ``--pkl_path`` (``None`` when the option is not supplied).
    """
    arg_parser = argparse.ArgumentParser(
        description='Visualize data distribution before training.')
    arg_parser.add_argument(
        '--pkl_path',
        dest='pkl_path',
        type=str,
        default=None,
        help='Path of img_pixel_statistics.pkl')

    # Anything other than a bare invocation is handed to argparse.
    if len(sys.argv) != 1:
        return arg_parser.parse_args()

    # Bare invocation: show usage instead of continuing without a path.
    arg_parser.print_help()
    sys.exit(1)
if __name__ == "__main__":
    # Script entry point: load the pickled pixel statistics given by
    # --pkl_path and show one log-scaled bar chart per channel.
    args = parse_args()
    path = args.pkl_path

    # SECURITY NOTE: unpickling can execute arbitrary code. Only open
    # statistics files that come from a trusted source.
    with open(path, 'rb') as f:
        # The pickle holds a pair; the first element is not used here.
        _percentage, img_value_num = pickle.load(f)

    # Assumes img_value_num is a sequence with one entry of per-value counts
    # for each channel -- TODO confirm against the script that writes the pkl.
    for k, value_counts in enumerate(img_value_num):
        print('channel = {}'.format(k))
        plt.bar(
            list(range(len(value_counts))),
            value_counts,
            width=1,
            log=True)
        plt.xlabel('image value')
        plt.ylabel('number')
        plt.title('channel={}'.format(k))
        # Display the figure for this channel before drawing the next one.
        plt.show()
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -40,12 +40,46 @@ def parse_args(): ...@@ -40,12 +40,46 @@ def parse_args():
help='model save directory', help='model save directory',
default=None, default=None,
type=str) type=str)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
default=None,
type=int)
parser.add_argument( parser.add_argument(
'--channel', '--channel',
dest='channel', dest='channel',
help='number of data channel', help='number of data channel',
default=3, default=3,
type=int) type=int)
parser.add_argument(
'--clip_min_value',
dest='clip_min_value',
help='Min values for clipping data',
nargs='+',
default=None,
type=int)
parser.add_argument(
'--clip_max_value',
dest='clip_max_value',
help='Max values for clipping data',
nargs='+',
default=None,
type=int)
parser.add_argument(
'--mean',
dest='mean',
help='Data means',
nargs='+',
default=None,
type=float)
parser.add_argument(
'--std',
dest='std',
help='Data standard deviation',
nargs='+',
default=None,
type=float)
parser.add_argument( parser.add_argument(
'--num_epochs', '--num_epochs',
dest='num_epochs', dest='num_epochs',
...@@ -66,15 +100,32 @@ def parse_args(): ...@@ -66,15 +100,32 @@ def parse_args():
args = parse_args() args = parse_args()
data_dir = args.data_dir data_dir = args.data_dir
save_dir = args.save_dir save_dir = args.save_dir
num_classes = args.num_classes
channel = args.channel channel = args.channel
clip_min_value = args.clip_min_value
clip_max_value = args.clip_max_value
mean = args.mean
std = args.std
num_epochs = args.num_epochs num_epochs = args.num_epochs
train_batch_size = args.train_batch_size train_batch_size = args.train_batch_size
lr = args.lr lr = args.lr
# 定义训练和验证时的transforms # 定义训练和验证时的transforms
train_transforms = T.Compose([T.RandomHorizontalFlip(0.5), T.Normalize()]) train_transforms = T.Compose([
T.RandomVerticalFlip(0.5),
T.RandomHorizontalFlip(0.5),
T.ResizeStepScaling(0.5, 2.0, 0.25),
T.RandomPaddingCrop(1000),
T.Clip(min_val=clip_min_value, max_val=clip_max_value),
T.Normalize(
min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
])
eval_transforms = T.Compose([T.Normalize()]) eval_transforms = T.Compose([
T.Clip(min_val=clip_min_value, max_val=clip_max_value),
T.Normalize(
min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
])
train_list = osp.join(data_dir, 'train.txt') train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt') val_list = osp.join(data_dir, 'val.txt')
...@@ -95,17 +146,9 @@ eval_reader = Reader( ...@@ -95,17 +146,9 @@ eval_reader = Reader(
transforms=eval_transforms) transforms=eval_transforms)
if args.model_type == 'unet': if args.model_type == 'unet':
model = UNet( model = UNet(num_classes=num_classes, input_channel=channel)
num_classes=2,
input_channel=channel,
use_bce_loss=True,
use_dice_loss=True)
elif args.model_type == 'hrnet': elif args.model_type == 'hrnet':
model = HRNet( model = HRNet(num_classes=num_classes, input_channel=channel)
num_classes=2,
input_channel=channel,
use_bce_loss=True,
use_dice_loss=True)
else: else:
raise ValueError( raise ValueError(
"--model_type: {} is set wrong, it shold be one of ('unet', " "--model_type: {} is set wrong, it shold be one of ('unet', "
...@@ -116,6 +159,7 @@ model.train( ...@@ -116,6 +159,7 @@ model.train(
train_reader=train_reader, train_reader=train_reader,
train_batch_size=train_batch_size, train_batch_size=train_batch_size,
eval_reader=eval_reader, eval_reader=eval_reader,
eval_best_metric='miou',
save_interval_epochs=5, save_interval_epochs=5,
log_interval_steps=10, log_interval_steps=10,
save_dir=save_dir, save_dir=save_dir,
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from PIL import Image from PIL import Image
import cv2 import cv2
from collections import OrderedDict from collections import OrderedDict
from readers.reader import read_img
class Compose: class Compose:
...@@ -58,7 +59,7 @@ class Compose: ...@@ -58,7 +59,7 @@ class Compose:
if im_info is None: if im_info is None:
im_info = dict() im_info = dict()
im = np.load(im) im = read_img(im)
if im is None: if im is None:
raise ValueError('Can\'t read The image file {}!'.format(im)) raise ValueError('Can\'t read The image file {}!'.format(im))
if label is not None: if label is not None:
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册