diff --git a/paddlex/__init__.py b/paddlex/__init__.py index 0a9cb735de37ff1643ac715e20270b222a00407a..1b8cea3503a07b4301240b5b0834f5a648db75e8 100644 --- a/paddlex/__init__.py +++ b/paddlex/__init__.py @@ -32,7 +32,7 @@ from . import slim from . import convertor from . import tools from . import deploy -from . import RemoteSensing +from . import remotesensing try: import pycocotools diff --git a/paddlex/cv/datasets/__init__.py b/paddlex/cv/datasets/__init__.py index bd5275246eaf0f9357417de28c6f7c4eb68f3f07..04516ff08af6c08ce95c519ca809382e1de55d44 100644 --- a/paddlex/cv/datasets/__init__.py +++ b/paddlex/cv/datasets/__init__.py @@ -20,3 +20,4 @@ from .easydata_cls import EasyDataCls from .easydata_det import EasyDataDet from .easydata_seg import EasyDataSeg from .dataset import generate_minibatch +from .analysis import Seg diff --git a/paddlex/cv/datasets/analysis.py b/paddlex/cv/datasets/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..b09bafb2d642ec6a9444f918a3eeb588bf3eb217 --- /dev/null +++ b/paddlex/cv/datasets/analysis.py @@ -0,0 +1,366 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Seg:
    """Offline statistics for a semantic-segmentation dataset.

    The constructor reads a label list and a file list (one
    ``image_path label_path`` pair per line, separated by a single space) and
    stores the resolved path pairs in ``self.file_list``.  ``analysis()`` then
    collects per-image statistics (shape, channel count, per-channel mean/std,
    pixel-value histograms, label histograms) and logs a summary.
    ``cal_clipvalue_ratio`` and ``cal_clipped_mean_std`` rely on attributes
    that ``analysis()`` sets (``channel_num``, ``min_im_value``,
    ``max_im_value``, ``im_pixel_info``), so call ``analysis()`` first.

    Args:
        data_dir (str): Root directory that the paths in ``file_list`` are
            relative to; also where the pixel-info ``.pkl``/``.png`` are saved.
        file_list (str): Path of the text file listing image/label pairs.
        label_list (str): Path of the text file listing one label name per line.
    """

    def __init__(self, data_dir, file_list, label_list):
        self.data_dir = data_dir
        self.file_list_path = file_list
        self.file_list = list()
        self.labels = list()
        with open(label_list, encoding=get_encoding(label_list)) as f:
            for line in f:
                item = line.strip()
                self.labels.append(item)

        with open(file_list, encoding=get_encoding(file_list)) as f:
            for line in f:
                # A trailing blank line is common in hand-written lists;
                # skip it instead of failing with an IndexError below.
                if not line.strip():
                    continue
                if line.count(" ") > 1:
                    raise Exception(
                        "A space is defined as the separator, but it exists in image or label name {}."
                        .format(line))
                items = line.strip().split()
                if len(items) < 2:
                    raise Exception(
                        "Each line should contain an image path and a label path, but got {}."
                        .format(line))
                items[0] = path_normalization(items[0])
                items[1] = path_normalization(items[1])
                full_path_im = osp.join(data_dir, items[0])
                full_path_label = osp.join(data_dir, items[1])
                if not osp.exists(full_path_im):
                    raise IOError('The image file {} is not exist!'.format(
                        full_path_im))
                if not osp.exists(full_path_label):
                    raise IOError('The label file {} is not exist!'.format(
                        full_path_label))
                self.file_list.append([full_path_im, full_path_label])
        self.num_samples = len(self.file_list)

    @staticmethod
    def decode_image(im, label):
        """Load (if needed) and validate an image/label pair.

        Args:
            im (np.ndarray|str): A 3-dimensional array or an image file path
                that is read with ``cv2.imread``.
            label (np.ndarray|str|None): A 2-dimensional array, a label file
                path that is read with PIL, or None to skip label handling.

        Returns:
            tuple: ``(im, label)`` where ``im`` is a float32 array.

        Raises:
            ValueError: If the image or label file cannot be read.
            Exception: If the dimensions are wrong or the image and label
                sizes do not match.
        """
        if isinstance(im, np.ndarray):
            if len(im.shape) != 3:
                raise Exception(
                    "im should be 3-dimensions, but now is {}-dimensions".
                    format(len(im.shape)))
        else:
            im_path = im
            try:
                im = cv2.imread(im_path)
            except Exception:
                raise ValueError('Can\'t read The image file {}!'.format(
                    im_path))
            # cv2.imread signals failure by returning None, not by raising.
            if im is None:
                raise ValueError('Can\'t read The image file {}!'.format(
                    im_path))
        im = im.astype('float32')
        if label is not None:
            if not isinstance(label, np.ndarray):
                label_path = label
                try:
                    label = np.asarray(Image.open(label_path))
                except Exception:
                    raise ValueError('Can\'t read The label file {}!'.format(
                        label_path))
            if len(label.shape) != 2:
                raise Exception(
                    "label should be 2-dimensions, but now is {}-dimensions".
                    format(len(label.shape)))
            im_height, im_width, _ = im.shape
            label_height, label_width = label.shape
            if im_height != label_height or im_width != label_width:
                raise Exception(
                    "The height or width of the image is not same as the label")
        return (im, label)

    def _reset_statistics(self):
        """Allocate one empty per-sample slot for every statistic."""
        num_samples = len(self.file_list)
        self.im_mean_list = [[] for _ in range(num_samples)]
        self.im_std_list = [[] for _ in range(num_samples)]
        self.im_value_list = [[] for _ in range(num_samples)]
        self.im_value_num_list = [[] for _ in range(num_samples)]
        self.im_height_list = np.zeros(num_samples, dtype='int32')
        self.im_width_list = np.zeros(num_samples, dtype='int32')
        self.im_channel_list = np.zeros(num_samples, dtype='int32')
        self.label_value_list = [[] for _ in range(num_samples)]
        self.label_value_num_list = [[] for _ in range(num_samples)]

    def _run_in_threads(self, target, extra_args=()):
        """Run ``target(start, end, *extra_args)`` over all samples in threads.

        The sample index range is split into contiguous chunks, one per
        worker.  All workers are started before any is joined so they really
        run concurrently.  The first exception raised inside a worker is
        re-raised here so that a failure is not silently lost.
        """
        num_samples = len(self.file_list)
        if num_samples == 0:
            return
        # At least one worker (cpu_count() // 2 is 0 on a single-core host),
        # at most 8, and never more workers than samples.
        num_workers = max(1, min(mp.cpu_count() // 2, 8, num_samples))
        one_worker_file = num_samples // num_workers
        errors = []

        def _worker(start, end):
            try:
                target(start, end, *extra_args)
            except Exception as e:
                errors.append(e)

        threads = []
        for i in range(num_workers):
            start = one_worker_file * i
            # The last worker also takes the remainder.
            end = one_worker_file * (
                i + 1) if i < num_workers - 1 else num_samples
            t = threading.Thread(target=_worker, args=(start, end))
            t.daemon = True
            threads.append(t)
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        if errors:
            raise errors[0]

    def _get_shape(self):
        """Return the max/min image height and width over the dataset."""
        shape_info = {
            'max_height': max(self.im_height_list),
            'max_width': max(self.im_width_list),
            'min_height': min(self.im_height_list),
            'min_width': min(self.im_width_list),
        }
        return shape_info

    def _get_label_pixel_info(self):
        """Accumulate ``{label_value: [pixel_count, ratio]}`` over all labels."""
        # Sum of height * width over all images.  Cast to int64 first: the
        # per-image lists are int32 and the total easily exceeds 2**31.
        pixel_num = np.dot(
            np.asarray(self.im_height_list, dtype='int64'),
            np.asarray(self.im_width_list, dtype='int64'))
        label_pixel_info = dict()
        for label_value, label_value_num in zip(self.label_value_list,
                                                self.label_value_num_list):
            for v, n in zip(label_value, label_value_num):
                ratio = float(n) / float(pixel_num)
                if v not in label_pixel_info:
                    label_pixel_info[v] = [n, ratio]
                else:
                    label_pixel_info[v][0] += n
                    label_pixel_info[v][1] += ratio

        return label_pixel_info

    def _get_image_pixel_info(self):
        """Merge per-image value histograms, then save them as pkl and png.

        Returns:
            list[dict]: One ``{pixel_value: count}`` dict per channel.
        """
        channel = max([len(im_value) for im_value in self.im_value_list])
        im_pixel_info = [dict() for _ in range(channel)]
        for im_value, im_value_num in zip(self.im_value_list,
                                          self.im_value_num_list):
            for c in range(channel):
                channel_info = im_pixel_info[c]
                for v, n in zip(im_value[c], im_value_num[c]):
                    channel_info[v] = channel_info.get(v, 0) + n
        # Output files are named after the file list, e.g. "train".
        mode = osp.split(self.file_list_path)[-1].split('.')[0]
        with open(
                osp.join(self.data_dir,
                         '{}_image_pixel_info.pkl'.format(mode)), 'wb') as f:
            pickle.dump(im_pixel_info, f)

        # Imported lazily so that matplotlib is only needed for the plot.
        import matplotlib.pyplot as plt
        plot_id = (channel // 3 + 1) * 100 + 31
        for c in range(channel):
            # Only the first nine channels are drawn.
            if c > 8:
                continue
            plt.subplot(plot_id + c)
            plt.bar(list(im_pixel_info[c].keys()),
                    list(im_pixel_info[c].values()),
                    width=1,
                    log=True)
            plt.xlabel('image pixel value')
            plt.ylabel('number')
            plt.title('channel={}'.format(c))
        plt.savefig(
            osp.join(self.data_dir, '{}_image_pixel_info.png'.format(mode)),
            dpi=800)
        plt.close()
        return im_pixel_info

    def _get_mean_std(self):
        """Average the per-image channel means/stds and divide by 255."""
        num_samples = len(self.file_list)

        im_mean = np.asarray(self.im_mean_list)
        im_mean = im_mean.sum(axis=0)
        im_mean = im_mean / num_samples
        im_mean /= 255.

        im_std = np.asarray(self.im_std_list)
        im_std = im_std.sum(axis=0)
        im_std = im_std / num_samples
        im_std /= 255.

        return (im_mean, im_std)

    def _get_image_info(self, start, end):
        """Collect the statistics of samples ``start`` (incl.) to ``end`` (excl.)."""
        for idx in range(start, end):
            full_path_im, full_path_label = self.file_list[idx]
            image, label = self.decode_image(full_path_im, full_path_label)

            height, width, channel = image.shape
            self.im_height_list[idx] = height
            self.im_width_list[idx] = width
            self.im_channel_list[idx] = channel

            self.im_mean_list[
                idx] = [np.mean(image[:, :, c]) for c in range(channel)]
            self.im_std_list[
                idx] = [np.std(image[:, :, c]) for c in range(channel)]
            for c in range(channel):
                unique, counts = np.unique(image[:, :, c], return_counts=True)
                self.im_value_list[idx].append(unique)
                self.im_value_num_list[idx].append(counts)

            unique, counts = np.unique(label, return_counts=True)
            self.label_value_list[idx] = unique
            self.label_value_num_list[idx] = counts

    def _get_clipped_mean_std(self, start, end, clip_min_value,
                              clip_max_value):
        """Per-image mean/std after clipping and rescaling each channel."""
        for idx in range(start, end):
            full_path_im, full_path_label = self.file_list[idx]
            image, label = self.decode_image(full_path_im, full_path_label)
            for c in range(self.channel_num):
                # Clip in place, then map [clip_min, clip_max] onto [0, 1].
                np.clip(
                    image[:, :, c],
                    clip_min_value[c],
                    clip_max_value[c],
                    out=image[:, :, c])
                image[:, :, c] -= clip_min_value[c]
                image[:, :, c] /= clip_max_value[c] - clip_min_value[c]
            self.clipped_im_mean_list[idx] = [
                image[:, :, c].mean() for c in range(self.channel_num)
            ]
            self.clipped_im_std_list[idx] = [
                image[:, :, c].std() for c in range(self.channel_num)
            ]

    def _check_clip_range(self, clip_min_value, clip_max_value):
        """Raise if a clip bound lies outside the observed per-channel range."""
        for c in range(self.channel_num):
            lower, upper = self.min_im_value[c], self.max_im_value[c]
            if clip_min_value[c] < lower or clip_min_value[c] > upper:
                raise Exception(
                    "Clip_min_value of the channel {} is not in [{}, {}]".
                    format(c, lower, upper))
            if clip_max_value[c] < lower or clip_max_value[c] > upper:
                raise Exception(
                    "Clip_max_value of the channel {} is not in [{}, {}]".
                    format(c, lower, upper))

    def analysis(self):
        """Collect the dataset statistics and log a summary.

        Sets ``channel_num``, ``max_height``/``max_width``/``min_height``/
        ``min_width``, ``label_pixel_info``, ``im_pixel_info``,
        ``max_im_value`` and ``min_im_value`` on the instance, and writes the
        pixel histogram ``.pkl``/``.png`` files into ``data_dir``.

        Raises:
            Exception: If the file list is empty or the images do not all
                have the same number of channels.
        """
        if not self.file_list:
            raise Exception("There is no sample in file {}.".format(
                self.file_list_path))
        self._reset_statistics()
        self._run_in_threads(self._get_image_info)

        unique, counts = np.unique(self.im_channel_list, return_counts=True)
        if len(unique) > 1:
            raise Exception("There are {} kinds of image channels: {}.".format(
                len(unique), unique[:]))
        self.channel_num = int(unique[0])
        shape_info = self._get_shape()
        self.max_height = shape_info['max_height']
        self.max_width = shape_info['max_width']
        self.min_height = shape_info['min_height']
        self.min_width = shape_info['min_width']
        self.label_pixel_info = self._get_label_pixel_info()
        self.im_pixel_info = self._get_image_pixel_info()
        im_mean, im_std = self._get_mean_std()
        max_im_value = list()
        min_im_value = list()
        for c in range(self.channel_num):
            max_im_value.append(max(self.im_pixel_info[c].keys()))
            min_im_value.append(min(self.im_pixel_info[c].keys()))
        self.max_im_value = np.asarray(max_im_value)
        self.min_im_value = np.asarray(min_im_value)

        logging.info(
            "############## The analysis results are as follows ##############\n"
        )
        logging.info("{} samples in file {}\n".format(
            len(self.file_list), self.file_list_path))
        logging.info("Maximal image height: {} Maximal image width: {}.\n".
                     format(self.max_height, self.max_width))
        logging.info("Minimal image height: {} Minimal image width: {}.\n".
                     format(self.min_height, self.min_width))
        logging.info("Image channel is {}.\n".format(self.channel_num))
        logging.info(
            "Image mean value: {} Image standard deviation: {} (normalized by 255, sorted by a BGR format).\n".
            format(im_mean, im_std))
        logging.info(
            "Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):"
        )
        for v, (n, r) in self.label_pixel_info.items():
            logging.info("({}, {}, {})".format(v, n, r))
        mode = osp.split(self.file_list_path)[-1].split('.')[0]
        saved_pkl_file = osp.join(self.data_dir,
                                  '{}_image_pixel_info.pkl'.format(mode))
        saved_png_file = osp.join(self.data_dir,
                                  '{}_image_pixel_info.png'.format(mode))
        logging.info(
            "Image pixel information is saved in the file '{}' and shown in the file '{}'".
            format(saved_pkl_file, saved_png_file))

    def cal_clipvalue_ratio(self, clip_min_value, clip_max_value):
        """Log, per channel, the fraction of pixels outside the clip range.

        Args:
            clip_min_value (list): Lower clip bound of every channel.
            clip_max_value (list): Upper clip bound of every channel.

        Raises:
            Exception: If a list length differs from the channel count or a
                bound is outside the value range found by ``analysis()``.
        """
        if len(clip_min_value) != self.channel_num or len(
                clip_max_value) != self.channel_num:
            raise Exception(
                "The length of clip_min_value or clip_max_value should be equal to the number of image channel {}."
                .format(self.channel_num))
        self._check_clip_range(clip_min_value, clip_max_value)
        for c in range(self.channel_num):
            clip_pixel_num = 0
            pixel_num = sum(self.im_pixel_info[c].values())
            for v, n in self.im_pixel_info[c].items():
                if v < clip_min_value[c] or v > clip_max_value[c]:
                    clip_pixel_num += n
            logging.info("Channel {}, the ratio of pixels to be clipped = {}".
                         format(c, clip_pixel_num / pixel_num))

    def cal_clipped_mean_std(self, clip_min_value, clip_max_value):
        """Log the dataset mean/std after clipping and rescaling to [0, 1].

        Args:
            clip_min_value (list): Lower clip bound of every channel.
            clip_max_value (list): Upper clip bound of every channel.

        Raises:
            Exception: If a bound is outside the value range found by
                ``analysis()``.
        """
        self._check_clip_range(clip_min_value, clip_max_value)

        num_samples = len(self.file_list)
        self.clipped_im_mean_list = [[] for _ in range(num_samples)]
        self.clipped_im_std_list = [[] for _ in range(num_samples)]

        self._run_in_threads(
            self._get_clipped_mean_std,
            extra_args=(clip_min_value, clip_max_value))

        im_mean = np.asarray(self.clipped_im_mean_list)
        im_mean = im_mean.sum(axis=0)
        im_mean = im_mean / num_samples

        im_std = np.asarray(self.clipped_im_std_list)
        im_std = im_std.sum(axis=0)
        im_std = im_std / num_samples

        logging.info(
            "Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value)).\n".
            format(im_mean, im_std))
+ .format(line)) items = line.strip().split() items[0] = path_normalization(items[0]) items[1] = path_normalization(items[1]) diff --git a/paddlex/cv/models/deeplabv3p.py b/paddlex/cv/models/deeplabv3p.py index fe1c294ae61d5d7e6e18696e56ff22909d8cc6c8..f9c76293abe54020a1fc4c5aae3a55a57770cbbe 100644 --- a/paddlex/cv/models/deeplabv3p.py +++ b/paddlex/cv/models/deeplabv3p.py @@ -65,6 +65,7 @@ class DeepLabv3p(BaseAPI): def __init__(self, num_classes=2, + input_channel=3, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, @@ -114,6 +115,7 @@ class DeepLabv3p(BaseAPI): self.backbone = backbone self.num_classes = num_classes + self.input_channel = input_channel self.use_bce_loss = use_bce_loss self.use_dice_loss = use_dice_loss self.class_weight = class_weight @@ -215,6 +217,7 @@ class DeepLabv3p(BaseAPI): def build_net(self, mode='train'): model = paddlex.cv.nets.segmentation.DeepLabv3p( self.num_classes, + input_channel=self.input_channel, mode=mode, backbone=self._get_backbone(self.backbone), output_stride=self.output_stride, diff --git a/paddlex/cv/models/fast_scnn.py b/paddlex/cv/models/fast_scnn.py index 36f6ffbb887ce868c38578dec18e099a71fb7f02..21003bffbfef2b4466a66db0fe8cc4680b03fb8b 100644 --- a/paddlex/cv/models/fast_scnn.py +++ b/paddlex/cv/models/fast_scnn.py @@ -48,6 +48,7 @@ class FastSCNN(DeepLabv3p): def __init__(self, num_classes=2, + input_channel=3, use_bce_loss=False, use_dice_loss=False, class_weight=None, @@ -86,6 +87,7 @@ class FastSCNN(DeepLabv3p): ) self.num_classes = num_classes + self.input_channel = input_channel self.use_bce_loss = use_bce_loss self.use_dice_loss = use_dice_loss self.class_weight = class_weight @@ -97,6 +99,7 @@ class FastSCNN(DeepLabv3p): def build_net(self, mode='train'): model = paddlex.cv.nets.segmentation.FastSCNN( self.num_classes, + input_channel=self.input_channel, mode=mode, use_bce_loss=self.use_bce_loss, use_dice_loss=self.use_dice_loss, diff --git a/paddlex/cv/models/hrnet.py 
b/paddlex/cv/models/hrnet.py index 8d9a224de34c91ea9663d2fe4cbed2683f817662..cc4154a42ad079271b2e55f754305123ab4e4474 100644 --- a/paddlex/cv/models/hrnet.py +++ b/paddlex/cv/models/hrnet.py @@ -44,6 +44,7 @@ class HRNet(DeepLabv3p): def __init__(self, num_classes=2, + input_channel=3, width=18, use_bce_loss=False, use_dice_loss=False, @@ -72,6 +73,7 @@ class HRNet(DeepLabv3p): 'Expect class_weight is a list or string but receive {}'. format(type(class_weight))) self.num_classes = num_classes + self.input_channel = input_channel self.width = width self.use_bce_loss = use_bce_loss self.use_dice_loss = use_dice_loss @@ -83,6 +85,7 @@ class HRNet(DeepLabv3p): def build_net(self, mode='train'): model = paddlex.cv.nets.segmentation.HRNet( self.num_classes, + input_channel=self.input_channel, width=self.width, mode=mode, use_bce_loss=self.use_bce_loss, diff --git a/paddlex/cv/models/ppyolo.py b/paddlex/cv/models/ppyolo.py index e82dea4b10b4857d4aeea86e1c4998fdaa7358dc..eab6e9565adcbad559100b7b8aad031e5815c39d 100644 --- a/paddlex/cv/models/ppyolo.py +++ b/paddlex/cv/models/ppyolo.py @@ -36,7 +36,7 @@ class PPYOLO(BaseAPI): Args: num_classes (int): 类别数。默认为80。 - backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd']。默认为'ResNet50_vd'。 + backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd_ssld']。默认为'ResNet50_vd_ssld'。 with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。 anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值 [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], diff --git a/paddlex/cv/nets/segmentation/deeplabv3p.py b/paddlex/cv/nets/segmentation/deeplabv3p.py index 7d597a606a88a78513452c37357b806c4dfa156f..db8c0ef80c2c110c115fce3def1c6ca285db0dac 100644 --- a/paddlex/cv/nets/segmentation/deeplabv3p.py +++ b/paddlex/cv/nets/segmentation/deeplabv3p.py @@ -72,6 +72,7 @@ class DeepLabv3p(object): def __init__(self, num_classes, backbone, + input_channel=3, mode='train', output_stride=16, aspp_with_sep_conv=True, @@ -115,6 +116,7 @@ class DeepLabv3p(object): 
format(type(class_weight))) self.num_classes = num_classes + self.input_channel = input_channel self.backbone = backbone self.mode = mode self.use_bce_loss = use_bce_loss @@ -402,13 +404,16 @@ class DeepLabv3p(object): if self.fixed_input_shape is not None: input_shape = [ - None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] + None, self.input_channel, self.fixed_input_shape[1], + self.fixed_input_shape[0] ] inputs['image'] = fluid.data( dtype='float32', shape=input_shape, name='image') else: inputs['image'] = fluid.data( - dtype='float32', shape=[None, 3, None, None], name='image') + dtype='float32', + shape=[None, self.input_channel, None, None], + name='image') if self.mode == 'train': inputs['label'] = fluid.data( dtype='int32', shape=[None, 1, None, None], name='label') diff --git a/paddlex/cv/nets/segmentation/fast_scnn.py b/paddlex/cv/nets/segmentation/fast_scnn.py index 8e86f4bffa275c3d7660d3d2f7b01151c2785c41..daffa69db752a71b82f02200a15b0f79fb21e277 100644 --- a/paddlex/cv/nets/segmentation/fast_scnn.py +++ b/paddlex/cv/nets/segmentation/fast_scnn.py @@ -33,6 +33,7 @@ from .model_utils.loss import bce_loss class FastSCNN(object): def __init__(self, num_classes, + input_channel=3, mode='train', use_bce_loss=False, use_dice_loss=False, @@ -62,6 +63,7 @@ class FastSCNN(object): format(type(class_weight))) self.num_classes = num_classes + self.input_channel = input_channel self.mode = mode self.use_bce_loss = use_bce_loss self.use_dice_loss = use_dice_loss @@ -137,13 +139,16 @@ class FastSCNN(object): inputs = OrderedDict() if self.fixed_input_shape is not None: input_shape = [ - None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] + None, self.input_channel, self.fixed_input_shape[1], + self.fixed_input_shape[0] ] inputs['image'] = fluid.data( dtype='float32', shape=input_shape, name='image') else: inputs['image'] = fluid.data( - dtype='float32', shape=[None, 3, None, None], name='image') + dtype='float32', + shape=[None, 
self.input_channel, None, None], + name='image') if self.mode == 'train': inputs['label'] = fluid.data( dtype='int32', shape=[None, 1, None, None], name='label') diff --git a/paddlex/cv/nets/segmentation/hrnet.py b/paddlex/cv/nets/segmentation/hrnet.py index b74c044951f62a0dcc70fbc9964f42f781f4d573..fe0f690bccefe9a38cebb5d00794ab99dfb1658e 100644 --- a/paddlex/cv/nets/segmentation/hrnet.py +++ b/paddlex/cv/nets/segmentation/hrnet.py @@ -32,6 +32,7 @@ import paddlex class HRNet(object): def __init__(self, num_classes, + input_channel=3, mode='train', width=18, use_bce_loss=False, @@ -61,6 +62,7 @@ class HRNet(object): format(type(class_weight))) self.num_classes = num_classes + self.input_channel = input_channel self.mode = mode self.use_bce_loss = use_bce_loss self.use_dice_loss = use_dice_loss @@ -136,13 +138,16 @@ class HRNet(object): if self.fixed_input_shape is not None: input_shape = [ - None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] + None, self.input_channel, self.fixed_input_shape[1], + self.fixed_input_shape[0] ] inputs['image'] = fluid.data( dtype='float32', shape=input_shape, name='image') else: inputs['image'] = fluid.data( - dtype='float32', shape=[None, 3, None, None], name='image') + dtype='float32', + shape=[None, self.input_channel, None, None], + name='image') if self.mode == 'train': inputs['label'] = fluid.data( dtype='int32', shape=[None, 1, None, None], name='label') diff --git a/paddlex/cv/transforms/seg_transforms.py b/paddlex/cv/transforms/seg_transforms.py index b0e6fa015a2d1cc2ee9590661ce5fe72a8bfdba7..bc16af3c1b555d9f1bdda5d345afff89010071fb 100644 --- a/paddlex/cv/transforms/seg_transforms.py +++ b/paddlex/cv/transforms/seg_transforms.py @@ -74,8 +74,22 @@ class Compose(SegTransform): raise ValueError('Can\'t read The image file {}!'.format(im)) im = im.astype('float32') if label is not None: - if not isinstance(label, np.ndarray): - label = np.asarray(Image.open(label)) + if isinstance(label, np.ndarray): + if 
len(label.shape) != 2: + raise Exception( + "label should be 2-dimensions, but now is {}-dimensions". + format(len(label.shape))) + + else: + try: + label = np.asarray(Image.open(label)) + except: + ValueError('Can\'t read The label file {}!'.format(label)) + im_height, im_width, _ = im.shape + label_height, label_width = label.shape + if im_height != label_height or im_width != label_width: + raise Exception( + "The height or width of the image is not same as the label") return (im, label) def __call__(self, im, im_info=None, label=None): @@ -605,6 +619,7 @@ class Normalize(SegTransform): mean = np.array(self.mean)[np.newaxis, np.newaxis, :] std = np.array(self.std)[np.newaxis, np.newaxis, :] im = normalize(im, mean, std, self.min_val, self.max_val) + im = im.astype('float32') if label is None: return (im, im_info) diff --git a/paddlex/RemoteSensing/__init__.py b/paddlex/remotesensing/__init__.py similarity index 100% rename from paddlex/RemoteSensing/__init__.py rename to paddlex/remotesensing/__init__.py diff --git a/paddlex/RemoteSensing/train_demo.py b/paddlex/remotesensing/train_demo.py similarity index 94% rename from paddlex/RemoteSensing/train_demo.py rename to paddlex/remotesensing/train_demo.py index d36528e1381aecafa0276e273de0d442b888d484..4a193d5b77814b9cda2e50843656f528050d51d6 100644 --- a/paddlex/RemoteSensing/train_demo.py +++ b/paddlex/remotesensing/train_demo.py @@ -16,7 +16,7 @@ import os.path as osp import argparse from paddlex.seg import transforms -import paddlex.RemoteSensing.transforms as custom_transforms +import paddlex.remotesensing.transforms as rs_transforms import paddlex as pdx @@ -110,22 +110,22 @@ train_transforms = transforms.Compose([ transforms.RandomHorizontalFlip(0.5), transforms.ResizeStepScaling(0.5, 2.0, 0.25), transforms.RandomPaddingCrop(im_padding_value=[1000] * channel), - custom_transforms.Clip( + rs_transforms.Clip( min_val=clip_min_value, max_val=clip_max_value), transforms.Normalize( min_val=clip_min_value, 
max_val=clip_max_value, mean=mean, std=std), ]) -train_transforms.decode_image = custom_transforms.decode_image +train_transforms.decode_image = rs_transforms.decode_image eval_transforms = transforms.Compose([ - custom_transforms.Clip( + rs_transforms.Clip( min_val=clip_min_value, max_val=clip_max_value), transforms.Normalize( min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std), ]) -eval_transforms.decode_image = custom_transforms.decode_image +eval_transforms.decode_image = rs_transforms.decode_image train_list = osp.join(data_dir, 'train.txt') val_list = osp.join(data_dir, 'val.txt') diff --git a/paddlex/RemoteSensing/transforms.py b/paddlex/remotesensing/transforms.py similarity index 69% rename from paddlex/RemoteSensing/transforms.py rename to paddlex/remotesensing/transforms.py index f342d99a634604e52a7342785347c68edc07ad81..d1f4422333436659b5a24287f808e55eeac52405 100644 --- a/paddlex/RemoteSensing/transforms.py +++ b/paddlex/remotesensing/transforms.py @@ -2,17 +2,23 @@ import os import os.path as osp import imghdr import gdal +gdal.UseExceptions() +gdal.PushErrorHandler('CPLQuietErrorHandler') import numpy as np from PIL import Image from paddlex.seg import transforms +import paddlex.utils.logging as logging def read_img(img_path): img_format = imghdr.what(img_path) name, ext = osp.splitext(img_path) if img_format == 'tiff' or ext == '.img': - dataset = gdal.Open(img_path) + try: + dataset = gdal.Open(img_path) + except: + logging.error(gdal.GetLastErrorMsg()) if dataset == None: raise Exception('Can not open', img_path) im_data = dataset.ReadAsArray() @@ -36,9 +42,25 @@ def decode_image(im, label): im = read_img(im) except: raise ValueError('Can\'t read The image file {}!'.format(im)) + im = im.astype('float32') + if label is not None: - if not isinstance(label, np.ndarray): - label = read_img(label) + if isinstance(label, np.ndarray): + if len(label.shape) != 2: + raise Exception( + "label should be 2-dimensions, but now is 
{}-dimensions". + format(len(label.shape))) + + else: + try: + label = np.asarray(Image.open(label)) + except: + ValueError('Can\'t read The label file {}!'.format(label)) + im_height, im_width, _ = im.shape + label_height, label_width = label.shape + if im_height != label_height or im_width != label_width: + raise Exception( + "The height or width of the image is not same as the label") return (im, label) diff --git a/paddlex/remotesensing/utils/__init__.py b/paddlex/remotesensing/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/paddlex/remotesensing/utils/analyse.py b/paddlex/remotesensing/utils/analyse.py new file mode 100644 index 0000000000000000000000000000000000000000..000c0bd0fbc859626a416b059dff719b461a20c9 --- /dev/null +++ b/paddlex/remotesensing/utils/analyse.py @@ -0,0 +1,506 @@ +# coding: utf8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import os +import os.path as osp +import sys +import argparse +from PIL import Image +from tqdm import tqdm +import imghdr +import logging +import pickle +import gdal + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Data analyse and data check before training.') + parser.add_argument( + '--data_dir', + dest='data_dir', + help='Dataset directory', + default=None, + type=str) + parser.add_argument( + '--num_classes', + dest='num_classes', + help='Number of classes', + default=None, + type=int) + parser.add_argument( + '--separator', + dest='separator', + help='file list separator', + default=" ", + type=str) + parser.add_argument( + '--ignore_index', + dest='ignore_index', + help='Ignored class index', + default=255, + type=int) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + return parser.parse_args() + + +def read_img(img_path): + img_format = imghdr.what(img_path) + name, ext = osp.splitext(img_path) + if img_format == 'tiff' or ext == '.img': + dataset = gdal.Open(img_path) + if dataset == None: + raise Exception('Can not open', img_path) + im_data = dataset.ReadAsArray() + return im_data.transpose((1, 2, 0)) + elif ext == '.npy': + return np.load(img_path) + else: + raise Exception('Not support {} image format!'.format(ext)) + + +def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value): + channel = img.shape[2] + means = np.zeros(channel) + stds = np.zeros(channel) + for k in range(channel): + img_k = img[:, :, k] + + # count mean, std + means[k] = np.mean(img_k) + stds[k] = np.std(img_k) + + # count min, max + min_value = np.min(img_k) + max_value = np.max(img_k) + if img_max_value[k] < max_value: + img_max_value[k] = max_value + if img_min_value[k] > min_value: + img_min_value[k] = min_value + + # count the distribution of image value, value number + unique, counts = 
np.unique(img_k, return_counts=True) + add_num = [] + max_unique = np.max(unique) + add_len = max_unique - len(img_value_num[k]) + 1 + if add_len > 0: + img_value_num[k] += ([0] * add_len) + for i in range(len(unique)): + value = unique[i] + img_value_num[k][value] += counts[i] + + img_value_num[k] += add_num + return means, stds, img_min_value, img_max_value, img_value_num + + +def data_distribution_statistics(data_dir, img_value_num, logger): + """count the distribution of image value, value number + """ + logger.info( + "\n-----------------------------\nThe whole dataset statistics...") + + if not img_value_num: + return + logger.info("\nImage pixel statistics:") + total_ratio = [] + [total_ratio.append([]) for i in range(len(img_value_num))] + for k in range(len(img_value_num)): + total_num = sum(img_value_num[k]) + total_ratio[k] = [i / total_num for i in img_value_num[k]] + total_ratio[k] = np.around(total_ratio[k], decimals=4) + with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f: + pickle.dump([total_ratio, img_value_num], f) + + +def data_range_statistics(img_min_value, img_max_value, logger): + """print min value, max value + """ + logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}". 
+ format(img_min_value, img_max_value)) + + +def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger): + """count mean, std + """ + total_means = total_means / total_img_num + total_stds = total_stds / total_img_num + logger.info("\nCount the channel-by-channel mean and std of the image:\n" + "mean = {}\nstd = {}".format(total_means, total_stds)) + + +def error_print(str): + return "".join(["\nNOT PASS ", str]) + + +def correct_print(str): + return "".join(["\nPASS ", str]) + + +def pil_imread(file_path): + """read pseudo-color label""" + im = Image.open(file_path) + return np.asarray(im) + + +def get_img_shape_range(img, max_width, max_height, min_width, min_height): + """获取图片最大和最小宽高""" + img_shape = img.shape + height, width = img_shape[0], img_shape[1] + max_height = max(height, max_height) + max_width = max(width, max_width) + min_height = min(height, min_height) + min_width = min(width, min_width) + return max_width, max_height, min_width, min_height + + +def get_img_channel_num(img, img_channels): + """获取图像的通道数""" + img_shape = img.shape + if img_shape[-1] not in img_channels: + img_channels.append(img_shape[-1]) + return img_channels + + +def is_label_single_channel(label): + """判断标签是否为灰度图""" + label_shape = label.shape + if len(label_shape) == 2: + return True + else: + return False + + +def image_label_shape_check(img, label): + """ + 验证图像和标注的大小是否匹配 + """ + + flag = True + img_height = img.shape[0] + img_width = img.shape[1] + label_height = label.shape[0] + label_width = label.shape[1] + + if img_height != label_height or img_width != label_width: + flag = False + return flag + + +def ground_truth_check(label, label_path): + """ + 验证标注图像的格式 + 统计标注图类别和像素数 + params: + label: 标注图 + label_path: 标注图路径 + return: + png_format: 返回是否是png格式图片 + unique: 返回标注类别 + counts: 返回标注的像素数 + """ + if imghdr.what(label_path) == "png": + png_format = True + else: + png_format = False + + unique, counts = np.unique(label, return_counts=True) + + return 
def sum_label_check(label_classes, num_of_each_class, ignore_index,
                    num_classes, total_label_classes, total_num_of_each_class):
    """Validate one label's classes and merge its counts into the totals.

    Args:
        label_classes: Distinct label values of one label image.
        num_of_each_class: Pixel count of every entry of ``label_classes``.
        ignore_index: Label value that is exempt from the range check.
        num_classes: Number of classes; valid values are
            ``0 .. num_classes - 1``.
        total_label_classes: List of classes seen so far; extended in place.
        total_num_of_each_class: Pixel totals aligned with
            ``total_label_classes``; updated in place.

    Returns:
        Tuple ``(is_label_correct, total_num_of_each_class,
        total_label_classes)``. ``is_label_correct`` is False when a value
        other than ``ignore_index`` lies outside ``[0, num_classes - 1]``.
    """
    valid_classes = np.asarray(label_classes)
    valid_classes = valid_classes[valid_classes != ignore_index]
    is_label_correct = True
    # A label made only of ``ignore_index`` pixels leaves nothing to range
    # check; min()/max() on the empty array would raise ValueError.
    if valid_classes.size and (valid_classes.min() < 0 or
                               valid_classes.max() > num_classes - 1):
        is_label_correct = False

    # Map class -> position once instead of scanning the list per class.
    position = {cls: idx for idx, cls in enumerate(total_label_classes)}
    # ``ignore_index`` is still counted in the totals, only the range check
    # skips it.
    for cls, pixel_num in zip(label_classes, num_of_each_class):
        idx = position.get(cls)
        if idx is None:
            position[cls] = len(total_label_classes)
            total_label_classes.append(cls)
            total_num_of_each_class.append(pixel_num)
        else:
            total_num_of_each_class[idx] += pixel_num
    return is_label_correct, total_num_of_each_class, total_label_classes


def label_class_check(num_classes, total_label_classes,
                      total_num_of_each_class, wrong_labels, logger):
    """Check that the observed label classes match ``num_classes``.

    Label values must lie in ``[0, num_classes - 1]`` or equal the ignore
    index, and classes should start from 0, otherwise accuracy may suffer.

    Args:
        num_classes: Configured number of classes.
        total_label_classes: Classes observed over the whole file list.
        total_num_of_each_class: Pixel totals aligned with
            ``total_label_classes``.
        wrong_labels: File-list lines whose label held out-of-range values.
        logger: Object exposing ``info`` and ``debug``.

    Returns:
        List of ``(label class, ratio, pixel count)`` tuples sorted by class,
        with ratios rounded to four decimals; an empty list when no label
        statistics were collected.
    """
    if not total_num_of_each_class:
        # Without any counted pixel the ratio would divide by zero and there
        # is no smallest class to inspect.
        logger.info(error_print("label class check!"))
        logger.info("No label pixel was counted, label classes are unknown")
        for wrong_label in wrong_labels:
            logger.debug(wrong_label)
        return []

    total_pixels = sum(total_num_of_each_class)
    total_ratio = np.around(
        np.asarray(total_num_of_each_class) / total_pixels, decimals=4)
    total_nc = sorted(
        zip(total_label_classes, total_ratio, total_num_of_each_class))
    # After sorting, the first tuple carries the smallest observed class.
    starts_from_zero = not total_nc[0][0]
    if not wrong_labels and starts_from_zero:
        logger.info(correct_print("label class check!"))
    else:
        logger.info(error_print("label class check!"))
        if not starts_from_zero:
            logger.info("Warning: label classes should start from 0")
        if wrong_labels:
            logger.info("fatal error: label class is out of range [0, {}]".
                        format(num_classes - 1))
            for wrong_label in wrong_labels:
                logger.debug(wrong_label)
    return total_nc
def label_class_statistics(total_nc, logger):
    """Log the per-class label statistics.

    Args:
        total_nc: Sequence of ``(label class, percentage, pixel count)``.
        logger: Object exposing an ``info`` method.
    """
    logger.info("\nLabel class statistics:\n"
                "(label class, percentage, total pixel number) = {} ".format(
                    total_nc))


def shape_check(shape_unequal_image, logger):
    """Report the image/label shape check.

    Args:
        shape_unequal_image: File-list lines whose image and label shapes
            differ.
        logger: Object exposing ``info`` and ``debug``.
    """
    if not shape_unequal_image:
        logger.info(correct_print("shape check"))
        logger.info("All images are the same shape as the labels")
        return
    logger.info(error_print("shape check"))
    logger.info(
        "Some images are not the same shape as the labels as follow: ")
    for entry in shape_unequal_image:
        logger.debug(entry)


def separator_check(wrong_lines, file_list, separator, logger):
    """Report whether every file-list line used the expected separator.

    Args:
        wrong_lines: Lines that could not be split into one or two parts.
        file_list: Path of the checked file list; only its file name is
            shown in the report.
        separator: Separator the lines are expected to use.
        logger: Object exposing ``info`` and ``debug``.
    """
    # basename copes with every separator the platform accepts, unlike a
    # manual split on os.sep.
    title = os.path.basename(file_list) + " DATASET.separator check"
    if not wrong_lines:
        logger.info(correct_print(title))
        return
    logger.info(error_print(title))
    logger.info("The following list is not separated by {}".format(
        separator))
    for entry in wrong_lines:
        logger.debug(entry)


def imread_check(imread_failed, logger):
    """Report the image reading check.

    Args:
        imread_failed: One entry per image that could not be read.
        logger: Object exposing ``info`` and ``debug``.
    """
    if not imread_failed:
        logger.info(correct_print("dataset reading check"))
        logger.info("All images can be read successfully")
        return
    logger.info(error_print("dataset reading check"))
    logger.info("Failed to read {} images".format(len(imread_failed)))
    for entry in imread_failed:
        logger.debug(entry)


def single_channel_label_check(label_not_single_channel, logger):
    """Report whether every label image is single-channel.

    Args:
        label_not_single_channel: File-list lines whose label is not a
            single-channel image.
        logger: Object exposing ``info`` and ``debug``.
    """
    if not label_not_single_channel:
        logger.info(correct_print("label single_channel check"))
        logger.info("All label images are single_channel")
        return
    logger.info(error_print("label single_channel check"))
    logger.info(
        "{} label images are not single_channel\nLabel pixel statistics may be insignificant"
        .format(len(label_not_single_channel)))
    for entry in label_not_single_channel:
        logger.debug(entry)
def img_shape_range_statistics(max_width, min_width, max_height, min_height,
                               logger):
    """Log the largest and smallest image width and height.

    Args:
        max_width, min_width, max_height, min_height: Extremes to report.
        logger: Object exposing an ``info`` method.
    """
    logger.info("\nImage size statistics:")
    logger.info(
        "max width = {} min width = {} max height = {} min height = {}".
        format(max_width, min_width, max_height, min_height))


def img_channels_statistics(img_channels, logger):
    """Log the distinct channel counts found in ``img_channels``."""
    logger.info("\nImage channels statistics\nImage channels = {}".format(
        np.unique(img_channels)))


def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
                           logger):
    """Check and analyse the train/val/test file lists under ``data_dir``.

    For each of ``train.txt``, ``val.txt`` and ``test.txt`` the images (and
    labels, when a line has two parts) are read, the separator, readability,
    label channel, shape and label class checks are reported, and the image
    channel / size statistics are logged. After all lists are processed the
    value range, value distribution and mean/std over every image that was
    read are reported.

    Args:
        data_dir: Dataset root holding the file lists; image and label paths
            in the lists are joined onto it.
        num_classes: Number of label classes.
        separator: Separator between image and label path in a list line.
        ignore_index: Label value excluded from the class range check.
        logger: Logger exposing ``info`` and ``debug``.
    """
    train_file_list = osp.join(data_dir, 'train.txt')
    val_file_list = osp.join(data_dir, 'val.txt')
    test_file_list = osp.join(data_dir, 'test.txt')
    total_img_num = 0
    # Whole-dataset accumulators, sized from the first image that is read.
    total_means = None
    total_stds = None
    img_min_value = None
    img_max_value = None
    img_value_num = []
    for file_list in [train_file_list, val_file_list, test_file_list]:
        if not osp.exists(file_list):
            # A dataset is not required to ship all three lists.
            logger.info(
                "\n-----------------------------\n{} does not exist, skip it.".
                format(file_list))
            continue

        # Per-list state. ``has_label`` is reset here so that a list
        # without labels is not checked against the labels of an earlier
        # list.
        has_label = False
        imread_failed = []
        max_width = 0
        max_height = 0
        min_width = sys.float_info.max
        min_height = sys.float_info.max
        label_not_single_channel = []
        shape_unequal_image = []
        wrong_labels = []
        wrong_lines = []
        total_label_classes = []
        total_num_of_each_class = []
        img_channels = []
        label_required = file_list in (train_file_list, val_file_list)

        with open(file_list, 'r') as fid:
            logger.info("\n-----------------------------\nCheck {}...".format(
                file_list))
            lines = fid.readlines()
        if not lines:
            logger.info("File list is empty!")
            continue
        for line in tqdm(lines):
            line = line.strip()
            parts = line.split(separator)
            if len(parts) == 1:
                if label_required:
                    logger.info("Train or val list must have labels!")
                    break
                img_path = os.path.join(data_dir, parts[0])
                try:
                    img = read_img(img_path)
                except Exception as e:
                    imread_failed.append((line, str(e)))
                    continue
            elif len(parts) == 2:
                has_label = True
                img_name, label_name = parts
                img_path = os.path.join(data_dir, img_name)
                label_path = os.path.join(data_dir, label_name)
                try:
                    img = read_img(img_path)
                    label = pil_imread(label_path)
                except Exception as e:
                    imread_failed.append((line, str(e)))
                    continue

                if not is_label_single_channel(label):
                    label_not_single_channel.append(line)
                    continue
                if not image_label_shape_check(img, label):
                    shape_unequal_image.append(line)
                _, label_classes, num_of_each_class = ground_truth_check(
                    label, label_path)
                (is_label_correct, total_num_of_each_class,
                 total_label_classes) = sum_label_check(
                     label_classes, num_of_each_class, ignore_index,
                     num_classes, total_label_classes,
                     total_num_of_each_class)
                if not is_label_correct:
                    wrong_labels.append(line)
            else:
                # Record only the offending line, not the whole file.
                wrong_lines.append(line)
                continue

            if total_img_num == 0:
                channel = img.shape[2]
                total_means = np.zeros(channel)
                total_stds = np.zeros(channel)
                img_min_value = [sys.float_info.max] * channel
                img_max_value = [0] * channel
                img_value_num = [[] for _ in range(channel)]
            (means, stds, img_min_value, img_max_value,
             img_value_num) = img_pixel_statistics(
                 img, img_value_num, img_min_value, img_max_value)
            total_means += means
            total_stds += stds
            max_width, max_height, min_width, min_height = get_img_shape_range(
                img, max_width, max_height, min_width, min_height)
            img_channels = get_img_channel_num(img, img_channels)
            total_img_num += 1

        # data check
        separator_check(wrong_lines, file_list, separator, logger)
        imread_check(imread_failed, logger)
        if has_label:
            single_channel_label_check(label_not_single_channel, logger)
            shape_check(shape_unequal_image, logger)
            total_nc = label_class_check(num_classes, total_label_classes,
                                         total_num_of_each_class,
                                         wrong_labels, logger)

        # data analyse on train, validation, test set.
        img_channels_statistics(img_channels, logger)
        img_shape_range_statistics(max_width, min_width, max_height,
                                   min_height, logger)
        if has_label:
            label_class_statistics(total_nc, logger)

    # data analyse on the whole dataset.
    if total_img_num == 0:
        # No image was read, so the accumulators were never sized.
        logger.info(
            "\nNo image was read successfully, skip the whole dataset statistics."
        )
        return
    data_range_statistics(img_min_value, img_max_value, logger)
    data_distribution_statistics(data_dir, img_value_num, logger)
    cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)


def main():
    """Parse the command line, set up logging and run the dataset check.

    INFO and above go to the console; everything from DEBUG up is written
    to ``data_analyse_and_check.log`` inside the data directory.
    """
    args = parse_args()
    data_dir = args.data_dir
    ignore_index = args.ignore_index
    num_classes = args.num_classes
    separator = args.separator
    log_path = os.path.join(data_dir, 'data_analyse_and_check.log')

    logger = logging.getLogger()
    logger.setLevel('DEBUG')
    BASIC_FORMAT = "%(message)s"
    formatter = logging.Formatter(BASIC_FORMAT)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    sh.setLevel('INFO')
    th = logging.FileHandler(log_path, 'w')
    th.setFormatter(formatter)
    logger.addHandler(sh)
    logger.addHandler(th)

    try:
        data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
                               logger)
    finally:
        # Detach and close the handlers so the log file is flushed and the
        # root logger is left as it was found.
        for handler in (sh, th):
            logger.removeHandler(handler)
            handler.close()

    print("\nDetailed error information can be viewed in {}.".format(
        log_path))


if __name__ == "__main__":
    main()