From 04b3f1b0f7716e938fab1f0a8749416e6df9cf06 Mon Sep 17 00:00:00 2001 From: LutaoChu <30695251+LutaoChu@users.noreply.github.com> Date: Sun, 7 Jun 2020 21:39:23 +0800 Subject: [PATCH] Add whole process of remote sensing (#287) * support .tif .img read * add data analyse and check tools * upgrade to adapt new dataset * remove global * optimize function, file names --- contrib/RemoteSensing/__init__.py | 2 +- contrib/RemoteSensing/models/__init__.py | 2 +- contrib/RemoteSensing/models/base.py | 2 +- contrib/RemoteSensing/models/hrnet.py | 2 +- contrib/RemoteSensing/models/load_model.py | 2 +- contrib/RemoteSensing/models/unet.py | 4 +- .../RemoteSensing/models/utils/visualize.py | 46 ++ contrib/RemoteSensing/nets/__init__.py | 2 +- contrib/RemoteSensing/nets/hrnet.py | 2 +- contrib/RemoteSensing/nets/libs.py | 2 +- contrib/RemoteSensing/nets/loss.py | 2 +- contrib/RemoteSensing/nets/unet.py | 2 +- contrib/RemoteSensing/predict_demo.py | 28 +- contrib/RemoteSensing/readers/__init__.py | 2 +- contrib/RemoteSensing/readers/base.py | 2 +- contrib/RemoteSensing/readers/reader.py | 25 +- contrib/RemoteSensing/tools/cal_norm_coef.py | 168 ++++++ .../tools/create_dataset_list.py | 2 +- .../tools/data_analyse_and_check.py | 513 ++++++++++++++++++ .../tools/data_distribution_vis.py | 52 ++ .../RemoteSensing/tools/split_dataset_list.py | 2 +- contrib/RemoteSensing/train_demo.py | 70 ++- contrib/RemoteSensing/transforms/__init__.py | 2 +- contrib/RemoteSensing/transforms/ops.py | 2 +- .../RemoteSensing/transforms/transforms.py | 5 +- contrib/RemoteSensing/utils/__init__.py | 2 +- contrib/RemoteSensing/utils/logging.py | 2 +- contrib/RemoteSensing/utils/metrics.py | 2 +- .../RemoteSensing/utils/pretrain_weights.py | 2 +- contrib/RemoteSensing/utils/utils.py | 2 +- 30 files changed, 903 insertions(+), 50 deletions(-) create mode 100644 contrib/RemoteSensing/models/utils/visualize.py create mode 100644 contrib/RemoteSensing/tools/cal_norm_coef.py create mode 100644 
contrib/RemoteSensing/tools/data_analyse_and_check.py create mode 100644 contrib/RemoteSensing/tools/data_distribution_vis.py diff --git a/contrib/RemoteSensing/__init__.py b/contrib/RemoteSensing/__init__.py index 8e5bc0b3..6406dd3b 100644 --- a/contrib/RemoteSensing/__init__.py +++ b/contrib/RemoteSensing/__init__.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/models/__init__.py b/contrib/RemoteSensing/models/__init__.py index 08e8fd5a..aae31421 100644 --- a/contrib/RemoteSensing/models/__init__.py +++ b/contrib/RemoteSensing/models/__init__.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/models/base.py b/contrib/RemoteSensing/models/base.py index 06a6bc6d..556c9ee0 100644 --- a/contrib/RemoteSensing/models/base.py +++ b/contrib/RemoteSensing/models/base.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/contrib/RemoteSensing/models/hrnet.py b/contrib/RemoteSensing/models/hrnet.py index 2b575b04..65663386 100644 --- a/contrib/RemoteSensing/models/hrnet.py +++ b/contrib/RemoteSensing/models/hrnet.py @@ -125,7 +125,7 @@ class HRNet(BaseModel): train_reader, train_batch_size=2, eval_reader=None, - eval_best_metric='kappa', + eval_best_metric='miou', save_interval_epochs=1, log_interval_steps=2, save_dir='output', diff --git a/contrib/RemoteSensing/models/load_model.py b/contrib/RemoteSensing/models/load_model.py index 53a61a01..1fcf22b0 100644 --- a/contrib/RemoteSensing/models/load_model.py +++ b/contrib/RemoteSensing/models/load_model.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/models/unet.py b/contrib/RemoteSensing/models/unet.py index 732c8374..3ae0b780 100644 --- a/contrib/RemoteSensing/models/unet.py +++ b/contrib/RemoteSensing/models/unet.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -127,7 +127,7 @@ class UNet(BaseModel): train_reader, train_batch_size=2, eval_reader=None, - eval_best_metric='kappa', + eval_best_metric='miou', save_interval_epochs=1, log_interval_steps=2, save_dir='output', diff --git a/contrib/RemoteSensing/models/utils/visualize.py b/contrib/RemoteSensing/models/utils/visualize.py new file mode 100644 index 00000000..a47a756f --- /dev/null +++ b/contrib/RemoteSensing/models/utils/visualize.py @@ -0,0 +1,46 @@ +import os +import os.path as osp +import numpy as np +from PIL import Image as Image + + +def get_color_map_list(num_classes): + """ Returns the color map for visualizing the segmentation mask, + which can support arbitrary number of classes. + Args: + num_classes: Number of classes + Returns: + The color map + """ + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + + return color_map + + +def splice_imgs(img_list, vis_path): + """Splice pictures horizontally + """ + IMAGE_WIDTH, IMAGE_HEIGHT = img_list[0].size + padding_width = 20 + img_num = len(img_list) + to_image = Image.new('RGB', + (img_num * IMAGE_WIDTH + (img_num - 1) * padding_width, + IMAGE_HEIGHT)) # Create a new picture + padding = Image.new('RGB', (padding_width, IMAGE_HEIGHT), (255, 255, 255)) + + # Loop through, paste each picture to the corresponding position in order + for i, from_image in enumerate(img_list): + to_image.paste(from_image, (i * (IMAGE_WIDTH + padding_width), 0)) + if i < img_num - 1: + to_image.paste(padding, + (i * (IMAGE_WIDTH + padding_width) + IMAGE_WIDTH, 0)) + return to_image.save(vis_path) diff --git a/contrib/RemoteSensing/nets/__init__.py b/contrib/RemoteSensing/nets/__init__.py index 7e56e961..381f327f 100644 --- a/contrib/RemoteSensing/nets/__init__.py +++ 
b/contrib/RemoteSensing/nets/__init__.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/nets/hrnet.py b/contrib/RemoteSensing/nets/hrnet.py index 854a0f03..32d613e6 100644 --- a/contrib/RemoteSensing/nets/hrnet.py +++ b/contrib/RemoteSensing/nets/hrnet.py @@ -1,5 +1,5 @@ # coding: utf8 -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/nets/libs.py b/contrib/RemoteSensing/nets/libs.py index f74c93fc..e475e197 100644 --- a/contrib/RemoteSensing/nets/libs.py +++ b/contrib/RemoteSensing/nets/libs.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/nets/loss.py b/contrib/RemoteSensing/nets/loss.py index a0794474..3d80416f 100644 --- a/contrib/RemoteSensing/nets/loss.py +++ b/contrib/RemoteSensing/nets/loss.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/contrib/RemoteSensing/nets/unet.py b/contrib/RemoteSensing/nets/unet.py index 79669c9d..0574f0f5 100644 --- a/contrib/RemoteSensing/nets/unet.py +++ b/contrib/RemoteSensing/nets/unet.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/predict_demo.py b/contrib/RemoteSensing/predict_demo.py index eb09cc29..5da7f0c5 100644 --- a/contrib/RemoteSensing/predict_demo.py +++ b/contrib/RemoteSensing/predict_demo.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import numpy as np from PIL import Image as Image import argparse from models import load_model +from models.utils.visualize import get_color_map_list def parse_args(): @@ -54,6 +55,13 @@ def parse_args(): help='save directory name of predict results', default='predict_results', type=str) + parser.add_argument( + '--color_map', + dest='color_map', + help='color map of predict results', + type=int, + nargs='*', + default=-1) if len(sys.argv) < 2: parser.print_help() sys.exit(1) @@ -68,37 +76,41 @@ load_model_dir = args.load_model_dir save_img_dir = args.save_img_dir if not osp.exists(save_img_dir): os.makedirs(save_img_dir) +if args.color_map == -1: + color_map = get_color_map_list(256) +else: + color_map = args.color_map # predict model = load_model(load_model_dir) -color_map = [0, 0, 0, 0, 255, 0] if single_img is not None: pred = model.predict(single_img) # 以伪彩色png图片保存预测结果 - pred_name = osp.basename(single_img).rstrip('npy') + 'png' - pred_path = 
osp.join(save_img_dir, pred_name) + pred_name, _ = osp.splitext(osp.basename(single_img)) + pred_path = osp.join(save_img_dir, pred_name + '.png') pred_mask = Image.fromarray(pred['label_map'].astype(np.uint8), mode='P') pred_mask.putpalette(color_map) pred_mask.save(pred_path) + print('Predict result is saved in {}'.format(pred_path)) elif (file_list is not None) and (data_dir is not None): with open(osp.join(data_dir, file_list)) as f: lines = f.readlines() for line in lines: img_path = line.split(' ')[0] - print('Predicting {}'.format(img_path)) img_path_ = osp.join(data_dir, img_path) pred = model.predict(img_path_) # 以伪彩色png图片保存预测结果 - pred_name = osp.basename(img_path).rstrip('npy') + 'png' - pred_path = osp.join(save_img_dir, pred_name) + pred_name, _ = osp.splitext(osp.basename(img_path)) + pred_path = osp.join(save_img_dir, pred_name + '.png') pred_mask = Image.fromarray( pred['label_map'].astype(np.uint8), mode='P') pred_mask.putpalette(color_map) pred_mask.save(pred_path) + print('Predict result is saved in {}'.format(pred_path)) else: raise Exception( - 'You should either set the parameter single_img, or set the parameters data_dir, file_list.' + 'You should either set the parameter single_img, or set the parameters data_dir and file_list.' ) diff --git a/contrib/RemoteSensing/readers/__init__.py b/contrib/RemoteSensing/readers/__init__.py index 42eafa21..642d80b1 100644 --- a/contrib/RemoteSensing/readers/__init__.py +++ b/contrib/RemoteSensing/readers/__init__.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/contrib/RemoteSensing/readers/base.py b/contrib/RemoteSensing/readers/base.py index 2ce42564..8e73adbc 100644 --- a/contrib/RemoteSensing/readers/base.py +++ b/contrib/RemoteSensing/readers/base.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/readers/reader.py b/contrib/RemoteSensing/readers/reader.py index a86a22e9..3e350a8a 100644 --- a/contrib/RemoteSensing/readers/reader.py +++ b/contrib/RemoteSensing/readers/reader.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,19 +15,36 @@ from __future__ import absolute_import import os.path as osp import random +import imghdr +import gdal +import numpy as np from utils import logging from .base import BaseReader from .base import get_encoding from collections import OrderedDict -from .base import is_pic + + +def read_img(img_path): + img_format = imghdr.what(img_path) + name, ext = osp.splitext(img_path) + if img_format == 'tiff' or ext == '.img': + dataset = gdal.Open(img_path) + if dataset == None: + raise Exception('Can not open', img_path) + im_data = dataset.ReadAsArray() + return im_data.transpose((1, 2, 0)) + elif ext == '.npy': + return np.load(img_path) + else: + raise Exception('Not support {} image format!'.format(ext)) class Reader(BaseReader): - """读取语分分割任务数据集,并对样本进行相应的处理。 + """读取数据集,并对样本进行相应的处理。 Args: data_dir (str): 数据集所在的目录路径。 - file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。 + file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路径)。 label_list (str): 描述数据集包含的类别信息文件路径。 transforms (list): 数据集中每个样本的预处理/增强算子。 num_workers (int): 数据集中样本在预处理过程中的线程或进程数。默认为4。 diff --git a/contrib/RemoteSensing/tools/cal_norm_coef.py b/contrib/RemoteSensing/tools/cal_norm_coef.py new file mode 100644 index 00000000..a5fc03e0 --- /dev/null +++ b/contrib/RemoteSensing/tools/cal_norm_coef.py @@ -0,0 +1,168 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import os +import os.path as osp +import sys +import argparse +from tqdm import tqdm +import pickle +from data_analyse_and_check import read_img + + +def parse_args(): + parser = argparse.ArgumentParser( + description= + 'Compute normalization coefficient and clip percentage before training.' + ) + parser.add_argument( + '--data_dir', + dest='data_dir', + help='Dataset directory', + default=None, + type=str) + parser.add_argument( + '--pkl_path', + dest='pkl_path', + help='Path of img_pixel_statistics.pkl', + default=None, + type=str) + parser.add_argument( + '--clip_min_value', + dest='clip_min_value', + help='Min values for clipping data', + nargs='+', + default=None, + type=int) + parser.add_argument( + '--clip_max_value', + dest='clip_max_value', + help='Max values for clipping data', + nargs='+', + default=None, + type=int) + parser.add_argument( + '--separator', + dest='separator', + help='file list separator', + default=" ", + type=str) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + return parser.parse_args() + + +def compute_single_img(img, clip_min_value, clip_max_value): + channel = img.shape[2] + means = np.zeros(channel) + stds = np.zeros(channel) + for k in range(channel): + if clip_max_value != [] and clip_min_value != []: + np.clip( + img[:, :, k], + clip_min_value[k], + clip_max_value[k], + out=img[:, :, k]) + + # Rescaling (min-max normalization) + range_value = [ + clip_max_value[i] - clip_min_value[i] + for i in range(len(clip_max_value)) + ] + img_k = (img[:, :, k].astype(np.float32, copy=False) - + clip_min_value[k]) / range_value[k] + else: + img_k = img[:, :, k] + + # count mean, std + means[k] = np.mean(img_k) + stds[k] = np.std(img_k) + return means, stds + + +def cal_normalize_coefficient(data_dir, separator, clip_min_value, + clip_max_value): + train_file_list = osp.join(data_dir, 
'train.txt') + val_file_list = osp.join(data_dir, 'val.txt') + test_file_list = osp.join(data_dir, 'test.txt') + total_img_num = 0 + for file_list in [train_file_list, val_file_list, test_file_list]: + with open(file_list, 'r') as fid: + print("\n-----------------------------\nCheck {}...".format( + file_list)) + lines = fid.readlines() + if not lines: + print("File list is empty!") + continue + for line in tqdm(lines): + line = line.strip() + parts = line.split(separator) + img_name, grt_name = parts[0], parts[1] + img_path = os.path.join(data_dir, img_name) + img = read_img(img_path) + if total_img_num == 0: + channel = img.shape[2] + total_means = np.zeros(channel) + total_stds = np.zeros(channel) + means, stds = compute_single_img(img, clip_min_value, + clip_max_value) + total_means += means + total_stds += stds + total_img_num += 1 + + # count mean, std + total_means = total_means / total_img_num + total_stds = total_stds / total_img_num + print("\nCount the channel-by-channel mean and std of the image:\n" + "mean = {}\nstd = {}".format(total_means, total_stds)) + + +def cal_clip_percentage(pkl_path, clip_min_value, clip_max_value): + """ + Calculate the percentage of pixels to be clipped + """ + with open(pkl_path, 'rb') as f: + percentage, img_value_num = pickle.load(f) + + for k in range(len(img_value_num)): + range_pixel = 0 + for i, element in enumerate(img_value_num[k]): + if clip_min_value[k] <= i <= clip_max_value[k]: + range_pixel += element + sum_pixel = sum(img_value_num[k]) + print('channel {}, the percentage of pixels to be clipped = {}'.format( + k, 1 - range_pixel / sum_pixel)) + + +def main(): + args = parse_args() + data_dir = args.data_dir + separator = args.separator + clip_min_value = args.clip_min_value + clip_max_value = args.clip_max_value + pkl_path = args.pkl_path + + cal_normalize_coefficient(data_dir, separator, clip_min_value, + clip_max_value) + cal_clip_percentage(pkl_path, clip_min_value, clip_max_value) + + +if __name__ == 
"__main__": + main() diff --git a/contrib/RemoteSensing/tools/create_dataset_list.py b/contrib/RemoteSensing/tools/create_dataset_list.py index 78a6c9b3..85d37e64 100644 --- a/contrib/RemoteSensing/tools/create_dataset_list.py +++ b/contrib/RemoteSensing/tools/create_dataset_list.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/tools/data_analyse_and_check.py b/contrib/RemoteSensing/tools/data_analyse_and_check.py new file mode 100644 index 00000000..d6225cb5 --- /dev/null +++ b/contrib/RemoteSensing/tools/data_analyse_and_check.py @@ -0,0 +1,513 @@ +# coding: utf8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import os +import os.path as osp +import sys +import argparse +from PIL import Image +from tqdm import tqdm +import imghdr +import logging +import pickle +import gdal + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Data analyse and data check before training.') + parser.add_argument( + '--data_dir', + dest='data_dir', + help='Dataset directory', + default=None, + type=str) + parser.add_argument( + '--num_classes', + dest='num_classes', + help='Number of classes', + default=None, + type=int) + parser.add_argument( + '--separator', + dest='separator', + help='file list separator', + default=" ", + type=str) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + return parser.parse_args() + + +def read_img(img_path): + img_format = imghdr.what(img_path) + name, ext = osp.splitext(img_path) + if img_format == 'tiff' or ext == '.img': + dataset = gdal.Open(img_path) + if dataset == None: + raise Exception('Can not open', img_path) + im_data = dataset.ReadAsArray() + return im_data.transpose((1, 2, 0)) + elif ext == '.npy': + return np.load(img_path) + else: + raise Exception('Not support {} image format!'.format(ext)) + + +def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value): + channel = img.shape[2] + means = np.zeros(channel) + stds = np.zeros(channel) + for k in range(channel): + img_k = img[:, :, k] + + # count mean, std + means[k] = np.mean(img_k) + stds[k] = np.std(img_k) + + # count min, max + min_value = np.min(img_k) + max_value = np.max(img_k) + if img_max_value[k] < max_value: + img_max_value[k] = max_value + if img_min_value[k] > min_value: + img_min_value[k] = min_value + + # count the distribution of image value, value number + unique, counts = np.unique(img_k, return_counts=True) + add_num = [] + max_unique = np.max(unique) + add_len = max_unique - 
len(img_value_num[k]) + 1 + if add_len > 0: + img_value_num[k] += ([0] * add_len) + for i in range(len(unique)): + value = unique[i] + img_value_num[k][value] += counts[i] + + img_value_num[k] += add_num + return means, stds, img_min_value, img_max_value, img_value_num + + +def dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num, + img_min_value, img_max_value, total_img_num, + logger): + logger.info("\n-----------------------------\nDataset pixel statistics...") + + # count the distribution of image value, value number + if not img_value_num: + return + logger.info("\nImage pixel statistics:") + total_ratio = [] + [total_ratio.append([]) for i in range(len(img_value_num))] + for k in range(len(img_value_num)): + total_num = sum(img_value_num[k]) + total_ratio[k] = [i / total_num for i in img_value_num[k]] + total_ratio[k] = np.around(total_ratio[k], decimals=4) + with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f: + pickle.dump([total_ratio, img_value_num], f) + + # print min value, max value + logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".format( + img_min_value, img_max_value)) + + # count mean, std + total_means = total_means / total_img_num + total_stds = total_stds / total_img_num + print("\nCount the channel-by-channel mean and std of the image:\n" + "mean = {}\nstd = {}".format(total_means, total_stds)) + + +def error_print(str): + return "".join(["\nNOT PASS ", str]) + + +def correct_print(str): + return "".join(["\nPASS ", str]) + + +def pil_imread(file_path): + """read pseudo-color label""" + im = Image.open(file_path) + return np.asarray(im) + + +def get_img_shape_range(img, max_width, max_height, min_width, min_height): + """获取图片最大和最小宽高""" + img_shape = img.shape + height, width = img_shape[0], img_shape[1] + max_height = max(height, max_height) + max_width = max(width, max_width) + min_height = min(height, min_height) + min_width = min(width, min_width) + return max_width, max_height, 
min_width, min_height + + +def get_image_dim(img, img_dim): + """获取图像的通道数""" + img_shape = img.shape + if img_shape[-1] not in img_dim: + img_dim.append(img_shape[-1]) + return img_dim + + +def is_label_single_channel(label): + """判断标签是否为灰度图""" + label_shape = label.shape + if len(label_shape) == 2: + return True + else: + return False + + +def image_label_shape_check(img, label): + """ + 验证图像和标注的大小是否匹配 + """ + + flag = True + img_height = img.shape[0] + img_width = img.shape[1] + label_height = label.shape[0] + label_width = label.shape[1] + + if img_height != label_height or img_width != label_width: + flag = False + return flag + + +def ground_truth_check(label, label_path): + """ + 验证标注图像的格式 + 统计标注图类别和像素数 + params: + label: 标注图 + label_path: 标注图路径 + return: + png_format: 返回是否是png格式图片 + unique: 返回标注类别 + counts: 返回标注的像素数 + """ + if imghdr.what(label_path) == "png": + png_format = True + else: + png_format = False + + unique, counts = np.unique(label, return_counts=True) + + return png_format, unique, counts + + +def sum_label_check(png_format, label_classes, num_of_each_class, ignore_index, + num_classes, png_format_right_num, png_format_wrong_num, + total_label_classes, total_num_of_each_class): + """ + 统计所有标注图上的格式、类别和每个类别的像素数 + params: + png_format: 是否是png格式图片 + label_classes: 标注类别 + num_of_each_class: 各个类别的像素数目 + """ + is_label_correct = True + + if png_format: + png_format_right_num += 1 + else: + png_format_wrong_num += 1 + + if ignore_index in label_classes: + label_classes2 = np.delete(label_classes, + np.where(label_classes == ignore_index)) + else: + label_classes2 = label_classes + if min(label_classes2) < 0 or max(label_classes2) > num_classes - 1: + is_label_correct = False + add_class = [] + add_num = [] + for i in range(len(label_classes)): + gi = label_classes[i] + if gi in total_label_classes: + j = total_label_classes.index(gi) + total_num_of_each_class[j] += num_of_each_class[i] + else: + add_class.append(gi) + 
add_num.append(num_of_each_class[i]) + total_num_of_each_class += add_num + total_label_classes += add_class + return is_label_correct, png_format_right_num, png_format_wrong_num, total_num_of_each_class, total_label_classes + + +def label_check_statistics(num_classes, png_format_wrong_image, + png_format_right_num, png_format_wrong_num, + total_label_classes, total_num_of_each_class, + wrong_labels, logger): + """ + 对标注图像进行校验,输出校验结果 + """ + if png_format_wrong_num == 0: + if png_format_right_num: + logger.info(correct_print("label format check")) + else: + logger.info(error_print("label format check")) + logger.info("No label image to check") + return + else: + logger.info(error_print("label format check")) + logger.info( + "total {} label images are png format, {} label images are not png " + "format".format(png_format_right_num, png_format_wrong_num)) + if len(png_format_wrong_image) > 0: + for i in png_format_wrong_image: + logger.debug(i) + + total_ratio = total_num_of_each_class / sum(total_num_of_each_class) + total_ratio = np.around(total_ratio, decimals=4) + total_nc = sorted( + zip(total_label_classes, total_ratio, total_num_of_each_class)) + if len(wrong_labels) == 0 and not total_nc[0][0]: + logger.info(correct_print("label class check!")) + else: + logger.info(error_print("label class check!")) + if total_nc[0][0]: + logger.info("Warning: label classes should start from 0") + if len(wrong_labels) > 0: + logger.info( + "fatal error: label class is out of range [0, {}]".format( + num_classes - 1)) + for i in wrong_labels: + logger.debug(i) + + logger.info( + "\nLabel pixel statistics:\n" + "(label class, percentage, total pixel number) = {} ".format(total_nc)) + + +def shape_check(shape_unequal_image, logger): + """输出shape校验结果""" + if len(shape_unequal_image) == 0: + logger.info(correct_print("shape check")) + logger.info("All images are the same shape as the labels") + else: + logger.info(error_print("shape check")) + logger.info( + "Some images are not 
the same shape as the labels as follow: ") + for i in shape_unequal_image: + logger.debug(i) + + +def separator_check(wrong_lines, file_list, separator, logger): + """检查分割符是否复合要求""" + if len(wrong_lines) == 0: + logger.info( + correct_print( + file_list.split(os.sep)[-1] + " DATASET.separator check")) + else: + logger.info( + error_print( + file_list.split(os.sep)[-1] + " DATASET.separator check")) + logger.info( + "The following list is not separated by {}".format(separator)) + for i in wrong_lines: + logger.debug(i) + + +def imread_check(imread_failed, logger): + if len(imread_failed) == 0: + logger.info(correct_print("dataset reading check")) + logger.info("All images can be read successfully") + else: + logger.info(error_print("dataset reading check")) + logger.info("Failed to read {} images".format(len(imread_failed))) + for i in imread_failed: + logger.debug(i) + + +def single_channel_label_check(label_not_single_channel, logger): + if len(label_not_single_channel) == 0: + logger.info(correct_print("label single_channel check")) + logger.info("All label images are single_channel") + else: + logger.info(error_print("label single_channel check")) + logger.info( + "{} label images are not single_channel\nLabel pixel statistics may be insignificant" + .format(len(label_not_single_channel))) + for i in label_not_single_channel: + logger.debug(i) + + +def img_shape_range_statistics(max_width, min_width, max_height, min_height, + logger): + logger.info("\nImage size statistics:") + logger.info( + "max width = {} min width = {} max height = {} min height = {}". 
def img_shape_range_statistics(max_width, min_width, max_height, min_height,
                               logger):
    """Log the extreme image widths and heights collected for one file list.

    Args:
        max_width: Largest image width seen.
        min_width: Smallest image width seen.
        max_height: Largest image height seen.
        min_height: Smallest image height seen.
        logger: Receives the two report lines via ``info``.
    """
    logger.info("\nImage size statistics:")
    logger.info(
        "max width = {} min width = {} max height = {} min height = {}".format(
            max_width, min_width, max_height, min_height))


def img_dim_statistics(img_dim, logger):
    """Log the distinct channel counts collected for one file list.

    Args:
        img_dim: Sequence of per-image channel counts.
        logger: Receives the report via ``info``.
    """
    logger.info("\nImage channels statistics\nImage channels = {}".format(
        np.unique(img_dim)))


def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
                           logger):
    """Validate the train/val/test file lists and gather dataset statistics.

    Each of ``train.txt``, ``val.txt`` and ``test.txt`` under ``data_dir`` is
    processed row by row. A row is split by ``separator`` into an image path
    and an optional label path, both relative to ``data_dir``. Per-list check
    results are reported through ``logger``. Pixel statistics are accumulated
    over every readable image of all lists and passed to
    ``dataset_pixel_statistics`` at the end.

    Args:
        data_dir: Directory holding the list files and the data they name.
        num_classes: Number of classes, forwarded to the label checks.
        separator: String separating image path and label path in a row.
        ignore_index: Label value forwarded to ``sum_label_check``.
        logger: Receives summaries via ``info`` and details via ``debug``.

    Returns:
        None. A list file that does not exist is reported and skipped. When
        no image could be read at all, the pixel statistics step is skipped
        with a message.
    """
    train_file_list = osp.join(data_dir, 'train.txt')
    val_file_list = osp.join(data_dir, 'val.txt')
    test_file_list = osp.join(data_dir, 'test.txt')
    lists_requiring_labels = (train_file_list, val_file_list)

    total_img_num = 0
    # Dataset-wide pixel accumulators; sized from the first readable image.
    total_means = None
    total_stds = None
    img_min_value = None
    img_max_value = None
    img_value_num = None

    for file_list in [train_file_list, val_file_list, test_file_list]:
        if not osp.isfile(file_list):
            logger.info(
                "\n-----------------------------\n{} does not exist, skipped.".
                format(file_list))
            continue

        # Per-list state. ``has_label`` is reset here so that an unlabelled
        # list is not run through label checks because an earlier list had
        # labels.
        has_label = False
        imread_failed = []
        max_width = 0
        max_height = 0
        min_width = sys.float_info.max
        min_height = sys.float_info.max
        label_not_single_channel = []
        shape_unequal_image = []
        png_format_wrong_image = []
        wrong_labels = []
        wrong_lines = []
        png_format_right_num = 0
        png_format_wrong_num = 0
        total_label_classes = []
        total_num_of_each_class = []
        img_dim = []

        with open(file_list, 'r') as fid:
            logger.info("\n-----------------------------\nCheck {}...".format(
                file_list))
            lines = fid.readlines()

        if not lines:
            logger.info("File list is empty!")
            continue

        for line in tqdm(lines):
            line = line.strip()
            parts = line.split(separator)
            if len(parts) == 1:
                # Image-only row: allowed in the test list only.
                if file_list in lists_requiring_labels:
                    logger.info("Train or val list must have labels!")
                    break
                img_path = osp.join(data_dir, parts[0])
                try:
                    img = read_img(img_path)
                except Exception as e:
                    imread_failed.append((line, str(e)))
                    continue
            elif len(parts) == 2:
                has_label = True
                img_name, label_name = parts
                img_path = osp.join(data_dir, img_name)
                label_path = osp.join(data_dir, label_name)
                try:
                    img = read_img(img_path)
                    label = pil_imread(label_path)
                except Exception as e:
                    imread_failed.append((line, str(e)))
                    continue

                if not is_label_single_channel(label):
                    label_not_single_channel.append(line)
                    continue
                if not image_label_shape_check(img, label):
                    shape_unequal_image.append(line)
                png_format, label_classes, num_of_each_class = \
                    ground_truth_check(label, label_path)
                if not png_format:
                    png_format_wrong_image.append(line)
                (is_label_correct, png_format_right_num,
                 png_format_wrong_num, total_num_of_each_class,
                 total_label_classes) = sum_label_check(
                     png_format, label_classes, num_of_each_class,
                     ignore_index, num_classes, png_format_right_num,
                     png_format_wrong_num, total_label_classes,
                     total_num_of_each_class)
                if not is_label_correct:
                    wrong_labels.append(line)
            else:
                # Record the offending row itself, not the whole file content.
                wrong_lines.append(line)
                continue

            if total_img_num == 0:
                # NOTE(review): assumes read_img returns an array with a third
                # (channel) axis - confirm against readers.reader.read_img.
                channel = img.shape[2]
                total_means = np.zeros(channel)
                total_stds = np.zeros(channel)
                img_min_value = [sys.float_info.max] * channel
                img_max_value = [0] * channel
                img_value_num = [[] for _ in range(channel)]
            (means, stds, img_min_value, img_max_value,
             img_value_num) = img_pixel_statistics(
                 img, img_value_num, img_min_value, img_max_value)
            total_means += means
            total_stds += stds
            max_width, max_height, min_width, min_height = \
                get_img_shape_range(img, max_width, max_height, min_width,
                                    min_height)
            img_dim = get_image_dim(img, img_dim)
            total_img_num += 1

        separator_check(wrong_lines, file_list, separator, logger)
        imread_check(imread_failed, logger)
        img_dim_statistics(img_dim, logger)
        img_shape_range_statistics(max_width, min_width, max_height,
                                   min_height, logger)

        if has_label:
            single_channel_label_check(label_not_single_channel, logger)
            shape_check(shape_unequal_image, logger)
            label_check_statistics(
                num_classes, png_format_wrong_image, png_format_right_num,
                png_format_wrong_num, total_label_classes,
                total_num_of_each_class, wrong_labels, logger)

    if total_img_num == 0:
        # Nothing was accumulated, so there is nothing to summarise.
        logger.info(
            "\nNo image could be read, dataset pixel statistics are skipped.")
        return

    dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num,
                             img_min_value, img_max_value, total_img_num,
                             logger)


def main():
    """Parse the command line, configure logging and run the dataset check.

    INFO-level messages go to the console; everything (including DEBUG
    details) is written to ``data_analyse_and_check.log`` inside the data
    directory, which is overwritten on every run.
    """
    args = parse_args()
    data_dir = args.data_dir
    ignore_index = 255
    num_classes = args.num_classes
    separator = args.separator
    log_path = os.path.join(data_dir, 'data_analyse_and_check.log')

    logger = logging.getLogger()
    logger.setLevel('DEBUG')
    formatter = logging.Formatter("%(message)s")

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    # Keep per-item details out of the console; they go to the log file only.
    console_handler.setLevel('INFO')

    file_handler = logging.FileHandler(log_path, 'w')
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
                           logger)

    print("\nDetailed error information can be viewed in {}.".format(log_path))


if __name__ == "__main__":
    main()
def parse_args():
    """Parse the command line of the data-distribution visualisation tool.

    Prints the help text and exits with status 1 when the script is invoked
    without any argument. Exits through ``parser.error`` when ``--pkl_path``
    was not supplied, instead of letting ``open(None)`` fail later.

    Returns:
        argparse.Namespace whose ``pkl_path`` attribute is the path given on
        the command line.
    """
    parser = argparse.ArgumentParser(
        description='Visualize data distribution before training.')
    parser.add_argument(
        '--pkl_path',
        dest='pkl_path',
        help='Path of img_pixel_statistics.pkl',
        default=None,
        type=str)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    if args.pkl_path is None:
        parser.error('--pkl_path is required')
    return args


def main():
    """Plot, per channel, a log-scaled bar chart of the pixel-value counts."""
    args = parse_args()
    # NOTE(security): pickle.load can execute arbitrary code; only open
    # statistics files produced by a trusted run of the analysis tool.
    with open(args.pkl_path, 'rb') as f:
        # The file stores a 2-tuple; only the second item is plotted here.
        _, img_value_num = pickle.load(f)

    for k, value_counts in enumerate(img_value_num):
        print('channel = {}'.format(k))
        plt.bar(
            list(range(len(value_counts))),
            value_counts,
            width=1,
            log=True)
        plt.xlabel('image value')
        plt.ylabel('number')
        plt.title('channel={}'.format(k))
        plt.show()


if __name__ == "__main__":
    main()
@@ -40,12 +40,46 @@ def parse_args(): help='model save directory', default=None, type=str) + parser.add_argument( + '--num_classes', + dest='num_classes', + help='Number of classes', + default=None, + type=int) parser.add_argument( '--channel', dest='channel', help='number of data channel', default=3, type=int) + parser.add_argument( + '--clip_min_value', + dest='clip_min_value', + help='Min values for clipping data', + nargs='+', + default=None, + type=int) + parser.add_argument( + '--clip_max_value', + dest='clip_max_value', + help='Max values for clipping data', + nargs='+', + default=None, + type=int) + parser.add_argument( + '--mean', + dest='mean', + help='Data means', + nargs='+', + default=None, + type=float) + parser.add_argument( + '--std', + dest='std', + help='Data standard deviation', + nargs='+', + default=None, + type=float) parser.add_argument( '--num_epochs', dest='num_epochs', @@ -66,15 +100,32 @@ def parse_args(): args = parse_args() data_dir = args.data_dir save_dir = args.save_dir +num_classes = args.num_classes channel = args.channel +clip_min_value = args.clip_min_value +clip_max_value = args.clip_max_value +mean = args.mean +std = args.std num_epochs = args.num_epochs train_batch_size = args.train_batch_size lr = args.lr # 定义训练和验证时的transforms -train_transforms = T.Compose([T.RandomHorizontalFlip(0.5), T.Normalize()]) +train_transforms = T.Compose([ + T.RandomVerticalFlip(0.5), + T.RandomHorizontalFlip(0.5), + T.ResizeStepScaling(0.5, 2.0, 0.25), + T.RandomPaddingCrop(1000), + T.Clip(min_val=clip_min_value, max_val=clip_max_value), + T.Normalize( + min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std), +]) -eval_transforms = T.Compose([T.Normalize()]) +eval_transforms = T.Compose([ + T.Clip(min_val=clip_min_value, max_val=clip_max_value), + T.Normalize( + min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std), +]) train_list = osp.join(data_dir, 'train.txt') val_list = osp.join(data_dir, 'val.txt') @@ -95,17 +146,9 
@@ eval_reader = Reader( transforms=eval_transforms) if args.model_type == 'unet': - model = UNet( - num_classes=2, - input_channel=channel, - use_bce_loss=True, - use_dice_loss=True) + model = UNet(num_classes=num_classes, input_channel=channel) elif args.model_type == 'hrnet': - model = HRNet( - num_classes=2, - input_channel=channel, - use_bce_loss=True, - use_dice_loss=True) + model = HRNet(num_classes=num_classes, input_channel=channel) else: raise ValueError( "--model_type: {} is set wrong, it shold be one of ('unet', " @@ -116,6 +159,7 @@ model.train( train_reader=train_reader, train_batch_size=train_batch_size, eval_reader=eval_reader, + eval_best_metric='miou', save_interval_epochs=5, log_interval_steps=10, save_dir=save_dir, diff --git a/contrib/RemoteSensing/transforms/__init__.py b/contrib/RemoteSensing/transforms/__init__.py index 625b2bcc..2cc178c2 100644 --- a/contrib/RemoteSensing/transforms/__init__.py +++ b/contrib/RemoteSensing/transforms/__init__.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/transforms/ops.py b/contrib/RemoteSensing/transforms/ops.py index 0efca13a..9abb18de 100644 --- a/contrib/RemoteSensing/transforms/ops.py +++ b/contrib/RemoteSensing/transforms/ops.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/contrib/RemoteSensing/transforms/transforms.py b/contrib/RemoteSensing/transforms/transforms.py index 00faca4e..48ac23ae 100644 --- a/contrib/RemoteSensing/transforms/transforms.py +++ b/contrib/RemoteSensing/transforms/transforms.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import numpy as np from PIL import Image import cv2 from collections import OrderedDict +from readers.reader import read_img class Compose: @@ -58,7 +59,7 @@ class Compose: if im_info is None: im_info = dict() - im = np.load(im) + im = read_img(im) if im is None: raise ValueError('Can\'t read The image file {}!'.format(im)) if label is not None: diff --git a/contrib/RemoteSensing/utils/__init__.py b/contrib/RemoteSensing/utils/__init__.py index d839dbcc..7a4a8112 100644 --- a/contrib/RemoteSensing/utils/__init__.py +++ b/contrib/RemoteSensing/utils/__init__.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/utils/logging.py b/contrib/RemoteSensing/utils/logging.py index 8c850f61..64532505 100644 --- a/contrib/RemoteSensing/utils/logging.py +++ b/contrib/RemoteSensing/utils/logging.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/contrib/RemoteSensing/utils/metrics.py b/contrib/RemoteSensing/utils/metrics.py index 71dacd9b..80df6c5d 100644 --- a/contrib/RemoteSensing/utils/metrics.py +++ b/contrib/RemoteSensing/utils/metrics.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/utils/pretrain_weights.py b/contrib/RemoteSensing/utils/pretrain_weights.py index 241e7632..2e5397c9 100644 --- a/contrib/RemoteSensing/utils/pretrain_weights.py +++ b/contrib/RemoteSensing/utils/pretrain_weights.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/RemoteSensing/utils/utils.py b/contrib/RemoteSensing/utils/utils.py index c97a76ec..d39a43e7 100644 --- a/contrib/RemoteSensing/utils/utils.py +++ b/contrib/RemoteSensing/utils/utils.py @@ -1,5 +1,5 @@ # coding: utf8 -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -- GitLab