提交 9a35bdde 编写于 作者: C chenguowei01

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSeg into develop

# Mac system
.DS_Store
# Pycharm
.idea/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
......
......@@ -121,13 +121,13 @@ def infer(args):
# 图像背景替换
if args.image_path is not None:
if not osp.exists(args.image_path):
raise ('The --image_path is not existed: {}'.format(
raise Exception('The --image_path is not existed: {}'.format(
args.image_path))
if args.background_image_path is None:
raise ('The --background_image_path is not set. Please set it')
raise Exception('The --background_image_path is not set. Please set it')
else:
if not osp.exists(args.background_image_path):
raise ('The --background_image_path is not existed: {}'.format(
raise Exception('The --background_image_path is not existed: {}'.format(
args.background_image_path))
img = cv2.imread(args.image_path)
score_map, im_info = predict(img, model, test_transforms)
......@@ -144,15 +144,15 @@ def infer(args):
is_video_bg = False
if args.background_video_path is not None:
if not osp.exists(args.background_video_path):
raise ('The --background_video_path is not existed: {}'.format(
raise Exception('The --background_video_path is not existed: {}'.format(
args.background_video_path))
is_video_bg = True
elif args.background_image_path is not None:
if not osp.exists(args.background_image_path):
raise ('The --background_image_path is not existed: {}'.format(
raise Exception('The --background_image_path is not existed: {}'.format(
args.background_image_path))
else:
raise (
raise Exception(
'Please offer backgound image or video. You should set --backbground_iamge_paht or --background_video_path'
)
......@@ -162,9 +162,9 @@ def infer(args):
prev_cfd = np.zeros((resize_h, resize_w), np.float32)
is_init = True
if args.video_path is not None:
print('Please waite. It is computing......')
print('Please wait. It is computing......')
if not osp.exists(args.video_path):
raise ('The --video_path is not existed: {}'.format(
raise Exception('The --video_path is not existed: {}'.format(
args.video_path))
cap_video = cv2.VideoCapture(args.video_path)
......
......@@ -109,7 +109,7 @@ def video_infer(args):
fps = cap.get(cv2.CAP_PROP_FPS)
if args.video_path:
print('Please waite. It is computing......')
print('Please wait. It is computing......')
# 用于保存预测结果视频
if not osp.exists(args.save_dir):
os.makedirs(args.save_dir)
......
# PaddleSeg遥感影像分割
遥感影像分割是图像分割领域中的重要应用场景,广泛应用于土地测绘、环境监测、城市建设等领域。遥感影像分割的目标多种多样,有诸如积雪、农作物、道路、建筑、水源等地物目标,也有例如云层的空中目标。
PaddleSeg遥感影像分割涵盖图像预处理、数据增强、模型训练、预测流程,帮助用户利用深度学习技术解决遥感影像分割问题。
PaddleSeg遥感影像分割涵盖数据分析、预处理、数据增强、模型训练、预测等流程,帮助用户利用深度学习技术解决遥感影像分割问题。
## 特点
- 针对遥感数据多通道、分布范围大、分布不均的特点,我们支持多通道训练预测,内置10+多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。
- 针对遥感影像多通道、标注数据稀少的特点,我们支持多通道训练预测,内置10+多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。
- 内置U-Net, HRNet两种主流分割网络,可选择不同的损失函数如Dice Loss, BCE Loss等方式强化小目标和不均衡样本场景下的分割精度。
以下是遥感云检测的示例效果:
- 针对遥感影像分布范围广、分布不均的特点,我们提供数据分析工具,帮助深入了解数据组成、优化模型训练效果。为确保正常训练,我们提供数据校验工具,帮助排查数据问题。
![](./docs/imgs/rs.png)
- 内置U-Net, HRNet两种主流分割网络,可选择不同的损失函数如Dice Loss, BCE Loss等方式强化小目标和不均衡样本场景下的分割精度。
## 前置依赖
**Note:** 若没有特殊说明,以下所有命令需要在`PaddleSeg/contrib/RemoteSensing/`目录下执行。
......@@ -22,10 +20,15 @@ PaddlePaddle的安装, 请按照[官网指引](https://paddlepaddle.org.cn/insta
- Python 3.5+
- 其他依赖安装
通过以下命令安装python包依赖,请确保至少执行过一次以下命令:
```
pip install -r requirements.txt
```
另外需要安装gdal. **Note:** 使用pip安装gdal可能出错,推荐使用conda进行安装:
```
conda install gdal
```
## 目录结构说明
```
......@@ -40,93 +43,107 @@ RemoteSensing # 根目录
|-- utils # 公用模块
|-- train_demo.py # 训练demo脚本
|-- predict_demo.py # 预测demo脚本
|-- visualize_demo.py # 可视化demo脚本
|-- README.md # 使用手册
```
## 数据协议
数据集包含原图、标注图及相应的文件列表文件。
参考数据文件结构如下:
```
./dataset/ # 数据集根目录
|--images # 原图目录
| |--xxx1.npy
| |--...
| └--...
|
|--annotations # 标注图目录
| |--xxx1.png
| |--...
| └--...
|
|--train_list.txt # 训练文件列表文件
|
|--val_list.txt # 验证文件列表文件
|
└--labels.txt # 标签列表
## 使用教程
```
其中,相应的文件名可根据需要自行定义。
基于L8 SPARCS数据集进行云雪分割,提供数据准备、数据分析、训练、预测、可视化的全流程展示。
遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleSeg以numpy.ndarray数据类型进行图像预处理。为统一接口并方便数据加载,我们采用numpy存储格式`npy`作为原图格式,采用`png`无损压缩格式作为标注图片格式。
原图的尺寸应为(h, w, channel),其中h, w为图像的高和宽,channel为图像的通道数。
标注图像为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增
例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)
### 1. 数据准备
#### L8 SPARCS数据集
[L8 SPARCS](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation)数据集包含80张 Landsat 8 卫星影像,涵盖10个波段
原始标注图片包含7个类别,分别是 “cloud”, “cloud shadow”, “shadow over water”, “snow/ice”, ”water”, “land”和”flooded”
`train_list.txt``val_list.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为标注图像文件相对于dataset的相对路径。如下所示:
```
images/xxx1.npy annotations/xxx1.png
images/xxx2.npy annotations/xxx2.png
...
```
<p align="center">
<img src="./docs/imgs/dataset.png" align="middle"
</p>
具体要求和如何生成文件列表可参考[文件列表规范](../../docs/data_prepare.md#文件列表)
`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
```
labelA
labelB
...
```
<p align='center'>
L8 SPARCS数据集示例
</p>
由于“flooded”和“shadow over water”2个类别占比仅为1.8%和0.24%,我们将其进行合并,
“flooded”归为“land”,“shadow over water”归为“shadow”,合并后标注包含5个类别。
数值、类别、颜色对应表:
## 快速上手
|Pixel value|Class|Color|
|---|---|---|
|0|cloud|white|
|1|shadow|black|
|2|snow/ice|cyan|
|3|water|blue|
|4|land|grey|
本章节在一个小数据集上展示了如何通过RemoteSensing进行训练预测。
### 1. 准备数据集
为了快速体验,我们准备了一个小型demo数据集,已位于`RemoteSensing/dataset/demo/`目录下.
执行以下命令下载并解压经过处理之后的数据集`remote_sensing_seg`
```shell script
mkdir dataset && cd dataset
wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip
unzip remote_sensing_seg.zip
cd ..
```
其中`data`目录存放遥感影像,`data_vis`目录存放彩色合成预览图,`mask`目录存放标注图。
对于您自己的数据集,您需要按照上述的数据协议进行格式转换,可分别使用numpy和Pillow库保存遥感数据和标注图片。其中numpy API示例如下:
```python
import numpy as np
#### 数据协议
对于您自己的数据集,需要按照我们的[数据协议](docs/data_prepare.md)进行数据准备。
# 将遥感数据保存到以 .npy 为扩展名的文件中
# img类型:numpy.ndarray
np.save(save_path, img)
```
### 2. 数据校验与分析
为确保能正常训练,我们应该先对数据集进行校验。同时,遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的组成、优化模型训练效果,需要对数据进行分析。
具体步骤参见[数据校验与分析](docs/data_analyse_and_check.md)章节。
### 2. 模型训练
### 3. 模型训练
#### (1) 设置GPU卡号
```shell script
export CUDA_VISIBLE_DEVICES=0
```
#### (2) 以U-Net为例,在RemoteSensing目录下运行`train_demo.py`即可开始训练。
```shell script
python train_demo.py --model_type unet --data_dir dataset/demo/ --save_dir saved_model/unet/ --channel 3 --num_epochs 20
python train_demo.py --data_dir dataset/remote_sensing_seg \
--model_type unet \
--save_dir saved_model/remote_sensing_unet \
--num_classes 5 \
--channel 10 \
--lr 0.01 \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
--mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
--std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
--num_epochs 500 \
--train_batch_size 3
```
### 3. 模型预测
训练过程将自动开启边训边评估策略,并使用VisualDL保存训练日志,显示如下:
![](docs/imgs/visualdl.png)
`mIoU`最高的模型将自动保存在`saved_model/remote_sensing_unet/best_model`目录下,最高mIoU=0.7782
### 4. 模型预测
#### (1) 设置GPU卡号
```shell script
export CUDA_VISIBLE_DEVICES=0
```
#### (2) 以刚训练好的U-Net最优模型为例,在RemoteSensing目录下运行`predict_demo.py`即可开始训练
#### (2) 以刚训练好的U-Net最优模型为例,在RemoteSensing目录下运行`predict_demo.py`即可开始预测
```shell script
python predict_demo.py --data_dir dataset/demo/ --file_list val.txt --load_model_dir saved_model/unet/best_model
python predict_demo.py --data_dir dataset/remote_sensing_seg/ \
--file_list val.txt \
--load_model_dir saved_model/remote_sensing_unet/best_model \
--save_img_dir saved_model/remote_sensing_unet/best_model/predict \
--color_map 255 255 255 0 0 0 0 255 255 0 0 255 150 150 150
```
### 5. 可视化
我们提供可视化API对预测效果进行直观的展示和对比。每张可视化图片包括彩色合成预览图、标注图、预测结果,使得效果好坏一目了然。
```shell script
python visualize_demo.py --data_dir dataset/remote_sensing_seg/ \
--file_list val.txt \
--pred_dir saved_model/remote_sensing_unet/best_model/predict \
--save_dir saved_model/remote_sensing_unet/best_model/vis_results
````
3张可视化图片示例:
![](docs/imgs/vis.png)
## API说明
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
__background__
cloud
\ No newline at end of file
images/1001.npy annotations/1001.png
images/1002.npy annotations/1002.png
images/1005.npy annotations/1005.png
images/0.npy annotations/0.png
images/1003.npy annotations/1003.png
images/1000.npy annotations/1000.png
images/1004.npy annotations/1004.png
images/100.npy annotations/100.png
images/1.npy annotations/1.png
images/10.npy annotations/10.png
# 数据校验和分析
为确保能正常训练,我们应该先对数据集进行校验。同时,遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的组成、优化模型训练效果,需要对数据进行分析。
接下来以`remote_sensing_seg`数据集为例展示数据校验和分析的全过程。
## Step1 数据校验和初步分析
我们提供数据校验和分析的脚本,帮助您排查基本的数据问题,为如何配置训练参数提供指导。使用方式如下所示:
```shell script
python tools/data_analyse_and_check.py --data_dir 'dataset/remote_sensing_seg/' --num_classes 5
```
参数说明:
- --data_dir: 数据集所在目录
- --num_classes: 数据的类别数
运行后,命令行将显示概览信息,详细的错误信息将以data_analyse_and_check.log文件保存到数据集所在目录。
### 数据校验
数据校验内容如下:
#### 1 列表分割符校验(separator_check)
检查在`train.txt``val.txt``test.txt`列表文件中的分隔符设置是否正确。
#### 2 数据读取校验(imread_check)
检查是否能成功读取`train.txt``val.txt``test.txt`中所有图片。
若不正确返回错误信息。错误可能有多种情况,如数据集路径设置错误、图片损坏等。
#### 3 标注通道数校验(single_channel_label_check)
检查标注图的通道数。正确的标注图应该为单通道图像。
#### 4 标注类别校验(label_class_check)
检查实际标注类别是否和配置参数`num_classes``ignore_index`匹配。
**NOTE:**
标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`
标注类别最好从0开始,否则可能影响精度。
#### 5 图像与标注图尺寸一致性校验(shape_check)
验证图像尺寸和对应标注图尺寸是否一致。
### 数据分析
数据统计分析内容如下:
#### 1 标注类别统计(label_class_statistics)
统计每种类别的像素总数和所占比例。统计结果示例如下:
```
Label class statistics:
(label class, percentage, total pixel number) = [(0, 0.1372, 2194601), (1, 0.0827, 1322459), (2, 0.0179, 286548), (3, 0.1067, 1706810), (4, 0.6556, 10489582)]
```
#### 2 图像尺寸范围统计(img_shape_range_statistics)
统计数据集中图片的最大和最小的宽高。
#### 3 图像通道数统计(img_channels_statistics)
统计数据集中图片的通道个数。
#### 4 数据范围统计(data_range_statistics)
逐通道地统计数据集的数值范围。
#### 5 数据分布统计(data_distribution_statistics)
逐通道地统计数据集分布。并将分布保存为`pkl`文件,方便后续可视化和数据裁剪。
#### 6 归一化系数计算(cal_normalize_coefficient)
逐通道地计算归一化系数mean、standard deviation.
**备注:** 数据分析步骤1\~3在训练集、验证集、测试集上分别进行,步骤4\~6在整个数据集上进行。
## Step2 数据分布可视化,确定数据裁剪范围
### 数据分布可视化
我们提供可视化数据分布脚本,对数据集的数据分布按通道进行可视化。
可视化需要先安装matplotlib:
```shell script
pip install matplotlib
```
使用方式如下:
```shell script
python tools/data_distribution_vis.py --pkl_path 'dataset/remote_sensing_seg/img_pixel_statistics.pkl'
```
参数说明:
- --pkl_path: 数据分布文件保存路径
其中部分通道的可视化效果如下:
![](./imgs/data_distribution.png)
需要注意的是,为便于观察,纵坐标为对数坐标。
### 确定数据裁剪范围
遥感影像数据分布范围广,其中往往存在一些异常值,影响算法对实际数据分布的拟合效果。为更好地对数据进行归一化,需要抑制遥感影像中少量的异常值。
我们可以根据上述的数据分布统计结果来确定数据裁剪范围,并在后续图像预处理过程中对超出范围的像素值通过截断进行校正,从而去除异常值带来的干扰。
例如对于上述数据分布进行逐通道数据裁剪,我们选取的截断范围是:
```
裁剪范围最小值: clip_min_value = [7172, 6561, 5777, 5103, 4291, 4000, 4000, 4232, 6934, 7199]
裁剪范围最大值: clip_max_value = [50000, 50000, 50000, 50000, 50000, 40000, 30000, 18000, 40000, 36000]
```
## Step3 统计裁剪比例、归一化系数
为避免数据裁剪范围选取不当带来的影响,应该统计异常值像素占比,确保受影响的像素比例不要过高。
接着对裁剪后的数据计算归一化系数mean和standard deviation,用于图像预处理中的归一化参数设置。
使用方式如下:
```shell script
python tools/cal_norm_coef.py --data_dir 'dataset/remote_sensing_seg/' \
--pkl_path 'dataset/remote_sensing_seg/img_pixel_statistics.pkl' \
--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000
```
参数说明:
- --data_dir: 数据集路径
- --pkl_path: 数据分布文件保存路径
- --clip_min_value: 数据裁剪范围最小值
- --clip_max_value: 数据裁剪范围最大值
裁剪像素占比统计结果如下:
```
channel 0, the percentage of pixels to be clipped = 0.0005625999999999687
channel 1, the percentage of pixels to be clipped = 0.0011332250000000155
channel 2, the percentage of pixels to be clipped = 0.0008772375000000165
channel 3, the percentage of pixels to be clipped = 0.0013191750000000058
channel 4, the percentage of pixels to be clipped = 0.0012433250000000173
channel 5, the percentage of pixels to be clipped = 7.49875000000122e-05
channel 6, the percentage of pixels to be clipped = 0.0006973750000000001
channel 7, the percentage of pixels to be clipped = 4.950000000003563e-06
channel 8, the percentage of pixels to be clipped = 0.00014873749999999575
channel 9, the percentage of pixels to be clipped = 0.00011173750000004201
```
可看出,被裁剪像素占比均不超过0.2%
裁剪后数据的归一化系数如下:
```
Count the channel-by-channel mean and std of the image:
mean = [0.14311189 0.14288498 0.14812998 0.16377212 0.27375384 0.27409344 0.27749602 0.07767443 0.56946994 0.55497161]
std = [0.09101633 0.09600706 0.09619362 0.10371447 0.10911952 0.11043593 0.12648043 0.02774626 0.06822348 0.06237759]
```
# 数据准备
## 数据协议
数据集包含原图、标注图及相应的文件列表文件。
### 数据格式
遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。
PaddleSeg已兼容以下4种格式图片读取:
- `tif`
- `png`
- `img`
- `npy`
### 原图要求
原图数据的尺寸应为(h, w, channel),其中h, w为图像的高和宽,channel为图像的通道数。
### 标注图要求
标注图像必须为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增。
例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)。
### 文件列表文件
文件列表文件包括`train.txt``val.txt``test.txt``labels.txt`.
`train.txt``val.txt``test.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为标注图像文件相对于dataset的相对路径。如下所示:
```
images/xxx1.tif annotations/xxx1.png
images/xxx2.tif annotations/xxx2.png
...
```
`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
```
labelA
labelB
...
```
## 数据集切分和文件列表生成
数据集切分有2种方式:随机切分和手动切分。对于这2种方式,PaddleSeg均提供了生成文件列表的脚本,您可以按需要选择。
### 1 对数据集按比例随机切分,并生成文件列表
数据文件结构如下:
```
./dataset/ # 数据集根目录
|--images # 原图目录
| |--xxx1.tif
| |--...
| └--...
|
|--annotations # 标注图目录
| |--xxx1.png
| |--...
| └--...
```
其中,相应的文件名可根据需要自行定义。
使用命令如下,支持通过不同的Flags来开启特定功能。
```
python tools/split_dataset_list.py <dataset_root> <images_dir_name> <labels_dir_name> ${FLAGS}
```
参数说明:
- dataset_root: 数据集根目录
- images_dir_name: 原图目录名
- labels_dir_name: 标注图目录名
FLAGS说明:
|FLAG|含义|默认值|参数数目|
|-|-|-|-|
|--split|数据集切分比例|0.7 0.3 0|3|
|--separator|文件列表分隔符|" "|1|
|--format|图片和标签集的数据格式|"tif" "png"|2|
|--label_class|标注类别|'\_\_background\_\_' '\_\_foreground\_\_'|若干|
|--postfix|按文件主名(无扩展名)是否包含指定后缀对图片和标签集进行筛选|"" ""(2个空字符)|2|
运行后将在数据集根目录下生成`train.txt``val.txt``test.txt``labels.txt`.
**Note:** 生成文件列表要求:要么原图和标注图片数量一致,要么只有原图,没有标注图片。若数据集缺少标注图片,将生成不含分隔符和标注图片路径的文件列表。
#### 使用示例
```
python tools/split_dataset_list.py <dataset_root> images annotations --split 0.6 0.2 0.2 --format tif png
```
### 2 已经手工划分好数据集,按照目录结构生成文件列表
数据目录手工划分成如下结构:
```
./dataset/ # 数据集根目录
├── annotations # 标注目录
│   ├── test
│   │   ├── ...
│   │   └── ...
│   ├── train
│   │   ├── ...
│   │   └── ...
│   └── val
│   ├── ...
│   └── ...
└── images # 原图目录
├── test
│   ├── ...
│   └── ...
├── train
│   ├── ...
│   └── ...
└── val
├── ...
└── ...
```
其中,相应的文件名可根据需要自行定义。
使用命令如下,支持通过不同的Flags来开启特定功能。
```
python tools/create_dataset_list.py <dataset_root> ${FLAGS}
```
参数说明:
- dataset_root: 数据集根目录
FLAGS说明:
|FLAG|含义|默认值|参数数目|
|-|-|-|-|
|--separator|文件列表分隔符|" "|1|
|--folder|图片和标签集的文件夹名|"images" "annotations"|2|
|--second_folder|训练/验证/测试集的文件夹名|"train" "val" "test"|若干|
|--format|图片和标签集的数据格式|"tif" "png"|2|
|--label_class|标注类别|'\_\_background\_\_' '\_\_foreground\_\_'|若干|
|--postfix|按文件主名(无扩展名)是否包含指定后缀对图片和标签集进行筛选|"" ""(2个空字符)|2|
运行后将在数据集根目录下生成`train.txt``val.txt``test.txt``labels.txt`.
**Note:** 生成文件列表要求:要么原图和标注图片数量一致,要么只有原图,没有标注图片。若数据集缺少标注图片,将生成不含分隔符和标注图片路径的文件列表。
#### 使用示例
若您已经按上述说明整理好了数据集目录结构,可以运行下面的命令生成文件列表。
```
# 生成文件列表,其分隔符为空格,图片和标签集的数据格式都为png
python tools/create_dataset_list.py <dataset_root> --separator " " --format png png
```
```
# 生成文件列表,其图片和标签集的文件夹名为img和gt,训练和验证集的文件夹名为training和validation,不生成测试集列表
python tools/create_dataset_list.py <dataset_root> \
--folder img gt --second_folder training validation
```
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -125,7 +125,7 @@ class HRNet(BaseModel):
train_reader,
train_batch_size=2,
eval_reader=None,
eval_best_metric='kappa',
eval_best_metric='miou',
save_interval_epochs=1,
log_interval_steps=2,
save_dir='output',
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -127,7 +127,7 @@ class UNet(BaseModel):
train_reader,
train_batch_size=2,
eval_reader=None,
eval_best_metric='kappa',
eval_best_metric='miou',
save_interval_epochs=1,
log_interval_steps=2,
save_dir='output',
......
import os
import os.path as osp
import numpy as np
from PIL import Image as Image
def get_color_map_list(num_classes):
""" Returns the color map for visualizing the segmentation mask,
which can support arbitrary number of classes.
Args:
num_classes: Number of classes
Returns:
The color map
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
return color_map
def splice_imgs(img_list, vis_path):
"""Splice pictures horizontally
"""
IMAGE_WIDTH, IMAGE_HEIGHT = img_list[0].size
padding_width = 20
img_num = len(img_list)
to_image = Image.new('RGB',
(img_num * IMAGE_WIDTH + (img_num - 1) * padding_width,
IMAGE_HEIGHT)) # Create a new picture
padding = Image.new('RGB', (padding_width, IMAGE_HEIGHT), (255, 255, 255))
# Loop through, paste each picture to the corresponding position in order
for i, from_image in enumerate(img_list):
to_image.paste(from_image, (i * (IMAGE_WIDTH + padding_width), 0))
if i < img_num - 1:
to_image.paste(padding,
(i * (IMAGE_WIDTH + padding_width) + IMAGE_WIDTH, 0))
return to_image.save(vis_path)
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,6 +20,7 @@ import numpy as np
from PIL import Image as Image
import argparse
from models import load_model
from models.utils.visualize import get_color_map_list
def parse_args():
......@@ -54,6 +55,13 @@ def parse_args():
help='save directory name of predict results',
default='predict_results',
type=str)
parser.add_argument(
'--color_map',
dest='color_map',
help='color map of predict results',
type=int,
nargs='*',
default=-1)
if len(sys.argv) < 2:
parser.print_help()
sys.exit(1)
......@@ -68,37 +76,41 @@ load_model_dir = args.load_model_dir
save_img_dir = args.save_img_dir
if not osp.exists(save_img_dir):
os.makedirs(save_img_dir)
if args.color_map == -1:
color_map = get_color_map_list(256)
else:
color_map = args.color_map
# predict
model = load_model(load_model_dir)
color_map = [0, 0, 0, 0, 255, 0]
if single_img is not None:
pred = model.predict(single_img)
# 以伪彩色png图片保存预测结果
pred_name = osp.basename(single_img).rstrip('npy') + 'png'
pred_path = osp.join(save_img_dir, pred_name)
pred_name, _ = osp.splitext(osp.basename(single_img))
pred_path = osp.join(save_img_dir, pred_name + '.png')
pred_mask = Image.fromarray(pred['label_map'].astype(np.uint8), mode='P')
pred_mask.putpalette(color_map)
pred_mask.save(pred_path)
print('Predict result is saved in {}'.format(pred_path))
elif (file_list is not None) and (data_dir is not None):
with open(osp.join(data_dir, file_list)) as f:
lines = f.readlines()
for line in lines:
img_path = line.split(' ')[0]
print('Predicting {}'.format(img_path))
img_path_ = osp.join(data_dir, img_path)
pred = model.predict(img_path_)
# 以伪彩色png图片保存预测结果
pred_name = osp.basename(img_path).rstrip('npy') + 'png'
pred_path = osp.join(save_img_dir, pred_name)
pred_name, _ = osp.splitext(osp.basename(img_path))
pred_path = osp.join(save_img_dir, pred_name + '.png')
pred_mask = Image.fromarray(
pred['label_map'].astype(np.uint8), mode='P')
pred_mask.putpalette(color_map)
pred_mask.save(pred_path)
print('Predict result is saved in {}'.format(pred_path))
else:
raise Exception(
'You should either set the parameter single_img, or set the parameters data_dir, file_list.'
'You should either set the parameter single_img, or set the parameters data_dir and file_list.'
)
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,19 +15,39 @@
from __future__ import absolute_import
import os.path as osp
import random
import imghdr
import gdal
import numpy as np
from utils import logging
from .base import BaseReader
from .base import get_encoding
from collections import OrderedDict
from .base import is_pic
from PIL import Image
def read_img(img_path):
img_format = imghdr.what(img_path)
name, ext = osp.splitext(img_path)
if img_format == 'tiff' or ext == '.img':
dataset = gdal.Open(img_path)
if dataset == None:
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
return im_data.transpose((1, 2, 0))
elif img_format == 'png':
return np.asarray(Image.open(img_path))
elif ext == '.npy':
return np.load(img_path)
else:
raise Exception('Not support {} image format!'.format(ext))
class Reader(BaseReader):
"""读取语分分割任务数据集,并对样本进行相应的处理。
"""读取数据集,并对样本进行相应的处理。
Args:
data_dir (str): 数据集所在的目录路径。
file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
label_list (str): 描述数据集包含的类别信息文件路径。
transforms (list): 数据集中每个样本的预处理/增强算子。
num_workers (int): 数据集中样本在预处理过程中的线程或进程数。默认为4。
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import os.path as osp
import sys
import argparse
from tqdm import tqdm
import pickle
from data_analyse_and_check import read_img
def parse_args():
parser = argparse.ArgumentParser(
description=
'Compute normalization coefficient and clip percentage before training.'
)
parser.add_argument(
'--data_dir',
dest='data_dir',
help='Dataset directory',
default=None,
type=str)
parser.add_argument(
'--pkl_path',
dest='pkl_path',
help='Path of img_pixel_statistics.pkl',
default=None,
type=str)
parser.add_argument(
'--clip_min_value',
dest='clip_min_value',
help='Min values for clipping data',
nargs='+',
default=None,
type=int)
parser.add_argument(
'--clip_max_value',
dest='clip_max_value',
help='Max values for clipping data',
nargs='+',
default=None,
type=int)
parser.add_argument(
'--separator',
dest='separator',
help='file list separator',
default=" ",
type=str)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def compute_single_img(img, clip_min_value, clip_max_value):
channel = img.shape[2]
means = np.zeros(channel)
stds = np.zeros(channel)
for k in range(channel):
if clip_max_value != [] and clip_min_value != []:
np.clip(
img[:, :, k],
clip_min_value[k],
clip_max_value[k],
out=img[:, :, k])
# Rescaling (min-max normalization)
range_value = [
clip_max_value[i] - clip_min_value[i]
for i in range(len(clip_max_value))
]
img_k = (img[:, :, k].astype(np.float32, copy=False) -
clip_min_value[k]) / range_value[k]
else:
img_k = img[:, :, k]
# count mean, std
means[k] = np.mean(img_k)
stds[k] = np.std(img_k)
return means, stds
def cal_normalize_coefficient(data_dir, separator, clip_min_value,
clip_max_value):
train_file_list = osp.join(data_dir, 'train.txt')
val_file_list = osp.join(data_dir, 'val.txt')
test_file_list = osp.join(data_dir, 'test.txt')
total_img_num = 0
for file_list in [train_file_list, val_file_list, test_file_list]:
with open(file_list, 'r') as fid:
print("\n-----------------------------\nCheck {}...".format(
file_list))
lines = fid.readlines()
if not lines:
print("File list is empty!")
continue
for line in tqdm(lines):
line = line.strip()
parts = line.split(separator)
img_name, grt_name = parts[0], parts[1]
img_path = os.path.join(data_dir, img_name)
img = read_img(img_path)
if total_img_num == 0:
channel = img.shape[2]
total_means = np.zeros(channel)
total_stds = np.zeros(channel)
means, stds = compute_single_img(img, clip_min_value,
clip_max_value)
total_means += means
total_stds += stds
total_img_num += 1
# count mean, std
total_means = total_means / total_img_num
total_stds = total_stds / total_img_num
print("\nCount the channel-by-channel mean and std of the image:\n"
"mean = {}\nstd = {}".format(total_means, total_stds))
def cal_clip_percentage(pkl_path, clip_min_value, clip_max_value):
"""
Calculate the percentage of pixels to be clipped
"""
with open(pkl_path, 'rb') as f:
percentage, img_value_num = pickle.load(f)
for k in range(len(img_value_num)):
range_pixel = 0
for i, element in enumerate(img_value_num[k]):
if clip_min_value[k] <= i <= clip_max_value[k]:
range_pixel += element
sum_pixel = sum(img_value_num[k])
print('channel {}, the percentage of pixels to be clipped = {}'.format(
k, 1 - range_pixel / sum_pixel))
def main():
args = parse_args()
data_dir = args.data_dir
separator = args.separator
clip_min_value = args.clip_min_value
clip_max_value = args.clip_max_value
pkl_path = args.pkl_path
cal_normalize_coefficient(data_dir, separator, clip_min_value,
clip_max_value)
cal_clip_percentage(pkl_path, clip_min_value, clip_max_value)
if __name__ == "__main__":
main()
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -46,10 +46,10 @@ def parse_args():
default=['train', 'val', 'test'])
parser.add_argument(
'--format',
help='data format of images and labels, default npy, png.',
help='data format of images and labels, default tif, png.',
type=str,
nargs=2,
default=['npy', 'png'])
default=['tif', 'png'])
parser.add_argument(
'--label_class',
help='label class names',
......
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import os.path as osp
import sys
import argparse
from PIL import Image
from tqdm import tqdm
import imghdr
import logging
import pickle
import gdal
def parse_args():
parser = argparse.ArgumentParser(
description='Data analyse and data check before training.')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='Dataset directory',
default=None,
type=str)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
default=None,
type=int)
parser.add_argument(
'--separator',
dest='separator',
help='file list separator',
default=" ",
type=str)
parser.add_argument(
'--ignore_index',
dest='ignore_index',
help='Ignored class index',
default=255,
type=int)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def read_img(img_path):
img_format = imghdr.what(img_path)
name, ext = osp.splitext(img_path)
if img_format == 'tiff' or ext == '.img':
dataset = gdal.Open(img_path)
if dataset == None:
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
return im_data.transpose((1, 2, 0))
elif ext == '.npy':
return np.load(img_path)
else:
raise Exception('Not support {} image format!'.format(ext))
def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
channel = img.shape[2]
means = np.zeros(channel)
stds = np.zeros(channel)
for k in range(channel):
img_k = img[:, :, k]
# count mean, std
means[k] = np.mean(img_k)
stds[k] = np.std(img_k)
# count min, max
min_value = np.min(img_k)
max_value = np.max(img_k)
if img_max_value[k] < max_value:
img_max_value[k] = max_value
if img_min_value[k] > min_value:
img_min_value[k] = min_value
# count the distribution of image value, value number
unique, counts = np.unique(img_k, return_counts=True)
add_num = []
max_unique = np.max(unique)
add_len = max_unique - len(img_value_num[k]) + 1
if add_len > 0:
img_value_num[k] += ([0] * add_len)
for i in range(len(unique)):
value = unique[i]
img_value_num[k][value] += counts[i]
img_value_num[k] += add_num
return means, stds, img_min_value, img_max_value, img_value_num
def data_distribution_statistics(data_dir, img_value_num, logger):
"""count the distribution of image value, value number
"""
logger.info(
"\n-----------------------------\nThe whole dataset statistics...")
if not img_value_num:
return
logger.info("\nImage pixel statistics:")
total_ratio = []
[total_ratio.append([]) for i in range(len(img_value_num))]
for k in range(len(img_value_num)):
total_num = sum(img_value_num[k])
total_ratio[k] = [i / total_num for i in img_value_num[k]]
total_ratio[k] = np.around(total_ratio[k], decimals=4)
with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f:
pickle.dump([total_ratio, img_value_num], f)
def data_range_statistics(img_min_value, img_max_value, logger):
"""print min value, max value
"""
logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".format(
img_min_value, img_max_value))
def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
"""count mean, std
"""
total_means = total_means / total_img_num
total_stds = total_stds / total_img_num
logger.info("\nCount the channel-by-channel mean and std of the image:\n"
"mean = {}\nstd = {}".format(total_means, total_stds))
def error_print(str):
return "".join(["\nNOT PASS ", str])
def correct_print(str):
return "".join(["\nPASS ", str])
def pil_imread(file_path):
"""read pseudo-color label"""
im = Image.open(file_path)
return np.asarray(im)
def get_img_shape_range(img, max_width, max_height, min_width, min_height):
"""获取图片最大和最小宽高"""
img_shape = img.shape
height, width = img_shape[0], img_shape[1]
max_height = max(height, max_height)
max_width = max(width, max_width)
min_height = min(height, min_height)
min_width = min(width, min_width)
return max_width, max_height, min_width, min_height
def get_img_channel_num(img, img_channels):
"""获取图像的通道数"""
img_shape = img.shape
if img_shape[-1] not in img_channels:
img_channels.append(img_shape[-1])
return img_channels
def is_label_single_channel(label):
"""判断标签是否为灰度图"""
label_shape = label.shape
if len(label_shape) == 2:
return True
else:
return False
def image_label_shape_check(img, label):
"""
验证图像和标注的大小是否匹配
"""
flag = True
img_height = img.shape[0]
img_width = img.shape[1]
label_height = label.shape[0]
label_width = label.shape[1]
if img_height != label_height or img_width != label_width:
flag = False
return flag
def ground_truth_check(label, label_path):
"""
验证标注图像的格式
统计标注图类别和像素数
params:
label: 标注图
label_path: 标注图路径
return:
png_format: 返回是否是png格式图片
unique: 返回标注类别
counts: 返回标注的像素数
"""
if imghdr.what(label_path) == "png":
png_format = True
else:
png_format = False
unique, counts = np.unique(label, return_counts=True)
return png_format, unique, counts
def sum_label_check(label_classes, num_of_each_class, ignore_index, num_classes,
total_label_classes, total_num_of_each_class):
"""
统计所有标注图上的类别和每个类别的像素数
params:
label_classes: 标注类别
num_of_each_class: 各个类别的像素数目
"""
is_label_correct = True
if ignore_index in label_classes:
label_classes2 = np.delete(label_classes,
np.where(label_classes == ignore_index))
else:
label_classes2 = label_classes
if min(label_classes2) < 0 or max(label_classes2) > num_classes - 1:
is_label_correct = False
add_class = []
add_num = []
for i in range(len(label_classes)):
gi = label_classes[i]
if gi in total_label_classes:
j = total_label_classes.index(gi)
total_num_of_each_class[j] += num_of_each_class[i]
else:
add_class.append(gi)
add_num.append(num_of_each_class[i])
total_num_of_each_class += add_num
total_label_classes += add_class
return is_label_correct, total_num_of_each_class, total_label_classes
def label_class_check(num_classes, total_label_classes, total_num_of_each_class,
wrong_labels, logger):
"""
检查实际标注类别是否和配置参数`num_classes`,`ignore_index`匹配。
**NOTE:**
标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`。
标注类别最好从0开始,否则可能影响精度。
"""
total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
total_ratio = np.around(total_ratio, decimals=4)
total_nc = sorted(
zip(total_label_classes, total_ratio, total_num_of_each_class))
if len(wrong_labels) == 0 and not total_nc[0][0]:
logger.info(correct_print("label class check!"))
else:
logger.info(error_print("label class check!"))
if total_nc[0][0]:
logger.info("Warning: label classes should start from 0")
if len(wrong_labels) > 0:
logger.info(
"fatal error: label class is out of range [0, {}]".format(
num_classes - 1))
for i in wrong_labels:
logger.debug(i)
return total_nc
def label_class_statistics(total_nc, logger):
"""
对标注图像进行校验,输出校验结果
"""
logger.info(
"\nLabel class statistics:\n"
"(label class, percentage, total pixel number) = {} ".format(total_nc))
def shape_check(shape_unequal_image, logger):
"""输出shape校验结果"""
if len(shape_unequal_image) == 0:
logger.info(correct_print("shape check"))
logger.info("All images are the same shape as the labels")
else:
logger.info(error_print("shape check"))
logger.info(
"Some images are not the same shape as the labels as follow: ")
for i in shape_unequal_image:
logger.debug(i)
def separator_check(wrong_lines, file_list, separator, logger):
"""检查分割符是否复合要求"""
if len(wrong_lines) == 0:
logger.info(
correct_print(
file_list.split(os.sep)[-1] + " DATASET.separator check"))
else:
logger.info(
error_print(
file_list.split(os.sep)[-1] + " DATASET.separator check"))
logger.info(
"The following list is not separated by {}".format(separator))
for i in wrong_lines:
logger.debug(i)
def imread_check(imread_failed, logger):
if len(imread_failed) == 0:
logger.info(correct_print("dataset reading check"))
logger.info("All images can be read successfully")
else:
logger.info(error_print("dataset reading check"))
logger.info("Failed to read {} images".format(len(imread_failed)))
for i in imread_failed:
logger.debug(i)
def single_channel_label_check(label_not_single_channel, logger):
if len(label_not_single_channel) == 0:
logger.info(correct_print("label single_channel check"))
logger.info("All label images are single_channel")
else:
logger.info(error_print("label single_channel check"))
logger.info(
"{} label images are not single_channel\nLabel pixel statistics may be insignificant"
.format(len(label_not_single_channel)))
for i in label_not_single_channel:
logger.debug(i)
def img_shape_range_statistics(max_width, min_width, max_height, min_height,
logger):
logger.info("\nImage size statistics:")
logger.info(
"max width = {} min width = {} max height = {} min height = {}".
format(max_width, min_width, max_height, min_height))
def img_channels_statistics(img_channels, logger):
logger.info("\nImage channels statistics\nImage channels = {}".format(
np.unique(img_channels)))
def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
logger):
train_file_list = osp.join(data_dir, 'train.txt')
val_file_list = osp.join(data_dir, 'val.txt')
test_file_list = osp.join(data_dir, 'test.txt')
total_img_num = 0
has_label = False
for file_list in [train_file_list, val_file_list, test_file_list]:
# initialization
imread_failed = []
max_width = 0
max_height = 0
min_width = sys.float_info.max
min_height = sys.float_info.max
label_not_single_channel = []
shape_unequal_image = []
wrong_labels = []
wrong_lines = []
total_label_classes = []
total_num_of_each_class = []
img_channels = []
with open(file_list, 'r') as fid:
logger.info("\n-----------------------------\nCheck {}...".format(
file_list))
lines = fid.readlines()
if not lines:
logger.info("File list is empty!")
continue
for line in tqdm(lines):
line = line.strip()
parts = line.split(separator)
if len(parts) == 1:
if file_list == train_file_list or file_list == val_file_list:
logger.info("Train or val list must have labels!")
break
img_name = parts
img_path = os.path.join(data_dir, img_name[0])
try:
img = read_img(img_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
elif len(parts) == 2:
has_label = True
img_name, label_name = parts[0], parts[1]
img_path = os.path.join(data_dir, img_name)
label_path = os.path.join(data_dir, label_name)
try:
img = read_img(img_path)
label = pil_imread(label_path)
except Exception as e:
imread_failed.append((line, str(e)))
continue
is_single_channel = is_label_single_channel(label)
if not is_single_channel:
label_not_single_channel.append(line)
continue
is_equal_img_label_shape = image_label_shape_check(
img, label)
if not is_equal_img_label_shape:
shape_unequal_image.append(line)
png_format, label_classes, num_of_each_class = ground_truth_check(
label, label_path)
is_label_correct, total_num_of_each_class, total_label_classes = sum_label_check(
label_classes, num_of_each_class, ignore_index,
num_classes, total_label_classes,
total_num_of_each_class)
if not is_label_correct:
wrong_labels.append(line)
else:
wrong_lines.append(lines)
continue
if total_img_num == 0:
channel = img.shape[2]
total_means = np.zeros(channel)
total_stds = np.zeros(channel)
img_min_value = [sys.float_info.max] * channel
img_max_value = [0] * channel
img_value_num = []
[img_value_num.append([]) for i in range(channel)]
means, stds, img_min_value, img_max_value, img_value_num = img_pixel_statistics(
img, img_value_num, img_min_value, img_max_value)
total_means += means
total_stds += stds
max_width, max_height, min_width, min_height = get_img_shape_range(
img, max_width, max_height, min_width, min_height)
img_channels = get_img_channel_num(img, img_channels)
total_img_num += 1
# data check
separator_check(wrong_lines, file_list, separator, logger)
imread_check(imread_failed, logger)
if has_label:
single_channel_label_check(label_not_single_channel, logger)
shape_check(shape_unequal_image, logger)
total_nc = label_class_check(num_classes, total_label_classes,
total_num_of_each_class,
wrong_labels, logger)
# data analyse on train, validation, test set.
img_channels_statistics(img_channels, logger)
img_shape_range_statistics(max_width, min_width, max_height,
min_height, logger)
if has_label:
label_class_statistics(total_nc, logger)
# data analyse on the whole dataset.
data_range_statistics(img_min_value, img_max_value, logger)
data_distribution_statistics(data_dir, img_value_num, logger)
cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)
def main():
args = parse_args()
data_dir = args.data_dir
ignore_index = args.ignore_index
num_classes = args.num_classes
separator = args.separator
logger = logging.getLogger()
logger.setLevel('DEBUG')
BASIC_FORMAT = "%(message)s"
formatter = logging.Formatter(BASIC_FORMAT)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
sh.setLevel('INFO')
th = logging.FileHandler(
os.path.join(data_dir, 'data_analyse_and_check.log'), 'w')
th.setFormatter(formatter)
logger.addHandler(sh)
logger.addHandler(th)
data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
logger)
print("\nDetailed error information can be viewed in {}.".format(
os.path.join(data_dir, 'data_analyse_and_check.log')))
if __name__ == "__main__":
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import sys
import argparse
import matplotlib.pyplot as plt
def parse_args():
parser = argparse.ArgumentParser(
description='Visualize data distribution before training.')
parser.add_argument(
'--pkl_path',
dest='pkl_path',
help='Path of img_pixel_statistics.pkl',
default=None,
type=str)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
path = args.pkl_path
with open(path, 'rb') as f:
percentage, img_value_num = pickle.load(f)
for k in range(len(img_value_num)):
print('channel = {}'.format(k))
plt.bar(
list(range(len(img_value_num[k]))),
img_value_num[k],
width=1,
log=True)
plt.xlabel('image value')
plt.ylabel('number')
plt.title('channel={}'.format(k))
plt.show()
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -25,8 +25,10 @@ def parse_args():
description=
'A tool for proportionally randomizing dataset to produce file lists.')
parser.add_argument('dataset_root', help='the dataset root path', type=str)
parser.add_argument('images', help='the directory name of images', type=str)
parser.add_argument('labels', help='the directory name of labels', type=str)
parser.add_argument(
'images_dir_name', help='the directory name of images', type=str)
parser.add_argument(
'labels_dir_name', help='the directory name of labels', type=str)
parser.add_argument(
'--split', help='', nargs=3, type=float, default=[0.7, 0.3, 0])
parser.add_argument(
......@@ -43,10 +45,10 @@ def parse_args():
type=str)
parser.add_argument(
'--format',
help='data format of images and labels, e.g. jpg, npy or png.',
help='data format of images and labels, e.g. jpg, tif or png.',
type=str,
nargs=2,
default=['npy', 'png'])
default=['tif', 'png'])
parser.add_argument(
'--postfix',
help='postfix of images or labels',
......@@ -84,8 +86,8 @@ def generate_list(args):
for label_class in args.label_class:
f.write(label_class + '\n')
image_dir = os.path.join(dataset_root, args.images)
label_dir = os.path.join(dataset_root, args.labels)
image_dir = os.path.join(dataset_root, args.images_dir_name)
label_dir = os.path.join(dataset_root, args.labels_dir_name)
image_files = get_files(image_dir, args.format[0], args.postfix[0])
label_files = get_files(label_dir, args.format[1], args.postfix[1])
if not image_files:
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -40,12 +40,46 @@ def parse_args():
help='model save directory',
default=None,
type=str)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
default=None,
type=int)
parser.add_argument(
'--channel',
dest='channel',
help='number of data channel',
default=3,
type=int)
parser.add_argument(
'--clip_min_value',
dest='clip_min_value',
help='Min values for clipping data',
nargs='+',
default=None,
type=int)
parser.add_argument(
'--clip_max_value',
dest='clip_max_value',
help='Max values for clipping data',
nargs='+',
default=None,
type=int)
parser.add_argument(
'--mean',
dest='mean',
help='Data means',
nargs='+',
default=None,
type=float)
parser.add_argument(
'--std',
dest='std',
help='Data standard deviation',
nargs='+',
default=None,
type=float)
parser.add_argument(
'--num_epochs',
dest='num_epochs',
......@@ -66,15 +100,32 @@ def parse_args():
args = parse_args()
data_dir = args.data_dir
save_dir = args.save_dir
num_classes = args.num_classes
channel = args.channel
clip_min_value = args.clip_min_value
clip_max_value = args.clip_max_value
mean = args.mean
std = args.std
num_epochs = args.num_epochs
train_batch_size = args.train_batch_size
lr = args.lr
# 定义训练和验证时的transforms
train_transforms = T.Compose([T.RandomHorizontalFlip(0.5), T.Normalize()])
train_transforms = T.Compose([
T.RandomVerticalFlip(0.5),
T.RandomHorizontalFlip(0.5),
T.ResizeStepScaling(0.5, 2.0, 0.25),
T.RandomPaddingCrop(1000),
T.Clip(min_val=clip_min_value, max_val=clip_max_value),
T.Normalize(
min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
])
eval_transforms = T.Compose([T.Normalize()])
eval_transforms = T.Compose([
T.Clip(min_val=clip_min_value, max_val=clip_max_value),
T.Normalize(
min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
])
train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt')
......@@ -95,17 +146,9 @@ eval_reader = Reader(
transforms=eval_transforms)
if args.model_type == 'unet':
model = UNet(
num_classes=2,
input_channel=channel,
use_bce_loss=True,
use_dice_loss=True)
model = UNet(num_classes=num_classes, input_channel=channel)
elif args.model_type == 'hrnet':
model = HRNet(
num_classes=2,
input_channel=channel,
use_bce_loss=True,
use_dice_loss=True)
model = HRNet(num_classes=num_classes, input_channel=channel)
else:
raise ValueError(
"--model_type: {} is set wrong, it shold be one of ('unet', "
......@@ -116,6 +159,7 @@ model.train(
train_reader=train_reader,
train_batch_size=train_batch_size,
eval_reader=eval_reader,
eval_best_metric='miou',
save_interval_epochs=5,
log_interval_steps=10,
save_dir=save_dir,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,11 +15,10 @@
from .ops import *
import random
import os.path as osp
import numpy as np
from PIL import Image
import cv2
from collections import OrderedDict
from readers.reader import read_img
class Compose:
......@@ -58,11 +57,11 @@ class Compose:
if im_info is None:
im_info = dict()
im = np.load(im)
im = read_img(im)
if im is None:
raise ValueError('Can\'t read The image file {}!'.format(im))
if label is not None:
label = np.asarray(Image.open(label))
label = read_img(label)
for op in self.transforms:
outputs = op(im, im_info, label)
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
import os
import os.path as osp
import argparse
from PIL import Image as Image
from models.utils import visualize as vis
def parse_args():
parser = argparse.ArgumentParser(description='RemoteSensing visualization')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='Dataset directory',
default=None,
type=str)
parser.add_argument(
'--file_list',
dest='file_list',
help='The name of file list that need to be visualized',
default=None,
type=str)
parser.add_argument(
'--pred_dir',
dest='pred_dir',
help='Directory for predict results',
default=None,
type=str)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='Save directory for visual results',
default=None,
type=str)
return parser.parse_args()
args = parse_args()
data_dir = args.data_dir
pred_dir = args.pred_dir
save_dir = args.save_dir
file_list = osp.join(data_dir, args.file_list)
if not osp.exists(save_dir):
os.mkdir(save_dir)
with open(file_list) as f:
lines = f.readlines()
for line in lines:
img_list = []
img_line = line.split(' ')[0]
img_name = osp.basename(img_line).replace('data.tif', 'photo.png')
img_path = osp.join(data_dir, 'data_vis', img_name)
img = Image.open(img_path)
img_list.append(img)
print('visualizing {}'.format(img_path))
gt_line = line.split(' ')[1].rstrip('\n')
gt_path = osp.join(data_dir, gt_line)
gt_pil = Image.open(gt_path)
img_list.append(gt_pil)
pred_name = osp.basename(img_line).replace('tif', 'png')
pred_path = osp.join(pred_dir, pred_name)
pred_pil = Image.open(pred_path)
img_list.append(pred_pil)
save_path = osp.join(save_dir, pred_name)
vis.splice_imgs(img_list, save_path)
print('saved in {}'.format(save_path))
......@@ -79,7 +79,7 @@ DEPLOY:
### 5.2 执行预测程序
在终端输入以下命令进行预测:
```bash
python infer.py --conf=/path/to/deploy.yaml --input_dir/path/to/images_directory
python infer.py --conf=/path/to/deploy.yaml --input_dir=/path/to/images_directory
```
参数说明如下:
......
......@@ -115,9 +115,10 @@ class ImageReader:
# image processing thread worker
def process_worker(self, imgs, idx, use_pr=False):
image_path = imgs[idx]
im = cv2.imread(image_path, -1)
if len(im.shape) == 2:
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
cv2_imread_flag = cv2.IMREAD_COLOR
if self.config.channels == 4:
cv2_imread_flag = cv2.IMREAD_UNCHANGED
im = cv2.imread(image_path, cv2_imread_flag)
channels = im.shape[2]
if channels != 3 and channels != 4:
print("Only support rgb(gray) or rgba image.")
......@@ -133,8 +134,10 @@ class ImageReader:
# if use models with no pre-processing/post-processing op optimizations
if not use_pr:
im_mean = np.array(self.config.mean).reshape((3, 1, 1))
im_std = np.array(self.config.std).reshape((3, 1, 1))
im_mean = np.array(self.config.mean).reshape((self.config.channels,
1, 1))
im_std = np.array(self.config.std).reshape((self.config.channels, 1,
1))
# HWC -> CHW, don't use transpose((2, 0, 1))
im = im.swapaxes(1, 2)
im = im.swapaxes(0, 1)
......
......@@ -4,15 +4,16 @@
Lovasz loss基于子模损失(submodular losses)的凸Lovasz扩展,对神经网络的mean IoU损失进行优化。Lovasz loss根据分割目标的类别数量可分为两种:lovasz hinge loss和lovasz softmax loss. 其中lovasz hinge loss适用于二分类问题,lovasz softmax loss适用于多分类问题。该工作发表在CVPR 2018上,可点击[参考文献](#参考文献)查看具体原理。
## Lovasz hinge loss
### 使用指南
## Lovasz loss使用指南
接下来介绍如何使用lovasz loss进行训练。需要注意的是,通常的直接训练方式并一定管用,我们推荐另外2种训练方式:
- (1)与softmax loss或bce loss(binary cross-entropy loss)加权结合使用。
- (2)先使用softmax loss或bec loss进行训练,再使用lovasz softmax loss或lovasz hinge loss进行finetuning.
PaddleSeg通过`cfg.SOLVER.LOSS`参数可以选择训练时的损失函数,
`cfg.SOLVER.LOSS=['lovasz_hinge_loss','bce_loss']`将指定训练loss为`lovasz hinge loss``bce loss`(binary cross-entropy loss)的组合。
配置lovasz loss仅需要设定2个参数:
Lovasz hinge loss有3种使用方式:(1)直接训练使用。(2)bce loss结合使用。(3)先使用bec loss进行训练,再使用lovasz hinge loss进行finetuning. 第1种方式不一定达到理想效果,推荐使用后两种方式。本文以第2种方式为例
首先通过`cfg.SOLVER.LOSS`参数选择训练时的损失函数, 例如`cfg.SOLVER.LOSS=['lovasz_hinge_loss','bce_loss']`将指定训练loss为lovasz hinge loss与bce loss的组合。`cfg.SOLVER.LOSS=['lovasz_softmax_loss','softmax_loss']`将指定训练loss为lovasz softmax loss与softmax loss的组合
同时,也可以通过`cfg.SOLVER.LOSS_WEIGHT`参数对不同loss进行权重配比,灵活运用于训练调参。如下所示
其次,也可以通过`cfg.SOLVER.LOSS_WEIGHT`参数对不同loss进行权重配比,从而灵活地进行训练调参。Lovasz hinge loss配置位于`PaddleSeg/configs/lovasz_hinge_deeplabv3p_mobilenet_road.yaml`,如下所示:
```yaml
SOLVER:
LOSS: ["lovasz_hinge_loss","bce_loss"]
......@@ -21,15 +22,24 @@ SOLVER:
BCE_LOSS: 0.5
```
### 实验对比
Lovasz softmax loss配置位于`PaddleSeg/configs/lovasz_softmax_deeplabv3p_mobilenet_pascal.yaml`,如下所示:
```yaml
SOLVER:
LOSS: ["lovasz_softmax_loss","softmax_loss"]
LOSS_WEIGHT:
LOVASZ_SOFTMAX_LOSS: 0.2
SOFTMAX_LOSS: 0.8
```
## Lovasz hinge loss实验对比
我们以道路提取任务为例应用lovasz hinge loss.
基于MiniDeepGlobeRoadExtraction数据集与bce loss进行了实验对比。
该数据集来源于DeepGlobe比赛的Road Extraction单项,训练数据道路占比为:4.5%. 如下为其图片样例
该数据集来源于DeepGlobe比赛的Road Extraction单项,训练数据道路占比为:4.5%. 道路在整张图片中的比例很小,是典型的类别不均衡场景。图片样例如下
<p align="center">
<img src="./imgs/deepglobe.png" hspace='10'/> <br />
</p>
可以看出道路在整张图片中的比例很小。
为进行快速体验,这里使用DeepLabv3+模型,backbone为MobileNetV2.
......@@ -62,33 +72,21 @@ python pdseg/eval.py --cfg ./configs/lovasz_hinge_deeplabv3p_mobilenet_road.yaml
* 结果比较
lovasz hinge loss + bce loss和softmax loss的对比结果如下图所示。
lovasz hinge loss + bce loss和softmax loss的mIoU曲线如下图所示。
<p align="center">
<img src="./imgs/lovasz-hinge.png" hspace='10'/> <br />
</p>
图中蓝色曲线为lovasz hinge loss + bce loss,最高mIoU为76.2%,橙色曲线为softmax loss, 最高mIoU为73.44%,相比提升2.76个百分点。
分割效果如下:
<p align="center">
<img src="./imgs/lovasz-hinge-vis.png" hspace='10'/> <br />
</p>
可以看出,softmax loss训练的结果中道路并不连续,主干道部分缺失尤为严重。而lovasz loss训练的结果提升显著,主干道并无缺失,连小路也基本连续。
## Lovasz softmax loss
### 使用指南
PaddleSeg通过`cfg.SOLVER.LOSS`参数可以选择训练时的损失函数,
`cfg.SOLVER.LOSS=['lovasz_softmax_loss','softmax_loss']`将指定训练loss为`lovasz softmax loss``softmax loss`的组合。
Lovasz softmax loss有3种使用方式:(1)直接训练使用。(2)softmax loss结合使用。(3)先使用softmax loss进行训练,再使用lovasz softmax loss进行finetuning. 第1种方式不一定达到理想效果,推荐使用后两种方式。本文以第2种方式为例。
同时,也可以通过`cfg.SOLVER.LOSS_WEIGHT`参数对不同loss进行权重配比,灵活运用于训练调参。如下所示
```yaml
SOLVER:
LOSS: ["lovasz_softmax_loss","softmax_loss"]
LOSS_WEIGHT:
LOVASZ_SOFTMAX_LOSS: 0.2
SOFTMAX_LOSS: 0.8
```
### 实验对比
## Lovasz softmax loss实验对比
接下来以PASCAL VOC 2012数据集为例应用lovasz softmax loss. 我们将lovasz softmax loss与softmax loss进行了实验对比。为进行快速体验,这里使用DeepLabv3+模型,backbone为MobileNetV2.
......@@ -125,7 +123,7 @@ python pdseg/eval.py --cfg ./configs/lovasz_softmax_deeplabv3p_mobilenet_pascal.
* 结果比较
lovasz softmax loss + softmax loss和softmax loss的对比结果如下图所示。
lovasz softmax loss + softmax loss和softmax loss的mIoU曲线如下图所示。
<p align="center">
<img src="./imgs/lovasz-softmax.png" hspace='10' /> <br />
</p>
......
# 动态图执行
## 数据集设置
```
data_dir='data/path'
train_list='train/list/path'
val_list='val/list/path'
test_list='test/list/path'
num_classes=number/of/dataset/classes
```
## 训练
```
python3 train.py --model_name UNet \
--data_dir $data_dir \
--train_list $train_list \
--val_list $val_list \
--num_classes $num_classes \
--input_size 192 192 \
--num_epochs 4 \
--save_interval_epochs 1 \
--save_dir output
```
## 评估
```
python3 val.py --model_name UNet \
--data_dir $data_dir \
--val_list $val_list \
--num_classes $num_classes \
--input_size 192 192 \
--model_dir output/epoch_1
```
## 预测
```
python3 infer.py --model_name UNet \
--data_dir $data_dir \
--test_list $test_list \
--num_classes $num_classes \
--input_size 192 192 \
--model_dir output/epoch_1
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.fluid.io import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/cityscapes.tar"
class Cityscapes(Dataset):
def __init__(self,
data_dir=None,
transforms=None,
mode='train',
download=True):
self.data_dir = data_dir
self.transforms = transforms
self.file_list = list()
self.mode = mode
self.num_classes = 19
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("transform is necessary, but it is None.")
self.data_dir = data_dir
if self.data_dir is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
if mode == 'train':
file_list = os.path.join(self.data_dir, 'train.list')
elif mode == 'eval':
file_list = os.path.join(self.data_dir, 'val.list')
else:
file_list = os.path.join(self.data_dir, 'test.list')
with open(file_list, 'r') as f:
for line in f:
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'eval':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(self.data_dir, items[0])
grt_path = None
else:
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
self.file_list.append([image_path, grt_path])
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
im, im_info, label = self.transforms(im=image_path, label=grt_path)
if self.mode == 'train':
return im, label
elif self.mode == 'eval':
return im, label
if self.mode == 'test':
return im, im_info, image_path
def __len__(self):
return len(self.file_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.fluid.io import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
class OpticDiscSeg(Dataset):
def __init__(self,
data_dir=None,
transforms=None,
mode='train',
download=True):
self.data_dir = data_dir
self.transforms = transforms
self.file_list = list()
self.mode = mode
self.num_classes = 2
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("transform is necessary, but it is None.")
self.data_dir = data_dir
if self.data_dir is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
if mode == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt')
elif mode == 'eval':
file_list = os.path.join(self.data_dir, 'val_list.txt')
else:
file_list = os.path.join(self.data_dir, 'test_list.txt')
with open(file_list, 'r') as f:
for line in f:
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'eval':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(self.data_dir, items[0])
grt_path = None
else:
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
self.file_list.append([image_path, grt_path])
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
im, im_info, label = self.transforms(im=image_path, label=grt_path)
if self.mode == 'train':
return im, label
elif self.mode == 'eval':
return im, label
if self.mode == 'test':
return im, im_info, image_path
def __len__(self):
return len(self.file_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
import cv2
import tqdm
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils
import utils.logging as logging
from utils import get_environ_info
def parse_args():
parser = argparse.ArgumentParser(description='Model training')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help=
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str,
default='OpticDiscSeg')
# params of prediction
parser.add_argument(
"--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=2)
parser.add_argument(
'--model_dir',
dest='model_dir',
help='The path of model for evaluation',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the inference results',
type=str,
default='./output/result')
return parser.parse_args()
def mkdir(path):
sub_dir = os.path.dirname(path)
if not os.path.exists(sub_dir):
os.makedirs(sub_dir)
def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
ckpt_path = os.path.join(model_dir, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
model.eval()
added_saved_dir = os.path.join(save_dir, 'added')
pred_saved_dir = os.path.join(save_dir, 'prediction')
logging.info("Start to predict...")
for im, im_info, im_path in tqdm.tqdm(test_dataset):
im = im[np.newaxis, ...]
im = to_variable(im)
pred, _ = model(im, mode='test')
pred = pred.numpy()
pred = np.squeeze(pred).astype('uint8')
keys = list(im_info.keys())
for k in keys[::-1]:
if k == 'shape_before_resize':
h, w = im_info[k][0], im_info[k][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif k == 'shape_before_padding':
h, w = im_info[k][0], im_info[k][1]
pred = pred[0:h, 0:w]
im_file = im_path.replace(test_dataset.data_dir, '')
if im_file[0] == '/':
im_file = im_file[1:]
# save added image
added_image = utils.visualize(im_path, pred, weight=0.6)
added_image_path = os.path.join(added_saved_dir, im_file)
mkdir(added_image_path)
cv2.imwrite(added_image_path, added_image)
# save prediction
pred_im = utils.visualize(im_path, pred, weight=0.0)
pred_saved_path = os.path.join(pred_saved_dir, im_file)
mkdir(pred_saved_path)
cv2.imwrite(pred_saved_path, pred_im)
def main(args):
env_info = get_environ_info()
places = fluid.CUDAPlace(ParallelEnv().dev_id) \
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
test_dataset = dataset(transforms=test_transforms, mode='test')
if args.model_name == 'UNet':
model = models.UNet(num_classes=test_dataset.num_classes)
infer(
model,
model_dir=args.model_dir,
test_dataset=test_dataset,
save_dir=args.save_dir)
if __name__ == '__main__':
args = parse_args()
main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet import UNet
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
class UNet(fluid.dygraph.Layer):
def __init__(self, num_classes, ignore_index=255):
super().__init__()
self.encode = UnetEncoder()
self.decode = UnetDecode()
self.get_logit = GetLogit(64, num_classes)
self.ignore_index = ignore_index
self.EPS = 1e-5
def forward(self, x, label=None, mode='train'):
encode_data, short_cuts = self.encode(x)
decode_data = self.decode(encode_data, short_cuts)
logit = self.get_logit(decode_data)
if mode == 'train':
return self._get_loss(logit, label)
else:
score_map = fluid.layers.softmax(logit, axis=1)
score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
pred = fluid.layers.argmax(score_map, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
return pred, score_map
def _get_loss(self, logit, label):
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
class UnetEncoder(fluid.dygraph.Layer):
def __init__(self):
super().__init__()
self.double_conv = DoubleConv(3, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 512)
def forward(self, x):
short_cuts = []
x = self.double_conv(x)
short_cuts.append(x)
x = self.down1(x)
short_cuts.append(x)
x = self.down2(x)
short_cuts.append(x)
x = self.down3(x)
short_cuts.append(x)
x = self.down4(x)
return x, short_cuts
class UnetDecode(fluid.dygraph.Layer):
def __init__(self):
super().__init__()
self.up1 = Up(512, 256)
self.up2 = Up(256, 128)
self.up3 = Up(128, 64)
self.up4 = Up(64, 64)
def forward(self, x, short_cuts):
x = self.up1(x, short_cuts[3])
x = self.up2(x, short_cuts[2])
x = self.up3(x, short_cuts[1])
x = self.up4(x, short_cuts[0])
return x
class DoubleConv(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
self.conv0 = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=1,
padding=1)
self.bn0 = BatchNorm(num_channels=num_filters)
self.conv1 = Conv2D(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=1,
padding=1)
self.bn1 = BatchNorm(num_channels=num_filters)
def forward(self, x):
x = self.conv0(x)
x = self.bn0(x)
x = fluid.layers.relu(x)
x = self.conv1(x)
x = self.bn1(x)
x = fluid.layers.relu(x)
return x
class Down(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
self.max_pool = Pool2D(
pool_size=2, pool_type='max', pool_stride=2, pool_padding=0)
self.double_conv = DoubleConv(num_channels, num_filters)
def forward(self, x):
x = self.max_pool(x)
x = self.double_conv(x)
return x
class Up(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
self.double_conv = DoubleConv(2 * num_channels, num_filters)
def forward(self, x, short_cut):
short_cut_shape = fluid.layers.shape(short_cut)
x = fluid.layers.resize_bilinear(x, short_cut_shape[2:])
x = fluid.layers.concat([x, short_cut], axis=1)
x = self.double_conv(x)
return x
class GetLogit(fluid.dygraph.Layer):
def __init__(self, num_channels, num_classes):
super().__init__()
self.conv = Conv2D(
num_channels=num_channels,
num_filters=num_classes,
filter_size=3,
stride=1,
padding=1)
def forward(self, x):
x = self.conv(x)
return x
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
from utils import resume
from utils import Timer, calculate_eta
from val import evaluate
def parse_args():
parser = argparse.ArgumentParser(description='Model training')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help=
"The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str,
default='OpticDiscSeg')
# params of training
parser.add_argument(
"--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
parser.add_argument(
'--num_epochs',
dest='num_epochs',
help='Number epochs for training',
type=int,
default=100)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size of one gpu or cpu',
type=int,
default=2)
parser.add_argument(
'--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=0.01)
parser.add_argument(
'--pretrained_model',
dest='pretrained_model',
help='The path of pretrained model',
type=str,
default=None)
parser.add_argument(
'--resume_model',
dest='resume_model',
help='The path of resume model',
type=str,
default=None)
parser.add_argument(
'--save_interval_epochs',
dest='save_interval_epochs',
help='The interval epochs for save a model snapshot',
type=int,
default=5)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the model snapshot',
type=str,
default='./output')
parser.add_argument(
'--num_workers',
dest='num_workers',
help='Num workers for data loader',
type=int,
default=0)
parser.add_argument(
'--do_eval',
dest='do_eval',
help='Eval while training',
action='store_true')
parser.add_argument(
'--log_steps',
dest='log_steps',
help='Display logging information at every log_steps',
default=10,
type=int)
parser.add_argument(
'--use_vdl',
dest='use_vdl',
help='Whether to record the data to VisualDL during training',
action='store_true')
return parser.parse_args()
def train(model,
train_dataset,
places=None,
eval_dataset=None,
optimizer=None,
save_dir='output',
num_epochs=100,
batch_size=2,
pretrained_model=None,
resume_model=None,
save_interval_epochs=1,
log_steps=10,
num_classes=None,
num_workers=8,
use_vdl=False):
ignore_index = model.ignore_index
nranks = ParallelEnv().nranks
start_epoch = 0
if resume_model is not None:
start_epoch = resume(model, optimizer, resume_model)
elif pretrained_model is not None:
load_pretrained_model(model, pretrained_model)
if not os.path.isdir(save_dir):
if os.path.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
if nranks > 1:
strategy = fluid.dygraph.prepare_context()
model_parallel = fluid.dygraph.DataParallel(model, strategy)
batch_sampler = DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
places=places,
num_workers=num_workers,
return_list=True,
)
if use_vdl:
from visualdl import LogWriter
log_writer = LogWriter(save_dir)
timer = Timer()
timer.start()
avg_loss = 0.0
steps_per_epoch = len(batch_sampler)
total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0
best_mean_iou = -1.0
best_model_epoch = 1
for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader):
images = data[0]
labels = data[1].astype('int64')
if nranks > 1:
loss = model_parallel(images, labels, mode='train')
loss = model_parallel.scale_loss(loss)
loss.backward()
model_parallel.apply_collective_grads()
else:
loss = model(images, labels, mode='train')
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
avg_loss += loss.numpy()[0]
lr = optimizer.current_step_lr()
num_steps += 1
if num_steps % log_steps == 0 and ParallelEnv().local_rank == 0:
avg_loss /= log_steps
time_step = timer.elapsed_time() / log_steps
remain_steps = total_steps - num_steps
logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={:.4f}, lr={:.6f}, sec/step={:.4f} | ETA {}"
.format(epoch + 1, num_epochs, step + 1, steps_per_epoch,
avg_loss, lr, time_step,
calculate_eta(remain_steps, time_step)))
if use_vdl:
log_writer.add_scalar('Train/loss', avg_loss, num_steps)
log_writer.add_scalar('Train/lr', lr, num_steps)
avg_loss = 0.0
timer.restart()
if ((epoch + 1) % save_interval_epochs == 0
or epoch + 1 == num_epochs) and ParallelEnv().local_rank == 0:
current_save_dir = os.path.join(save_dir,
"epoch_{}".format(epoch + 1))
if not os.path.isdir(current_save_dir):
os.makedirs(current_save_dir)
fluid.save_dygraph(model.state_dict(),
os.path.join(current_save_dir, 'model'))
fluid.save_dygraph(optimizer.state_dict(),
os.path.join(current_save_dir, 'model'))
if eval_dataset is not None:
mean_iou, mean_acc = evaluate(
model,
eval_dataset,
places=places,
model_dir=current_save_dir,
num_classes=num_classes,
batch_size=batch_size,
ignore_index=ignore_index,
epoch_id=epoch + 1)
if mean_iou > best_mean_iou:
best_mean_iou = mean_iou
best_model_epoch = epoch + 1
best_model_dir = os.path.join(save_dir, "best_model")
fluid.save_dygraph(model.state_dict(),
os.path.join(best_model_dir, 'model'))
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.format(best_model_epoch, best_mean_iou))
if use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
epoch + 1)
log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
epoch + 1)
model.train()
if use_vdl:
log_writer.close()
def main(args):
env_info = get_environ_info()
places = fluid.CUDAPlace(ParallelEnv().dev_id) \
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places):
# Creat dataset reader
train_transforms = T.Compose([
T.Resize(args.input_size),
T.RandomHorizontalFlip(),
T.Normalize()
])
train_dataset = dataset(transforms=train_transforms, mode='train')
eval_dataset = None
if args.do_eval:
eval_transforms = T.Compose(
[T.Resize(args.input_size),
T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet':
model = models.UNet(
num_classes=train_dataset.num_classes, ignore_index=255)
# Creat optimizer
# todo, may less one than len(loader)
num_steps_each_epoch = len(train_dataset) // (
args.batch_size * ParallelEnv().nranks)
decay_step = args.num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
parameter_list=model.parameters(),
regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
train(
model,
train_dataset,
places=places,
eval_dataset=eval_dataset,
optimizer=optimizer,
save_dir=args.save_dir,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
resume_model=args.resume_model,
save_interval_epochs=args.save_interval_epochs,
log_steps=args.log_steps,
num_classes=train_dataset.num_classes,
num_workers=args.num_workers,
use_vdl=args.use_vdl)
if __name__ == '__main__':
args = parse_args()
main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import *
from . import functional
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
from PIL import Image, ImageEnhance
def normalize(im, mean, std):
im = im.astype(np.float32, copy=False) / 255.0
im -= mean
im /= std
return im
def permute(im):
im = np.transpose(im, (2, 0, 1))
return im
def resize(im, target_size=608, interp=cv2.INTER_LINEAR):
if isinstance(target_size, list) or isinstance(target_size, tuple):
w = target_size[0]
h = target_size[1]
else:
w = target_size
h = target_size
im = cv2.resize(im, (w, h), interpolation=interp)
return im
def resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR):
value = max(im.shape[0], im.shape[1])
scale = float(long_size) / float(value)
resized_width = int(round(im.shape[1] * scale))
resized_height = int(round(im.shape[0] * scale))
im = cv2.resize(
im, (resized_width, resized_height), interpolation=interpolation)
return im
def horizontal_flip(im):
if len(im.shape) == 3:
im = im[:, ::-1, :]
elif len(im.shape) == 2:
im = im[:, ::-1]
return im
def vertical_flip(im):
if len(im.shape) == 3:
im = im[::-1, :, :]
elif len(im.shape) == 2:
im = im[::-1, :]
return im
def brightness(im, brightness_lower, brightness_upper):
brightness_delta = np.random.uniform(brightness_lower, brightness_upper)
im = ImageEnhance.Brightness(im).enhance(brightness_delta)
return im
def contrast(im, contrast_lower, contrast_upper):
contrast_delta = np.random.uniform(contrast_lower, contrast_upper)
im = ImageEnhance.Contrast(im).enhance(contrast_delta)
return im
def saturation(im, saturation_lower, saturation_upper):
saturation_delta = np.random.uniform(saturation_lower, saturation_upper)
im = ImageEnhance.Color(im).enhance(saturation_delta)
return im
def hue(im, hue_lower, hue_upper):
hue_delta = np.random.uniform(hue_lower, hue_upper)
im = np.array(im.convert('HSV'))
im[:, :, 0] = im[:, :, 0] + hue_delta
im = Image.fromarray(im, mode='HSV').convert('RGB')
return im
def rotate(im, rotate_lower, rotate_upper):
rotate_delta = np.random.uniform(rotate_lower, rotate_upper)
im = im.rotate(int(rotate_delta))
return im
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import logging
from . import download
from .metrics import ConfusionMatrix
from .utils import *
from .timer import Timer, calculate_eta
import os
import sys
import time
import requests
import tarfile
import zipfile
import shutil
import functools
lasttime = time.time()
FLUSH_INTERVAL = 0.1
def progress(str, end=False):
global lasttime
if end:
str += "\n"
lasttime = 0
if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str)
lasttime = time.time()
sys.stdout.flush()
def _download_file(url, savepath, print_progress):
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
if total_length is None:
with open(savepath, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(savepath, 'wb') as f:
dl = 0
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % os.path.basename(savepath))
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * dl) / total_length))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
def _uncompress_file_zip(filepath, extrapath):
files = zipfile.ZipFile(filepath, 'r')
filelist = files.namelist()
rootpath = filelist[0]
total_num = len(filelist)
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath
def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
files = tarfile.open(filepath, mode)
filelist = files.getnames()
total_num = len(filelist)
rootpath = filelist[0]
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath
def _uncompress_file(filepath, extrapath, delete_file, print_progress):
if print_progress:
print("Uncompress %s" % os.path.basename(filepath))
if filepath.endswith("zip"):
handler = _uncompress_file_zip
elif filepath.endswith("tgz"):
handler = _uncompress_file_tar
else:
handler = functools.partial(_uncompress_file_tar, mode="r")
for total_num, index, rootpath in handler(filepath, extrapath):
if print_progress:
done = int(50 * float(index) / total_num)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * index) / total_num))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
if delete_file:
os.remove(filepath)
return rootpath
def download_file_and_uncompress(url,
savepath=None,
extrapath=None,
extraname=None,
print_progress=True,
cover=False,
delete_file=True):
if savepath is None:
savepath = "."
if extrapath is None:
extrapath = "."
savename = url.split("/")[-1]
savepath = os.path.join(savepath, savename)
savename = ".".join(savename.split(".")[:-1])
savename = os.path.join(extrapath, savename)
extraname = savename if extraname is None else os.path.join(
extrapath, extraname)
if cover:
if os.path.exists(savepath):
shutil.rmtree(savepath)
if os.path.exists(savename):
shutil.rmtree(savename)
if os.path.exists(extraname):
shutil.rmtree(extraname)
if not os.path.exists(extraname):
if not os.path.exists(savename):
if not os.path.exists(savepath):
_download_file(url, savepath, print_progress)
savename = _uncompress_file(savepath, extrapath, delete_file,
print_progress)
savename = os.path.join(extrapath, savename)
shutil.move(savename, extraname)
return savename
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import sys
from paddle.fluid.dygraph.parallel import ParallelEnv
levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
log_level = 2
def log(level=2, message=""):
if ParallelEnv().local_rank == 0:
current_time = time.time()
time_array = time.localtime(current_time)
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
if log_level >= level:
print(
"{} [{}]\t{}".format(current_time, levels[level],
message).encode("utf-8").decode("latin1"))
sys.stdout.flush()
def debug(message=""):
log(level=3, message=message)
def info(message=""):
log(level=2, message=message)
def warning(message=""):
log(level=1, message=message)
def error(message=""):
log(level=0, message=message)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
from scipy.sparse import csr_matrix
class ConfusionMatrix(object):
"""
Confusion Matrix for segmentation evaluation
"""
def __init__(self, num_classes=2, streaming=False):
self.confusion_matrix = np.zeros([num_classes, num_classes],
dtype='int64')
self.num_classes = num_classes
self.streaming = streaming
def calculate(self, pred, label, ignore=None):
# If not in streaming mode, clear matrix everytime when call `calculate`
if not self.streaming:
self.zero_matrix()
label = np.transpose(label, (0, 2, 3, 1))
ignore = np.transpose(ignore, (0, 2, 3, 1))
mask = np.array(ignore) == 1
label = np.asarray(label)[mask]
pred = np.asarray(pred)[mask]
one = np.ones_like(pred)
# Accumuate ([row=label, col=pred], 1) into sparse matrix
spm = csr_matrix((one, (label, pred)),
shape=(self.num_classes, self.num_classes))
spm = spm.todense()
self.confusion_matrix += spm
def zero_matrix(self):
""" Clear confusion matrix """
self.confusion_matrix = np.zeros([self.num_classes, self.num_classes],
dtype='int64')
def mean_iou(self):
iou_list = []
avg_iou = 0
# TODO: use numpy sum axis api to simpliy
vji = np.zeros(self.num_classes, dtype=int)
vij = np.zeros(self.num_classes, dtype=int)
for j in range(self.num_classes):
v_j = 0
for i in range(self.num_classes):
v_j += self.confusion_matrix[j][i]
vji[j] = v_j
for i in range(self.num_classes):
v_i = 0
for j in range(self.num_classes):
v_i += self.confusion_matrix[j][i]
vij[i] = v_i
for c in range(self.num_classes):
total = vji[c] + vij[c] - self.confusion_matrix[c][c]
if total == 0:
iou = 0
else:
iou = float(self.confusion_matrix[c][c]) / total
avg_iou += iou
iou_list.append(iou)
avg_iou = float(avg_iou) / float(self.num_classes)
return np.array(iou_list), avg_iou
def accuracy(self):
total = self.confusion_matrix.sum()
total_right = 0
for c in range(self.num_classes):
total_right += self.confusion_matrix[c][c]
if total == 0:
avg_acc = 0
else:
avg_acc = float(total_right) / total
vij = np.zeros(self.num_classes, dtype=int)
for i in range(self.num_classes):
v_i = 0
for j in range(self.num_classes):
v_i += self.confusion_matrix[j][i]
vij[i] = v_i
acc_list = []
for c in range(self.num_classes):
if vij[c] == 0:
acc = 0
else:
acc = self.confusion_matrix[c][c] / float(vij[c])
acc_list.append(acc)
return np.array(acc_list), avg_acc
def kappa(self):
vji = np.zeros(self.num_classes)
vij = np.zeros(self.num_classes)
for j in range(self.num_classes):
v_j = 0
for i in range(self.num_classes):
v_j += self.confusion_matrix[j][i]
vji[j] = v_j
for i in range(self.num_classes):
v_i = 0
for j in range(self.num_classes):
v_i += self.confusion_matrix[j][i]
vij[i] = v_i
total = self.confusion_matrix.sum()
# avoid spillovers
# TODO: is it reasonable to hard code 10000.0?
total = float(total) / 10000.0
vji = vji / 10000.0
vij = vij / 10000.0
tp = 0
tc = 0
for c in range(self.num_classes):
tp += vji[c] * vij[c]
tc += self.confusion_matrix[c][c]
tc = tc / 10000.0
pe = tp / (total * total)
po = tc / total
kappa = (po - pe) / (1 - pe)
return kappa
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
class Timer(object):
""" Simple timer class for measuring time consuming """
def __init__(self):
self._start_time = 0.0
self._end_time = 0.0
self._elapsed_time = 0.0
self._is_running = False
def start(self):
self._is_running = True
self._start_time = time.time()
def restart(self):
self.start()
def stop(self):
self._is_running = False
self._end_time = time.time()
def elapsed_time(self):
self._end_time = time.time()
self._elapsed_time = self._end_time - self._start_time
if not self.is_running:
return 0.0
return self._elapsed_time
@property
def is_running(self):
return self._is_running
def calculate_eta(remaining_step, speed):
if remaining_step < 0:
remaining_step = 0
remaining_time = int(remaining_step * speed)
result = "{:0>2}:{:0>2}:{:0>2}"
arr = []
for i in range(2, -1, -1):
arr.append(int(remaining_time / 60**i))
remaining_time %= 60**i
return result.format(*arr)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import math
import cv2
import paddle.fluid as fluid
from . import logging
def seconds_to_hms(seconds):
h = math.floor(seconds / 3600)
m = math.floor((seconds - h * 3600) / 60)
s = int(seconds - h * 3600 - m * 60)
hms_str = "{}:{}:{}".format(h, m, s)
return hms_str
def get_environ_info():
info = dict()
info['place'] = 'cpu'
info['num'] = int(os.environ.get('CPU_NUM', 1))
if os.environ.get('CUDA_VISIBLE_DEVICES', None) != "":
if hasattr(fluid.core, 'get_cuda_device_count'):
gpu_num = 0
try:
gpu_num = fluid.core.get_cuda_device_count()
except:
os.environ['CUDA_VISIBLE_DEVICES'] = ''
pass
if gpu_num > 0:
info['place'] = 'cuda'
info['num'] = fluid.core.get_cuda_device_count()
return info
def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None:
logging.info('Load pretrained model from {}'.format(pretrained_model))
if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model')
para_state_dict, _ = fluid.load_dygraph(ckpt_path)
model_state_dict = model.state_dict()
keys = model_state_dict.keys()
num_params_loaded = 0
for k in keys:
if k not in para_state_dict:
logging.warning("{} is not in pretrained model".format(k))
elif list(para_state_dict[k].shape) != list(
model_state_dict[k].shape):
logging.warning(
"[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
.format(k, para_state_dict[k].shape,
model_state_dict[k].shape))
else:
model_state_dict[k] = para_state_dict[k]
num_params_loaded += 1
model.set_dict(model_state_dict)
logging.info("There are {}/{} varaibles are loaded.".format(
num_params_loaded, len(model_state_dict)))
else:
raise ValueError(
'The pretrained model directory is not Found: {}'.format(
pretrained_model))
else:
logging.info('No pretrained model to load, train from scratch')
def resume(model, optimizer, resume_model):
if resume_model is not None:
logging.info('Resume model from {}'.format(resume_model))
if os.path.exists(resume_model):
ckpt_path = os.path.join(resume_model, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
optimizer.set_dict(opti_state_dict)
epoch = resume_model.split('_')[-1]
if epoch.isdigit():
epoch = int(epoch)
return epoch
else:
raise ValueError(
'The resume model directory is not Found: {}'.format(
resume_model))
else:
logging.info('No model need to resume')
def visualize(image, result, save_dir=None, weight=0.6):
"""
Convert segment result to color image, and save added image.
Args:
image: the path of origin image
result: the predict result of image
save_dir: the directory for saving visual image
weight: the image weight of visual image, and the result weight is (1 - weight)
"""
color_map = get_color_map_list(256)
color_map = np.array(color_map).astype("uint8")
# Use OpenCV LUT for color mapping
c1 = cv2.LUT(result, color_map[:, 0])
c2 = cv2.LUT(result, color_map[:, 1])
c3 = cv2.LUT(result, color_map[:, 2])
pseudo_img = np.dstack((c1, c2, c3))
im = cv2.imread(image)
vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
if save_dir is not None:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
image_name = os.path.split(image)[-1]
out_path = os.path.join(save_dir, image_name)
cv2.imwrite(out_path, vis_result)
else:
return vis_result
def get_color_map_list(num_classes):
""" Returns the color map for visualizing the segmentation mask,
which can support arbitrary number of classes.
Args:
num_classes: Number of classes
Returns:
The color map
"""
num_classes += 1
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
color_map = color_map[1:]
return color_map
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import math
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils.logging as logging
from utils import get_environ_info
from utils import ConfusionMatrix
from utils import Timer, calculate_eta
def parse_args():
parser = argparse.ArgumentParser(description='Model evaluation')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for evaluation, which is one of ('UNet')",
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help=
"The dataset you want to evaluation, which is one of ('OpticDiscSeg', 'Cityscapes')",
type=str,
default='OpticDiscSeg')
# params of evaluate
parser.add_argument(
"--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=2)
parser.add_argument(
'--model_dir',
dest='model_dir',
help='The path of model for evaluation',
type=str,
default=None)
return parser.parse_args()
def evaluate(model,
eval_dataset=None,
places=None,
model_dir=None,
num_classes=None,
batch_size=2,
ignore_index=255,
epoch_id=None):
ckpt_path = os.path.join(model_dir, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
model.eval()
batch_sampler = BatchSampler(
eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
loader = DataLoader(
eval_dataset,
batch_sampler=batch_sampler,
places=places,
return_list=True,
)
total_steps = len(batch_sampler)
conf_mat = ConfusionMatrix(num_classes, streaming=True)
logging.info(
"Start to evaluating(total_samples={}, total_steps={})...".format(
len(eval_dataset), total_steps))
timer = Timer()
timer.start()
for step, data in enumerate(loader):
images = data[0]
labels = data[1].astype('int64')
pred, _ = model(images, mode='eval')
pred = pred.numpy()
labels = labels.numpy()
mask = labels != ignore_index
conf_mat.calculate(pred=pred, label=labels, ignore=mask)
_, iou = conf_mat.mean_iou()
time_step = timer.elapsed_time()
remain_step = total_steps - step - 1
logging.info(
"[EVAL] Epoch={}, Step={}/{}, iou={:4f}, sec/step={:.4f} | ETA {}".
format(epoch_id, step + 1, total_steps, iou, time_step,
calculate_eta(remain_step, time_step)))
timer.restart()
category_iou, miou = conf_mat.mean_iou()
category_acc, macc = conf_mat.accuracy()
logging.info("[EVAL] #image={} acc={:.4f} IoU={:.4f}".format(
len(eval_dataset), macc, miou))
logging.info("[EVAL] Category IoU: " + str(category_iou))
logging.info("[EVAL] Category Acc: " + str(category_acc))
logging.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa()))
return miou, macc
def main(args):
env_info = get_environ_info()
places = fluid.CUDAPlace(ParallelEnv().dev_id) \
if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
if args.dataset.lower() == 'opticdiscseg':
dataset = OpticDiscSeg
elif args.dataset.lower() == 'cityscapes':
dataset = Cityscapes
else:
raise Exception(
"The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
)
with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet':
model = models.UNet(num_classes=eval_dataset.num_classes)
evaluate(
model,
eval_dataset,
places=places,
model_dir=args.model_dir,
num_classes=eval_dataset.num_classes,
batch_size=args.batch_size)
if __name__ == '__main__':
args = parse_args()
main(args)
......@@ -77,7 +77,7 @@ def softmax_with_loss(logit,
weighted_label_one_hot.stop_gradient = True
loss = loss * ignore_mask
avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
avg_loss = fluid.layers.mean(loss) / (fluid.layers.mean(ignore_mask) + cfg.MODEL.DEFAULT_EPSILON)
label.stop_gradient = True
ignore_mask.stop_gradient = True
......@@ -133,10 +133,12 @@ def multi_softmax_with_loss(logits,
for i, logit in enumerate(logits):
if label.shape[2] != logit.shape[2] or label.shape[
3] != logit.shape[3]:
label = fluid.layers.resize_nearest(label, logit.shape[2:])
logit_mask = (label.astype('int32') !=
logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
else:
logit_label = label
logit_mask = (logit_label.astype('int32') !=
cfg.DATASET.IGNORE_INDEX).astype('int32')
loss = softmax_with_loss(logit, label, logit_mask, num_classes)
loss = softmax_with_loss(logit, logit_label, logit_mask, num_classes, weight=weight)
avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss
else:
avg_loss = softmax_with_loss(
......@@ -148,7 +150,11 @@ def multi_dice_loss(logits, label, ignore_mask=None):
if isinstance(logits, tuple):
avg_loss = 0
for i, logit in enumerate(logits):
logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
if label.shape[2] != logit.shape[2] or label.shape[
3] != logit.shape[3]:
logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
else:
logit_label = label
logit_mask = (logit_label.astype('int32') !=
cfg.DATASET.IGNORE_INDEX).astype('int32')
loss = dice_loss(logit, logit_label, logit_mask)
......@@ -162,7 +168,11 @@ def multi_bce_loss(logits, label, ignore_mask=None):
if isinstance(logits, tuple):
avg_loss = 0
for i, logit in enumerate(logits):
logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
if label.shape[2] != logit.shape[2] or label.shape[
3] != logit.shape[3]:
logit_label = fluid.layers.resize_nearest(label, logit.shape[2:])
else:
logit_label = label
logit_mask = (logit_label.astype('int32') !=
cfg.DATASET.IGNORE_INDEX).astype('int32')
loss = bce_loss(logit, logit_label, logit_mask)
......
......@@ -2,4 +2,4 @@ pre-commit
yapf == 0.26.0
flake8
pyyaml >= 5.1
visualdl == 2.0.0b1
visualdl == 2.0.0b4
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册