提交 746122f4 编写于 作者: F FlyingQianMM

add analysis for seg dataset

上级 564417a4
...@@ -32,7 +32,7 @@ from . import slim ...@@ -32,7 +32,7 @@ from . import slim
from . import convertor from . import convertor
from . import tools from . import tools
from . import deploy from . import deploy
from . import RemoteSensing from . import remotesensing
try: try:
import pycocotools import pycocotools
......
...@@ -20,3 +20,4 @@ from .easydata_cls import EasyDataCls ...@@ -20,3 +20,4 @@ from .easydata_cls import EasyDataCls
from .easydata_det import EasyDataDet from .easydata_det import EasyDataDet
from .easydata_seg import EasyDataSeg from .easydata_seg import EasyDataSeg
from .dataset import generate_minibatch from .dataset import generate_minibatch
from .analysis import Seg
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import numpy as np
import os.path as osp
import cv2
from PIL import Image
import pickle
import threading
import multiprocessing as mp
import paddlex.utils.logging as logging
from paddlex.utils import path_normalization
from .dataset import get_encoding
class Seg:
def __init__(self, data_dir, file_list, label_list):
self.data_dir = data_dir
self.file_list_path = file_list
self.file_list = list()
self.labels = list()
with open(label_list, encoding=get_encoding(label_list)) as f:
for line in f:
item = line.strip()
self.labels.append(item)
with open(file_list, encoding=get_encoding(file_list)) as f:
for line in f:
if line.count(" ") > 1:
raise Exception(
"A space is defined as the separator, but it exists in image or label name {}."
.format(line))
items = line.strip().split()
items[0] = path_normalization(items[0])
items[1] = path_normalization(items[1])
full_path_im = osp.join(data_dir, items[0])
full_path_label = osp.join(data_dir, items[1])
if not osp.exists(full_path_im):
raise IOError('The image file {} is not exist!'.format(
full_path_im))
if not osp.exists(full_path_label):
raise IOError('The image file {} is not exist!'.format(
full_path_label))
self.file_list.append([full_path_im, full_path_label])
self.num_samples = len(self.file_list)
@staticmethod
def decode_image(im, label):
if isinstance(im, np.ndarray):
if len(im.shape) != 3:
raise Exception(
"im should be 3-dimensions, but now is {}-dimensions".
format(len(im.shape)))
else:
try:
im = cv2.imread(im)
except:
raise ValueError('Can\'t read The image file {}!'.format(im))
im = im.astype('float32')
if label is not None:
if isinstance(label, np.ndarray):
if len(label.shape) != 2:
raise Exception(
"label should be 2-dimensions, but now is {}-dimensions".
format(len(label.shape)))
else:
try:
label = np.asarray(Image.open(label))
except:
ValueError('Can\'t read The label file {}!'.format(label))
im_height, im_width, _ = im.shape
label_height, label_width = label.shape
if im_height != label_height or im_width != label_width:
raise Exception(
"The height or width of the image is not same as the label")
return (im, label)
def _get_shape(self):
max_height = max(self.im_height_list)
max_width = max(self.im_width_list)
min_height = min(self.im_height_list)
min_width = min(self.im_width_list)
shape_info = {
'max_height': max_height,
'max_width': max_width,
'min_height': min_height,
'min_width': min_width,
}
return shape_info
def _get_label_pixel_info(self):
pixel_num = np.dot(self.im_height_list, self.im_width_list)
label_pixel_info = dict()
for label_value, label_value_num in zip(self.label_value_list,
self.label_value_num_list):
for v, n in zip(label_value, label_value_num):
if v not in label_pixel_info.keys():
label_pixel_info[v] = [n, float(n) / float(pixel_num)]
else:
label_pixel_info[v][0] += n
label_pixel_info[v][1] += float(n) / float(pixel_num)
return label_pixel_info
def _get_image_pixel_info(self):
channel = max([len(im_value) for im_value in self.im_value_list])
im_pixel_info = [dict() for c in range(channel)]
for im_value, im_value_num in zip(self.im_value_list,
self.im_value_num_list):
for c in range(channel):
for v, n in zip(im_value[c], im_value_num[c]):
if v not in im_pixel_info[c].keys():
im_pixel_info[c][v] = n
else:
im_pixel_info[c][v] += n
mode = osp.split(self.file_list_path)[-1].split('.')[0]
with open(
osp.join(self.data_dir,
'{}_image_pixel_info.pkl'.format(mode)), 'wb') as f:
pickle.dump(im_pixel_info, f)
import matplotlib.pyplot as plt
plot_id = (channel // 3 + 1) * 100 + 31
for c in range(channel):
if c > 8:
continue
plt.subplot(plot_id + c)
plt.bar(im_pixel_info[c].keys(),
im_pixel_info[c].values(),
width=1,
log=True)
plt.xlabel('image pixel value')
plt.ylabel('number')
plt.title('channel={}'.format(c))
plt.savefig(
osp.join(self.data_dir, '{}_image_pixel_info.png'.format(mode)),
dpi=800)
plt.close()
return im_pixel_info
def _get_mean_std(self):
im_mean = np.asarray(self.im_mean_list)
im_mean = im_mean.sum(axis=0)
im_mean = im_mean / len(self.file_list)
im_mean /= 255.
im_std = np.asarray(self.im_std_list)
im_std = im_std.sum(axis=0)
im_std = im_std / len(self.file_list)
im_std /= 255.
return (im_mean, im_std)
def _get_image_info(self, start, end):
for id in range(start, end):
full_path_im, full_path_label = self.file_list[id]
image, label = self.decode_image(full_path_im, full_path_label)
height, width, channel = image.shape
self.im_height_list[id] = height
self.im_width_list[id] = width
self.im_channel_list[id] = channel
self.im_mean_list[
id] = [np.mean(image[:, :, c]) for c in range(channel)]
self.im_std_list[
id] = [np.mean(image[:, :, c]) for c in range(channel)]
for c in range(channel):
unique, counts = np.unique(image[:, :, c], return_counts=True)
self.im_value_list[id].extend([unique])
self.im_value_num_list[id].extend([counts])
unique, counts = np.unique(label, return_counts=True)
self.label_value_list[id] = unique
self.label_value_num_list[id] = counts
def _get_clipped_mean_std(self, start, end, clip_min_value,
clip_max_value):
for id in range(start, end):
full_path_im, full_path_label = self.file_list[id]
image, label = self.decode_image(full_path_im, full_path_label)
for c in range(self.channel_num):
np.clip(
image[:, :, c],
clip_min_value[c],
clip_max_value[c],
out=image[:, :, c])
image[:, :, c] -= clip_min_value[c]
image[:, :, c] /= clip_max_value[c] - clip_min_value[c]
self.clipped_im_mean_list[id] = [
image[:, :, c].mean() for c in range(self.channel_num)
]
self.clipped_im_std_list[
id] = [image[:, :, c].std() for c in range(self.channel_num)]
def analysis(self):
self.im_mean_list = [[] for i in range(len(self.file_list))]
self.im_std_list = [[] for i in range(len(self.file_list))]
self.im_value_list = [[] for i in range(len(self.file_list))]
self.im_value_num_list = [[] for i in range(len(self.file_list))]
self.im_height_list = np.zeros(len(self.file_list), dtype='int32')
self.im_width_list = np.zeros(len(self.file_list), dtype='int32')
self.im_channel_list = np.zeros(len(self.file_list), dtype='int32')
self.label_value_list = [[] for i in range(len(self.file_list))]
self.label_value_num_list = [[] for i in range(len(self.file_list))]
num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
num_workers = 6
threads = []
one_worker_file = len(self.file_list) // num_workers
for i in range(num_workers):
start = one_worker_file * i
end = one_worker_file * (
i + 1) if i < num_workers - 1 else len(self.file_list)
t = threading.Thread(
target=self._get_image_info, args=(start, end))
print("====", len(self.file_list), start, end)
#t.daemon = True
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
print('ok')
import time
import sys
sys.exit(0)
time.sleep(1000000)
return
#self._get_image_info(0, len(self.file_list))
unique, counts = np.unique(self.im_channel_list, return_counts=True)
print('==== unique')
if len(unique) > 1:
raise Exception("There are {} kinds of image channels: {}.".format(
len(unique), unique[:]))
self.channel_num = unique[0]
shape_info = self._get_shape()
print('==== shape_info')
self.max_height = shape_info['max_height']
self.max_width = shape_info['max_width']
self.min_height = shape_info['min_height']
self.min_width = shape_info['min_width']
self.label_pixel_info = self._get_label_pixel_info()
print('==== label_pixel_info')
self.im_pixel_info = self._get_image_pixel_info()
print('==== im_pixel_info')
im_mean, im_std = self._get_mean_std()
print('==== get_mean_std')
max_im_value = list()
min_im_value = list()
for c in range(self.channel_num):
max_im_value.append(max(self.im_pixel_info[c].keys()))
min_im_value.append(min(self.im_pixel_info[c].keys()))
self.max_im_value = np.asarray(max_im_value)
self.min_im_value = np.asarray(min_im_value)
logging.info(
"############## The analysis results are as follows ##############\n"
)
logging.info("{} samples in file {}\n".format(
len(self.file_list), self.file_list_path))
logging.info("Maximal image height: {} Maximal image width: {}.\n".
format(self.max_height, self.max_width))
logging.info("Minimal image height: {} Minimal image width: {}.\n".
format(self.min_height, self.min_width))
logging.info("Image channel is {}.\n".format(self.channel_num))
logging.info(
"Image mean value: {} Image standard deviation: {} (normalized by 255, sorted by a BGR format).\n".
format(im_mean, im_std))
logging.info(
"Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):"
)
for v, (n, r) in self.label_pixel_info.items():
logging.info("({}, {}, {})".format(v, n, r))
mode = osp.split(self.file_list_path)[-1].split('.')[0]
saved_pkl_file = osp.join(self.data_dir,
'{}_image_pixel_info.pkl'.format(mode))
saved_png_file = osp.join(self.data_dir,
'{}_image_pixel_info.png'.format(mode))
logging.info(
"Image pixel information is saved in the file '{}' and shown in the file '{}'".
format(saved_pkl_file, saved_png_file))
def cal_clipvalue_ratio(self, clip_min_value, clip_max_value):
if len(clip_min_value) != self.channel_num or len(
clip_max_value) != self.channel_num:
raise Exception(
"The length of clip_min_value or clip_max_value should be equal to the number of image channel {}."
.format(self.channle_num))
for c in range(self.channel_num):
if clip_min_value[c] < self.min_im_value[c] or clip_min_value[
c] > self.max_im_value[c]:
raise Exception(
"Clip_min_value of the channel {} is not in [{}, {}]".
format(c, self.min_im_value[c], self.max_im_value[c]))
if clip_max_value[c] < self.min_im_value[c] or clip_max_value[
c] > self.max_im_value[c]:
raise Exception(
"Clip_max_value of the channel {} is not in [{}, {}]".
format(c, self.min_im_value[c], self.max_im_value[c]))
clip_pixel_num = 0
pixel_num = sum(self.im_pixel_info[c].values())
for v, n in self.im_pixel_info[c].items():
if v < clip_min_value[c] or v > clip_max_value[c]:
clip_pixel_num += n
logging.info("Channel {}, the ratio of pixels to be clipped = {}".
format(c, clip_pixel_num / pixel_num))
def cal_clipped_mean_std(self, clip_min_value, clip_max_value):
for c in range(self.channel_num):
if clip_min_value[c] < self.min_im_value[c] or clip_min_value[
c] > self.max_im_value[c]:
raise Exception(
"Clip_min_value of the channel {} is not in [{}, {}]".
format(c, self.min_im_value[c], self.max_im_value[c]))
if clip_max_value[c] < self.min_im_value[c] or clip_max_value[
c] > self.max_im_value[c]:
raise Exception(
"Clip_max_value of the channel {} is not in [{}, {}]".
format(c, self.min_im_value[c], self.max_im_value[c]))
self.clipped_im_mean_list = [[] for i in range(len(self.file_list))]
self.clipped_im_std_list = [[] for i in range(len(self.file_list))]
num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
threads = []
one_worker_file = len(self.file_list) // num_workers
for i in range(num_workers):
start = one_worker_file * i
end = one_worker_file * (
i + 1) if i < num_workers - 1 else len(self.file_list)
t = threading.Thread(
target=self._get_clipped_mean_std,
args=(start, end, clip_min_value, clip_max_value))
threads.append(t)
for t in threads:
t.setDaemon(True)
t.start()
t.join()
im_mean = np.asarray(self.clipped_im_mean_list)
im_mean = im_mean.sum(axis=0)
im_mean = im_mean / len(self.file_list)
im_std = np.asarray(self.clipped_im_std_list)
im_std = im_std.sum(axis=0)
im_std = im_std / len(self.file_list)
logging.info(
"Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value)).\n".
format(im_mean, im_std))
...@@ -20,7 +20,6 @@ import paddlex.utils.logging as logging ...@@ -20,7 +20,6 @@ import paddlex.utils.logging as logging
from paddlex.utils import path_normalization from paddlex.utils import path_normalization
from .dataset import Dataset from .dataset import Dataset
from .dataset import get_encoding from .dataset import get_encoding
from .dataset import is_pic
class SegDataset(Dataset): class SegDataset(Dataset):
...@@ -64,6 +63,10 @@ class SegDataset(Dataset): ...@@ -64,6 +63,10 @@ class SegDataset(Dataset):
self.labels.append(item) self.labels.append(item)
with open(file_list, encoding=get_encoding(file_list)) as f: with open(file_list, encoding=get_encoding(file_list)) as f:
for line in f: for line in f:
if line.count(" ") > 1:
raise Exception(
"A space is defined as the separator, but it exists in image or label name {}."
.format(line))
items = line.strip().split() items = line.strip().split()
items[0] = path_normalization(items[0]) items[0] = path_normalization(items[0])
items[1] = path_normalization(items[1]) items[1] = path_normalization(items[1])
......
...@@ -65,6 +65,7 @@ class DeepLabv3p(BaseAPI): ...@@ -65,6 +65,7 @@ class DeepLabv3p(BaseAPI):
def __init__(self, def __init__(self,
num_classes=2, num_classes=2,
input_channel=3,
backbone='MobileNetV2_x1.0', backbone='MobileNetV2_x1.0',
output_stride=16, output_stride=16,
aspp_with_sep_conv=True, aspp_with_sep_conv=True,
...@@ -114,6 +115,7 @@ class DeepLabv3p(BaseAPI): ...@@ -114,6 +115,7 @@ class DeepLabv3p(BaseAPI):
self.backbone = backbone self.backbone = backbone
self.num_classes = num_classes self.num_classes = num_classes
self.input_channel = input_channel
self.use_bce_loss = use_bce_loss self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss self.use_dice_loss = use_dice_loss
self.class_weight = class_weight self.class_weight = class_weight
...@@ -215,6 +217,7 @@ class DeepLabv3p(BaseAPI): ...@@ -215,6 +217,7 @@ class DeepLabv3p(BaseAPI):
def build_net(self, mode='train'): def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.DeepLabv3p( model = paddlex.cv.nets.segmentation.DeepLabv3p(
self.num_classes, self.num_classes,
input_channel=self.input_channel,
mode=mode, mode=mode,
backbone=self._get_backbone(self.backbone), backbone=self._get_backbone(self.backbone),
output_stride=self.output_stride, output_stride=self.output_stride,
......
...@@ -48,6 +48,7 @@ class FastSCNN(DeepLabv3p): ...@@ -48,6 +48,7 @@ class FastSCNN(DeepLabv3p):
def __init__(self, def __init__(self,
num_classes=2, num_classes=2,
input_channel=3,
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
class_weight=None, class_weight=None,
...@@ -86,6 +87,7 @@ class FastSCNN(DeepLabv3p): ...@@ -86,6 +87,7 @@ class FastSCNN(DeepLabv3p):
) )
self.num_classes = num_classes self.num_classes = num_classes
self.input_channel = input_channel
self.use_bce_loss = use_bce_loss self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss self.use_dice_loss = use_dice_loss
self.class_weight = class_weight self.class_weight = class_weight
...@@ -97,6 +99,7 @@ class FastSCNN(DeepLabv3p): ...@@ -97,6 +99,7 @@ class FastSCNN(DeepLabv3p):
def build_net(self, mode='train'): def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.FastSCNN( model = paddlex.cv.nets.segmentation.FastSCNN(
self.num_classes, self.num_classes,
input_channel=self.input_channel,
mode=mode, mode=mode,
use_bce_loss=self.use_bce_loss, use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss, use_dice_loss=self.use_dice_loss,
......
...@@ -44,6 +44,7 @@ class HRNet(DeepLabv3p): ...@@ -44,6 +44,7 @@ class HRNet(DeepLabv3p):
def __init__(self, def __init__(self,
num_classes=2, num_classes=2,
input_channel=3,
width=18, width=18,
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
...@@ -72,6 +73,7 @@ class HRNet(DeepLabv3p): ...@@ -72,6 +73,7 @@ class HRNet(DeepLabv3p):
'Expect class_weight is a list or string but receive {}'. 'Expect class_weight is a list or string but receive {}'.
format(type(class_weight))) format(type(class_weight)))
self.num_classes = num_classes self.num_classes = num_classes
self.input_channel = input_channel
self.width = width self.width = width
self.use_bce_loss = use_bce_loss self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss self.use_dice_loss = use_dice_loss
...@@ -83,6 +85,7 @@ class HRNet(DeepLabv3p): ...@@ -83,6 +85,7 @@ class HRNet(DeepLabv3p):
def build_net(self, mode='train'): def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.HRNet( model = paddlex.cv.nets.segmentation.HRNet(
self.num_classes, self.num_classes,
input_channel=self.input_channel,
width=self.width, width=self.width,
mode=mode, mode=mode,
use_bce_loss=self.use_bce_loss, use_bce_loss=self.use_bce_loss,
......
...@@ -36,7 +36,7 @@ class PPYOLO(BaseAPI): ...@@ -36,7 +36,7 @@ class PPYOLO(BaseAPI):
Args: Args:
num_classes (int): 类别数。默认为80。 num_classes (int): 类别数。默认为80。
backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd']。默认为'ResNet50_vd'。 backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd_ssld']。默认为'ResNet50_vd_ssld'。
with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。 with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。
anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值 anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
......
...@@ -72,6 +72,7 @@ class DeepLabv3p(object): ...@@ -72,6 +72,7 @@ class DeepLabv3p(object):
def __init__(self, def __init__(self,
num_classes, num_classes,
backbone, backbone,
input_channel=3,
mode='train', mode='train',
output_stride=16, output_stride=16,
aspp_with_sep_conv=True, aspp_with_sep_conv=True,
...@@ -115,6 +116,7 @@ class DeepLabv3p(object): ...@@ -115,6 +116,7 @@ class DeepLabv3p(object):
format(type(class_weight))) format(type(class_weight)))
self.num_classes = num_classes self.num_classes = num_classes
self.input_channel = input_channel
self.backbone = backbone self.backbone = backbone
self.mode = mode self.mode = mode
self.use_bce_loss = use_bce_loss self.use_bce_loss = use_bce_loss
...@@ -402,13 +404,16 @@ class DeepLabv3p(object): ...@@ -402,13 +404,16 @@ class DeepLabv3p(object):
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape = [ input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] None, self.input_channel, self.fixed_input_shape[1],
self.fixed_input_shape[0]
] ]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') dtype='float32',
shape=[None, self.input_channel, None, None],
name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['label'] = fluid.data( inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label') dtype='int32', shape=[None, 1, None, None], name='label')
......
...@@ -33,6 +33,7 @@ from .model_utils.loss import bce_loss ...@@ -33,6 +33,7 @@ from .model_utils.loss import bce_loss
class FastSCNN(object): class FastSCNN(object):
def __init__(self, def __init__(self,
num_classes, num_classes,
input_channel=3,
mode='train', mode='train',
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
...@@ -62,6 +63,7 @@ class FastSCNN(object): ...@@ -62,6 +63,7 @@ class FastSCNN(object):
format(type(class_weight))) format(type(class_weight)))
self.num_classes = num_classes self.num_classes = num_classes
self.input_channel = input_channel
self.mode = mode self.mode = mode
self.use_bce_loss = use_bce_loss self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss self.use_dice_loss = use_dice_loss
...@@ -137,13 +139,16 @@ class FastSCNN(object): ...@@ -137,13 +139,16 @@ class FastSCNN(object):
inputs = OrderedDict() inputs = OrderedDict()
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape = [ input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] None, self.input_channel, self.fixed_input_shape[1],
self.fixed_input_shape[0]
] ]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') dtype='float32',
shape=[None, self.input_channel, None, None],
name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['label'] = fluid.data( inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label') dtype='int32', shape=[None, 1, None, None], name='label')
......
...@@ -32,6 +32,7 @@ import paddlex ...@@ -32,6 +32,7 @@ import paddlex
class HRNet(object): class HRNet(object):
def __init__(self, def __init__(self,
num_classes, num_classes,
input_channel=3,
mode='train', mode='train',
width=18, width=18,
use_bce_loss=False, use_bce_loss=False,
...@@ -61,6 +62,7 @@ class HRNet(object): ...@@ -61,6 +62,7 @@ class HRNet(object):
format(type(class_weight))) format(type(class_weight)))
self.num_classes = num_classes self.num_classes = num_classes
self.input_channel = input_channel
self.mode = mode self.mode = mode
self.use_bce_loss = use_bce_loss self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss self.use_dice_loss = use_dice_loss
...@@ -136,13 +138,16 @@ class HRNet(object): ...@@ -136,13 +138,16 @@ class HRNet(object):
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape = [ input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] None, self.input_channel, self.fixed_input_shape[1],
self.fixed_input_shape[0]
] ]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') dtype='float32',
shape=[None, self.input_channel, None, None],
name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['label'] = fluid.data( inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label') dtype='int32', shape=[None, 1, None, None], name='label')
......
...@@ -74,8 +74,22 @@ class Compose(SegTransform): ...@@ -74,8 +74,22 @@ class Compose(SegTransform):
raise ValueError('Can\'t read The image file {}!'.format(im)) raise ValueError('Can\'t read The image file {}!'.format(im))
im = im.astype('float32') im = im.astype('float32')
if label is not None: if label is not None:
if not isinstance(label, np.ndarray): if isinstance(label, np.ndarray):
label = np.asarray(Image.open(label)) if len(label.shape) != 2:
raise Exception(
"label should be 2-dimensions, but now is {}-dimensions".
format(len(label.shape)))
else:
try:
label = np.asarray(Image.open(label))
except:
ValueError('Can\'t read The label file {}!'.format(label))
im_height, im_width, _ = im.shape
label_height, label_width = label.shape
if im_height != label_height or im_width != label_width:
raise Exception(
"The height or width of the image is not same as the label")
return (im, label) return (im, label)
def __call__(self, im, im_info=None, label=None): def __call__(self, im, im_info=None, label=None):
...@@ -605,6 +619,7 @@ class Normalize(SegTransform): ...@@ -605,6 +619,7 @@ class Normalize(SegTransform):
mean = np.array(self.mean)[np.newaxis, np.newaxis, :] mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :] std = np.array(self.std)[np.newaxis, np.newaxis, :]
im = normalize(im, mean, std, self.min_val, self.max_val) im = normalize(im, mean, std, self.min_val, self.max_val)
im = im.astype('float32')
if label is None: if label is None:
return (im, im_info) return (im, im_info)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import os.path as osp import os.path as osp
import argparse import argparse
from paddlex.seg import transforms from paddlex.seg import transforms
import paddlex.RemoteSensing.transforms as custom_transforms import paddlex.remotesensing.transforms as rs_transforms
import paddlex as pdx import paddlex as pdx
...@@ -110,22 +110,22 @@ train_transforms = transforms.Compose([ ...@@ -110,22 +110,22 @@ train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(0.5), transforms.RandomHorizontalFlip(0.5),
transforms.ResizeStepScaling(0.5, 2.0, 0.25), transforms.ResizeStepScaling(0.5, 2.0, 0.25),
transforms.RandomPaddingCrop(im_padding_value=[1000] * channel), transforms.RandomPaddingCrop(im_padding_value=[1000] * channel),
custom_transforms.Clip( rs_transforms.Clip(
min_val=clip_min_value, max_val=clip_max_value), min_val=clip_min_value, max_val=clip_max_value),
transforms.Normalize( transforms.Normalize(
min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std), min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
]) ])
train_transforms.decode_image = custom_transforms.decode_image train_transforms.decode_image = rs_transforms.decode_image
eval_transforms = transforms.Compose([ eval_transforms = transforms.Compose([
custom_transforms.Clip( rs_transforms.Clip(
min_val=clip_min_value, max_val=clip_max_value), min_val=clip_min_value, max_val=clip_max_value),
transforms.Normalize( transforms.Normalize(
min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std), min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
]) ])
eval_transforms.decode_image = custom_transforms.decode_image eval_transforms.decode_image = rs_transforms.decode_image
train_list = osp.join(data_dir, 'train.txt') train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt') val_list = osp.join(data_dir, 'val.txt')
......
...@@ -2,17 +2,23 @@ import os ...@@ -2,17 +2,23 @@ import os
import os.path as osp import os.path as osp
import imghdr import imghdr
import gdal import gdal
gdal.UseExceptions()
gdal.PushErrorHandler('CPLQuietErrorHandler')
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from paddlex.seg import transforms from paddlex.seg import transforms
import paddlex.utils.logging as logging
def read_img(img_path): def read_img(img_path):
img_format = imghdr.what(img_path) img_format = imghdr.what(img_path)
name, ext = osp.splitext(img_path) name, ext = osp.splitext(img_path)
if img_format == 'tiff' or ext == '.img': if img_format == 'tiff' or ext == '.img':
dataset = gdal.Open(img_path) try:
dataset = gdal.Open(img_path)
except:
logging.error(gdal.GetLastErrorMsg())
if dataset == None: if dataset == None:
raise Exception('Can not open', img_path) raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray() im_data = dataset.ReadAsArray()
...@@ -36,9 +42,25 @@ def decode_image(im, label): ...@@ -36,9 +42,25 @@ def decode_image(im, label):
im = read_img(im) im = read_img(im)
except: except:
raise ValueError('Can\'t read The image file {}!'.format(im)) raise ValueError('Can\'t read The image file {}!'.format(im))
im = im.astype('float32')
if label is not None: if label is not None:
if not isinstance(label, np.ndarray): if isinstance(label, np.ndarray):
label = read_img(label) if len(label.shape) != 2:
raise Exception(
"label should be 2-dimensions, but now is {}-dimensions".
format(len(label.shape)))
else:
try:
label = np.asarray(Image.open(label))
except:
ValueError('Can\'t read The label file {}!'.format(label))
im_height, im_width, _ = im.shape
label_height, label_width = label.shape
if im_height != label_height or im_width != label_width:
raise Exception(
"The height or width of the image is not same as the label")
return (im, label) return (im, label)
......
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import os.path as osp
import sys
import argparse
from PIL import Image
from tqdm import tqdm
import imghdr
import logging
import pickle
import gdal
def parse_args():
parser = argparse.ArgumentParser(
description='Data analyse and data check before training.')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='Dataset directory',
default=None,
type=str)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
default=None,
type=int)
parser.add_argument(
'--separator',
dest='separator',
help='file list separator',
default=" ",
type=str)
parser.add_argument(
'--ignore_index',
dest='ignore_index',
help='Ignored class index',
default=255,
type=int)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def read_img(img_path):
img_format = imghdr.what(img_path)
name, ext = osp.splitext(img_path)
if img_format == 'tiff' or ext == '.img':
dataset = gdal.Open(img_path)
if dataset == None:
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
return im_data.transpose((1, 2, 0))
elif ext == '.npy':
return np.load(img_path)
else:
raise Exception('Not support {} image format!'.format(ext))
def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
channel = img.shape[2]
means = np.zeros(channel)
stds = np.zeros(channel)
for k in range(channel):
img_k = img[:, :, k]
# count mean, std
means[k] = np.mean(img_k)
stds[k] = np.std(img_k)
# count min, max
min_value = np.min(img_k)
max_value = np.max(img_k)
if img_max_value[k] < max_value:
img_max_value[k] = max_value
if img_min_value[k] > min_value:
img_min_value[k] = min_value
# count the distribution of image value, value number
unique, counts = np.unique(img_k, return_counts=True)
add_num = []
max_unique = np.max(unique)
add_len = max_unique - len(img_value_num[k]) + 1
if add_len > 0:
img_value_num[k] += ([0] * add_len)
for i in range(len(unique)):
value = unique[i]
img_value_num[k][value] += counts[i]
img_value_num[k] += add_num
return means, stds, img_min_value, img_max_value, img_value_num
def data_distribution_statistics(data_dir, img_value_num, logger):
"""count the distribution of image value, value number
"""
logger.info(
"\n-----------------------------\nThe whole dataset statistics...")
if not img_value_num:
return
logger.info("\nImage pixel statistics:")
total_ratio = []
[total_ratio.append([]) for i in range(len(img_value_num))]
for k in range(len(img_value_num)):
total_num = sum(img_value_num[k])
total_ratio[k] = [i / total_num for i in img_value_num[k]]
total_ratio[k] = np.around(total_ratio[k], decimals=4)
with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f:
pickle.dump([total_ratio, img_value_num], f)
def data_range_statistics(img_min_value, img_max_value, logger):
"""print min value, max value
"""
logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".
format(img_min_value, img_max_value))
def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
"""count mean, std
"""
total_means = total_means / total_img_num
total_stds = total_stds / total_img_num
logger.info("\nCount the channel-by-channel mean and std of the image:\n"
"mean = {}\nstd = {}".format(total_means, total_stds))
def error_print(str):
return "".join(["\nNOT PASS ", str])
def correct_print(str):
return "".join(["\nPASS ", str])
def pil_imread(file_path):
"""read pseudo-color label"""
im = Image.open(file_path)
return np.asarray(im)
def get_img_shape_range(img, max_width, max_height, min_width, min_height):
"""获取图片最大和最小宽高"""
img_shape = img.shape
height, width = img_shape[0], img_shape[1]
max_height = max(height, max_height)
max_width = max(width, max_width)
min_height = min(height, min_height)
min_width = min(width, min_width)
return max_width, max_height, min_width, min_height
def get_img_channel_num(img, img_channels):
"""获取图像的通道数"""
img_shape = img.shape
if img_shape[-1] not in img_channels:
img_channels.append(img_shape[-1])
return img_channels
def is_label_single_channel(label):
"""判断标签是否为灰度图"""
label_shape = label.shape
if len(label_shape) == 2:
return True
else:
return False
def image_label_shape_check(img, label):
"""
验证图像和标注的大小是否匹配
"""
flag = True
img_height = img.shape[0]
img_width = img.shape[1]
label_height = label.shape[0]
label_width = label.shape[1]
if img_height != label_height or img_width != label_width:
flag = False
return flag
def ground_truth_check(label, label_path):
"""
验证标注图像的格式
统计标注图类别和像素数
params:
label: 标注图
label_path: 标注图路径
return:
png_format: 返回是否是png格式图片
unique: 返回标注类别
counts: 返回标注的像素数
"""
if imghdr.what(label_path) == "png":
png_format = True
else:
png_format = False
unique, counts = np.unique(label, return_counts=True)
return png_format, unique, counts
def sum_label_check(label_classes, num_of_each_class, ignore_index,
num_classes, total_label_classes, total_num_of_each_class):
"""
统计所有标注图上的类别和每个类别的像素数
params:
label_classes: 标注类别
num_of_each_class: 各个类别的像素数目
"""
is_label_correct = True
if ignore_index in label_classes:
label_classes2 = np.delete(label_classes,
np.where(label_classes == ignore_index))
else:
label_classes2 = label_classes
if min(label_classes2) < 0 or max(label_classes2) > num_classes - 1:
is_label_correct = False
add_class = []
add_num = []
for i in range(len(label_classes)):
gi = label_classes[i]
if gi in total_label_classes:
j = total_label_classes.index(gi)
total_num_of_each_class[j] += num_of_each_class[i]
else:
add_class.append(gi)
add_num.append(num_of_each_class[i])
total_num_of_each_class += add_num
total_label_classes += add_class
return is_label_correct, total_num_of_each_class, total_label_classes
def label_class_check(num_classes, total_label_classes,
total_num_of_each_class, wrong_labels, logger):
"""
检查实际标注类别是否和配置参数`num_classes`,`ignore_index`匹配。
**NOTE:**
标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`。
标注类别最好从0开始,否则可能影响精度。
"""
total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
total_ratio = np.around(total_ratio, decimals=4)
total_nc = sorted(
zip(total_label_classes, total_ratio, total_num_of_each_class))
if len(wrong_labels) == 0 and not total_nc[0][0]:
logger.info(correct_print("label class check!"))
else:
logger.info(error_print("label class check!"))
if total_nc[0][0]:
logger.info("Warning: label classes should start from 0")
if len(wrong_labels) > 0:
logger.info("fatal error: label class is out of range [0, {}]".
format(num_classes - 1))
for i in wrong_labels:
logger.debug(i)
return total_nc
def label_class_statistics(total_nc, logger):
"""
对标注图像进行校验,输出校验结果
"""
logger.info("\nLabel class statistics:\n"
"(label class, percentage, total pixel number) = {} ".format(
total_nc))
def shape_check(shape_unequal_image, logger):
"""输出shape校验结果"""
if len(shape_unequal_image) == 0:
logger.info(correct_print("shape check"))
logger.info("All images are the same shape as the labels")
else:
logger.info(error_print("shape check"))
logger.info(
"Some images are not the same shape as the labels as follow: ")
for i in shape_unequal_image:
logger.debug(i)
def separator_check(wrong_lines, file_list, separator, logger):
"""检查分割符是否复合要求"""
if len(wrong_lines) == 0:
logger.info(
correct_print(
file_list.split(os.sep)[-1] + " DATASET.separator check"))
else:
logger.info(
error_print(
file_list.split(os.sep)[-1] + " DATASET.separator check"))
logger.info("The following list is not separated by {}".format(
separator))
for i in wrong_lines:
logger.debug(i)
def imread_check(imread_failed, logger):
if len(imread_failed) == 0:
logger.info(correct_print("dataset reading check"))
logger.info("All images can be read successfully")
else:
logger.info(error_print("dataset reading check"))
logger.info("Failed to read {} images".format(len(imread_failed)))
for i in imread_failed:
logger.debug(i)
def single_channel_label_check(label_not_single_channel, logger):
if len(label_not_single_channel) == 0:
logger.info(correct_print("label single_channel check"))
logger.info("All label images are single_channel")
else:
logger.info(error_print("label single_channel check"))
logger.info(
"{} label images are not single_channel\nLabel pixel statistics may be insignificant"
.format(len(label_not_single_channel)))
for i in label_not_single_channel:
logger.debug(i)
def img_shape_range_statistics(max_width, min_width, max_height, min_height,
logger):
logger.info("\nImage size statistics:")
logger.info(
"max width = {} min width = {} max height = {} min height = {}".
format(max_width, min_width, max_height, min_height))
def img_channels_statistics(img_channels, logger):
logger.info("\nImage channels statistics\nImage channels = {}".format(
np.unique(img_channels)))
def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
logger):
train_file_list = osp.join(data_dir, 'train.txt')
val_file_list = osp.join(data_dir, 'val.txt')
test_file_list = osp.join(data_dir, 'test.txt')
total_img_num = 0
has_label = False
for file_list in [train_file_list, val_file_list, test_file_list]:
# initialization
imread_failed = []
max_width = 0
max_height = 0
min_width = sys.float_info.max
min_height = sys.float_info.max
label_not_single_channel = []
shape_unequal_image = []
wrong_labels = []
wrong_lines = []
total_label_classes = []
total_num_of_each_class = []
img_channels = []
with open(file_list, 'r') as fid:
logger.info("\n-----------------------------\nCheck {}...".format(
file_list))
lines = fid.readlines()
if not lines:
logger.info("File list is empty!")
continue
for line in tqdm(lines):
line = line.strip()
parts = line.split(separator)
if len(parts) == 1:
if file_list == train_file_list or file_list == val_file_list:
logger.info("Train or val list must have labels!")
break
img_name = parts
img_path = os.path.join(data_dir, img_name[0])
try:
img = read_img(img_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
elif len(parts) == 2:
has_label = True
img_name, label_name = parts[0], parts[1]
img_path = os.path.join(data_dir, img_name)
label_path = os.path.join(data_dir, label_name)
try:
img = read_img(img_path)
label = pil_imread(label_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
is_single_channel = is_label_single_channel(label)
if not is_single_channel:
label_not_single_channel.append(line)
continue
is_equal_img_label_shape = image_label_shape_check(img,
label)
if not is_equal_img_label_shape:
shape_unequal_image.append(line)
png_format, label_classes, num_of_each_class = ground_truth_check(
label, label_path)
is_label_correct, total_num_of_each_class, total_label_classes = sum_label_check(
label_classes, num_of_each_class, ignore_index,
num_classes, total_label_classes,
total_num_of_each_class)
if not is_label_correct:
wrong_labels.append(line)
else:
wrong_lines.append(lines)
continue
if total_img_num == 0:
channel = img.shape[2]
total_means = np.zeros(channel)
total_stds = np.zeros(channel)
img_min_value = [sys.float_info.max] * channel
img_max_value = [0] * channel
img_value_num = []
[img_value_num.append([]) for i in range(channel)]
means, stds, img_min_value, img_max_value, img_value_num = img_pixel_statistics(
img, img_value_num, img_min_value, img_max_value)
total_means += means
total_stds += stds
max_width, max_height, min_width, min_height = get_img_shape_range(
img, max_width, max_height, min_width, min_height)
img_channels = get_img_channel_num(img, img_channels)
total_img_num += 1
# data check
separator_check(wrong_lines, file_list, separator, logger)
imread_check(imread_failed, logger)
if has_label:
single_channel_label_check(label_not_single_channel, logger)
shape_check(shape_unequal_image, logger)
total_nc = label_class_check(num_classes, total_label_classes,
total_num_of_each_class,
wrong_labels, logger)
# data analyse on train, validation, test set.
img_channels_statistics(img_channels, logger)
img_shape_range_statistics(max_width, min_width, max_height,
min_height, logger)
if has_label:
label_class_statistics(total_nc, logger)
# data analyse on the whole dataset.
data_range_statistics(img_min_value, img_max_value, logger)
data_distribution_statistics(data_dir, img_value_num, logger)
cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)
def main():
args = parse_args()
data_dir = args.data_dir
ignore_index = args.ignore_index
num_classes = args.num_classes
separator = args.separator
logger = logging.getLogger()
logger.setLevel('DEBUG')
BASIC_FORMAT = "%(message)s"
formatter = logging.Formatter(BASIC_FORMAT)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
sh.setLevel('INFO')
th = logging.FileHandler(
os.path.join(data_dir, 'data_analyse_and_check.log'), 'w')
th.setFormatter(formatter)
logger.addHandler(sh)
logger.addHandler(th)
data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
logger)
print("\nDetailed error information can be viewed in {}.".format(
os.path.join(data_dir, 'data_analyse_and_check.log')))
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册