Commit 1093eacd authored by F FlyingQianMM

add analysis

Parent 746122f4
......@@ -6,6 +6,7 @@ API Reference
transforms/index.rst
datasets.md
analysis.md
models/index.rst
slim.md
visualize.md
......
......@@ -78,16 +78,19 @@ paddlex.seg.transforms.ResizeStepScaling(min_scale_factor=0.75, max_scale_factor
## Normalize
```python
paddlex.seg.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0])
```
Standardizes the image:
1. Subtract min_val from each pixel value.
2. Divide by (max_val - min_val), normalizing the values to the range [0.0, 1.0].
3. Subtract the mean, then divide by the standard deviation.
### Parameters
* **mean** (list): Mean of the image dataset. Defaults to [0.5, 0.5, 0.5].
* **std** (list): Standard deviation of the image dataset. Defaults to [0.5, 0.5, 0.5].
* **min_val** (list): Minimum value of the image dataset. Defaults to [0, 0, 0].
* **max_val** (list): Maximum value of the image dataset. Defaults to [255.0, 255.0, 255.0].
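The arithmetic of the three steps above, as a minimal runnable numpy sketch (the 3-channel input and the values are illustrative only, not part of the API):
```python
import numpy as np

im = np.random.uniform(0, 255, (64, 64, 3)).astype('float32')  # stand-in image
mean, std = np.array([0.5, 0.5, 0.5]), np.array([0.5, 0.5, 0.5])
min_val, max_val = np.array([0.0, 0.0, 0.0]), np.array([255.0, 255.0, 255.0])

im -= min_val              # step 1: subtract the per-channel minimum
im /= max_val - min_val    # step 2: scale to [0.0, 1.0]
im = (im - mean) / std     # step 3: subtract mean, divide by std
```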
## Padding
```python
......@@ -167,6 +170,16 @@ paddlex.seg.transforms.RandomDistort(brightness_range=0.5, brightness_prob=0.5,
* **hue_range** (int): Range of the hue factor. Defaults to 18.
* **hue_prob** (float): Probability of randomly adjusting the hue. Defaults to 0.5.
## Clip
```python
paddlex.seg.transforms.Clip(min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0])
```
Clips pixel values that fall outside the given range.
### Parameters
* **min_val** (list): Lower bound; values below min_val are set to min_val. Defaults to [0, 0, 0].
* **max_val** (list): Upper bound; values above max_val are set to max_val. Defaults to [255.0, 255.0, 255.0].
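A sketch of how Clip is typically chained before Normalize in a transform pipeline (the bounds here are illustrative; the train.py script later in this commit does the same for 10-band imagery):
```python
from paddlex.seg import transforms

eval_transforms = transforms.Compose([
    transforms.Clip(min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]),
    transforms.Normalize(
        min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0],
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
```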
<!--
## ComposedSegTransforms
```python
......
......@@ -27,7 +27,7 @@ pdx.det.visualize('./xiaoduxiong_epoch_12/xiaoduxiong.jpeg', result, save_dir='.
## paddlex.seg.visualize
> **Visualize the prediction of a semantic segmentation model**
```
paddlex.seg.visualize(image, result, weight=0.6, save_dir='./', color=None)
```
Overlays the mask predicted by a semantic segmentation model on the original image.
......@@ -36,6 +36,7 @@ paddlex.seg.visualize(image, result, weight=0.6, save_dir='./')
> * **result** (dict): Prediction result of the model.
> * **weight** (float): Weight of the original image when blending it with the mask visualization; the mask gets (1 - weight). Defaults to 0.6.
> * **save_dir** (str): Directory for saving the visualization. If None, nothing is saved and the function returns the visualization as an np.ndarray; if set to a directory path, the visualization is saved there. Defaults to './'.
> * **color** (list): List of per-class BGR color values. For example, with two classes it can be set to [255, 255, 255, 0, 0, 0]. Defaults to None, in which case an auto-generated color list is used.
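For instance, the `color` list holds three BGR values per class (a hypothetical two-class example; the blend itself is weight * image + (1 - weight) * mask):
```python
color = [255, 255, 255,  # class 0 -> white (BGR)
         0, 0, 0]        # class 1 -> black (BGR)
```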
### Example
> Download the [model](https://bj.bcebos.com/paddlex/models/cityscape_deeplab.tar.gz) and [test image](https://bj.bcebos.com/paddlex/datasets/city.png) used in the example below.
......
......@@ -12,3 +12,4 @@ PaddleX curates mature model structures, proven in industrial practice, from the PaddlePaddle vision development kits,
solutions.md
meter_reader.md
human_segmentation.md
multi-channel_remote_sensing/README.md
# Multi-Channel Remote Sensing Image Segmentation
Remote sensing image segmentation is an important application of image segmentation, widely used in land surveying, environmental monitoring, urban construction, and other fields. Segmentation targets are diverse: ground objects such as snow, crops, roads, buildings, and water bodies, as well as aerial targets such as clouds.
This case study implements multi-channel remote sensing image segmentation with PaddleX, covering data analysis, model training, and model prediction, to help users apply deep learning to multi-channel remote sensing segmentation problems.
## Prerequisites
* PaddlePaddle >= 1.8.4
* Python >= 3.5
* PaddleX >= 1.1.0
For installation issues, see [PaddleX installation](../../install.md).
**gdal must also be installed.** Installing gdal with pip may fail, so installing with conda is recommended:
```
conda install gdal
```
Download the PaddleX source code:
```
git clone https://github.com/PaddlePaddle/PaddleX
```
All scripts for this case live in `PaddleX/examples/channel_remote_sensing/`; change into that directory:
```
cd PaddleX/examples/channel_remote_sensing/
```
## Data preparation
Remote sensing images come in many formats, and different sensors may produce different formats. PaddleX can currently read images in the following four formats:
- `tif`
- `png`
- `img`
- `npy`
The annotation image must be a single-channel PNG whose pixel values are the class ids, which must start from 0 and increase consecutively; for example, 0, 1, 2, 3 indicate four classes. The value 255 marks pixels that do not take part in training or evaluation, and at most 256 classes are supported (a quick check is sketched below).
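A quick way to check these constraints on one mask (the path is hypothetical, not part of the tutorial scripts):
```python
import numpy as np
from PIL import Image

mask = np.asarray(Image.open('dataset/remote_sensing_seg/mask/example_mask.png'))
assert mask.ndim == 2, 'the annotation must be single-channel'
classes = np.unique(mask)
classes = classes[classes != 255]  # 255 marks pixels ignored in training/evaluation
# class ids must start at 0 and increase consecutively
assert classes.min() == 0 and np.array_equal(classes, np.arange(len(classes)))
```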
This case uses the [L8 SPARCS public dataset](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation) for cloud and snow segmentation. The dataset contains 80 satellite images, each with 10 bands. The original annotations contain 7 classes: `cloud`, `cloud shadow`, `shadow over water`, `snow/ice`, `water`, `land`, and `flooded`. Since `flooded` and `shadow over water` account for only `1.8%` and `0.24%` of the pixels, they are merged: `flooded` into `land` and `shadow over water` into `shadow`, leaving 5 annotated classes.
Pixel value, class, and color mapping:
|Pixel value|Class|Color|
|---|---|---|
|0|cloud|white|
|1|shadow|black|
|2|snow/ice|cyan|
|3|water|blue|
|4|land|grey|
![](../../../examples/multi-channel_remote_sensing/docs/images/dataset.png)
Run the following commands to download and unpack the class-merged dataset:
```shell script
mkdir dataset && cd dataset
wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip
unzip remote_sensing_seg.zip
cd ..
```
The `data` directory holds the remote sensing images, `data_vis` the color-composite preview images, and `mask` the annotation images.
## Data analysis
Remote sensing images often consist of many bands, and the value distributions of different bands can differ greatly; for example, visible-light and thermal-infrared bands are distributed very differently. To understand the data distribution more deeply and thereby improve training, the data should be analyzed.
Follow the [data analysis](./analysis.md) document to run statistics on the training set, determine the pixel value clipping range, and compute the mean and standard deviation after clipping.
## Model training
This case uses the `UNet` semantic segmentation model for cloud and snow segmentation. Run the following steps to train the model; the best accuracy achieved is a `miou` of `77.99%`.
* Set the GPU device id
```shell script
export CUDA_VISIBLE_DEVICES=0
```
* Run the following script to start training
```shell script
python train.py --data_dir dataset/remote_sensing_seg \
--train_file_list dataset/remote_sensing_seg/train.txt \
--eval_file_list dataset/remote_sensing_seg/val.txt \
--label_list dataset/remote_sensing_seg/labels.txt \
--save_dir saved_model/remote_sensing_unet \
--num_classes 5 \
--channel 10 \
--lr 0.01 \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
--mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
--std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
--num_epochs 500 \
--train_batch_size 3
```
Alternatively, skip the training step and download the pretrained model to run prediction directly:
```
wget https://bj.bcebos.com/paddlex/examples/multi-channel_remote_sensing/models/l8sparcs_remote_model.tar.gz
tar -xvf l8sparcs_remote_model.tar.gz
```
## Model prediction
Run the following script to predict on a remote sensing image and visualize the prediction; the corresponding annotation file is also visualized for comparison.
```shell script
export CUDA_VISIBLE_DEVICES=0
python predict.py
```
The visualization looks as follows:
![](../../../examples/multi-channel_remote_sensing/docs/images/prediction.jpg)
Pixel value, class, and color mapping:
|Pixel value|Class|Color|
|---|---|---|
|0|cloud|white|
|1|shadow|black|
|2|snow/ice|cyan|
|3|water|blue|
|4|land|grey|
# Multi-Channel Remote Sensing Image Segmentation
Remote sensing image segmentation is an important application of image segmentation, widely used in land surveying, environmental monitoring, urban construction, and other fields. Segmentation targets are diverse: ground objects such as snow, crops, roads, buildings, and water bodies, as well as aerial targets such as clouds.
This case study implements multi-channel remote sensing image segmentation with PaddleX, covering data analysis, model training, and model prediction, to help users apply deep learning to multi-channel remote sensing segmentation problems.
## Contents
* [Prerequisites](#1)
* [Data preparation](#2)
* [Data analysis](#3)
* [Model training](#4)
* [Model prediction](#5)
## <h2 id="1">前置依赖</h2>
* Paddle paddle >= 1.8.4
* Python >= 3.5
* PaddleX >= 1.1.0
安装的相关问题参考[PaddleX安装](../../docs/install.md)
**另外还需安装gdal**, 使用pip安装gdal可能出错,推荐使用conda进行安装:
```
conda install gdal
```
Download the PaddleX source code:
```
git clone https://github.com/PaddlePaddle/PaddleX
```
All scripts for this case live in `PaddleX/examples/channel_remote_sensing/`; change into that directory:
```
cd PaddleX/examples/channel_remote_sensing/
```
## <h2 id="2">数据准备</h2>
遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleX现已兼容以下4种格式图片读取:
- `tif`
- `png`
- `img`
- `npy`
The annotation image must be a single-channel PNG whose pixel values are the class ids, which must start from 0 and increase consecutively; for example, 0, 1, 2, 3 indicate four classes. The value 255 marks pixels that do not take part in training or evaluation, and at most 256 classes are supported.
This case uses the [L8 SPARCS public dataset](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation) for cloud and snow segmentation. The dataset contains 80 satellite images, each with 10 bands. The original annotations contain 7 classes: `cloud`, `cloud shadow`, `shadow over water`, `snow/ice`, `water`, `land`, and `flooded`. Since `flooded` and `shadow over water` account for only `1.8%` and `0.24%` of the pixels, they are merged: `flooded` into `land` and `shadow over water` into `shadow`, leaving 5 annotated classes.
Pixel value, class, and color mapping:
|Pixel value|Class|Color|
|---|---|---|
|0|cloud|white|
|1|shadow|black|
|2|snow/ice|cyan|
|3|water|blue|
|4|land|grey|
<p align="center">
<img src="./docs/images/dataset.png" align="middle"
</p>
<p align='center'>
L8 SPARCS数据集示例
</p>
执行以下命令下载并解压经过类别合并后的数据集:
```shell script
mkdir dataset && cd dataset
wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip
unzip remote_sensing_seg.zip
cd ..
```
The `data` directory holds the remote sensing images, `data_vis` the color-composite preview images, and `mask` the annotation images.
## <h2 id="3">Data analysis</h2>
Remote sensing images often consist of many bands, and the value distributions of different bands can differ greatly; for example, visible-light and thermal-infrared bands are distributed very differently. To understand the data distribution more deeply and thereby improve training, the data should be analyzed.
Follow the [data analysis](./docs/analysis.md) document to run statistics on the training set, determine the pixel value clipping range, and compute the mean and standard deviation after clipping.
## <h2 id="4">Model training</h2>
This case uses the `UNet` semantic segmentation model for cloud and snow segmentation. Run the following steps to train the model; the best accuracy achieved is a `miou` of `77.99%`.
* Set the GPU device id
```shell script
export CUDA_VISIBLE_DEVICES=0
```
* Run the following script to start training
```shell script
python train.py --data_dir dataset/remote_sensing_seg \
--train_file_list dataset/remote_sensing_seg/train.txt \
--eval_file_list dataset/remote_sensing_seg/val.txt \
--label_list dataset/remote_sensing_seg/labels.txt \
--save_dir saved_model/remote_sensing_unet \
--num_classes 5 \
--channel 10 \
--lr 0.01 \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
--mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
--std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
--num_epochs 500 \
--train_batch_size 3
```
Alternatively, skip the training step and download the pretrained model to run prediction directly:
```
wget https://bj.bcebos.com/paddlex/examples/multi-channel_remote_sensing/models/l8sparcs_remote_model.tar.gz
tar -xvf l8sparcs_remote_model.tar.gz
```
## <h2 id="2">模型预测</h2>
运行以下脚本,对遥感图像进行预测并可视化预测结果,相应地也将对应的标注文件进行可视化,以比较预测效果。
```shell script
export CUDA_VISIBLE_DEVICES=0
python predict.py
```
The visualization looks as follows:
<img src="./docs/images/prediction.jpg" alt="prediction" align=center />
Pixel value, class, and color mapping:
|Pixel value|Class|Color|
|---|---|---|
|0|cloud|white|
|1|shadow|black|
|2|snow/ice|cyan|
|3|water|blue|
|4|land|grey|
# Data Analysis
Remote sensing images often consist of many bands, and the value distributions of different bands can differ greatly; for example, visible-light and thermal-infrared bands are distributed very differently. To understand the data distribution more deeply and thereby improve training, the data should be analyzed.
## Contents
* [1. Statistical analysis](#1)
* [2. Determine the pixel value clipping range](#2)
* [3. Compute the mean and standard deviation after clipping](#3)
## <h2 id="1">统计分析</h2>
执行以下脚本,对训练集进行统计分析,屏幕会输出分析结果,同时结果也会保存至文件`train_information.pkl`中:
```
python tools/analysis.py
```
The statistics cover the following items:
* Number of images
For example, there are 64 images in the training set:
```
64 samples in file dataset/remote_sensing_seg/train.txt
```
* Maximum and minimum image sizes
For example, both the maximum and the minimum (height, width) in the training set are (1000, 1000):
```
Minimal image height: 1000 Minimal image width: 1000.
Maximal image height: 1000 Maximal image width: 1000.
```
* Number of image channels
For example, the images are found to have 10 channels:
```
Image channel is 10.
```
* Per-channel minimum and maximum values
The minimum and maximum values are printed as lists, ordered by channel index. For example:
```
Minimal image value: [7.172e+03 6.561e+03 5.777e+03 5.103e+03 4.291e+03 1.000e+00 1.000e+00 4.232e+03 6.934e+03 7.199e+03]
Maximal image value: [65535. 65535. 65535. 65535. 65535. 65535. 65535. 56534. 65535. 63215.]
```
* Per-channel pixel value distribution
For each channel, the number of pixels at each value is counted and drawn as a bar chart saved in images whose names end with 'distribute.png'. **Note that, for readability, the y-axis is logarithmic.** Inspect these plots to decide whether the values in the head and tail of the distribution need to be clipped.
```
Image pixel distribution of each channel is saved with 'distribute.png' in the dataset/remote_sensing_seg
```
* Per-channel normalized mean and standard deviation
The normalization coefficient of each channel is the difference between that channel's maximum and minimum values; the means and standard deviations are printed as lists, ordered by channel index (see the sketch after the example output). For example:
```
Image mean value: [0.23417574 0.22283101 0.2119595 0.2119887 0.27910388 0.21294892 0.17294037 0.10158925 0.43623915 0.41019192]
Image standard deviation: [0.06831269 0.07243951 0.07284761 0.07875261 0.08120818 0.0609302 0.05110716 0.00696064 0.03849307 0.03205579]
```
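That is, per channel (a small sketch of the computation; the raw values are illustrative):
```python
import numpy as np

raw_mean = np.array([11000.0, 12000.0])  # averaged over all training images
raw_std = np.array([4000.0, 4500.0])
min_im_value = np.array([7000.0, 6500.0])
max_im_value = np.array([52000.0, 60000.0])

normalized_mean = raw_mean / (max_im_value - min_im_value)
normalized_std = raw_std / (max_im_value - min_im_value)
```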
* Class counts and proportions in the annotations
The pixel count of each class and its proportion of all pixels in the dataset are printed in the format (label_id, number of pixels with label_id, ratio of label_id). For example:
```
Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):
(0, 13302870, 0.20785734374999995)
(1, 4577005, 0.07151570312500002)
(2, 3955012, 0.0617970625)
(3, 2814243, 0.04397254687499999)
(4, 39350870, 0.6148573437500001)
```
## <h2 id="2">2 确定像素值截断范围</h2>
遥感影像数据分布范围广,往往存在一些异常值,这会影响算法对实际数据分布的拟合效果。为更好地对数据进行归一化,可以抑制遥感影像中少量的异常值。根据`图像各通道的像素值分布`来确定像素值的截断范围,并在后续图像预处理过程中对超出范围的像素值通过截断进行校正,从而去除异常值带来的干扰。**注意:该步骤是否执行根据数据集实际分布来决定。**
例如各通道的像素值分布可视化效果如下:
<img src="./images/image_pixel_distribution.png" width = "600" height = "600" alt="像素值分布图" align=center />
For the distributions above, we choose the following clipping ranges (ordered by channel index), applied per channel as in the sketch after this block:
```
Clipping lower bounds: clip_min_value = [7172, 6561, 5777, 5103, 4291, 4000, 4000, 4232, 6934, 7199]
Clipping upper bounds: clip_max_value = [50000, 50000, 50000, 50000, 50000, 40000, 30000, 18000, 40000, 36000]
```
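These bounds are applied per channel exactly as the `Clip` transform does (a runnable numpy sketch with a random stand-in image):
```python
import numpy as np

clip_min_value = [7172, 6561, 5777, 5103, 4291, 4000, 4000, 4232, 6934, 7199]
clip_max_value = [50000, 50000, 50000, 50000, 50000, 40000, 30000, 18000, 40000, 36000]

im = np.random.randint(0, 65536, (1000, 1000, 10)).astype('float32')  # stand-in image
for c in range(im.shape[2]):
    np.clip(im[:, :, c], clip_min_value[c], clip_max_value[c], out=im[:, :, c])
```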
## <h2 id="3">3 确定像素值截断范围</h2>
为避免数据截断范围选取不当带来的影响,应该统计异常值像素占比,确保受影响的像素比例不要过高。接着对截断后的数据计算归一化后的均值和方差,**用于后续模型训练时的图像预处理参数设置**
执行以下脚本:
```
python tools/cal_clipped_mean_std.py
```
The clipped-pixel ratios are as follows:
```
Channel 0, the ratio of pixels to be clipped = 0.00054778125
Channel 1, the ratio of pixels to be clipped = 0.0011129375
Channel 2, the ratio of pixels to be clipped = 0.000843703125
Channel 3, the ratio of pixels to be clipped = 0.00127125
Channel 4, the ratio of pixels to be clipped = 0.001330140625
Channel 5, the ratio of pixels to be clipped = 8.1375e-05
Channel 6, the ratio of pixels to be clipped = 0.0007348125
Channel 7, the ratio of pixels to be clipped = 6.5625e-07
Channel 8, the ratio of pixels to be clipped = 0.000185921875
Channel 9, the ratio of pixels to be clipped = 0.000139671875
```
As shown, no channel has more than 0.2% of its pixels clipped.
The normalization coefficients of the clipped data are as follows:
```
Image mean value: [0.15163569 0.15142828 0.15574491 0.1716084 0.2799778 0.27652043 0.28195933 0.07853807 0.56333154 0.5477584 ]
Image standard deviation: [0.09301891 0.09818967 0.09831126 0.1057784 0.10842132 0.11062996 0.12791838 0.02637859 0.0675052 0.06168227]
(normalized by (clip_max_value - clip_min_value), arranged in 0-10 channel order)
```
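# predict.py (the prediction script referenced above): load the trained UNet,
# predict on one 10-band image, and visualize both the prediction and the
# ground-truth mask with a per-class BGR color list.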
import numpy as np
from PIL import Image
import paddlex as pdx
model_dir = "saved_model/remote_sensing_unet/best_model/"
img_file = "dataset/remote_sensing_seg/data/LC80150242014146LGN00_23_data.tif"
label_file = "dataset/remote_sensing_seg/mask/LC80150242014146LGN00_23_mask.png"
color = [255, 255, 255, 0, 0, 0, 255, 255, 0, 255, 0, 0, 150, 150, 150]
# Predict and visualize the result
model = pdx.load_model(model_dir)
pred = model.predict(img_file)
pdx.seg.visualize(
img_file, pred, weight=0., save_dir='./output/pred', color=color)
# Visualize the annotation file
label = np.asarray(Image.open(label_file))
pred = {'label_map': label}
pdx.seg.visualize(
img_file, pred, weight=0., save_dir='./output/gt', color=color)
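# tools/analysis.py (the statistical-analysis script referenced above): build a
# Seg analyzer over the training file list and run the statistics.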
import paddlex as pdx
train_analysis = pdx.datasets.analysis.Seg(
data_dir='dataset/remote_sensing_seg',
file_list='dataset/remote_sensing_seg/train.txt',
label_list='dataset/remote_sensing_seg/labels.txt')
train_analysis.analysis()
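# tools/cal_clipped_mean_std.py (referenced above): compute the clipped-pixel
# ratios and the normalized mean/std from the chosen clipping ranges and the
# train_information.pkl produced by the analysis step.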
import paddlex as pdx
clip_min_value = [7172, 6561, 5777, 5103, 4291, 4000, 4000, 4232, 6934, 7199]
clip_max_value = [
50000, 50000, 50000, 50000, 50000, 40000, 30000, 18000, 40000, 36000
]
data_info_file = 'dataset/remote_sensing_seg/train_information.pkl'
train_analysis = pdx.datasets.analysis.Seg(
data_dir='dataset/remote_sensing_seg',
file_list='dataset/remote_sensing_seg/train.txt',
label_list='dataset/remote_sensing_seg/labels.txt')
train_analysis.cal_clipped_mean_std(clip_min_value, clip_max_value,
data_info_file)
......@@ -16,7 +16,6 @@
import os.path as osp
import argparse
from paddlex.seg import transforms
import paddlex as pdx
......@@ -28,6 +27,24 @@ def parse_args():
help='dataset directory',
default=None,
type=str)
parser.add_argument(
'--train_file_list',
dest='train_file_list',
help='train file_list',
default=None,
type=str)
parser.add_argument(
'--eval_file_list',
dest='eval_file_list',
help='eval file_list',
default=None,
type=str)
parser.add_argument(
'--label_list',
dest='label_list',
help='label_list file',
default=None,
type=str)
parser.add_argument(
'--save_dir',
dest='save_dir',
......@@ -93,6 +110,9 @@ def parse_args():
args = parse_args()
data_dir = args.data_dir
train_list = args.train_file_list
val_list = args.eval_file_list
label_list = args.label_list
save_dir = args.save_dir
num_classes = args.num_classes
channel = args.channel
......@@ -110,27 +130,19 @@ train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(0.5),
transforms.ResizeStepScaling(0.5, 2.0, 0.25),
transforms.RandomPaddingCrop(im_padding_value=[1000] * channel),
transforms.Clip(
min_val=clip_min_value, max_val=clip_max_value),
transforms.Normalize(
min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
])
eval_transforms = transforms.Compose([
transforms.Clip(
min_val=clip_min_value, max_val=clip_max_value),
transforms.Normalize(
min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
])
train_dataset = pdx.datasets.SegDataset(
data_dir=data_dir,
file_list=train_list,
......
......@@ -23,6 +23,7 @@ import multiprocessing as mp
import paddlex.utils.logging as logging
from paddlex.utils import path_normalization
from paddlex.cv.transforms.seg_transforms import Compose
from .dataset import get_encoding
......@@ -57,38 +58,6 @@ class Seg:
self.file_list.append([full_path_im, full_path_label])
self.num_samples = len(self.file_list)
def _get_shape(self):
max_height = max(self.im_height_list)
max_width = max(self.im_width_list)
......@@ -127,48 +96,25 @@ class Seg:
im_pixel_info[c][v] = n
else:
im_pixel_info[c][v] += n
return im_pixel_info
def _get_mean_std(self):
im_mean = np.asarray(self.im_mean_list)
im_mean = im_mean.sum(axis=0)
im_mean = im_mean / len(self.file_list)
im_mean /= self.max_im_value - self.min_im_value
im_std = np.asarray(self.im_std_list)
im_std = im_std.sum(axis=0)
im_std = im_std / len(self.file_list)
im_std /= self.max_im_value - self.min_im_value
return (im_mean, im_std)
def _get_image_info(self, start, end):
for id in range(start, end):
full_path_im, full_path_label = self.file_list[id]
image, label = Compose.decode_image(full_path_im, full_path_label)
height, width, channel = image.shape
self.im_height_list[id] = height
......@@ -176,9 +122,9 @@ class Seg:
self.im_channel_list[id] = channel
self.im_mean_list[
id] = [image[:, :, c].mean() for c in range(channel)]
self.im_std_list[
id] = [image[:, :, c].std() for c in range(channel)]
for c in range(channel):
unique, counts = np.unique(image[:, :, c], return_counts=True)
self.im_value_list[id].extend([unique])
......@@ -192,7 +138,7 @@ class Seg:
clip_max_value):
for id in range(start, end):
full_path_im, full_path_label = self.file_list[id]
image, label = Compose.decode_image(full_path_im, full_path_label)
for c in range(self.channel_num):
np.clip(
image[:, :, c],
......@@ -219,7 +165,6 @@ class Seg:
self.label_value_num_list = [[] for i in range(len(self.file_list))]
num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
threads = []
one_worker_file = len(self.file_list) // num_workers
for i in range(num_workers):
......@@ -228,39 +173,41 @@ class Seg:
i + 1) if i < num_workers - 1 else len(self.file_list)
t = threading.Thread(
target=self._get_image_info, args=(start, end))
print("====", len(self.file_list), start, end)
#t.daemon = True
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
unique, counts = np.unique(self.im_channel_list, return_counts=True)
if len(unique) > 1:
raise Exception("There are {} kinds of image channels: {}.".format(
len(unique), unique[:]))
self.channel_num = unique[0]
shape_info = self._get_shape()
self.max_height = shape_info['max_height']
self.max_width = shape_info['max_width']
self.min_height = shape_info['min_height']
self.min_width = shape_info['min_width']
self.label_pixel_info = self._get_label_pixel_info()
self.im_pixel_info = self._get_image_pixel_info()
mode = osp.split(self.file_list_path)[-1].split('.')[0]
import matplotlib.pyplot as plt
for c in range(self.channel_num):
plt.figure()
plt.bar(self.im_pixel_info[c].keys(),
self.im_pixel_info[c].values(),
width=1,
log=True)
plt.xlabel('image pixel value')
plt.ylabel('number')
plt.title('channel={}'.format(c))
plt.savefig(
osp.join(self.data_dir,
'{}_channel{}_distribute.png'.format(mode, c)),
dpi=100)
plt.close()
max_im_value = list()
min_im_value = list()
for c in range(self.channel_num):
......@@ -269,70 +216,78 @@ class Seg:
self.max_im_value = np.asarray(max_im_value)
self.min_im_value = np.asarray(min_im_value)
im_mean, im_std = self._get_mean_std()
info = {
'channel_num': self.channel_num,
'image_pixel': self.im_pixel_info,
'label_pixel': self.label_pixel_info,
'file_num': len(self.file_list),
'max_height': self.max_height,
'max_width': self.max_width,
'min_height': self.min_height,
'min_width': self.min_width,
'max_image_value': self.max_im_value,
'min_image_value': self.min_im_value
}
saved_pkl_file = osp.join(self.data_dir,
'{}_information.pkl'.format(mode))
with open(osp.join(saved_pkl_file), 'wb') as f:
pickle.dump(info, f)
logging.info(
"############## The analysis results are as follows ##############\n"
)
logging.info("{} samples in file {}\n".format(
len(self.file_list), self.file_list_path))
logging.info("Maximal image height: {} Maximal image width: {}.\n".
format(self.max_height, self.max_width))
logging.info("Minimal image height: {} Minimal image width: {}.\n".
format(self.min_height, self.min_width))
logging.info("Maximal image height: {} Maximal image width: {}.\n".
format(self.max_height, self.max_width))
logging.info("Image channel is {}.\n".format(self.channel_num))
logging.info(
"Image mean value: {} Image standard deviation: {} (normalized by 255, sorted by a BGR format).\n".
format(im_mean, im_std))
"Minimal image value: {} Maximal image value: {} (arranged in 0-{} channel order) \n".
format(self.min_im_value, self.max_im_value, self.channel_num))
logging.info(
"Image pixel distribution of each channel is saved with 'distribute.png' in the {}"
.format(self.data_dir))
logging.info(
"Image mean value: {} Image standard deviation: {} (normalized by the (max_im_value - min_im_value), arranged in 0-{} channel order).\n".
format(im_mean, im_std, self.channel_num))
logging.info(
"Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):"
)
for v, (n, r) in self.label_pixel_info.items():
logging.info("({}, {}, {})".format(v, n, r))
logging.info("Dataset information is saved in {}".format(
saved_pkl_file))
def cal_clipped_mean_std(self, clip_min_value, clip_max_value,
data_info_file):
with open(data_info_file, 'rb') as f:
im_info = pickle.load(f)
channel_num = im_info['channel_num']
min_im_value = im_info['min_image_value']
max_im_value = im_info['max_image_value']
im_pixel_info = im_info['image_pixel']
if len(clip_min_value) != channel_num or len(
clip_max_value) != channel_num:
raise Exception(
"The length of clip_min_value or clip_max_value should be equal to the number of image channel {}."
.format(channel_num))
for c in range(channel_num):
if clip_min_value[c] < min_im_value[c] or clip_min_value[
c] > max_im_value[c]:
raise Exception(
"Clip_min_value of the channel {} is not in [{}, {}]".
format(c, min_im_value[c], max_im_value[c]))
if clip_max_value[c] < min_im_value[c] or clip_max_value[
c] > max_im_value[c]:
raise Exception(
"Clip_max_value of the channel {} is not in [{}, {}]".
format(c, min_im_value[c], max_im_value[c]))
self.clipped_im_mean_list = [[] for i in range(len(self.file_list))]
self.clipped_im_std_list = [[] for i in range(len(self.file_list))]
......@@ -340,6 +295,7 @@ class Seg:
num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
threads = []
one_worker_file = len(self.file_list) // num_workers
self.channel_num = channel_num
for i in range(num_workers):
start = one_worker_file * i
end = one_worker_file * (
......@@ -349,9 +305,9 @@ class Seg:
args=(start, end, clip_min_value, clip_max_value))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
im_mean = np.asarray(self.clipped_im_mean_list)
im_mean = im_mean.sum(axis=0)
......@@ -361,6 +317,15 @@ class Seg:
im_std = im_std.sum(axis=0)
im_std = im_std / len(self.file_list)
for c in range(channel_num):
clip_pixel_num = 0
pixel_num = sum(im_pixel_info[c].values())
for v, n in im_pixel_info[c].items():
if v < clip_min_value[c] or v > clip_max_value[c]:
clip_pixel_num += n
logging.info("Channel {}, the ratio of pixels to be clipped = {}".
format(c, clip_pixel_num / pixel_num))
logging.info(
"Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value)).\n".
format(im_mean, im_std))
"Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value), arranged in 0-{} channel order).\n".
format(im_mean, im_std, self.channel_num))
......@@ -20,6 +20,7 @@ import numpy as np
import time
import paddlex.utils.logging as logging
from .detection_eval import fixed_linspace, backup_linspace, loadRes
from paddlex.cv.datasets.dataset import is_pic
def visualize_detection(image, result, threshold=0.5, save_dir='./'):
......@@ -44,7 +45,11 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
return image
def visualize_segmentation(image,
result,
weight=0.6,
save_dir='./',
color=None):
"""
Convert the segmentation result to a color image and save the blended visualization.
Args:
......@@ -52,10 +57,14 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
result: the predict result of image
weight: the weight of the original image in the blended output; the prediction gets (1 - weight)
save_dir: the directory for saving the visual image
color: the list of BGR color values, three per label.
"""
label_map = result['label_map']
color_map = get_color_map_list(256)
if color is not None:
# overwrite the first len(color) // 3 entries with the user-given BGR triples
color_map[:len(color) // 3] = [
color[i * 3:i * 3 + 3] for i in range(len(color) // 3)
]
color_map = np.array(color_map).astype("uint8")
# Use OpenCV LUT for color mapping
c1 = cv2.LUT(label_map, color_map[:, 0])
c2 = cv2.LUT(label_map, color_map[:, 1])
......@@ -65,11 +74,26 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
if isinstance(image, np.ndarray):
im = image
image_name = str(int(time.time() * 1000)) + '.jpg'
if image.shape[2] != 3:
logging.info(
"The image is not 3-channel array, so predicted label map is shown as a pseudo color image."
)
weight = 0.
else:
image_name = os.path.split(image)[-1]
if not is_pic(image):
logging.info(
"The image cannot be opened by opencv, so predicted label map is shown as a pseudo color image."
)
image_name = image_name.split('.')[0] + '.jpg'
weight = 0.
else:
im = cv2.imread(image)
if abs(weight) < 1e-5:
vis_result = pseudo_img
else:
vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
if save_dir is not None:
if not os.path.exists(save_dir):
......
......@@ -20,7 +20,9 @@ import os.path as osp
import numpy as np
from PIL import Image
import cv2
import imghdr
from collections import OrderedDict
import paddlex.utils.logging as logging
......@@ -60,6 +62,30 @@ class Compose(SegTransform):
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
@staticmethod
def read_img(img_path):
img_format = imghdr.what(img_path)
name, ext = osp.splitext(img_path)
if img_format == 'tiff' or ext == '.img':
import gdal
gdal.UseExceptions()
gdal.PushErrorHandler('CPLQuietErrorHandler')
try:
dataset = gdal.Open(img_path)
except:
dataset = None
logging.error(gdal.GetLastErrorMsg())
if dataset is None:
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
return im_data.transpose((1, 2, 0))
elif img_format == 'png':
return np.asarray(Image.open(img_path))
elif ext == '.npy':
return np.load(img_path)
else:
raise Exception('Image format {} is not supported!'.format(ext))
@staticmethod
def decode_image(im, label):
if isinstance(im, np.ndarray):
......@@ -69,7 +95,7 @@ class Compose(SegTransform):
format(len(im.shape)))
else:
try:
im = Compose.read_img(im)
except:
raise ValueError('Can\'t read The image file {}!'.format(im))
im = im.astype('float32')
......@@ -85,11 +111,11 @@ class Compose(SegTransform):
label = np.asarray(Image.open(label))
except:
raise ValueError('Can\'t read The label file {}!'.format(label))
im_height, im_width, _ = im.shape
label_height, label_width = label.shape
if im_height != label_height or im_width != label_width:
raise Exception(
"The height or width of the image is not same as the label")
return (im, label)
def __call__(self, im, im_info=None, label=None):
......@@ -570,12 +596,15 @@ class ResizeStepScaling(SegTransform):
class Normalize(SegTransform):
"""对图像进行标准化。
1.尺度缩放到 [0,1]。
2.对图像进行减均值除以标准差操作。
1.像素值减去min_val
2.像素值除以(max_val-min_val)
3.对图像进行减均值除以标准差操作。
Args:
mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
min_val (list): 图像数据集的最小值。默认值[0, 0, 0]。
max_val (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]。
Raises:
ValueError: mean或std不是list对象。std包含0。
......@@ -1099,6 +1128,33 @@ class RandomDistort(SegTransform):
return (im, im_info, label)
class Clip(SegTransform):
"""
Clip pixel values that fall outside the given range.
Args:
min_val (list): Lower bound; values below min_val are set to min_val. Defaults to [0, 0, 0].
max_val (list): Upper bound; values above max_val are set to max_val. Defaults to [255.0, 255.0, 255.0].
"""
def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
self.min_val = min_val
self.max_val = max_val
if not (isinstance(self.min_val, list) and isinstance(self.max_val,
list)):
raise ValueError("{}: input type is invalid.".format(self))
def __call__(self, im, im_info=None, label=None):
for k in range(im.shape[2]):
np.clip(
im[:, :, k], self.min_val[k], self.max_val[k], out=im[:, :, k])
if label is None:
return (im, im_info)
else:
return (im, im_info, label)
class ArrangeSegmenter(SegTransform):
"""获取训练/验证/预测所需的信息。
......
import os
import os.path as osp
import imghdr
import gdal
gdal.UseExceptions()
gdal.PushErrorHandler('CPLQuietErrorHandler')
import numpy as np
from PIL import Image
from paddlex.seg import transforms
import paddlex.utils.logging as logging
def read_img(img_path):
img_format = imghdr.what(img_path)
name, ext = osp.splitext(img_path)
if img_format == 'tiff' or ext == '.img':
try:
dataset = gdal.Open(img_path)
except:
dataset = None
logging.error(gdal.GetLastErrorMsg())
if dataset is None:
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
return im_data.transpose((1, 2, 0))
elif img_format == 'png':
return np.asarray(Image.open(img_path))
elif ext == '.npy':
return np.load(img_path)
else:
raise Exception('Image format {} is not supported!'.format(ext))
def decode_image(im, label):
if isinstance(im, np.ndarray):
if len(im.shape) != 3:
raise Exception(
"im should be 3-dimensions, but now is {}-dimensions".format(
len(im.shape)))
else:
try:
im = read_img(im)
except:
raise ValueError('Can\'t read The image file {}!'.format(im))
im = im.astype('float32')
if label is not None:
if isinstance(label, np.ndarray):
if len(label.shape) != 2:
raise Exception(
"label should be 2-dimensions, but now is {}-dimensions".
format(len(label.shape)))
else:
try:
label = np.asarray(Image.open(label))
except:
raise ValueError('Can\'t read The label file {}!'.format(label))
im_height, im_width, _ = im.shape
label_height, label_width = label.shape
if im_height != label_height or im_width != label_width:
raise Exception(
"The height or width of the image is not same as the label")
return (im, label)
class Clip(transforms.SegTransform):
"""
Clip pixel values that fall outside the given range.
Args:
min_val (list): Lower bound; values below min_val are set to min_val. Defaults to [0, 0, 0].
max_val (list): Upper bound; values above max_val are set to max_val. Defaults to [255.0, 255.0, 255.0].
"""
def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
self.min_val = min_val
self.max_val = max_val
if not (isinstance(self.min_val, list) and isinstance(self.max_val,
list)):
raise ValueError("{}: input type is invalid.".format(self))
def __call__(self, im, im_info=None, label=None):
for k in range(im.shape[2]):
np.clip(
im[:, :, k], self.min_val[k], self.max_val[k], out=im[:, :, k])
if label is None:
return (im, im_info)
else:
return (im, im_info, label)
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import os.path as osp
import sys
import argparse
from PIL import Image
from tqdm import tqdm
import imghdr
import logging
import pickle
import gdal
def parse_args():
parser = argparse.ArgumentParser(
description='Data analyse and data check before training.')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='Dataset directory',
default=None,
type=str)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
default=None,
type=int)
parser.add_argument(
'--separator',
dest='separator',
help='file list separator',
default=" ",
type=str)
parser.add_argument(
'--ignore_index',
dest='ignore_index',
help='Ignored class index',
default=255,
type=int)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def read_img(img_path):
img_format = imghdr.what(img_path)
name, ext = osp.splitext(img_path)
if img_format == 'tiff' or ext == '.img':
dataset = gdal.Open(img_path)
if dataset is None:
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
return im_data.transpose((1, 2, 0))
elif img_format == 'png':
# handle png as in the other read_img implementations in this commit
return np.asarray(Image.open(img_path))
elif ext == '.npy':
return np.load(img_path)
else:
raise Exception('Not support {} image format!'.format(ext))
def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
channel = img.shape[2]
means = np.zeros(channel)
stds = np.zeros(channel)
for k in range(channel):
img_k = img[:, :, k]
# count mean, std
means[k] = np.mean(img_k)
stds[k] = np.std(img_k)
# count min, max
min_value = np.min(img_k)
max_value = np.max(img_k)
if img_max_value[k] < max_value:
img_max_value[k] = max_value
if img_min_value[k] > min_value:
img_min_value[k] = min_value
# count the distribution of image value, value number
unique, counts = np.unique(img_k, return_counts=True)
max_unique = np.max(unique)
add_len = max_unique - len(img_value_num[k]) + 1
if add_len > 0:
img_value_num[k] += ([0] * add_len)
for i in range(len(unique)):
value = unique[i]
img_value_num[k][value] += counts[i]
return means, stds, img_min_value, img_max_value, img_value_num
def data_distribution_statistics(data_dir, img_value_num, logger):
"""count the distribution of image value, value number
"""
logger.info(
"\n-----------------------------\nThe whole dataset statistics...")
if not img_value_num:
return
logger.info("\nImage pixel statistics:")
total_ratio = []
[total_ratio.append([]) for i in range(len(img_value_num))]
for k in range(len(img_value_num)):
total_num = sum(img_value_num[k])
total_ratio[k] = [i / total_num for i in img_value_num[k]]
total_ratio[k] = np.around(total_ratio[k], decimals=4)
with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f:
pickle.dump([total_ratio, img_value_num], f)
def data_range_statistics(img_min_value, img_max_value, logger):
"""print min value, max value
"""
logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".
format(img_min_value, img_max_value))
def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
"""count mean, std
"""
total_means = total_means / total_img_num
total_stds = total_stds / total_img_num
logger.info("\nCount the channel-by-channel mean and std of the image:\n"
"mean = {}\nstd = {}".format(total_means, total_stds))
def error_print(str):
return "".join(["\nNOT PASS ", str])
def correct_print(str):
return "".join(["\nPASS ", str])
def pil_imread(file_path):
"""read pseudo-color label"""
im = Image.open(file_path)
return np.asarray(im)
def get_img_shape_range(img, max_width, max_height, min_width, min_height):
"""获取图片最大和最小宽高"""
img_shape = img.shape
height, width = img_shape[0], img_shape[1]
max_height = max(height, max_height)
max_width = max(width, max_width)
min_height = min(height, min_height)
min_width = min(width, min_width)
return max_width, max_height, min_width, min_height
def get_img_channel_num(img, img_channels):
"""获取图像的通道数"""
img_shape = img.shape
if img_shape[-1] not in img_channels:
img_channels.append(img_shape[-1])
return img_channels
def is_label_single_channel(label):
"""判断标签是否为灰度图"""
label_shape = label.shape
if len(label_shape) == 2:
return True
else:
return False
def image_label_shape_check(img, label):
"""
Check that the image and the label have matching sizes.
"""
flag = True
img_height = img.shape[0]
img_width = img.shape[1]
label_height = label.shape[0]
label_width = label.shape[1]
if img_height != label_height or img_width != label_width:
flag = False
return flag
def ground_truth_check(label, label_path):
"""
Check the annotation image format and count its classes and pixel numbers.
params:
label: annotation image
label_path: path of the annotation image
return:
png_format: whether the label is a PNG image
unique: the label classes present
counts: the pixel count of each class
"""
if imghdr.what(label_path) == "png":
png_format = True
else:
png_format = False
unique, counts = np.unique(label, return_counts=True)
return png_format, unique, counts
def sum_label_check(label_classes, num_of_each_class, ignore_index,
num_classes, total_label_classes, total_num_of_each_class):
"""
Accumulate the classes and per-class pixel counts over all annotation images.
params:
label_classes: label classes found in one annotation
num_of_each_class: pixel count of each of those classes
"""
is_label_correct = True
if ignore_index in label_classes:
label_classes2 = np.delete(label_classes,
np.where(label_classes == ignore_index))
else:
label_classes2 = label_classes
if min(label_classes2) < 0 or max(label_classes2) > num_classes - 1:
is_label_correct = False
add_class = []
add_num = []
for i in range(len(label_classes)):
gi = label_classes[i]
if gi in total_label_classes:
j = total_label_classes.index(gi)
total_num_of_each_class[j] += num_of_each_class[i]
else:
add_class.append(gi)
add_num.append(num_of_each_class[i])
total_num_of_each_class += add_num
total_label_classes += add_class
return is_label_correct, total_num_of_each_class, total_label_classes
def label_class_check(num_classes, total_label_classes,
total_num_of_each_class, wrong_labels, logger):
"""
Check that the actual label classes match the configured `num_classes` and `ignore_index`.
**NOTE:**
Label values must lie in [0, `num_classes` - 1] or equal `ignore_index`.
Label classes should start from 0, otherwise accuracy may be affected.
"""
total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
total_ratio = np.around(total_ratio, decimals=4)
total_nc = sorted(
zip(total_label_classes, total_ratio, total_num_of_each_class))
if len(wrong_labels) == 0 and not total_nc[0][0]:
logger.info(correct_print("label class check!"))
else:
logger.info(error_print("label class check!"))
if total_nc[0][0]:
logger.info("Warning: label classes should start from 0")
if len(wrong_labels) > 0:
logger.info("fatal error: label class is out of range [0, {}]".
format(num_classes - 1))
for i in wrong_labels:
logger.debug(i)
return total_nc
def label_class_statistics(total_nc, logger):
"""
Print the per-class pixel statistics of the annotations.
"""
logger.info("\nLabel class statistics:\n"
"(label class, percentage, total pixel number) = {} ".format(
total_nc))
def shape_check(shape_unequal_image, logger):
"""输出shape校验结果"""
if len(shape_unequal_image) == 0:
logger.info(correct_print("shape check"))
logger.info("All images are the same shape as the labels")
else:
logger.info(error_print("shape check"))
logger.info(
"Some images are not the same shape as the labels as follow: ")
for i in shape_unequal_image:
logger.debug(i)
def separator_check(wrong_lines, file_list, separator, logger):
"""检查分割符是否复合要求"""
if len(wrong_lines) == 0:
logger.info(
correct_print(
file_list.split(os.sep)[-1] + " DATASET.separator check"))
else:
logger.info(
error_print(
file_list.split(os.sep)[-1] + " DATASET.separator check"))
logger.info("The following list is not separated by {}".format(
separator))
for i in wrong_lines:
logger.debug(i)
def imread_check(imread_failed, logger):
if len(imread_failed) == 0:
logger.info(correct_print("dataset reading check"))
logger.info("All images can be read successfully")
else:
logger.info(error_print("dataset reading check"))
logger.info("Failed to read {} images".format(len(imread_failed)))
for i in imread_failed:
logger.debug(i)
def single_channel_label_check(label_not_single_channel, logger):
if len(label_not_single_channel) == 0:
logger.info(correct_print("label single_channel check"))
logger.info("All label images are single_channel")
else:
logger.info(error_print("label single_channel check"))
logger.info(
"{} label images are not single_channel\nLabel pixel statistics may be insignificant"
.format(len(label_not_single_channel)))
for i in label_not_single_channel:
logger.debug(i)
def img_shape_range_statistics(max_width, min_width, max_height, min_height,
logger):
logger.info("\nImage size statistics:")
logger.info(
"max width = {} min width = {} max height = {} min height = {}".
format(max_width, min_width, max_height, min_height))
def img_channels_statistics(img_channels, logger):
logger.info("\nImage channels statistics\nImage channels = {}".format(
np.unique(img_channels)))
def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
logger):
train_file_list = osp.join(data_dir, 'train.txt')
val_file_list = osp.join(data_dir, 'val.txt')
test_file_list = osp.join(data_dir, 'test.txt')
total_img_num = 0
has_label = False
for file_list in [train_file_list, val_file_list, test_file_list]:
# initialization
imread_failed = []
max_width = 0
max_height = 0
min_width = sys.float_info.max
min_height = sys.float_info.max
label_not_single_channel = []
shape_unequal_image = []
wrong_labels = []
wrong_lines = []
total_label_classes = []
total_num_of_each_class = []
img_channels = []
with open(file_list, 'r') as fid:
logger.info("\n-----------------------------\nCheck {}...".format(
file_list))
lines = fid.readlines()
if not lines:
logger.info("File list is empty!")
continue
for line in tqdm(lines):
line = line.strip()
parts = line.split(separator)
if len(parts) == 1:
if file_list == train_file_list or file_list == val_file_list:
logger.info("Train or val list must have labels!")
break
img_name = parts
img_path = os.path.join(data_dir, img_name[0])
try:
img = read_img(img_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
elif len(parts) == 2:
has_label = True
img_name, label_name = parts[0], parts[1]
img_path = os.path.join(data_dir, img_name)
label_path = os.path.join(data_dir, label_name)
try:
img = read_img(img_path)
label = pil_imread(label_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
is_single_channel = is_label_single_channel(label)
if not is_single_channel:
label_not_single_channel.append(line)
continue
is_equal_img_label_shape = image_label_shape_check(img,
label)
if not is_equal_img_label_shape:
shape_unequal_image.append(line)
png_format, label_classes, num_of_each_class = ground_truth_check(
label, label_path)
is_label_correct, total_num_of_each_class, total_label_classes = sum_label_check(
label_classes, num_of_each_class, ignore_index,
num_classes, total_label_classes,
total_num_of_each_class)
if not is_label_correct:
wrong_labels.append(line)
else:
wrong_lines.append(line)
continue
if total_img_num == 0:
channel = img.shape[2]
total_means = np.zeros(channel)
total_stds = np.zeros(channel)
img_min_value = [sys.float_info.max] * channel
img_max_value = [0] * channel
img_value_num = []
[img_value_num.append([]) for i in range(channel)]
means, stds, img_min_value, img_max_value, img_value_num = img_pixel_statistics(
img, img_value_num, img_min_value, img_max_value)
total_means += means
total_stds += stds
max_width, max_height, min_width, min_height = get_img_shape_range(
img, max_width, max_height, min_width, min_height)
img_channels = get_img_channel_num(img, img_channels)
total_img_num += 1
# data check
separator_check(wrong_lines, file_list, separator, logger)
imread_check(imread_failed, logger)
if has_label:
single_channel_label_check(label_not_single_channel, logger)
shape_check(shape_unequal_image, logger)
total_nc = label_class_check(num_classes, total_label_classes,
total_num_of_each_class,
wrong_labels, logger)
# data analyse on train, validation, test set.
img_channels_statistics(img_channels, logger)
img_shape_range_statistics(max_width, min_width, max_height,
min_height, logger)
if has_label:
label_class_statistics(total_nc, logger)
# data analyse on the whole dataset.
data_range_statistics(img_min_value, img_max_value, logger)
data_distribution_statistics(data_dir, img_value_num, logger)
cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)
def main():
args = parse_args()
data_dir = args.data_dir
ignore_index = args.ignore_index
num_classes = args.num_classes
separator = args.separator
logger = logging.getLogger()
logger.setLevel('DEBUG')
BASIC_FORMAT = "%(message)s"
formatter = logging.Formatter(BASIC_FORMAT)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
sh.setLevel('INFO')
th = logging.FileHandler(
os.path.join(data_dir, 'data_analyse_and_check.log'), 'w')
th.setFormatter(formatter)
logger.addHandler(sh)
logger.addHandler(th)
data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
logger)
print("\nDetailed error information can be viewed in {}.".format(
os.path.join(data_dir, 'data_analyse_and_check.log')))
if __name__ == "__main__":
main()