未验证 提交 a196414b 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #290 from LutaoChu/rs-doc

Update the whole process documentation for remote sensing
# PaddleSeg遥感影像分割 # PaddleSeg遥感影像分割
遥感影像分割是图像分割领域中的重要应用场景,广泛应用于土地测绘、环境监测、城市建设等领域。遥感影像分割的目标多种多样,有诸如积雪、农作物、道路、建筑、水源等地物目标,也有例如云层的空中目标。 遥感影像分割是图像分割领域中的重要应用场景,广泛应用于土地测绘、环境监测、城市建设等领域。遥感影像分割的目标多种多样,有诸如积雪、农作物、道路、建筑、水源等地物目标,也有例如云层的空中目标。
PaddleSeg遥感影像分割涵盖图像预处理、数据增强、模型训练、预测流程,帮助用户利用深度学习技术解决遥感影像分割问题。 PaddleSeg遥感影像分割涵盖数据分析、预处理、数据增强、模型训练、预测等流程,帮助用户利用深度学习技术解决遥感影像分割问题。
## 特点 ## 特点
- 针对遥感数据多通道、分布范围大、分布不均的特点,我们支持多通道训练预测,内置10+多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。 - 针对遥感影像多通道、标注数据稀少的特点,我们支持多通道训练预测,内置10+多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。
- 内置U-Net, HRNet两种主流分割网络,可选择不同的损失函数如Dice Loss, BCE Loss等方式强化小目标和不均衡样本场景下的分割精度。 - 针对遥感影像分布范围广、分布不均的特点,我们提供数据分析工具,帮助深入了解数据组成、优化模型训练效果。为确保正常训练,我们提供数据校验工具,帮助排查数据问题。
以下是遥感云检测的示例效果:
![](./docs/imgs/rs.png) - 内置U-Net, HRNet两种主流分割网络,可选择不同的损失函数如Dice Loss, BCE Loss等方式强化小目标和不均衡样本场景下的分割精度。
## 前置依赖 ## 前置依赖
**Note:** 若没有特殊说明,以下所有命令需要在`PaddleSeg/contrib/RemoteSensing/`目录下执行。 **Note:** 若没有特殊说明,以下所有命令需要在`PaddleSeg/contrib/RemoteSensing/`目录下执行。
...@@ -22,10 +20,15 @@ PaddlePaddle的安装, 请按照[官网指引](https://paddlepaddle.org.cn/insta ...@@ -22,10 +20,15 @@ PaddlePaddle的安装, 请按照[官网指引](https://paddlepaddle.org.cn/insta
- Python 3.5+ - Python 3.5+
- 其他依赖安装 - 其他依赖安装
通过以下命令安装python包依赖,请确保至少执行过一次以下命令: 通过以下命令安装python包依赖,请确保至少执行过一次以下命令:
``` ```
pip install -r requirements.txt pip install -r requirements.txt
``` ```
另外需要安装gdal. **Note:** 使用pip安装gdal可能出错,推荐使用conda进行安装:
```
conda install gdal
```
## 目录结构说明 ## 目录结构说明
``` ```
...@@ -40,93 +43,107 @@ RemoteSensing # 根目录 ...@@ -40,93 +43,107 @@ RemoteSensing # 根目录
|-- utils # 公用模块 |-- utils # 公用模块
|-- train_demo.py # 训练demo脚本 |-- train_demo.py # 训练demo脚本
|-- predict_demo.py # 预测demo脚本 |-- predict_demo.py # 预测demo脚本
|-- visualize_demo.py # 可视化demo脚本
|-- README.md # 使用手册 |-- README.md # 使用手册
``` ```
## 数据协议
数据集包含原图、标注图及相应的文件列表文件。
参考数据文件结构如下: ## 使用教程
```
./dataset/ # 数据集根目录
|--images # 原图目录
| |--xxx1.npy
| |--...
| └--...
|
|--annotations # 标注图目录
| |--xxx1.png
| |--...
| └--...
|
|--train_list.txt # 训练文件列表文件
|
|--val_list.txt # 验证文件列表文件
|
└--labels.txt # 标签列表
``` 基于L8 SPARCS数据集进行云雪分割,提供数据准备、数据分析、训练、预测、可视化的全流程展示。
其中,相应的文件名可根据需要自行定义。
遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleSeg以numpy.ndarray数据类型进行图像预处理。为统一接口并方便数据加载,我们采用numpy存储格式`npy`作为原图格式,采用`png`无损压缩格式作为标注图片格式。 ### 1. 数据准备
原图的尺寸应为(h, w, channel),其中h, w为图像的高和宽,channel为图像的通道数。 #### L8 SPARCS数据集
标注图像为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增 [L8 SPARCS](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation)数据集包含80张 Landsat 8 卫星影像,涵盖10个波段
例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255) 原始标注图片包含7个类别,分别是 “cloud”, “cloud shadow”, “shadow over water”, “snow/ice”, ”water”, “land”和”flooded”
`train_list.txt``val_list.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为标注图像文件相对于dataset的相对路径。如下所示: <p align="center">
``` <img src="./docs/imgs/dataset.png" align="middle"
images/xxx1.npy annotations/xxx1.png </p>
images/xxx2.npy annotations/xxx2.png
...
```
具体要求和如何生成文件列表可参考[文件列表规范](../../docs/data_prepare.md#文件列表) <p align='center'>
L8 SPARCS数据集示例
`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示: </p>
```
labelA
labelB
...
```
由于“flooded”和“shadow over water”2个类别占比仅为1.8%和0.24%,我们将其进行合并,
“flooded”归为“land”,“shadow over water”归为“shadow”,合并后标注包含5个类别。
数值、类别、颜色对应表:
## 快速上手 |Pixel value|Class|Color|
|---|---|---|
|0|cloud|white|
|1|shadow|black|
|2|snow/ice|cyan|
|3|water|blue|
|4|land|grey|
本章节在一个小数据集上展示了如何通过RemoteSensing进行训练预测。 执行以下命令下载并解压经过处理之后的数据集`remote_sensing_seg`
```shell script
### 1. 准备数据集 mkdir dataset && cd dataset
为了快速体验,我们准备了一个小型demo数据集,已位于`RemoteSensing/dataset/demo/`目录下. wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip
unzip remote_sensing_seg.zip
cd ..
```
其中`data`目录存放遥感影像,`data_vis`目录存放彩色合成预览图,`mask`目录存放标注图。
对于您自己的数据集,您需要按照上述的数据协议进行格式转换,可分别使用numpy和Pillow库保存遥感数据和标注图片。其中numpy API示例如下: #### 数据协议
```python 对于您自己的数据集,需要按照我们的[数据协议](docs/data_prepare.md)进行数据准备。
import numpy as np
# 将遥感数据保存到以 .npy 为扩展名的文件中 ### 2. 数据校验与分析
# img类型:numpy.ndarray 为确保能正常训练,我们应该先对数据集进行校验。同时,遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的组成、优化模型训练效果,需要对数据进行分析。
np.save(save_path, img) 具体步骤参见[数据校验与分析](docs/data_analyse_and_check.md)章节。
```
### 2. 模型训练 ### 3. 模型训练
#### (1) 设置GPU卡号 #### (1) 设置GPU卡号
```shell script ```shell script
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
``` ```
#### (2) 以U-Net为例,在RemoteSensing目录下运行`train_demo.py`即可开始训练。 #### (2) 以U-Net为例,在RemoteSensing目录下运行`train_demo.py`即可开始训练。
```shell script ```shell script
python train_demo.py --model_type unet --data_dir dataset/demo/ --save_dir saved_model/unet/ --channel 3 --num_epochs 20 python train_demo.py --data_dir dataset/remote_sensing_seg \
--model_type unet \
--save_dir saved_model/remote_sensing_unet \
--num_classes 5 \
--channel 10 \
--lr 0.01 \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
--mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
--std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
--num_epochs 500 \
--train_batch_size 3
``` ```
### 3. 模型预测 训练过程将自动开启边训边评估策略,并使用VisualDL保存训练日志,显示如下:
![](docs/imgs/visualdl.png)
`mIoU`最高的模型将自动保存在`saved_model/remote_sensing_unet/best_model`目录下,最高mIoU=0.7782
### 4. 模型预测
#### (1) 设置GPU卡号 #### (1) 设置GPU卡号
```shell script ```shell script
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
``` ```
#### (2) 以刚训练好的U-Net最优模型为例,在RemoteSensing目录下运行`predict_demo.py`即可开始训练 #### (2) 以刚训练好的U-Net最优模型为例,在RemoteSensing目录下运行`predict_demo.py`即可开始预测
```shell script ```shell script
python predict_demo.py --data_dir dataset/demo/ --file_list val.txt --load_model_dir saved_model/unet/best_model python predict_demo.py --data_dir dataset/remote_sensing_seg/ \
--file_list val.txt \
--load_model_dir saved_model/remote_sensing_unet/best_model \
--save_img_dir saved_model/remote_sensing_unet/best_model/predict \
--color_map 255 255 255 0 0 0 0 255 255 0 0 255 150 150 150
``` ```
### 5. 可视化
我们提供可视化API对预测效果进行直观的展示和对比。每张可视化图片包括彩色合成预览图、标注图、预测结果,使得效果好坏一目了然。
```shell script
python visualize_demo.py --data_dir dataset/remote_sensing_seg/ \
--file_list val.txt \
--pred_dir saved_model/remote_sensing_unet/best_model/predict \
--save_dir saved_model/remote_sensing_unet/best_model/vis_results
````
3张可视化图片示例:
![](docs/imgs/vis.png)
## API说明 ## API说明
......
__background__
cloud
\ No newline at end of file
images/1001.npy annotations/1001.png
images/1002.npy annotations/1002.png
images/1005.npy annotations/1005.png
images/0.npy annotations/0.png
images/1003.npy annotations/1003.png
images/1000.npy annotations/1000.png
images/1004.npy annotations/1004.png
images/100.npy annotations/100.png
images/1.npy annotations/1.png
images/10.npy annotations/10.png
# 数据校验和分析
为确保能正常训练,我们应该先对数据集进行校验。同时,遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的组成、优化模型训练效果,需要对数据进行分析。
接下来以`remote_sensing_seg`数据集为例展示数据校验和分析的全过程。
## Step1 数据校验和初步分析
我们提供数据校验和分析的脚本,帮助您排查基本的数据问题,为如何配置训练参数提供指导。使用方式如下所示:
```shell script
python tools/data_analyse_and_check.py --data_dir 'dataset/remote_sensing_seg/' --num_classes 5
```
参数说明:
- --data_dir: 数据集所在目录
- --num_classes: 数据的类别数
运行后,命令行将显示概览信息,详细的错误信息将以data_analyse_and_check.log文件保存到数据集所在目录。
### 数据校验
数据校验内容如下:
#### 1 列表分割符校验(separator_check)
检查在`train.txt``val.txt``test.txt`列表文件中的分隔符设置是否正确。
#### 2 数据读取校验(imread_check)
检查是否能成功读取`train.txt``val.txt``test.txt`中所有图片。
若不正确返回错误信息。错误可能有多种情况,如数据集路径设置错误、图片损坏等。
#### 3 标注通道数校验(single_channel_label_check)
检查标注图的通道数。正确的标注图应该为单通道图像。
#### 4 标注类别校验(label_class_check)
检查实际标注类别是否和配置参数`num_classes``ignore_index`匹配。
**NOTE:**
标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`
标注类别最好从0开始,否则可能影响精度。
#### 5 图像与标注图尺寸一致性校验(shape_check)
验证图像尺寸和对应标注图尺寸是否一致。
### 数据分析
数据统计分析内容如下:
#### 1 标注类别统计(label_class_statistics)
统计每种类别的像素总数和所占比例。统计结果示例如下:
```
Label class statistics:
(label class, percentage, total pixel number) = [(0, 0.1372, 2194601), (1, 0.0827, 1322459), (2, 0.0179, 286548), (3, 0.1067, 1706810), (4, 0.6556, 10489582)]
```
#### 2 图像尺寸范围统计(img_shape_range_statistics)
统计数据集中图片的最大和最小的宽高。
#### 3 图像通道数统计(img_channels_statistics)
统计数据集中图片的通道个数。
#### 4 数据范围统计(data_range_statistics)
逐通道地统计数据集的数值范围。
#### 5 数据分布统计(data_distribution_statistics)
逐通道地统计数据集分布。并将分布保存为`pkl`文件,方便后续可视化和数据裁剪。
#### 6 归一化系数计算(cal_normalize_coefficient)
逐通道地计算归一化系数mean、standard deviation.
**备注:** 数据分析步骤1\~3在训练集、验证集、测试集上分别进行,步骤4\~6在整个数据集上进行。
## Step2 数据分布可视化,确定数据裁剪范围
### 数据分布可视化
我们提供可视化数据分布脚本,对数据集的数据分布按通道进行可视化。
可视化需要先安装matplotlib:
```shell script
pip install matplotlib
```
使用方式如下:
```shell script
python tools/data_distribution_vis.py --pkl_path 'dataset/remote_sensing_seg/img_pixel_statistics.pkl'
```
参数说明:
- --pkl_path: 数据分布文件保存路径
其中部分通道的可视化效果如下:
![](./imgs/data_distribution.png)
需要注意的是,为便于观察,纵坐标为对数坐标。
### 确定数据裁剪范围
遥感影像数据分布范围广,其中往往存在一些异常值,影响算法对实际数据分布的拟合效果。为更好地对数据进行归一化,需要抑制遥感影像中少量的异常值。
我们可以根据上述的数据分布统计结果来确定数据裁剪范围,并在后续图像预处理过程中对超出范围的像素值通过截断进行校正,从而去除异常值带来的干扰。
例如对于上述数据分布进行逐通道数据裁剪,我们选取的截断范围是:
```
裁剪范围最小值: clip_min_value = [7172, 6561, 5777, 5103, 4291, 4000, 4000, 4232, 6934, 7199]
裁剪范围最大值: clip_max_value = [50000, 50000, 50000, 50000, 50000, 40000, 30000, 18000, 40000, 36000]
```
## Step3 统计裁剪比例、归一化系数
为避免数据裁剪范围选取不当带来的影响,应该统计异常值像素占比,确保受影响的像素比例不要过高。
接着对裁剪后的数据计算归一化系数mean和standard deviation,用于图像预处理中的归一化参数设置。
使用方式如下:
```shell script
python tools/cal_norm_coef.py --data_dir 'dataset/remote_sensing_seg/' \
--pkl_path 'dataset/remote_sensing_seg/img_pixel_statistics.pkl' \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000
```
参数说明:
- --data_dir: 数据集路径
- --pkl_path: 数据分布文件保存路径
- --clip_min_value: 数据裁剪范围最小值
- --clip_max_value: 数据裁剪范围最大值
裁剪像素占比统计结果如下:
```
channel 0, the percentage of pixels to be clipped = 0.0005625999999999687
channel 1, the percentage of pixels to be clipped = 0.0011332250000000155
channel 2, the percentage of pixels to be clipped = 0.0008772375000000165
channel 3, the percentage of pixels to be clipped = 0.0013191750000000058
channel 4, the percentage of pixels to be clipped = 0.0012433250000000173
channel 5, the percentage of pixels to be clipped = 7.49875000000122e-05
channel 6, the percentage of pixels to be clipped = 0.0006973750000000001
channel 7, the percentage of pixels to be clipped = 4.950000000003563e-06
channel 8, the percentage of pixels to be clipped = 0.00014873749999999575
channel 9, the percentage of pixels to be clipped = 0.00011173750000004201
```
可看出,被裁剪像素占比均不超过0.2%
裁剪后数据的归一化系数如下:
```
Count the channel-by-channel mean and std of the image:
mean = [0.14311189 0.14288498 0.14812998 0.16377212 0.27375384 0.27409344 0.27749602 0.07767443 0.56946994 0.55497161]
std = [0.09101633 0.09600706 0.09619362 0.10371447 0.10911952 0.11043593 0.12648043 0.02774626 0.06822348 0.06237759]
```
# 数据准备
## 数据协议
数据集包含原图、标注图及相应的文件列表文件。
### 数据格式
遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。
PaddleSeg已兼容以下4种格式图片读取:
- `tif`
- `png`
- `img`
- `npy`
### 原图要求
原图数据的尺寸应为(h, w, channel),其中h, w为图像的高和宽,channel为图像的通道数。
### 标注图要求
标注图像必须为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增。
例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)。
### 文件列表文件
文件列表文件包括`train.txt``val.txt``test.txt``labels.txt`.
`train.txt``val.txt``test.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为标注图像文件相对于dataset的相对路径。如下所示:
```
images/xxx1.tif annotations/xxx1.png
images/xxx2.tif annotations/xxx2.png
...
```
`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
```
labelA
labelB
...
```
## 数据集切分和文件列表生成
数据集切分有2种方式:随机切分和手动切分。对于这2种方式,PaddleSeg均提供了生成文件列表的脚本,您可以按需要选择。
### 1 对数据集按比例随机切分,并生成文件列表
数据文件结构如下:
```
./dataset/ # 数据集根目录
|--images # 原图目录
| |--xxx1.tif
| |--...
| └--...
|
|--annotations # 标注图目录
| |--xxx1.png
| |--...
| └--...
```
其中,相应的文件名可根据需要自行定义。
使用命令如下,支持通过不同的Flags来开启特定功能。
```
python tools/split_dataset_list.py <dataset_root> <images_dir_name> <labels_dir_name> ${FLAGS}
```
参数说明:
- dataset_root: 数据集根目录
- images_dir_name: 原图目录名
- labels_dir_name: 标注图目录名
FLAGS说明:
|FLAG|含义|默认值|参数数目|
|-|-|-|-|
|--split|数据集切分比例|0.7 0.3 0|3|
|--separator|文件列表分隔符|" "|1|
|--format|图片和标签集的数据格式|"tif" "png"|2|
|--label_class|标注类别|'\_\_background\_\_' '\_\_foreground\_\_'|若干|
|--postfix|按文件主名(无扩展名)是否包含指定后缀对图片和标签集进行筛选|"" ""(2个空字符)|2|
运行后将在数据集根目录下生成`train.txt``val.txt``test.txt``labels.txt`.
**Note:** 生成文件列表要求:要么原图和标注图片数量一致,要么只有原图,没有标注图片。若数据集缺少标注图片,将生成不含分隔符和标注图片路径的文件列表。
#### 使用示例
```
python tools/split_dataset_list.py <dataset_root> images annotations --split 0.6 0.2 0.2 --format tif png
```
### 2 已经手工划分好数据集,按照目录结构生成文件列表
数据目录手工划分成如下结构:
```
./dataset/ # 数据集根目录
├── annotations # 标注目录
│   ├── test
│   │   ├── ...
│   │   └── ...
│   ├── train
│   │   ├── ...
│   │   └── ...
│   └── val
│   ├── ...
│   └── ...
└── images # 原图目录
├── test
│   ├── ...
│   └── ...
├── train
│   ├── ...
│   └── ...
└── val
├── ...
└── ...
```
其中,相应的文件名可根据需要自行定义。
使用命令如下,支持通过不同的Flags来开启特定功能。
```
python tools/create_dataset_list.py <dataset_root> ${FLAGS}
```
参数说明:
- dataset_root: 数据集根目录
FLAGS说明:
|FLAG|含义|默认值|参数数目|
|-|-|-|-|
|--separator|文件列表分隔符|" "|1|
|--folder|图片和标签集的文件夹名|"images" "annotations"|2|
|--second_folder|训练/验证/测试集的文件夹名|"train" "val" "test"|若干|
|--format|图片和标签集的数据格式|"tif" "png"|2|
|--label_class|标注类别|'\_\_background\_\_' '\_\_foreground\_\_'|若干|
|--postfix|按文件主名(无扩展名)是否包含指定后缀对图片和标签集进行筛选|"" ""(2个空字符)|2|
运行后将在数据集根目录下生成`train.txt``val.txt``test.txt``labels.txt`.
**Note:** 生成文件列表要求:要么原图和标注图片数量一致,要么只有原图,没有标注图片。若数据集缺少标注图片,将生成不含分隔符和标注图片路径的文件列表。
#### 使用示例
若您已经按上述说明整理好了数据集目录结构,可以运行下面的命令生成文件列表。
```
# 生成文件列表,其分隔符为空格,图片和标签集的数据格式都为png
python tools/create_dataset_list.py <dataset_root> --separator " " --format png png
```
```
# 生成文件列表,其图片和标签集的文件夹名为img和gt,训练和验证集的文件夹名为training和validation,不生成测试集列表
python tools/create_dataset_list.py <dataset_root> \
--folder img gt --second_folder training validation
```
...@@ -22,6 +22,7 @@ from utils import logging ...@@ -22,6 +22,7 @@ from utils import logging
from .base import BaseReader from .base import BaseReader
from .base import get_encoding from .base import get_encoding
from collections import OrderedDict from collections import OrderedDict
from PIL import Image
def read_img(img_path): def read_img(img_path):
...@@ -33,6 +34,8 @@ def read_img(img_path): ...@@ -33,6 +34,8 @@ def read_img(img_path):
raise Exception('Can not open', img_path) raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray() im_data = dataset.ReadAsArray()
return im_data.transpose((1, 2, 0)) return im_data.transpose((1, 2, 0))
elif img_format == 'png':
return np.asarray(Image.open(img_path))
elif ext == '.npy': elif ext == '.npy':
return np.load(img_path) return np.load(img_path)
else: else:
......
...@@ -46,10 +46,10 @@ def parse_args(): ...@@ -46,10 +46,10 @@ def parse_args():
default=['train', 'val', 'test']) default=['train', 'val', 'test'])
parser.add_argument( parser.add_argument(
'--format', '--format',
help='data format of images and labels, default npy, png.', help='data format of images and labels, default tif, png.',
type=str, type=str,
nargs=2, nargs=2,
default=['npy', 'png']) default=['tif', 'png'])
parser.add_argument( parser.add_argument(
'--label_class', '--label_class',
help='label class names', help='label class names',
......
...@@ -51,6 +51,12 @@ def parse_args(): ...@@ -51,6 +51,12 @@ def parse_args():
help='file list separator', help='file list separator',
default=" ", default=" ",
type=str) type=str)
parser.add_argument(
'--ignore_index',
dest='ignore_index',
help='Ignored class index',
default=255,
type=int)
if len(sys.argv) == 1: if len(sys.argv) == 1:
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
...@@ -106,12 +112,12 @@ def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value): ...@@ -106,12 +112,12 @@ def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
return means, stds, img_min_value, img_max_value, img_value_num return means, stds, img_min_value, img_max_value, img_value_num
def dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num, def data_distribution_statistics(data_dir, img_value_num, logger):
img_min_value, img_max_value, total_img_num, """count the distribution of image value, value number
logger): """
logger.info("\n-----------------------------\nDataset pixel statistics...") logger.info(
"\n-----------------------------\nThe whole dataset statistics...")
# count the distribution of image value, value number
if not img_value_num: if not img_value_num:
return return
logger.info("\nImage pixel statistics:") logger.info("\nImage pixel statistics:")
...@@ -124,15 +130,21 @@ def dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num, ...@@ -124,15 +130,21 @@ def dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num,
with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f: with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f:
pickle.dump([total_ratio, img_value_num], f) pickle.dump([total_ratio, img_value_num], f)
# print min value, max value
def data_range_statistics(img_min_value, img_max_value, logger):
"""print min value, max value
"""
logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".format( logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".format(
img_min_value, img_max_value)) img_min_value, img_max_value))
# count mean, std
def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
"""count mean, std
"""
total_means = total_means / total_img_num total_means = total_means / total_img_num
total_stds = total_stds / total_img_num total_stds = total_stds / total_img_num
print("\nCount the channel-by-channel mean and std of the image:\n" logger.info("\nCount the channel-by-channel mean and std of the image:\n"
"mean = {}\nstd = {}".format(total_means, total_stds)) "mean = {}\nstd = {}".format(total_means, total_stds))
def error_print(str): def error_print(str):
...@@ -160,12 +172,12 @@ def get_img_shape_range(img, max_width, max_height, min_width, min_height): ...@@ -160,12 +172,12 @@ def get_img_shape_range(img, max_width, max_height, min_width, min_height):
return max_width, max_height, min_width, min_height return max_width, max_height, min_width, min_height
def get_image_dim(img, img_dim): def get_img_channel_num(img, img_channels):
"""获取图像的通道数""" """获取图像的通道数"""
img_shape = img.shape img_shape = img.shape
if img_shape[-1] not in img_dim: if img_shape[-1] not in img_channels:
img_dim.append(img_shape[-1]) img_channels.append(img_shape[-1])
return img_dim return img_channels
def is_label_single_channel(label): def is_label_single_channel(label):
...@@ -215,23 +227,16 @@ def ground_truth_check(label, label_path): ...@@ -215,23 +227,16 @@ def ground_truth_check(label, label_path):
return png_format, unique, counts return png_format, unique, counts
def sum_label_check(png_format, label_classes, num_of_each_class, ignore_index, def sum_label_check(label_classes, num_of_each_class, ignore_index, num_classes,
num_classes, png_format_right_num, png_format_wrong_num,
total_label_classes, total_num_of_each_class): total_label_classes, total_num_of_each_class):
""" """
统计所有标注图上的格式、类别和每个类别的像素数 统计所有标注图上的类别和每个类别的像素数
params: params:
png_format: 是否是png格式图片
label_classes: 标注类别 label_classes: 标注类别
num_of_each_class: 各个类别的像素数目 num_of_each_class: 各个类别的像素数目
""" """
is_label_correct = True is_label_correct = True
if png_format:
png_format_right_num += 1
else:
png_format_wrong_num += 1
if ignore_index in label_classes: if ignore_index in label_classes:
label_classes2 = np.delete(label_classes, label_classes2 = np.delete(label_classes,
np.where(label_classes == ignore_index)) np.where(label_classes == ignore_index))
...@@ -251,32 +256,18 @@ def sum_label_check(png_format, label_classes, num_of_each_class, ignore_index, ...@@ -251,32 +256,18 @@ def sum_label_check(png_format, label_classes, num_of_each_class, ignore_index,
add_num.append(num_of_each_class[i]) add_num.append(num_of_each_class[i])
total_num_of_each_class += add_num total_num_of_each_class += add_num
total_label_classes += add_class total_label_classes += add_class
return is_label_correct, png_format_right_num, png_format_wrong_num, total_num_of_each_class, total_label_classes return is_label_correct, total_num_of_each_class, total_label_classes
def label_check_statistics(num_classes, png_format_wrong_image, def label_class_check(num_classes, total_label_classes, total_num_of_each_class,
png_format_right_num, png_format_wrong_num, wrong_labels, logger):
total_label_classes, total_num_of_each_class,
wrong_labels, logger):
""" """
对标注图像进行校验,输出校验结果 检查实际标注类别是否和配置参数`num_classes`,`ignore_index`匹配。
"""
if png_format_wrong_num == 0:
if png_format_right_num:
logger.info(correct_print("label format check"))
else:
logger.info(error_print("label format check"))
logger.info("No label image to check")
return
else:
logger.info(error_print("label format check"))
logger.info(
"total {} label images are png format, {} label images are not png "
"format".format(png_format_right_num, png_format_wrong_num))
if len(png_format_wrong_image) > 0:
for i in png_format_wrong_image:
logger.debug(i)
**NOTE:**
标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`。
标注类别最好从0开始,否则可能影响精度。
"""
total_ratio = total_num_of_each_class / sum(total_num_of_each_class) total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
total_ratio = np.around(total_ratio, decimals=4) total_ratio = np.around(total_ratio, decimals=4)
total_nc = sorted( total_nc = sorted(
...@@ -293,9 +284,15 @@ def label_check_statistics(num_classes, png_format_wrong_image, ...@@ -293,9 +284,15 @@ def label_check_statistics(num_classes, png_format_wrong_image,
num_classes - 1)) num_classes - 1))
for i in wrong_labels: for i in wrong_labels:
logger.debug(i) logger.debug(i)
return total_nc
def label_class_statistics(total_nc, logger):
"""
对标注图像进行校验,输出校验结果
"""
logger.info( logger.info(
"\nLabel pixel statistics:\n" "\nLabel class statistics:\n"
"(label class, percentage, total pixel number) = {} ".format(total_nc)) "(label class, percentage, total pixel number) = {} ".format(total_nc))
...@@ -360,9 +357,9 @@ def img_shape_range_statistics(max_width, min_width, max_height, min_height, ...@@ -360,9 +357,9 @@ def img_shape_range_statistics(max_width, min_width, max_height, min_height,
format(max_width, min_width, max_height, min_height)) format(max_width, min_width, max_height, min_height))
def img_dim_statistics(img_dim, logger): def img_channels_statistics(img_channels, logger):
logger.info("\nImage channels statistics\nImage channels = {}".format( logger.info("\nImage channels statistics\nImage channels = {}".format(
np.unique(img_dim))) np.unique(img_channels)))
def data_analyse_and_check(data_dir, num_classes, separator, ignore_index, def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
...@@ -381,14 +378,11 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index, ...@@ -381,14 +378,11 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
min_height = sys.float_info.max min_height = sys.float_info.max
label_not_single_channel = [] label_not_single_channel = []
shape_unequal_image = [] shape_unequal_image = []
png_format_wrong_image = []
wrong_labels = [] wrong_labels = []
wrong_lines = [] wrong_lines = []
png_format_right_num = 0
png_format_wrong_num = 0
total_label_classes = [] total_label_classes = []
total_num_of_each_class = [] total_num_of_each_class = []
img_dim = [] img_channels = []
with open(file_list, 'r') as fid: with open(file_list, 'r') as fid:
logger.info("\n-----------------------------\nCheck {}...".format( logger.info("\n-----------------------------\nCheck {}...".format(
...@@ -433,12 +427,9 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index, ...@@ -433,12 +427,9 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
shape_unequal_image.append(line) shape_unequal_image.append(line)
png_format, label_classes, num_of_each_class = ground_truth_check( png_format, label_classes, num_of_each_class = ground_truth_check(
label, label_path) label, label_path)
if not png_format: is_label_correct, total_num_of_each_class, total_label_classes = sum_label_check(
png_format_wrong_image.append(line) label_classes, num_of_each_class, ignore_index,
is_label_correct, png_format_right_num, png_format_wrong_num, total_num_of_each_class, total_label_classes = sum_label_check( num_classes, total_label_classes,
png_format, label_classes, num_of_each_class,
ignore_index, num_classes, png_format_right_num,
png_format_wrong_num, total_label_classes,
total_num_of_each_class) total_num_of_each_class)
if not is_label_correct: if not is_label_correct:
wrong_labels.append(line) wrong_labels.append(line)
...@@ -460,32 +451,35 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index, ...@@ -460,32 +451,35 @@ def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
total_stds += stds total_stds += stds
max_width, max_height, min_width, min_height = get_img_shape_range( max_width, max_height, min_width, min_height = get_img_shape_range(
img, max_width, max_height, min_width, min_height) img, max_width, max_height, min_width, min_height)
img_dim = get_image_dim(img, img_dim) img_channels = get_img_channel_num(img, img_channels)
total_img_num += 1 total_img_num += 1
# data check
separator_check(wrong_lines, file_list, separator, logger) separator_check(wrong_lines, file_list, separator, logger)
imread_check(imread_failed, logger) imread_check(imread_failed, logger)
img_dim_statistics(img_dim, logger)
img_shape_range_statistics(max_width, min_width, max_height,
min_height, logger)
if has_label: if has_label:
single_channel_label_check(label_not_single_channel, logger) single_channel_label_check(label_not_single_channel, logger)
shape_check(shape_unequal_image, logger) shape_check(shape_unequal_image, logger)
label_check_statistics( total_nc = label_class_check(num_classes, total_label_classes,
num_classes, png_format_wrong_image, png_format_right_num, total_num_of_each_class,
png_format_wrong_num, total_label_classes, wrong_labels, logger)
total_num_of_each_class, wrong_labels, logger)
dataset_pixel_statistics(data_dir, total_means, total_stds, img_value_num, # data analyse on train, validation, test set.
img_min_value, img_max_value, total_img_num, img_channels_statistics(img_channels, logger)
logger) img_shape_range_statistics(max_width, min_width, max_height,
min_height, logger)
if has_label:
label_class_statistics(total_nc, logger)
# data analyse on the whole dataset.
data_range_statistics(img_min_value, img_max_value, logger)
data_distribution_statistics(data_dir, img_value_num, logger)
cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)
def main(): def main():
args = parse_args() args = parse_args()
data_dir = args.data_dir data_dir = args.data_dir
ignore_index = 255 ignore_index = args.ignore_index
num_classes = args.num_classes num_classes = args.num_classes
separator = args.separator separator = args.separator
......
...@@ -25,8 +25,10 @@ def parse_args(): ...@@ -25,8 +25,10 @@ def parse_args():
description= description=
'A tool for proportionally randomizing dataset to produce file lists.') 'A tool for proportionally randomizing dataset to produce file lists.')
parser.add_argument('dataset_root', help='the dataset root path', type=str) parser.add_argument('dataset_root', help='the dataset root path', type=str)
parser.add_argument('images', help='the directory name of images', type=str) parser.add_argument(
parser.add_argument('labels', help='the directory name of labels', type=str) 'images_dir_name', help='the directory name of images', type=str)
parser.add_argument(
'labels_dir_name', help='the directory name of labels', type=str)
parser.add_argument( parser.add_argument(
'--split', help='', nargs=3, type=float, default=[0.7, 0.3, 0]) '--split', help='', nargs=3, type=float, default=[0.7, 0.3, 0])
parser.add_argument( parser.add_argument(
...@@ -43,10 +45,10 @@ def parse_args(): ...@@ -43,10 +45,10 @@ def parse_args():
type=str) type=str)
parser.add_argument( parser.add_argument(
'--format', '--format',
help='data format of images and labels, e.g. jpg, npy or png.', help='data format of images and labels, e.g. jpg, tif or png.',
type=str, type=str,
nargs=2, nargs=2,
default=['npy', 'png']) default=['tif', 'png'])
parser.add_argument( parser.add_argument(
'--postfix', '--postfix',
help='postfix of images or labels', help='postfix of images or labels',
...@@ -84,8 +86,8 @@ def generate_list(args): ...@@ -84,8 +86,8 @@ def generate_list(args):
for label_class in args.label_class: for label_class in args.label_class:
f.write(label_class + '\n') f.write(label_class + '\n')
image_dir = os.path.join(dataset_root, args.images) image_dir = os.path.join(dataset_root, args.images_dir_name)
label_dir = os.path.join(dataset_root, args.labels) label_dir = os.path.join(dataset_root, args.labels_dir_name)
image_files = get_files(image_dir, args.format[0], args.postfix[0]) image_files = get_files(image_dir, args.format[0], args.postfix[0])
label_files = get_files(label_dir, args.format[1], args.postfix[1]) label_files = get_files(label_dir, args.format[1], args.postfix[1])
if not image_files: if not image_files:
......
...@@ -15,9 +15,7 @@ ...@@ -15,9 +15,7 @@
from .ops import * from .ops import *
import random import random
import os.path as osp
import numpy as np import numpy as np
from PIL import Image
import cv2 import cv2
from collections import OrderedDict from collections import OrderedDict
from readers.reader import read_img from readers.reader import read_img
...@@ -63,7 +61,7 @@ class Compose: ...@@ -63,7 +61,7 @@ class Compose:
if im is None: if im is None:
raise ValueError('Can\'t read The image file {}!'.format(im)) raise ValueError('Can\'t read The image file {}!'.format(im))
if label is not None: if label is not None:
label = np.asarray(Image.open(label)) label = read_img(label)
for op in self.transforms: for op in self.transforms:
outputs = op(im, im_info, label) outputs = op(im, im_info, label)
......
import os
import os.path as osp
import argparse
from PIL import Image as Image
from models.utils import visualize as vis
def parse_args():
parser = argparse.ArgumentParser(description='RemoteSensing visualization')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='Dataset directory',
default=None,
type=str)
parser.add_argument(
'--file_list',
dest='file_list',
help='The name of file list that need to be visualized',
default=None,
type=str)
parser.add_argument(
'--pred_dir',
dest='pred_dir',
help='Directory for predict results',
default=None,
type=str)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='Save directory for visual results',
default=None,
type=str)
return parser.parse_args()
args = parse_args()
data_dir = args.data_dir
pred_dir = args.pred_dir
save_dir = args.save_dir
file_list = osp.join(data_dir, args.file_list)
if not osp.exists(save_dir):
os.mkdir(save_dir)
with open(file_list) as f:
lines = f.readlines()
for line in lines:
img_list = []
img_line = line.split(' ')[0]
img_name = osp.basename(img_line).replace('data.tif', 'photo.png')
img_path = osp.join(data_dir, 'data_vis', img_name)
img = Image.open(img_path)
img_list.append(img)
print('visualizing {}'.format(img_path))
gt_line = line.split(' ')[1].rstrip('\n')
gt_path = osp.join(data_dir, gt_line)
gt_pil = Image.open(gt_path)
img_list.append(gt_pil)
pred_name = osp.basename(img_line).replace('tif', 'png')
pred_path = osp.join(pred_dir, pred_name)
pred_pil = Image.open(pred_path)
img_list.append(pred_pil)
save_path = osp.join(save_dir, pred_name)
vis.splice_imgs(img_list, save_path)
print('saved in {}'.format(save_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册