未验证 提交 a196414b 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #290 from LutaoChu/rs-doc

Update the whole process documentation for remote sensing
# PaddleSeg遥感影像分割
遥感影像分割是图像分割领域中的重要应用场景,广泛应用于土地测绘、环境监测、城市建设等领域。遥感影像分割的目标多种多样,有诸如积雪、农作物、道路、建筑、水源等地物目标,也有例如云层的空中目标。
PaddleSeg遥感影像分割涵盖图像预处理、数据增强、模型训练、预测流程,帮助用户利用深度学习技术解决遥感影像分割问题。
PaddleSeg遥感影像分割涵盖数据分析、预处理、数据增强、模型训练、预测等流程,帮助用户利用深度学习技术解决遥感影像分割问题。
## 特点
- 针对遥感数据多通道、分布范围大、分布不均的特点,我们支持多通道训练预测,内置10+多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。
- 针对遥感影像多通道、标注数据稀少的特点,我们支持多通道训练预测,内置10+多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。
- 内置U-Net, HRNet两种主流分割网络,可选择不同的损失函数如Dice Loss, BCE Loss等方式强化小目标和不均衡样本场景下的分割精度。
以下是遥感云检测的示例效果:
- 针对遥感影像分布范围广、分布不均的特点,我们提供数据分析工具,帮助深入了解数据组成、优化模型训练效果。为确保正常训练,我们提供数据校验工具,帮助排查数据问题。
![](./docs/imgs/rs.png)
- 内置U-Net, HRNet两种主流分割网络,可选择不同的损失函数如Dice Loss, BCE Loss等方式强化小目标和不均衡样本场景下的分割精度。
## 前置依赖
**Note:** 若没有特殊说明,以下所有命令需要在`PaddleSeg/contrib/RemoteSensing/`目录下执行。
......@@ -22,10 +20,15 @@ PaddlePaddle的安装, 请按照[官网指引](https://paddlepaddle.org.cn/insta
- Python 3.5+
- 其他依赖安装
通过以下命令安装python包依赖,请确保至少执行过一次以下命令:
```
pip install -r requirements.txt
```
另外需要安装gdal. **Note:** 使用pip安装gdal可能出错,推荐使用conda进行安装:
```
conda install gdal
```
## 目录结构说明
```
......@@ -40,93 +43,107 @@ RemoteSensing # 根目录
|-- utils # 公用模块
|-- train_demo.py # 训练demo脚本
|-- predict_demo.py # 预测demo脚本
|-- visualize_demo.py # 可视化demo脚本
|-- README.md # 使用手册
```
## 数据协议
数据集包含原图、标注图及相应的文件列表文件。
参考数据文件结构如下:
```
./dataset/ # 数据集根目录
|--images # 原图目录
| |--xxx1.npy
| |--...
| └--...
|
|--annotations # 标注图目录
| |--xxx1.png
| |--...
| └--...
|
|--train_list.txt # 训练文件列表文件
|
|--val_list.txt # 验证文件列表文件
|
└--labels.txt # 标签列表
## 使用教程
```
其中,相应的文件名可根据需要自行定义。
基于L8 SPARCS数据集进行云雪分割,提供数据准备、数据分析、训练、预测、可视化的全流程展示。
遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleSeg以numpy.ndarray数据类型进行图像预处理。为统一接口并方便数据加载,我们采用numpy存储格式`npy`作为原图格式,采用`png`无损压缩格式作为标注图片格式。
原图的尺寸应为(h, w, channel),其中h, w为图像的高和宽,channel为图像的通道数。
标注图像为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增
例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)
### 1. 数据准备
#### L8 SPARCS数据集
[L8 SPARCS](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation)数据集包含80张 Landsat 8 卫星影像,涵盖10个波段
原始标注图片包含7个类别,分别是 “cloud”, “cloud shadow”, “shadow over water”, “snow/ice”, ”water”, “land”和”flooded”
`train_list.txt``val_list.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为标注图像文件相对于dataset的相对路径。如下所示:
```
images/xxx1.npy annotations/xxx1.png
images/xxx2.npy annotations/xxx2.png
...
```
<p align="center">
<img src="./docs/imgs/dataset.png" align="middle"
</p>
具体要求和如何生成文件列表可参考[文件列表规范](../../docs/data_prepare.md#文件列表)
`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
```
labelA
labelB
...
```
<p align='center'>
L8 SPARCS数据集示例
</p>
由于“flooded”和“shadow over water”2个类别占比仅为1.8%和0.24%,我们将其进行合并,
“flooded”归为“land”,“shadow over water”归为“shadow”,合并后标注包含5个类别。
数值、类别、颜色对应表:
## 快速上手
|Pixel value|Class|Color|
|---|---|---|
|0|cloud|white|
|1|shadow|black|
|2|snow/ice|cyan|
|3|water|blue|
|4|land|grey|
本章节在一个小数据集上展示了如何通过RemoteSensing进行训练预测。
### 1. 准备数据集
为了快速体验,我们准备了一个小型demo数据集,已位于`RemoteSensing/dataset/demo/`目录下.
执行以下命令下载并解压经过处理之后的数据集`remote_sensing_seg`
```shell script
mkdir dataset && cd dataset
wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip
unzip remote_sensing_seg.zip
cd ..
```
其中`data`目录存放遥感影像,`data_vis`目录存放彩色合成预览图,`mask`目录存放标注图。
对于您自己的数据集,您需要按照上述的数据协议进行格式转换,可分别使用numpy和Pillow库保存遥感数据和标注图片。其中numpy API示例如下:
```python
import numpy as np
#### 数据协议
对于您自己的数据集,需要按照我们的[数据协议](docs/data_prepare.md)进行数据准备。
# 将遥感数据保存到以 .npy 为扩展名的文件中
# img类型:numpy.ndarray
np.save(save_path, img)
```
### 2. 数据校验与分析
为确保能正常训练,我们应该先对数据集进行校验。同时,遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的组成、优化模型训练效果,需要对数据进行分析。
具体步骤参见[数据校验与分析](docs/data_analyse_and_check.md)章节。
### 2. 模型训练
### 3. 模型训练
#### (1) 设置GPU卡号
```shell script
export CUDA_VISIBLE_DEVICES=0
```
#### (2) 以U-Net为例,在RemoteSensing目录下运行`train_demo.py`即可开始训练。
```shell script
python train_demo.py --model_type unet --data_dir dataset/demo/ --save_dir saved_model/unet/ --channel 3 --num_epochs 20
python train_demo.py --data_dir dataset/remote_sensing_seg \
--model_type unet \
--save_dir saved_model/remote_sensing_unet \
--num_classes 5 \
--channel 10 \
--lr 0.01 \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
--mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
--std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
--num_epochs 500 \
--train_batch_size 3
```
### 3. 模型预测
训练过程将自动开启边训边评估策略,并使用VisualDL保存训练日志,显示如下:
![](docs/imgs/visualdl.png)
`mIoU`最高的模型将自动保存在`saved_model/remote_sensing_unet/best_model`目录下,最高mIoU=0.7782
### 4. 模型预测
#### (1) 设置GPU卡号
```shell script
export CUDA_VISIBLE_DEVICES=0
```
#### (2) 以刚训练好的U-Net最优模型为例,在RemoteSensing目录下运行`predict_demo.py`即可开始训练
#### (2) 以刚训练好的U-Net最优模型为例,在RemoteSensing目录下运行`predict_demo.py`即可开始预测
```shell script
python predict_demo.py --data_dir dataset/demo/ --file_list val.txt --load_model_dir saved_model/unet/best_model
python predict_demo.py --data_dir dataset/remote_sensing_seg/ \
--file_list val.txt \
--load_model_dir saved_model/remote_sensing_unet/best_model \
--save_img_dir saved_model/remote_sensing_unet/best_model/predict \
--color_map 255 255 255 0 0 0 0 255 255 0 0 255 150 150 150
```
### 5. 可视化
我们提供可视化API对预测效果进行直观的展示和对比。每张可视化图片包括彩色合成预览图、标注图、预测结果,使得效果好坏一目了然。
```shell script
python visualize_demo.py --data_dir dataset/remote_sensing_seg/ \
--file_list val.txt \
--pred_dir saved_model/remote_sensing_unet/best_model/predict \
--save_dir saved_model/remote_sensing_unet/best_model/vis_results
````
3张可视化图片示例:
![](docs/imgs/vis.png)
## API说明
......
__background__
cloud
\ No newline at end of file
images/1001.npy annotations/1001.png
images/1002.npy annotations/1002.png
images/1005.npy annotations/1005.png
images/0.npy annotations/0.png
images/1003.npy annotations/1003.png
images/1000.npy annotations/1000.png
images/1004.npy annotations/1004.png
images/100.npy annotations/100.png
images/1.npy annotations/1.png
images/10.npy annotations/10.png
# 数据校验和分析
为确保能正常训练,我们应该先对数据集进行校验。同时,遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的组成、优化模型训练效果,需要对数据进行分析。
接下来以`remote_sensing_seg`数据集为例展示数据校验和分析的全过程。
## Step1 数据校验和初步分析
我们提供数据校验和分析的脚本,帮助您排查基本的数据问题,为如何配置训练参数提供指导。使用方式如下所示:
```shell script
python tools/data_analyse_and_check.py --data_dir 'dataset/remote_sensing_seg/' --num_classes 5
```
参数说明:
- --data_dir: 数据集所在目录
- --num_classes: 数据的类别数
运行后,命令行将显示概览信息,详细的错误信息将以data_analyse_and_check.log文件保存到数据集所在目录。
### 数据校验
数据校验内容如下:
#### 1 列表分割符校验(separator_check)
检查在`train.txt``val.txt``test.txt`列表文件中的分隔符设置是否正确。
#### 2 数据读取校验(imread_check)
检查是否能成功读取`train.txt``val.txt``test.txt`中所有图片。
若不正确返回错误信息。错误可能有多种情况,如数据集路径设置错误、图片损坏等。
#### 3 标注通道数校验(single_channel_label_check)
检查标注图的通道数。正确的标注图应该为单通道图像。
#### 4 标注类别校验(label_class_check)
检查实际标注类别是否和配置参数`num_classes``ignore_index`匹配。
**NOTE:**
标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`
标注类别最好从0开始,否则可能影响精度。
#### 5 图像与标注图尺寸一致性校验(shape_check)
验证图像尺寸和对应标注图尺寸是否一致。
### 数据分析
数据统计分析内容如下:
#### 1 标注类别统计(label_class_statistics)
统计每种类别的像素总数和所占比例。统计结果示例如下:
```
Label class statistics:
(label class, percentage, total pixel number) = [(0, 0.1372, 2194601), (1, 0.0827, 1322459), (2, 0.0179, 286548), (3, 0.1067, 1706810), (4, 0.6556, 10489582)]
```
#### 2 图像尺寸范围统计(img_shape_range_statistics)
统计数据集中图片的最大和最小的宽高。
#### 3 图像通道数统计(img_channels_statistics)
统计数据集中图片的通道个数。
#### 4 数据范围统计(data_range_statistics)
逐通道地统计数据集的数值范围。
#### 5 数据分布统计(data_distribution_statistics)
逐通道地统计数据集分布。并将分布保存为`pkl`文件,方便后续可视化和数据裁剪。
#### 6 归一化系数计算(cal_normalize_coefficient)
逐通道地计算归一化系数mean、standard deviation.
**备注:** 数据分析步骤1\~3在训练集、验证集、测试集上分别进行,步骤4\~6在整个数据集上进行。
## Step2 数据分布可视化,确定数据裁剪范围
### 数据分布可视化
我们提供可视化数据分布脚本,对数据集的数据分布按通道进行可视化。
可视化需要先安装matplotlib:
```shell script
pip install matplotlib
```
使用方式如下:
```shell script
python tools/data_distribution_vis.py --pkl_path 'dataset/remote_sensing_seg/img_pixel_statistics.pkl'
```
参数说明:
- --pkl_path: 数据分布文件保存路径
其中部分通道的可视化效果如下:
![](./imgs/data_distribution.png)
需要注意的是,为便于观察,纵坐标为对数坐标。
### 确定数据裁剪范围
遥感影像数据分布范围广,其中往往存在一些异常值,影响算法对实际数据分布的拟合效果。为更好地对数据进行归一化,需要抑制遥感影像中少量的异常值。
我们可以根据上述的数据分布统计结果来确定数据裁剪范围,并在后续图像预处理过程中对超出范围的像素值通过截断进行校正,从而去除异常值带来的干扰。
例如对于上述数据分布进行逐通道数据裁剪,我们选取的截断范围是:
```
裁剪范围最小值: clip_min_value = [7172, 6561, 5777, 5103, 4291, 4000, 4000, 4232, 6934, 7199]
裁剪范围最大值: clip_max_value = [50000, 50000, 50000, 50000, 50000, 40000, 30000, 18000, 40000, 36000]
```
## Step3 统计裁剪比例、归一化系数
为避免数据裁剪范围选取不当带来的影响,应该统计异常值像素占比,确保受影响的像素比例不要过高。
接着对裁剪后的数据计算归一化系数mean和standard deviation,用于图像预处理中的归一化参数设置。
使用方式如下:
```shell script
python tools/cal_norm_coef.py --data_dir 'dataset/remote_sensing_seg/' \
--pkl_path 'dataset/remote_sensing_seg/img_pixel_statistics.pkl' \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000
```
参数说明:
- --data_dir: 数据集路径
- --pkl_path: 数据分布文件保存路径
- --clip_min_value: 数据裁剪范围最小值
- --clip_max_value: 数据裁剪范围最大值
裁剪像素占比统计结果如下:
```
channel 0, the percentage of pixels to be clipped = 0.0005625999999999687
channel 1, the percentage of pixels to be clipped = 0.0011332250000000155
channel 2, the percentage of pixels to be clipped = 0.0008772375000000165
channel 3, the percentage of pixels to be clipped = 0.0013191750000000058
channel 4, the percentage of pixels to be clipped = 0.0012433250000000173
channel 5, the percentage of pixels to be clipped = 7.49875000000122e-05
channel 6, the percentage of pixels to be clipped = 0.0006973750000000001
channel 7, the percentage of pixels to be clipped = 4.950000000003563e-06
channel 8, the percentage of pixels to be clipped = 0.00014873749999999575
channel 9, the percentage of pixels to be clipped = 0.00011173750000004201
```
可看出,被裁剪像素占比均不超过0.2%
裁剪后数据的归一化系数如下:
```
Count the channel-by-channel mean and std of the image:
mean = [0.14311189 0.14288498 0.14812998 0.16377212 0.27375384 0.27409344 0.27749602 0.07767443 0.56946994 0.55497161]
std = [0.09101633 0.09600706 0.09619362 0.10371447 0.10911952 0.11043593 0.12648043 0.02774626 0.06822348 0.06237759]
```
# 数据准备
## 数据协议
数据集包含原图、标注图及相应的文件列表文件。
### 数据格式
遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。
PaddleSeg已兼容以下4种格式图片读取:
- `tif`
- `png`
- `img`
- `npy`
### 原图要求
原图数据的尺寸应为(h, w, channel),其中h, w为图像的高和宽,channel为图像的通道数。
### 标注图要求
标注图像必须为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增。
例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)。
### 文件列表文件
文件列表文件包括`train.txt``val.txt``test.txt``labels.txt`.
`train.txt``val.txt``test.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为标注图像文件相对于dataset的相对路径。如下所示:
```
images/xxx1.tif annotations/xxx1.png
images/xxx2.tif annotations/xxx2.png
...
```
`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
```
labelA
labelB
...
```
## 数据集切分和文件列表生成
数据集切分有2种方式:随机切分和手动切分。对于这2种方式,PaddleSeg均提供了生成文件列表的脚本,您可以按需要选择。
### 1 对数据集按比例随机切分,并生成文件列表
数据文件结构如下:
```
./dataset/ # 数据集根目录
|--images # 原图目录
| |--xxx1.tif
| |--...
| └--...
|
|--annotations # 标注图目录
| |--xxx1.png
| |--...
| └--...
```
其中,相应的文件名可根据需要自行定义。
使用命令如下,支持通过不同的Flags来开启特定功能。
```
python tools/split_dataset_list.py <dataset_root> <images_dir_name> <labels_dir_name> ${FLAGS}
```
参数说明:
- dataset_root: 数据集根目录
- images_dir_name: 原图目录名
- labels_dir_name: 标注图目录名
FLAGS说明:
|FLAG|含义|默认值|参数数目|
|-|-|-|-|
|--split|数据集切分比例|0.7 0.3 0|3|
|--separator|文件列表分隔符|" "|1|
|--format|图片和标签集的数据格式|"tif" "png"|2|
|--label_class|标注类别|'\_\_background\_\_' '\_\_foreground\_\_'|若干|
|--postfix|按文件主名(无扩展名)是否包含指定后缀对图片和标签集进行筛选|"" ""(2个空字符)|2|
运行后将在数据集根目录下生成`train.txt``val.txt``test.txt``labels.txt`.
**Note:** 生成文件列表要求:要么原图和标注图片数量一致,要么只有原图,没有标注图片。若数据集缺少标注图片,将生成不含分隔符和标注图片路径的文件列表。
#### 使用示例
```
python tools/split_dataset_list.py <dataset_root> images annotations --split 0.6 0.2 0.2 --format tif png
```
### 2 已经手工划分好数据集,按照目录结构生成文件列表
数据目录手工划分成如下结构:
```
./dataset/ # 数据集根目录
├── annotations # 标注目录
│   ├── test
│   │   ├── ...
│   │   └── ...
│   ├── train
│   │   ├── ...
│   │   └── ...
│   └── val
│   ├── ...
│   └── ...
└── images # 原图目录
├── test
│   ├── ...
│   └── ...
├── train
│   ├── ...
│   └── ...
└── val
├── ...
└── ...
```
其中,相应的文件名可根据需要自行定义。
使用命令如下,支持通过不同的Flags来开启特定功能。
```
python tools/create_dataset_list.py <dataset_root> ${FLAGS}
```
参数说明:
- dataset_root: 数据集根目录
FLAGS说明:
|FLAG|含义|默认值|参数数目|
|-|-|-|-|
|--separator|文件列表分隔符|" "|1|
|--folder|图片和标签集的文件夹名|"images" "annotations"|2|
|--second_folder|训练/验证/测试集的文件夹名|"train" "val" "test"|若干|
|--format|图片和标签集的数据格式|"tif" "png"|2|
|--label_class|标注类别|'\_\_background\_\_' '\_\_foreground\_\_'|若干|
|--postfix|按文件主名(无扩展名)是否包含指定后缀对图片和标签集进行筛选|"" ""(2个空字符)|2|
运行后将在数据集根目录下生成`train.txt``val.txt``test.txt``labels.txt`.
**Note:** 生成文件列表要求:要么原图和标注图片数量一致,要么只有原图,没有标注图片。若数据集缺少标注图片,将生成不含分隔符和标注图片路径的文件列表。
#### 使用示例
若您已经按上述说明整理好了数据集目录结构,可以运行下面的命令生成文件列表。
```
# 生成文件列表,其分隔符为空格,图片和标签集的数据格式都为png
python tools/create_dataset_list.py <dataset_root> --separator " " --format png png
```
```
# 生成文件列表,其图片和标签集的文件夹名为img和gt,训练和验证集的文件夹名为training和validation,不生成测试集列表
python tools/create_dataset_list.py <dataset_root> \
--folder img gt --second_folder training validation
```
......@@ -22,6 +22,7 @@ from utils import logging
from .base import BaseReader
from .base import get_encoding
from collections import OrderedDict
from PIL import Image
def read_img(img_path):
......@@ -33,6 +34,8 @@ def read_img(img_path):
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
return im_data.transpose((1, 2, 0))
elif img_format == 'png':
return np.asarray(Image.open(img_path))
elif ext == '.npy':
return np.load(img_path)
else:
......
......@@ -46,10 +46,10 @@ def parse_args():
default=['train', 'val', 'test'])
parser.add_argument(
'--format',
help='data format of images and labels, default npy, png.',
help='data format of images and labels, default tif, png.',
type=str,
nargs=2,
default=['npy', 'png'])
default=['tif', 'png'])
parser.add_argument(
'--label_class',
help='label class names',
......
......@@ -51,6 +51,12 @@ def parse_args():
help='file list separator',
default=" ",
type=str)
parser.add_argument(
'--ignore_index',
dest='ignore_index',
help='Ignored class index',
default=255,
type=int)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
......@@ -106,12 +112,12 @@ def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
return means, stds, img_min_value, img_max_value, img_value_num
def dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num,
img_min_value, img_max_value, total_img_num,
logger):
logger.info("\n-----------------------------\nDataset pixel statistics...")
def data_distribution_statistics(data_dir, img_value_num, logger):
"""count the distribution of image value, value number
"""
logger.info(
"\n-----------------------------\nThe whole dataset statistics...")
# count the distribution of image value, value number
if not img_value_num:
return
logger.info("\nImage pixel statistics:")
......@@ -124,15 +130,21 @@ def dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num,
with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f:
pickle.dump([total_ratio, img_value_num], f)
# print min value, max value
def data_range_statistics(img_min_value, img_max_value, logger):
"""print min value, max value
"""
logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".format(
img_min_value, img_max_value))
# count mean, std
def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
"""count mean, std
"""
total_means = total_means / total_img_num
total_stds = total_stds / total_img_num
print("\nCount the channel-by-channel mean and std of the image:\n"
"mean = {}\nstd = {}".format(total_means, total_stds))
logger.info("\nCount the channel-by-channel mean and std of the image:\n"
"mean = {}\nstd = {}".format(total_means, total_stds))
def error_print(str):
......@@ -160,12 +172,12 @@ def get_img_shape_range(img, max_width, max_height, min_width, min_height):
return max_width, max_height, min_width, min_height
def get_image_dim(img, img_dim):
def get_img_channel_num(img, img_channels):
"""获取图像的通道数"""
img_shape = img.shape
if img_shape[-1] not in img_dim:
img_dim.append(img_shape[-1])
return img_dim
if img_shape[-1] not in img_channels:
img_channels.append(img_shape[-1])
return img_channels
def is_label_single_channel(label):
......@@ -215,23 +227,16 @@ def ground_truth_check(label, label_path):
return png_format, unique, counts
def sum_label_check(png_format, label_classes, num_of_each_class, ignore_index,
num_classes, png_format_right_num, png_format_wrong_num,
def sum_label_check(label_classes, num_of_each_class, ignore_index, num_classes,
total_label_classes, total_num_of_each_class):
"""
统计所有标注图上的格式、类别和每个类别的像素数
统计所有标注图上的类别和每个类别的像素数
params:
png_format: 是否是png格式图片
label_classes: 标注类别
num_of_each_class: 各个类别的像素数目
"""
is_label_correct = True
if png_format:
png_format_right_num += 1
else:
png_format_wrong_num += 1
if ignore_index in label_classes:
label_classes2 = np.delete(label_classes,
np.where(label_classes == ignore_index))
......@@ -251,32 +256,18 @@ def sum_label_check(png_format, label_classes, num_of_each_class, ignore_index,
add_num.append(num_of_each_class[i])
total_num_of_each_class += add_num
total_label_classes += add_class
return is_label_correct, png_format_right_num, png_format_wrong_num, total_num_of_each_class, total_label_classes
return is_label_correct, total_num_of_each_class, total_label_classes
def label_check_statistics(num_classes, png_format_wrong_image,
png_format_right_num, png_format_wrong_num,
total_label_classes, total_num_of_each_class,
wrong_labels, logger):
def label_class_check(num_classes, total_label_classes, total_num_of_each_class,
wrong_labels, logger):
"""
对标注图像进行校验,输出校验结果
"""
if png_format_wrong_num == 0:
if png_format_right_num:
logger.info(correct_print("label format check"))
else:
logger.info(error_print("label format check"))
logger.info("No label image to check")
return
else:
logger.info(error_print("label format check"))
logger.info(
"total {} label images are png format, {} label images are not png "
"format".format(png_format_right_num, png_format_wrong_num))
if len(png_format_wrong_image) > 0:
for i in png_format_wrong_image:
logger.debug(i)
检查实际标注类别是否和配置参数`num_classes`,`ignore_index`匹配。
**NOTE:**
标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`。
标注类别最好从0开始,否则可能影响精度。
"""
total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
total_ratio = np.around(total_ratio, decimals=4)
total_nc = sorted(
......@@ -293,9 +284,15 @@ def label_check_statistics(num_classes, png_format_wrong_image,
num_classes - 1))
for i in wrong_labels:
logger.debug(i)
return total_nc
def label_class_statistics(total_nc, logger):
"""
对标注图像进行校验,输出校验结果
"""
logger.info(
"\nLabel pixel statistics:\n"
"\nLabel class statistics:\n"
"(label class, percentage, total pixel number) = {} ".format(total_nc))
......@@ -360,9 +357,9 @@ def img_shape_range_statistics(max_width, min_width, max_height, min_height,
format(max_width, min_width, max_height, min_height))
def img_dim_statistics(img_dim, logger):
def img_channels_statistics(img_channels, logger):
logger.info("\nImage channels statistics\nImage channels = {}".format(
np.unique(img_dim)))
np.unique(img_channels)))
def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
......@@ -381,14 +378,11 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
min_height = sys.float_info.max
label_not_single_channel = []
shape_unequal_image = []
png_format_wrong_image = []
wrong_labels = []
wrong_lines = []
png_format_right_num = 0
png_format_wrong_num = 0
total_label_classes = []
total_num_of_each_class = []
img_dim = []
img_channels = []
with open(file_list, 'r') as fid:
logger.info("\n-----------------------------\nCheck {}...".format(
......@@ -433,12 +427,9 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
shape_unequal_image.append(line)
png_format, label_classes, num_of_each_class = ground_truth_check(
label, label_path)
if not png_format:
png_format_wrong_image.append(line)
is_label_correct, png_format_right_num, png_format_wrong_num, total_num_of_each_class, total_label_classes = sum_label_check(
png_format, label_classes, num_of_each_class,
ignore_index, num_classes, png_format_right_num,
png_format_wrong_num, total_label_classes,
is_label_correct, total_num_of_each_class, total_label_classes = sum_label_check(
label_classes, num_of_each_class, ignore_index,
num_classes, total_label_classes,
total_num_of_each_class)
if not is_label_correct:
wrong_labels.append(line)
......@@ -460,32 +451,35 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
total_stds += stds
max_width, max_height, min_width, min_height = get_img_shape_range(
img, max_width, max_height, min_width, min_height)
img_dim = get_image_dim(img, img_dim)
img_channels = get_img_channel_num(img, img_channels)
total_img_num += 1
# data check
separator_check(wrong_lines, file_list, separator, logger)
imread_check(imread_failed, logger)
img_dim_statistics(img_dim, logger)
img_shape_range_statistics(max_width, min_width, max_height,
min_height, logger)
if has_label:
single_channel_label_check(label_not_single_channel, logger)
shape_check(shape_unequal_image, logger)
label_check_statistics(
num_classes, png_format_wrong_image, png_format_right_num,
png_format_wrong_num, total_label_classes,
total_num_of_each_class, wrong_labels, logger)
total_nc = label_class_check(num_classes, total_label_classes,
total_num_of_each_class,
wrong_labels, logger)
dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num,
img_min_value, img_max_value, total_img_num,
logger)
# data analyse on train, validation, test set.
img_channels_statistics(img_channels, logger)
img_shape_range_statistics(max_width, min_width, max_height,
min_height, logger)
if has_label:
label_class_statistics(total_nc, logger)
# data analyse on the whole dataset.
data_range_statistics(img_min_value, img_max_value, logger)
data_distribution_statistics(data_dir, img_value_num, logger)
cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)
def main():
args = parse_args()
data_dir = args.data_dir
ignore_index = 255
ignore_index = args.ignore_index
num_classes = args.num_classes
separator = args.separator
......
......@@ -25,8 +25,10 @@ def parse_args():
description=
'A tool for proportionally randomizing dataset to produce file lists.')
parser.add_argument('dataset_root', help='the dataset root path', type=str)
parser.add_argument('images', help='the directory name of images', type=str)
parser.add_argument('labels', help='the directory name of labels', type=str)
parser.add_argument(
'images_dir_name', help='the directory name of images', type=str)
parser.add_argument(
'labels_dir_name', help='the directory name of labels', type=str)
parser.add_argument(
'--split', help='', nargs=3, type=float, default=[0.7, 0.3, 0])
parser.add_argument(
......@@ -43,10 +45,10 @@ def parse_args():
type=str)
parser.add_argument(
'--format',
help='data format of images and labels, e.g. jpg, npy or png.',
help='data format of images and labels, e.g. jpg, tif or png.',
type=str,
nargs=2,
default=['npy', 'png'])
default=['tif', 'png'])
parser.add_argument(
'--postfix',
help='postfix of images or labels',
......@@ -84,8 +86,8 @@ def generate_list(args):
for label_class in args.label_class:
f.write(label_class + '\n')
image_dir = os.path.join(dataset_root, args.images)
label_dir = os.path.join(dataset_root, args.labels)
image_dir = os.path.join(dataset_root, args.images_dir_name)
label_dir = os.path.join(dataset_root, args.labels_dir_name)
image_files = get_files(image_dir, args.format[0], args.postfix[0])
label_files = get_files(label_dir, args.format[1], args.postfix[1])
if not image_files:
......
......@@ -15,9 +15,7 @@
from .ops import *
import random
import os.path as osp
import numpy as np
from PIL import Image
import cv2
from collections import OrderedDict
from readers.reader import read_img
......@@ -63,7 +61,7 @@ class Compose:
if im is None:
raise ValueError('Can\'t read The image file {}!'.format(im))
if label is not None:
label = np.asarray(Image.open(label))
label = read_img(label)
for op in self.transforms:
outputs = op(im, im_info, label)
......
import os
import os.path as osp
import argparse
from PIL import Image as Image
from models.utils import visualize as vis
def parse_args():
parser = argparse.ArgumentParser(description='RemoteSensing visualization')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='Dataset directory',
default=None,
type=str)
parser.add_argument(
'--file_list',
dest='file_list',
help='The name of file list that need to be visualized',
default=None,
type=str)
parser.add_argument(
'--pred_dir',
dest='pred_dir',
help='Directory for predict results',
default=None,
type=str)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='Save directory for visual results',
default=None,
type=str)
return parser.parse_args()
args = parse_args()
data_dir = args.data_dir
pred_dir = args.pred_dir
save_dir = args.save_dir
file_list = osp.join(data_dir, args.file_list)
if not osp.exists(save_dir):
os.mkdir(save_dir)
with open(file_list) as f:
lines = f.readlines()
for line in lines:
img_list = []
img_line = line.split(' ')[0]
img_name = osp.basename(img_line).replace('data.tif', 'photo.png')
img_path = osp.join(data_dir, 'data_vis', img_name)
img = Image.open(img_path)
img_list.append(img)
print('visualizing {}'.format(img_path))
gt_line = line.split(' ')[1].rstrip('\n')
gt_path = osp.join(data_dir, gt_line)
gt_pil = Image.open(gt_path)
img_list.append(gt_pil)
pred_name = osp.basename(img_line).replace('tif', 'png')
pred_path = osp.join(pred_dir, pred_name)
pred_pil = Image.open(pred_path)
img_list.append(pred_pil)
save_path = osp.join(save_dir, pred_name)
vis.splice_imgs(img_list, save_path)
print('saved in {}'.format(save_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册