提交 f55d38a8 编写于 作者: W wuzewu

Merge branch 'release/v0.1.0'

...@@ -88,6 +88,10 @@ A: 降低Batch size,使用Group Norm策略;请注意训练过程中当`DEFAU ...@@ -88,6 +88,10 @@ A: 降低Batch size,使用Group Norm策略;请注意训练过程中当`DEFAU
</br> </br>
#### Q: 出现错误 ModuleNotFoundError: No module named 'paddle.fluid.contrib.mixed_precision'
A: 请将PaddlePaddle升级至1.5.2版本或以上。
## 在线体验 ## 在线体验
PaddleSeg在AI Studio平台上提供了在线体验的教程,欢迎体验: PaddleSeg在AI Studio平台上提供了在线体验的教程,欢迎体验:
...@@ -101,15 +105,13 @@ PaddleSeg在AI Studio平台上提供了在线体验的教程,欢迎体验: ...@@ -101,15 +105,13 @@ PaddleSeg在AI Studio平台上提供了在线体验的教程,欢迎体验:
</br> </br>
## 交流与反馈 ## 交流与反馈
* 欢迎您通过Github Issues来提交问题、报告与建议 * 欢迎您通过[Github Issues](https://github.com/PaddlePaddle/PaddleSeg/issues)来提交问题、报告与建议
* 微信公众号:飞桨PaddlePaddle * 微信公众号:飞桨PaddlePaddle
* QQ群: 796771754 * QQ群: 796771754
<p align="center"><img width="200" height="200" src="https://user-images.githubusercontent.com/45189361/64117959-1969de80-cdc9-11e9-84f7-e1c2849a004c.jpeg"/>&#8194;&#8194;&#8194;&#8194;&#8194;<img width="200" height="200" margin="500" src="./docs/imgs/qq_group2.png"/></p> <p align="center"><img width="200" height="200" src="https://user-images.githubusercontent.com/45189361/64117959-1969de80-cdc9-11e9-84f7-e1c2849a004c.jpeg"/>&#8194;&#8194;&#8194;&#8194;&#8194;<img width="200" height="200" margin="500" src="./docs/imgs/qq_group2.png"/></p>
<p align="center"> &#8194;&#8194;&#8194;微信公众号&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;官方技术交流QQ群</p> <p align="center"> &#8194;&#8194;&#8194;微信公众号&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;&#8194;官方技术交流QQ群</p>
* 论坛: 欢迎大家在[PaddlePaddle论坛](https://ai.baidu.com/forum/topic/list/168)分享在使用PaddlePaddle中遇到的问题和经验, 营造良好的论坛氛围
## 更新日志 ## 更新日志
* 2019.09.10 * 2019.09.10
......
...@@ -30,13 +30,13 @@ MODEL: ...@@ -30,13 +30,13 @@ MODEL:
MODEL_NAME: "unet" MODEL_NAME: "unet"
DEFAULT_NORM_TYPE: "bn" DEFAULT_NORM_TYPE: "bn"
TEST: TEST:
TEST_MODEL: "./test/saved_model/unet_pet/final/" TEST_MODEL: "./saved_model/unet_pet/final/"
TRAIN: TRAIN:
MODEL_SAVE_DIR: "./test/saved_models/unet_pet/" MODEL_SAVE_DIR: "./saved_model/unet_pet/"
PRETRAINED_MODEL_DIR: "./test/models/unet_coco/" PRETRAINED_MODEL_DIR: "./pretrained_model/unet_bn_coco/"
SNAPSHOT_EPOCH: 10 SNAPSHOT_EPOCH: 10
SOLVER: SOLVER:
NUM_EPOCHS: 500 NUM_EPOCHS: 100
LR: 0.005 LR: 0.005
LR_POLICY: "poly" LR_POLICY: "poly"
OPTIMIZER: "adam" OPTIMIZER: "adam"
# 数据下载
## PASCAL VOC 2012数据集
下载 PASCAL VOC 2012数据集并将分割部分的假彩色标注图(`SegmentationClass`文件夹)转换成灰度图并存储在`SegmentationClassAug`文件夹, 并在文件夹`ImageSets/Segmentation`下重新生成列表文件`train.list、val.list和trainval.list。
```shell
# 下载数据集并进行解压转换
python download_and_convert_voc2012.py
```
如果已经下载好PASCAL VOC 2012数据集,将数据集移至dataset目录后使用下述命令直接进行转换即可。
```shell
# 数据集转换
python convert_voc2012.py
```
## Oxford-IIIT Pet数据集
我们使用了Oxford-IIIT中的猫和狗两个类别数据制作了一个小数据集mini_pet,更多关于数据集的介绍请参考[Oxford-IIIT Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/)。
```shell
# 下载数据集并进行解压
python dataset/download_pet.py
```
## Cityscapes数据集
运行下述命令下载并解压Cityscapes数据集。
```shell
# 下载数据集并进行解压
python dataset/download_cityscapes.py
```
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import numpy as np
import os
from PIL import Image
import glob
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
def remove_colormap(filename):
gray_anno = np.array(Image.open(filename))
return gray_anno
def save_annotation(annotation, filename):
annotation = annotation.astype(dtype=np.uint8)
annotation = Image.fromarray(annotation)
annotation.save(filename)
def convert_list(origin_file, seg_file, output_folder):
with open(seg_file, 'w') as fid_seg:
with open(origin_file) as fid_ori:
lines = fid_ori.readlines()
for line in lines:
line = line.strip()
line = '.'.join([line, 'jpg'])
img_name = os.path.join("JPEGImages", line)
line = line.replace('jpg', 'png')
anno_name = os.path.join(output_folder.split(os.sep)[-1], line)
new_line = ' '.join([img_name, anno_name])
fid_seg.write(new_line + "\n")
if __name__ == "__main__":
pascal_root = "./VOCtrainval_11-May-2012/VOC2012"
pascal_root = os.path.join(LOCAL_PATH, pascal_root)
seg_folder = os.path.join(pascal_root, "SegmentationClass")
txt_folder = os.path.join(pascal_root, "ImageSets/Segmentation")
train_path = os.path.join(txt_folder, "train.txt")
val_path = os.path.join(txt_folder, "val.txt")
trainval_path = os.path.join(txt_folder, "trainval.txt")
# 标注图转换后存储目录
output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert")
if not os.path.exists(os.path.join(LOCAL_PATH, output_folder)):
os.mkdir(os.path.join(LOCAL_PATH, output_folder))
annotation_names = glob.glob(os.path.join(seg_folder, '*.png'))
for annotation_name in annotation_names:
annotation = remove_colormap(annotation_name)
filename = os.path.basename(annotation_name)
save_name = os.path.join(output_folder, filename)
save_annotation(annotation, save_name)
convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import numpy as np
import os
import glob
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TEST_PATH = os.path.join(LOCAL_PATH, "..", "test")
sys.path.append(TEST_PATH)
from test_utils import download_file_and_uncompress
from convert_voc2012 import convert_list
from convert_voc2012 import remove_colormap
from convert_voc2012 import save_annotation
def download_VOC_dataset(savepath, extrapath):
url = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
download_file_and_uncompress(
url=url, savepath=savepath, extrapath=extrapath)
if __name__ == "__main__":
download_VOC_dataset(LOCAL_PATH, LOCAL_PATH)
print("Dataset download finish!")
pascal_root = "./VOCtrainval_11-May-2012/VOC2012"
pascal_root = os.path.join(LOCAL_PATH, pascal_root)
seg_folder = os.path.join(pascal_root, "SegmentationClass")
txt_folder = os.path.join(pascal_root, "ImageSets/Segmentation")
train_path = os.path.join(txt_folder, "train.txt")
val_path = os.path.join(txt_folder, "val.txt")
trainval_path = os.path.join(txt_folder, "trainval.txt")
# 标注图转换后存储目录
output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert")
if not os.path.exists(output_folder):
os.mkdir(output_folder)
annotation_names = glob.glob(os.path.join(seg_folder, '*.png'))
for annotation_name in annotation_names:
annotation = remove_colormap(annotation_name)
filename = os.path.basename(annotation_name)
save_name = os.path.join(output_folder, filename)
save_annotation(annotation, save_name)
convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder)
# PaddleSeg 数据标注
用户需预先采集好用于训练、评估和测试的图片,并使用数据标注工具完成数据标注。
PaddleSeg支持2种标注工具:LabelMe、精灵数据标注工具。
标注教程如下:
- [LabelMe标注教程](labelme2seg.md)
- [精灵数据标注工具教程](jingling2seg.md)
最后用我们提供的数据转换脚本将上述标注工具产出的数据格式转换为模型训练时所需的数据格式。
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# YAML_FILE_PATH为yaml配置文件路径 # YAML_FILE_PATH为yaml配置文件路径
python pdseg/check.py --cfg ${YAML_FILE_PATH} python pdseg/check.py --cfg ${YAML_FILE_PATH}
``` ```
运行后,命令行将显示校验结果的概览信息,详细信息可到detail.log文件中查看。 运行后,命令行将显示校验结果的概览信息,详细的错误信息可到detail.log文件中查看。
### 1 列表分割符校验 ### 1 列表分割符校验
判断在`TRAIN_FILE_LIST``VAL_FILE_LIST``TEST_FILE_LIST`列表文件中的分隔符`DATASET.SEPARATOR`设置是否正确。 判断在`TRAIN_FILE_LIST``VAL_FILE_LIST``TEST_FILE_LIST`列表文件中的分隔符`DATASET.SEPARATOR`设置是否正确。
...@@ -31,18 +31,24 @@ python pdseg/check.py --cfg ${YAML_FILE_PATH} ...@@ -31,18 +31,24 @@ python pdseg/check.py --cfg ${YAML_FILE_PATH}
标注类别最好从0开始,否则可能影响精度。 标注类别最好从0开始,否则可能影响精度。
### 6 标注像素统计 ### 6 标注像素统计
统计每种类别像素数量,显示以供参考。 统计每种类别的像素总数和所占比例,显示以供参考。统计结果如下:
```
Doing label pixel statistics:
(label class, total pixel number, percentage) = [(0, 2048984, 0.5211), (1, 1682943, 0.428), (2, 197976, 0.0503), (3, 2257, 0.0006)]
```
### 7 图像格式校验 ### 7 图像格式校验
检查图片类型`DATASET.IMAGE_TYPE`是否设置正确。 检查图片类型`DATASET.IMAGE_TYPE`是否设置正确。
**NOTE:** 当数据集包含三通道图片时`DATASET.IMAGE_TYPE`设置为rgb; **NOTE:** 当数据集包含三通道图片时`DATASET.IMAGE_TYPE`设置为rgb;
当数据集全部为四通道图片时`DATASET.IMAGE_TYPE`设置为rgba; 当数据集全部为四通道图片时`DATASET.IMAGE_TYPE`设置为rgba;
### 8 图像与标注图尺寸一致性校验 ### 8 图像最大尺寸统计
统计数据集中图片的最大高和最大宽,显示以供参考。
### 9 图像与标注图尺寸一致性校验
验证图像尺寸和对应标注图尺寸是否一致。 验证图像尺寸和对应标注图尺寸是否一致。
### 9 模型验证参数`EVAL_CROP_SIZE`校验 ### 10 模型验证参数`EVAL_CROP_SIZE`校验
验证`EVAL_CROP_SIZE`是否设置正确,共有3种情形: 验证`EVAL_CROP_SIZE`是否设置正确,共有3种情形:
-`AUG.AUG_METHOD`为unpadding时,`EVAL_CROP_SIZE`的宽高应不小于`AUG.FIX_RESIZE_SIZE`的宽高。 -`AUG.AUG_METHOD`为unpadding时,`EVAL_CROP_SIZE`的宽高应不小于`AUG.FIX_RESIZE_SIZE`的宽高。
...@@ -51,5 +57,5 @@ python pdseg/check.py --cfg ${YAML_FILE_PATH} ...@@ -51,5 +57,5 @@ python pdseg/check.py --cfg ${YAML_FILE_PATH}
-`AUG.AUG_METHOD`为rangscaling时,`EVAL_CROP_SIZE`的宽高应不小于缩放后图像中最大的宽高。 -`AUG.AUG_METHOD`为rangscaling时,`EVAL_CROP_SIZE`的宽高应不小于缩放后图像中最大的宽高。
### 10 数据增强参数`AUG.INF_RESIZE_VALUE`校验 ### 11 数据增强参数`AUG.INF_RESIZE_VALUE`校验
验证`AUG.INF_RESIZE_VALUE`是否在[`AUG.MIN_RESIZE_VALUE`~`AUG.MAX_RESIZE_VALUE`]范围内。若在范围内,则通过校验。 验证`AUG.INF_RESIZE_VALUE`是否在[`AUG.MIN_RESIZE_VALUE`~`AUG.MAX_RESIZE_VALUE`]范围内。若在范围内,则通过校验。
...@@ -2,10 +2,18 @@ ...@@ -2,10 +2,18 @@
## 数据标注 ## 数据标注
数据标注推荐使用LabelMe工具,具体可参考文档[PaddleSeg 数据标注](./annotation/README.md) 用户需预先采集好用于训练、评估和测试的图片,然后使用数据标注工具完成数据标注。
PddleSeg已支持2种标注工具:LabelMe、精灵数据标注工具。标注教程如下:
## 语义分割标注规范 - [LabelMe标注教程](annotation/labelme2seg.md)
- [精灵数据标注工具教程](annotation/jingling2seg.md)
最后用我们提供的数据转换脚本将上述标注工具产出的数据格式转换为模型训练时所需的数据格式。
## 文件列表
### 文件列表规范
PaddleSeg采用通用的文件列表方式组织训练集、验证集和测试集。像素标注类别需要从0开始递增。 PaddleSeg采用通用的文件列表方式组织训练集、验证集和测试集。像素标注类别需要从0开始递增。
...@@ -57,4 +65,94 @@ PaddleSeg采用通用的文件列表方式组织训练集、验证集和测试 ...@@ -57,4 +65,94 @@ PaddleSeg采用通用的文件列表方式组织训练集、验证集和测试
![cityscapes_filelist](./imgs/file_list.png) ![cityscapes_filelist](./imgs/file_list.png)
若数据集缺少标注图片,则文件列表不用包含分隔符和标注图片路径,如下图所示。
![cityscapes_filelist](./imgs/file_list2.png)
**注意事项**
此时的文件列表仅可在调用`pdseg/vis.py`进行可视化展示时使用,
即仅可在`DATASET.TEST_FILE_LIST``DATASET.VIS_FILE_LIST`配置项中使用。
不可在`DATASET.TRAIN_FILE_LIST``DATASET.VAL_FILE_LIST`配置项中使用。
完整的配置信息可以参考[`./dataset/cityscapes_demo`](../dataset/cityscapes_demo/)目录下的yaml和文件列表。 完整的配置信息可以参考[`./dataset/cityscapes_demo`](../dataset/cityscapes_demo/)目录下的yaml和文件列表。
### 文件列表生成
PaddleSeg提供了生成文件列表的使用脚本,可适用于自定义数据集或cityscapes数据集,并支持通过不同的Flags来开启特定功能。
```
python pdseg/tools/create_dataset_list.py <your/dataset/dir> ${FLAGS}
```
运行后将在数据集根目录下生成训练/验证/测试集的文件列表(文件主名与`--second_folder`一致,扩展名为`.txt`)。
**Note:** 若训练/验证/测试集缺少标注图片,仍可自动生成不含分隔符和标注图片路径的文件列表。
#### 命令行FLAGS列表
|FLAG|用途|默认值|参数数目|
|-|-|-|-|
|--type|指定数据集类型,`cityscapes``自定义`|`自定义`|1|
|--separator|文件列表分隔符|'&#124;'|1|
|--folder|图片和标签集的文件夹名|'images' 'annotations'|2|
|--second_folder|训练/验证/测试集的文件夹名|'train' 'val' 'test'|若干|
|--format|图片和标签集的数据格式|'jpg' 'png'|2|
|--postfix|按文件主名(无扩展名)是否包含指定后缀对图片和标签集进行筛选|'' ''(2个空字符)|2|
#### 使用示例
- **对于自定义数据集**
如果用户想要生成自己数据集的文件列表,需要整理成如下的目录结构:
```
./dataset/ # 数据集根目录
├── annotations # 标注目录
│   ├── test
│   │   ├── ...
│   │   └── ...
│   ├── train
│   │   ├── ...
│   │   └── ...
│   └── val
│   ├── ...
│   └── ...
└── images # 原图目录
├── test
│   ├── ...
│   └── ...
├── train
│   ├── ...
│   └── ...
└── val
├── ...
└── ...
Note:以上目录名可任意
```
必须指定自定义数据集目录,可以按需要设定FLAG。
**Note:** 无需指定`--type`
```
# 生成文件列表,其分隔符为空格,图片和标签集的数据格式都为png
python pdseg/tools/create_dataset_list.py <your/dataset/dir> --separator " " --format png png
```
```
# 生成文件列表,其图片和标签集的文件夹名为img和gt,训练和验证集的文件夹名为training和validation,不生成测试集列表
python pdseg/tools/create_dataset_list.py <your/dataset/dir> \
--folder img gt --second_folder training validation
```
- **对于cityscapes数据集**
必须指定cityscapes数据集目录,`--type`必须为`cityscapes`
在cityscapes类型下,部分FLAG将被重新设定,无需手动指定,具体如下:
|FLAG|固定值|
|-|-|
|--folder|'leftImg8bit' 'gtFine'|
|--format|'png' 'png'|
|--postfix|'_leftImg8bit' '_gtFine_labelTrainIds'|
其余FLAG可以按需要设定。
```
# 生成cityscapes文件列表,其分隔符为逗号
python pdseg/tools/create_dataset_list.py <your/dataset/dir> --type cityscapes --separator ","
```
# PaddleSeg 安装说明 # PaddleSeg 安装说明
## 推荐开发环境 ## 1. 安装PaddlePaddle
* Python 2.7 or 3.5+ 版本要求
* CUDA 9.2
* NVIDIA cuDNN v7.1
* PaddlePaddle >= 1.5.2 * PaddlePaddle >= 1.5.2
* 如果有多卡训练需求,请安装 NVIDIA NCCL >= 2.4.7,并在Linux环境下运行 * Python 2.7 or 3.5+
## 1. 安装PaddlePaddle
### pip安装 ### pip安装
...@@ -27,6 +22,8 @@ PaddlePaddle最新版本1.5支持Conda安装,可以减少相关依赖安装成 ...@@ -27,6 +22,8 @@ PaddlePaddle最新版本1.5支持Conda安装,可以减少相关依赖安装成
conda install -c paddle paddlepaddle-gpu cudatoolkit=9.0 conda install -c paddle paddlepaddle-gpu cudatoolkit=9.0
``` ```
* 如果有多卡训练需求,请安装 NVIDIA NCCL >= 2.4.7,并在Linux环境下运行
更多安装方式详情可以查看 [PaddlePaddle安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html) 更多安装方式详情可以查看 [PaddlePaddle安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html)
......
...@@ -117,16 +117,24 @@ NOTE: ...@@ -117,16 +117,24 @@ NOTE:
```shell ```shell
python pdseg/eval.py --use_gpu \ python pdseg/eval.py --use_gpu \
--cfg configs/unet_pet.yaml \ --cfg configs/unet_pet.yaml \
TEST.TEST_MODEL test/saved_models/unet_pet/final TEST.TEST_MODEL saved_model/unet_pet/final
``` ```
可以看到,在经过训练后,模型在验证集上的mIoU指标达到了0.70+(由于随机种子等因素的影响,效果会有小范围波动,属于正常情况)。
### 模型可视化 ### 模型可视化
通过vis.py来评估模型效果,我们选择最后保存的模型进行效果的评估: 通过vis.py来评估模型效果,我们选择最后保存的模型进行效果的评估:
```shell ```shell
python pdseg/vis.py --use_gpu \ python pdseg/vis.py --use_gpu \
--cfg configs/unet_pet.yaml \ --cfg configs/unet_pet.yaml \
TEST.TEST_MODEL test/saved_models/unet_pet/final TEST.TEST_MODEL saved_model/unet_pet/final
``` ```
执行上述脚本后,会在主目录下产生一个visual/visual_results文件夹,里面存放着测试集图片的预测结果,我们选择其中几张图片进行查看,可以看到,在测试集中的图片上的预测效果已经很不错:
![](./imgs/usage_vis_demo.jpg)
![](./imgs/usage_vis_demo2.jpg)
![](./imgs/usage_vis_demo3.jpg)
`NOTE` `NOTE`
1. 可视化的图片会默认保存在visual/visual_results目录下,可以通过`--vis_dir`来指定输出目录 1. 可视化的图片会默认保存在visual/visual_results目录下,可以通过`--vis_dir`来指定输出目录
2. 训练过程中会使用DATASET.VIS_FILE_LIST中的图片进行可视化显示,而vis.py则会使用DATASET.TEST_FILE_LIST 2. 训练过程中会使用DATASET.VIS_FILE_LIST中的图片进行可视化显示,而vis.py则会使用DATASET.TEST_FILE_LIST
...@@ -42,14 +42,23 @@ include_directories("${PADDLE_DIR}/third_party/install/protobuf/include") ...@@ -42,14 +42,23 @@ include_directories("${PADDLE_DIR}/third_party/install/protobuf/include")
include_directories("${PADDLE_DIR}/third_party/install/glog/include") include_directories("${PADDLE_DIR}/third_party/install/glog/include")
include_directories("${PADDLE_DIR}/third_party/install/gflags/include") include_directories("${PADDLE_DIR}/third_party/install/gflags/include")
include_directories("${PADDLE_DIR}/third_party/install/xxhash/include") include_directories("${PADDLE_DIR}/third_party/install/xxhash/include")
include_directories("${PADDLE_DIR}/third_party/install/snappy/include") if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/include")
include_directories("${PADDLE_DIR}/third_party/install/snappystream/include") include_directories("${PADDLE_DIR}/third_party/install/snappy/include")
endif()
if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/include")
include_directories("${PADDLE_DIR}/third_party/install/snappystream/include")
endif()
include_directories("${PADDLE_DIR}/third_party/install/zlib/include") include_directories("${PADDLE_DIR}/third_party/install/zlib/include")
include_directories("${PADDLE_DIR}/third_party/boost") include_directories("${PADDLE_DIR}/third_party/boost")
include_directories("${PADDLE_DIR}/third_party/eigen3") include_directories("${PADDLE_DIR}/third_party/eigen3")
link_directories("${PADDLE_DIR}/third_party/install/snappy/lib") if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/lib")
link_directories("${PADDLE_DIR}/third_party/install/snappystream/lib") link_directories("${PADDLE_DIR}/third_party/install/snappy/lib")
endif()
if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
link_directories("${PADDLE_DIR}/third_party/install/snappystream/lib")
endif()
link_directories("${PADDLE_DIR}/third_party/install/zlib/lib") link_directories("${PADDLE_DIR}/third_party/install/zlib/lib")
link_directories("${PADDLE_DIR}/third_party/install/protobuf/lib") link_directories("${PADDLE_DIR}/third_party/install/protobuf/lib")
link_directories("${PADDLE_DIR}/third_party/install/glog/lib") link_directories("${PADDLE_DIR}/third_party/install/glog/lib")
...@@ -82,7 +91,7 @@ if (WIN32) ...@@ -82,7 +91,7 @@ if (WIN32)
add_definitions(-DSTATIC_LIB) add_definitions(-DSTATIC_LIB)
endif() endif()
else() else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -o2 -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -o2 -fopenmp -std=c++11")
set(CMAKE_STATIC_LIBRARY_PREFIX "") set(CMAKE_STATIC_LIBRARY_PREFIX "")
endif() endif()
...@@ -160,13 +169,25 @@ if (NOT WIN32) ...@@ -160,13 +169,25 @@ if (NOT WIN32)
set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf yaml-cpp snappystream snappy z xxhash glog gflags protobuf yaml-cpp z xxhash
${EXTERNAL_LIB}) ${EXTERNAL_LIB})
if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
set(DEPS ${DEPS} snappystream)
endif()
if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/lib")
set(DEPS ${DEPS} snappy)
endif()
else() else()
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
opencv_world346 glog libyaml-cppmt gflags_static libprotobuf snappy zlibstatic xxhash snappystream ${EXTERNAL_LIB}) opencv_world346 glog libyaml-cppmt gflags_static libprotobuf zlibstatic xxhash ${EXTERNAL_LIB})
set(DEPS ${DEPS} libcmt shlwapi) set(DEPS ${DEPS} libcmt shlwapi)
if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/lib")
set(DEPS ${DEPS} snappy)
endif()
if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
set(DEPS ${DEPS} snappystream)
endif()
endif(NOT WIN32) endif(NOT WIN32)
if(WITH_GPU) if(WITH_GPU)
......
...@@ -71,7 +71,7 @@ deeplabv3p_xception65_humanseg ...@@ -71,7 +71,7 @@ deeplabv3p_xception65_humanseg
### 2. 修改配置 ### 2. 修改配置
基于`PaddleSeg`训练的模型导出时,会自动生成对应的预测模型配置文件,请参考文档:[模型导出](../docs/export_model.md) 基于`PaddleSeg`训练的模型导出时,会自动生成对应的预测模型配置文件,请参考文档:[模型导出](../docs/model_export.md)
`inference`源代码(即本目录)的`conf`目录下提供了示例人像分割模型的配置文件`humanseg.yaml`, 相关的字段含义和说明如下: `inference`源代码(即本目录)的`conf`目录下提供了示例人像分割模型的配置文件`humanseg.yaml`, 相关的字段含义和说明如下:
...@@ -88,7 +88,7 @@ DEPLOY: ...@@ -88,7 +88,7 @@ DEPLOY:
# 预测图片的的标准输入尺寸,输入尺寸不一致会做resize # 预测图片的的标准输入尺寸,输入尺寸不一致会做resize
EVAL_CROP_SIZE: (513, 513) EVAL_CROP_SIZE: (513, 513)
# 均值 # 均值
MEAN: [104.008, 116.669, 122.675] MEAN: [0.40787450980392154, 0.4575254901960784, 0.481078431372549]
# 方差 # 方差
STD: [1.0, 1.0, 1.0] STD: [1.0, 1.0, 1.0]
# 图片类型, rgb 或者 rgba # 图片类型, rgb 或者 rgba
......
DEPLOY: DEPLOY:
USE_GPU: 1 USE_GPU: 1
MODEL_PATH: "/root/projects/models/deeplabv3p_xception65_humanseg" MODEL_PATH: "/root/projects/models/deeplabv3p_xception65_humanseg"
MODEL_NAME: "unet"
MODEL_FILENAME: "__model__" MODEL_FILENAME: "__model__"
PARAMS_FILENAME: "__params__" PARAMS_FILENAME: "__params__"
EVAL_CROP_SIZE: (513, 513) EVAL_CROP_SIZE: (513, 513)
MEAN: [0.5, 0.5, 0.5] MEAN: [0.40787450980392154, 0.4575254901960784, 0.481078431372549]
STD: [1.0, 1.0, 1.0] STD: [0.00392156862745098, 0.00392156862745098, 0.00392156862745098]
IMAGE_TYPE: "rgb" IMAGE_TYPE: "rgb"
NUM_CLASSES: 2 NUM_CLASSES: 2
CHANNELS : 3 CHANNELS : 3
PRE_PROCESSOR: "SegPreProcessor" PRE_PROCESSOR: "SegPreProcessor"
PREDICTOR_MODE: "NATIVE" PREDICTOR_MODE: "NATIVE"
BATCH_SIZE : 3 BATCH_SIZE : 1
#include "seg_predictor.h" #include "seg_predictor.h"
#include <unsupported/Eigen/CXX11/Tensor>
namespace PaddleSolution { namespace PaddleSolution {
...@@ -78,26 +79,8 @@ namespace PaddleSolution { ...@@ -78,26 +79,8 @@ namespace PaddleSolution {
//post process //post process
_mask.clear(); _mask.clear();
_scoremap.clear(); _scoremap.clear();
int out_img_len = eval_height * eval_width; std::vector<int> out_shape{eval_num_class, eval_height, eval_width};
for (int i = 0; i < out_img_len; ++i) { utils::argmax(p_out, out_shape, _mask, _scoremap);
float max_value = -1;
int label = 0;
for (int j = 0; j < eval_num_class; ++j) {
int index = i + j * out_img_len;
if (index >= blob_out_len) {
break;
}
float value = p_out[index];
if (value > max_value) {
max_value = value;
label = j;
}
}
if (label == 0) max_value = 0;
_mask[i] = uchar(label);
_scoremap[i] = uchar(max_value * 255);
}
cv::Mat mask_png = cv::Mat(eval_height, eval_width, CV_8UC1); cv::Mat mask_png = cv::Mat(eval_height, eval_width, CV_8UC1);
mask_png.data = _mask.data(); mask_png.data = _mask.data();
std::string nname(fname); std::string nname(fname);
...@@ -251,6 +234,7 @@ namespace PaddleSolution { ...@@ -251,6 +234,7 @@ namespace PaddleSolution {
int idx = u * default_batch_size + i; int idx = u * default_batch_size + i;
imgs_batch.push_back(imgs[idx]); imgs_batch.push_back(imgs[idx]);
} }
if (!_preprocessor->batch_process(imgs_batch, input_buffer.data(), org_height.data(), org_width.data())) { if (!_preprocessor->batch_process(imgs_batch, input_buffer.data(), org_height.data(), org_width.data())) {
return -1; return -1;
} }
......
...@@ -32,21 +32,7 @@ namespace PaddleSolution { ...@@ -32,21 +32,7 @@ namespace PaddleSolution {
if (*ori_h != rh || *ori_w != rw) { if (*ori_h != rh || *ori_w != rw) {
cv::resize(im, im, resize_size, 0, 0, cv::INTER_LINEAR); cv::resize(im, im, resize_size, 0, 0, cv::INTER_LINEAR);
} }
utils::normalize(im, data, _config->_mean, _config->_std);
float* pmean = _config->_mean.data();
float* pscale = _config->_std.data();
for (int h = 0; h < rh; ++h) {
const uchar* ptr = im.ptr<uchar>(h);
int im_index = 0;
for (int w = 0; w < rw; ++w) {
for (int c = 0; c < channels; ++c) {
int top_index = (c * rh + h) * rw + w;
float pixel = static_cast<float>(ptr[im_index++]);
pixel = (pixel / 255 - pmean[c]) / pscale[c];
data[top_index] = pixel;
}
}
}
return true; return true;
} }
......
#pragma once #pragma once
#include "preprocessor.h" #include "preprocessor.h"
#include "utils/utils.h"
namespace PaddleSolution { namespace PaddleSolution {
......
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#ifdef _WIN32 #ifdef _WIN32
#include <filesystem> #include <filesystem>
#else #else
...@@ -59,5 +63,58 @@ namespace PaddleSolution { ...@@ -59,5 +63,58 @@ namespace PaddleSolution {
return imgs; return imgs;
} }
#endif #endif
// normalize and HWC_BGR -> CHW_RGB
inline void normalize(cv::Mat& im, float* data, std::vector<float>& fmean, std::vector<float>& fstd) {
int rh = im.rows;
int rw = im.cols;
int rc = im.channels();
double normf = (double)1.0 / 255.0;
#pragma omp parallel for
for (int h = 0; h < rh; ++h) {
const uchar* ptr = im.ptr<uchar>(h);
int im_index = 0;
for (int w = 0; w < rw; ++w) {
for (int c = 0; c < rc; ++c) {
int top_index = (c * rh + h) * rw + w;
float pixel = static_cast<float>(ptr[im_index++]);
pixel = (pixel * normf - fmean[c]) / fstd[c];
data[top_index] = pixel;
}
}
}
}
// argmax
inline void argmax(float* out, std::vector<int>& shape, std::vector<uchar>& mask, std::vector<uchar>& scoremap) {
int out_img_len = shape[1] * shape[2];
int blob_out_len = out_img_len * shape[0];
/*
Eigen::TensorMap<Eigen::Tensor<float, 3>> out_3d(out, shape[0], shape[1], shape[2]);
Eigen::Tensor<Eigen::DenseIndex, 2> argmax = out_3d.argmax(0);
*/
float max_value = -1;
int label = 0;
#pragma omp parallel private(label)
for (int i = 0; i < out_img_len; ++i) {
max_value = -1;
label = 0;
#pragma omp for reduction(max : max_value)
for (int j = 0; j < shape[0]; ++j) {
int index = i + j * out_img_len;
if (index >= blob_out_len) {
continue;
}
float value = out[index];
if (value > max_value) {
max_value = value;
label = j;
}
}
if (label == 0) max_value = 0;
mask[i] = uchar(label);
scoremap[i] = uchar(max_value * 255);
}
}
} }
} }
...@@ -16,6 +16,7 @@ import logging ...@@ -16,6 +16,7 @@ import logging
from utils.config import cfg from utils.config import cfg
def init_global_variable(): def init_global_variable():
""" """
初始化全局变量 初始化全局变量
...@@ -31,8 +32,8 @@ def init_global_variable(): ...@@ -31,8 +32,8 @@ def init_global_variable():
global min_aspectratio # 图片最小宽高比 global min_aspectratio # 图片最小宽高比
global max_aspectratio # 图片最大宽高比 global max_aspectratio # 图片最大宽高比
global img_dim # 图片的通道数 global img_dim # 图片的通道数
global list_wrong #文件名格式错误列表 global list_wrong # 文件名格式错误列表
global imread_failed #图片读取失败列表, 二元列表 global imread_failed # 图片读取失败列表, 二元列表
global label_wrong # 标注图片出错列表 global label_wrong # 标注图片出错列表
global label_gray_wrong # 标注图非灰度图列表 global label_gray_wrong # 标注图非灰度图列表
...@@ -52,29 +53,33 @@ def init_global_variable(): ...@@ -52,29 +53,33 @@ def init_global_variable():
label_wrong = [] label_wrong = []
label_gray_wrong = [] label_gray_wrong = []
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg check') parser = argparse.ArgumentParser(description='PaddleSeg check')
parser.add_argument( parser.add_argument(
'--cfg', '--cfg',
dest='cfg_file', dest='cfg_file',
help='Config file for training (and optionally testing)', help='Config file for training (and optionally testing)',
default=None, default=None,
type=str type=str)
)
return parser.parse_args() return parser.parse_args()
def error_print(str): def error_print(str):
return "".join(["\nNOT PASS ", str]) return "".join(["\nNOT PASS ", str])
def correct_print(str): def correct_print(str):
return "".join(["\nPASS ", str]) return "".join(["\nPASS ", str])
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR): def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
""" """
解决 cv2.imread 在window平台打开中文路径的问题. 解决 cv2.imread 在window平台打开中文路径的问题.
""" """
return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
def get_image_max_height_width(img): def get_image_max_height_width(img):
"""获取图片最大宽和高""" """获取图片最大宽和高"""
global max_width, max_height global max_width, max_height
...@@ -83,21 +88,24 @@ def get_image_max_height_width(img): ...@@ -83,21 +88,24 @@ def get_image_max_height_width(img):
max_height = max(height, max_height) max_height = max(height, max_height)
max_width = max(width, max_width) max_width = max(width, max_width)
def get_image_min_max_aspectratio(img): def get_image_min_max_aspectratio(img):
"""计算图片最大宽高比""" """计算图片最大宽高比"""
global min_aspectratio, max_aspectratio global min_aspectratio, max_aspectratio
img_shape = img.shape img_shape = img.shape
height, width = img_shape[0], img_shape[1] height, width = img_shape[0], img_shape[1]
min_aspectratio = min(width/height, min_aspectratio) min_aspectratio = min(width / height, min_aspectratio)
max_aspectratio = max(width/height, max_aspectratio) max_aspectratio = max(width / height, max_aspectratio)
return min_aspectratio, max_aspectratio return min_aspectratio, max_aspectratio
def get_image_dim(img): def get_image_dim(img):
"""获取图像的通道数""" """获取图像的通道数"""
img_shape = img.shape img_shape = img.shape
if img_shape[-1] not in img_dim: if img_shape[-1] not in img_dim:
img_dim.append(img_shape[-1]) img_dim.append(img_shape[-1])
def is_label_gray(grt): def is_label_gray(grt):
"""判断标签是否为灰度图""" """判断标签是否为灰度图"""
grt_shape = grt.shape grt_shape = grt.shape
...@@ -106,6 +114,7 @@ def is_label_gray(grt): ...@@ -106,6 +114,7 @@ def is_label_gray(grt):
else: else:
return False return False
def image_label_shape_check(img, grt): def image_label_shape_check(img, grt):
""" """
验证图像和标注的大小是否匹配 验证图像和标注的大小是否匹配
...@@ -117,11 +126,11 @@ def image_label_shape_check(img, grt): ...@@ -117,11 +126,11 @@ def image_label_shape_check(img, grt):
grt_height = grt.shape[0] grt_height = grt.shape[0]
grt_width = grt.shape[1] grt_width = grt.shape[1]
if img_height != grt_height or img_width != grt_width: if img_height != grt_height or img_width != grt_width:
flag = False flag = False
return flag return flag
def ground_truth_check(grt, grt_path): def ground_truth_check(grt, grt_path):
""" """
验证标注图像的格式 验证标注图像的格式
...@@ -143,6 +152,7 @@ def ground_truth_check(grt, grt_path): ...@@ -143,6 +152,7 @@ def ground_truth_check(grt, grt_path):
return png_format, unique, counts return png_format, unique, counts
def sum_gt_check(png_format, grt_classes, num_of_each_class): def sum_gt_check(png_format, grt_classes, num_of_each_class):
""" """
统计所有标注图上的格式、类别和每个类别的像素数 统计所有标注图上的格式、类别和每个类别的像素数
...@@ -160,7 +170,8 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): ...@@ -160,7 +170,8 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
png_format_wrong_num += 1 png_format_wrong_num += 1
if cfg.DATASET.IGNORE_INDEX in grt_classes: if cfg.DATASET.IGNORE_INDEX in grt_classes:
grt_classes2 = np.delete(grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX)) grt_classes2 = np.delete(
grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX))
else: else:
grt_classes2 = grt_classes grt_classes2 = grt_classes
if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1: if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1:
...@@ -179,6 +190,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): ...@@ -179,6 +190,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
total_grt_classes += add_class total_grt_classes += add_class
return is_label_correct return is_label_correct
def gt_check(): def gt_check():
""" """
对标注图像进行校验,输出校验结果 对标注图像进行校验,输出校验结果
...@@ -192,16 +204,20 @@ def gt_check(): ...@@ -192,16 +204,20 @@ def gt_check():
return return
else: else:
logger.info(error_print("label format check")) logger.info(error_print("label format check"))
logger.info("total {} label images are png format, {} label images are not png " logger.info(
"format".format(png_format_right_num, png_format_wrong_num)) "total {} label images are png format, {} label images are not png "
"format".format(png_format_right_num, png_format_wrong_num))
if len(png_format_wrong_image) > 0: if len(png_format_wrong_image) > 0:
for i in png_format_wrong_image: for i in png_format_wrong_image:
logger.debug(i) logger.debug(i)
total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
total_nc = sorted(zip(total_grt_classes, total_num_of_each_class)) total_ratio = np.around(total_ratio, decimals=4)
logger.info("\nDoing label pixel statistics...\nTotal label classes " total_nc = sorted(
"and their corresponding numbers:\n{} ".format(total_nc)) zip(total_grt_classes, total_num_of_each_class, total_ratio))
logger.info(
"\nDoing label pixel statistics:\n"
"(label class, total pixel number, percentage) = {} ".format(total_nc))
if len(label_wrong) == 0 and not total_nc[0][0]: if len(label_wrong) == 0 and not total_nc[0][0]:
logger.info(correct_print("label class check!")) logger.info(correct_print("label class check!"))
...@@ -210,13 +226,15 @@ def gt_check(): ...@@ -210,13 +226,15 @@ def gt_check():
if total_nc[0][0]: if total_nc[0][0]:
logger.info("Warning: label classes should start from 0") logger.info("Warning: label classes should start from 0")
if len(label_wrong) > 0: if len(label_wrong) > 0:
logger.info("fatal error: label class is out of range [0, {}]".format(cfg.DATASET.NUM_CLASSES - 1)) logger.info(
"fatal error: label class is out of range [0, {}]".format(
cfg.DATASET.NUM_CLASSES - 1))
for i in label_wrong: for i in label_wrong:
logger.debug(i) logger.debug(i)
def eval_crop_size_check(max_height, max_width, min_aspectratio,
def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio): max_aspectratio):
""" """
判断eval_crop_siz与验证集及测试集的max_height, max_width的关系 判断eval_crop_siz与验证集及测试集的max_height, max_width的关系
param param
...@@ -225,69 +243,109 @@ def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio ...@@ -225,69 +243,109 @@ def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio
""" """
if cfg.AUG.AUG_METHOD == "stepscaling": if cfg.AUG.AUG_METHOD == "stepscaling":
if max_width <= cfg.EVAL_CROP_SIZE[0] and max_height <= cfg.EVAL_CROP_SIZE[1]: if max_width <= cfg.EVAL_CROP_SIZE[
0] and max_height <= cfg.EVAL_CROP_SIZE[1]:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= max width and max height of images: ({},{})"
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], max_width,
max_height))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
if max_width > cfg.EVAL_CROP_SIZE[0]: if max_width > cfg.EVAL_CROP_SIZE[0]:
logger.info("The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!".format( logger.info(
cfg.EVAL_CROP_SIZE[0], max_width)) "EVAL_CROP_SIZE[0]: {} should >= max width of images {}!".
format(cfg.EVAL_CROP_SIZE[0], max_width))
if max_height > cfg.EVAL_CROP_SIZE[1]: if max_height > cfg.EVAL_CROP_SIZE[1]:
logger.info(error_print("The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!".format( logger.info(
cfg.EVAL_CROP_SIZE[1], max_height))) "EVAL_CROP_SIZE[1]: {} should >= max height of images {}!".
format(cfg.EVAL_CROP_SIZE[1], max_height))
elif cfg.AUG.AUG_METHOD == "rangescaling": elif cfg.AUG.AUG_METHOD == "rangescaling":
if min_aspectratio <= 1 and max_aspectratio >= 1: if min_aspectratio <= 1 and max_aspectratio >= 1:
if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE: if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE \
and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" logger.info(
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE)) .format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE,
cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))
elif min_aspectratio > 1: elif min_aspectratio > 1:
max_height_rangscaling = cfg.AUG.INF_RESIZE_VALUE / min_aspectratio max_height_rangscaling = cfg.AUG.INF_RESIZE_VALUE / min_aspectratio
max_height_rangscaling = round(max_height_rangscaling) max_height_rangscaling = round(max_height_rangscaling)
if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= max_height_rangscaling: if cfg.EVAL_CROP_SIZE[
0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[
1] >= max_height_rangscaling:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" logger.info(
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling)) .format(cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling,
cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))
elif max_aspectratio < 1: elif max_aspectratio < 1:
max_width_rangscaling = cfg.AUG.INF_RESIZE_VALUE * max_aspectratio max_width_rangscaling = cfg.AUG.INF_RESIZE_VALUE * max_aspectratio
max_width_rangscaling = round(max_width_rangscaling) max_width_rangscaling = round(max_width_rangscaling)
if cfg.EVAL_CROP_SIZE[0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE: if cfg.EVAL_CROP_SIZE[
0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[
1] >= cfg.AUG.INF_RESIZE_VALUE:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
max_height_rangscaling, cfg.AUG.INF_RESIZE_VALUE))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" logger.info(
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE)) .format(max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE,
cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))
elif cfg.AUG.AUG_METHOD == "unpadding": elif cfg.AUG.AUG_METHOD == "unpadding":
if len(cfg.AUG.FIX_RESIZE_SIZE) != 2: if len(cfg.AUG.FIX_RESIZE_SIZE) != 2:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("you set AUG.AUG_METHOD = 'unpadding', but AUG.FIX_RESIZE_SIZE is wrong. " logger.info(
"AUG.FIX_RESIZE_SIZE should be a tuple of length 2") "you set AUG.AUG_METHOD = 'unpadding', but AUG.FIX_RESIZE_SIZE is wrong. "
elif cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[0] and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]: "AUG.FIX_RESIZE_SIZE should be a tuple of length 2")
elif cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[0] \
and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= AUG.FIX_RESIZE_SIZE: ({},{}) "
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1]))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" logger.info(
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], "EVAL_CROP_SIZE: ({},{}) must >= AUG.FIX_RESIZE_SIZE: ({},{})".
cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1])) format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1]))
else: else:
logger.info("\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of " logger.info(
"[unpadding, stepscaling, rangescaling]") "\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of "
"[unpadding, stepscaling, rangescaling]")
def inf_resize_value_check(): def inf_resize_value_check():
if cfg.AUG.AUG_METHOD == "rangescaling": if cfg.AUG.AUG_METHOD == "rangescaling":
if cfg.AUG.INF_RESIZE_VALUE < cfg.AUG.MIN_RESIZE_VALUE or \ if cfg.AUG.INF_RESIZE_VALUE < cfg.AUG.MIN_RESIZE_VALUE or \
cfg.AUG.INF_RESIZE_VALUE > cfg.AUG.MIN_RESIZE_VALUE: cfg.AUG.INF_RESIZE_VALUE > cfg.AUG.MIN_RESIZE_VALUE:
logger.info("\nWARNING! you set AUG.AUG_METHOD = 'rangescaling'" logger.info(
"AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: " "\nWARNING! you set AUG.AUG_METHOD = 'rangescaling'"
"[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.MIN_RESIZE_VALUE, cfg.AUG.MAX_RESIZE_VALUE)) "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: "
"[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE,
cfg.AUG.MIN_RESIZE_VALUE,
cfg.AUG.MAX_RESIZE_VALUE))
def image_type_check(img_dim): def image_type_check(img_dim):
""" """
...@@ -299,13 +357,17 @@ def image_type_check(img_dim): ...@@ -299,13 +357,17 @@ def image_type_check(img_dim):
if (1 in img_dim or 3 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgba': if (1 in img_dim or 3 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgba':
logger.info(error_print("DATASET.IMAGE_TYPE check")) logger.info(error_print("DATASET.IMAGE_TYPE check"))
logger.info("DATASET.IMAGE_TYPE is {} but the type of image has " logger.info("DATASET.IMAGE_TYPE is {} but the type of image has "
"gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE)) "gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE))
elif (1 not in img_dim and 3 not in img_dim and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb': elif (1 not in img_dim and 3 not in img_dim
and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb':
logger.info(correct_print("DATASET.IMAGE_TYPE check")) logger.info(correct_print("DATASET.IMAGE_TYPE check"))
logger.info("\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba".format(cfg.DATASET.IMAGE_TYPE)) logger.info(
"\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba"
.format(cfg.DATASET.IMAGE_TYPE))
else: else:
logger.info(correct_print("DATASET.IMAGE_TYPE check")) logger.info(correct_print("DATASET.IMAGE_TYPE check"))
def shape_check(): def shape_check():
"""输出shape校验结果""" """输出shape校验结果"""
if len(shape_unequal_image) == 0: if len(shape_unequal_image) == 0:
...@@ -313,7 +375,8 @@ def shape_check(): ...@@ -313,7 +375,8 @@ def shape_check():
logger.info("All images are the same shape as the labels") logger.info("All images are the same shape as the labels")
else: else:
logger.info(error_print("shape check")) logger.info(error_print("shape check"))
logger.info("Some images are not the same shape as the labels as follow: ") logger.info(
"Some images are not the same shape as the labels as follow: ")
for i in shape_unequal_image: for i in shape_unequal_image:
logger.debug(i) logger.debug(i)
...@@ -321,13 +384,19 @@ def shape_check(): ...@@ -321,13 +384,19 @@ def shape_check():
def file_list_check(list_name): def file_list_check(list_name):
"""检查分割符是否复合要求""" """检查分割符是否复合要求"""
if len(list_wrong) == 0: if len(list_wrong) == 0:
logger.info(correct_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check")) logger.info(
correct_print(
list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check"))
else: else:
logger.info(error_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check")) logger.info(
logger.info("The following list is not separated by {}".format(cfg.DATASET.SEPARATOR)) error_print(
list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check"))
logger.info("The following list is not separated by {}".format(
cfg.DATASET.SEPARATOR))
for i in list_wrong: for i in list_wrong:
logger.debug(i) logger.debug(i)
def imread_check(): def imread_check():
if len(imread_failed) == 0: if len(imread_failed) == 0:
logger.info(correct_print("dataset reading check")) logger.info(correct_print("dataset reading check"))
...@@ -338,18 +407,25 @@ def imread_check(): ...@@ -338,18 +407,25 @@ def imread_check():
for i in imread_failed: for i in imread_failed:
logger.debug(i) logger.debug(i)
def label_gray_check(): def label_gray_check():
if len(label_gray_wrong) == 0: if len(label_gray_wrong) == 0:
logger.info(correct_print("label gray check")) logger.info(correct_print("label gray check"))
logger.info("All label images are gray") logger.info("All label images are gray")
else: else:
logger.info(error_print("label gray check")) logger.info(error_print("label gray check"))
logger.info("{} label images are not gray\nLabel pixel statistics may " logger.info(
"be insignificant".format(len(label_gray_wrong))) "{} label images are not gray\nLabel pixel statistics may be insignificant"
.format(len(label_gray_wrong)))
for i in label_gray_wrong: for i in label_gray_wrong:
logger.debug(i) logger.debug(i)
def max_img_size_statistics():
logger.info("\nDoing max image size statistics:")
logger.info("max width and max height of images are ({},{})".format(
max_width, max_height))
def check_train_dataset(): def check_train_dataset():
list_file = cfg.DATASET.TRAIN_FILE_LIST list_file = cfg.DATASET.TRAIN_FILE_LIST
...@@ -376,15 +452,18 @@ def check_train_dataset(): ...@@ -376,15 +452,18 @@ def check_train_dataset():
if not is_gray: if not is_gray:
label_gray_wrong.append(line) label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY) grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
get_image_max_height_width(img)
get_image_dim(img) get_image_dim(img)
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
shape_unequal_image.append(line) shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
if not png_format: if not png_format:
png_format_wrong_image.append(line) png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) is_label_correct = sum_gt_check(png_format, grt_classes,
num_of_each_class)
if not is_label_correct: if not is_label_correct:
label_wrong.append(line) label_wrong.append(line)
...@@ -393,12 +472,10 @@ def check_train_dataset(): ...@@ -393,12 +472,10 @@ def check_train_dataset():
label_gray_check() label_gray_check()
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
max_img_size_statistics()
shape_check() shape_check()
def check_val_dataset(): def check_val_dataset():
list_file = cfg.DATASET.VAL_FILE_LIST list_file = cfg.DATASET.VAL_FILE_LIST
logger.info("\n-----------------------------\n2. Check val dataset...") logger.info("\n-----------------------------\n2. Check val dataset...")
...@@ -417,7 +494,8 @@ def check_val_dataset(): ...@@ -417,7 +494,8 @@ def check_val_dataset():
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED) grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e: except Exception as e:
imread_failed.append((line, e.message)) imread_failed.append((line, str(e)))
continue
is_gray = is_label_gray(grt) is_gray = is_label_gray(grt)
if not is_gray: if not is_gray:
...@@ -429,10 +507,12 @@ def check_val_dataset(): ...@@ -429,10 +507,12 @@ def check_val_dataset():
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
shape_unequal_image.append(line) shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
if not png_format: if not png_format:
png_format_wrong_image.append(line) png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) is_label_correct = sum_gt_check(png_format, grt_classes,
num_of_each_class)
if not is_label_correct: if not is_label_correct:
label_wrong.append(line) label_wrong.append(line)
...@@ -441,8 +521,11 @@ def check_val_dataset(): ...@@ -441,8 +521,11 @@ def check_val_dataset():
label_gray_check() label_gray_check()
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
max_img_size_statistics()
shape_check() shape_check()
eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio) eval_crop_size_check(max_height, max_width, min_aspectratio,
max_aspectratio)
def check_test_dataset(): def check_test_dataset():
list_file = cfg.DATASET.TEST_FILE_LIST list_file = cfg.DATASET.TEST_FILE_LIST
...@@ -470,7 +553,7 @@ def check_test_dataset(): ...@@ -470,7 +553,7 @@ def check_test_dataset():
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED) grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e: except Exception as e:
imread_failed.append((line, e.message)) imread_failed.append((line, str(e)))
continue continue
is_gray = is_label_gray(grt) is_gray = is_label_gray(grt)
...@@ -480,10 +563,12 @@ def check_test_dataset(): ...@@ -480,10 +563,12 @@ def check_test_dataset():
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
shape_unequal_image.append(line) shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
if not png_format: if not png_format:
png_format_wrong_image.append(line) png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) is_label_correct = sum_gt_check(png_format, grt_classes,
num_of_each_class)
if not is_label_correct: if not is_label_correct:
label_wrong.append(line) label_wrong.append(line)
else: else:
...@@ -500,14 +585,17 @@ def check_test_dataset(): ...@@ -500,14 +585,17 @@ def check_test_dataset():
if has_label: if has_label:
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
max_img_size_statistics()
if has_label: if has_label:
shape_check() shape_check()
eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio) eval_crop_size_check(max_height, max_width, min_aspectratio,
max_aspectratio)
def main(args): def main(args):
if args.cfg_file is not None: if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file) cfg.update_from_file(args.cfg_file)
cfg.check_and_infer(reset_dataset=True) cfg.check_and_infer()
logger.info(pprint.pformat(cfg)) logger.info(pprint.pformat(cfg))
init_global_variable() init_global_variable()
...@@ -521,6 +609,9 @@ def main(args): ...@@ -521,6 +609,9 @@ def main(args):
inf_resize_value_check() inf_resize_value_check()
print("\nDetailed error information can be viewed in detail.log file.")
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
logger = logging.getLogger() logger = logging.getLogger()
...@@ -535,5 +626,3 @@ if __name__ == "__main__": ...@@ -535,5 +626,3 @@ if __name__ == "__main__":
logger.addHandler(sh) logger.addHandler(sh)
logger.addHandler(th) logger.addHandler(th)
main(args) main(args)
...@@ -361,7 +361,7 @@ def hsv_color_jitter(crop_img, ...@@ -361,7 +361,7 @@ def hsv_color_jitter(crop_img,
if brightness_jitter_ratio > 0 or \ if brightness_jitter_ratio > 0 or \
saturation_jitter_ratio > 0 or \ saturation_jitter_ratio > 0 or \
contrast_jitter_ratio > 0: contrast_jitter_ratio > 0:
random_jitter(crop_img, saturation_jitter_ratio, crop_img = random_jitter(crop_img, saturation_jitter_ratio,
brightness_jitter_ratio, contrast_jitter_ratio) brightness_jitter_ratio, contrast_jitter_ratio)
return crop_img return crop_img
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os.path
import argparse
import warnings
def parse_args():
parser = argparse.ArgumentParser(
description='PaddleSeg generate file list on cityscapes or your customized dataset.')
parser.add_argument(
'dataset_root',
help='dataset root directory',
type=str
)
parser.add_argument(
'--type',
help='dataset type: \n'
'- cityscapes \n'
'- custom(default)',
default="custom",
type=str
)
parser.add_argument(
'--separator',
dest='separator',
help='file list separator',
default="|",
type=str
)
parser.add_argument(
'--folder',
help='the folder names of images and labels',
type=str,
nargs=2,
default=['images', 'annotations']
)
parser.add_argument(
'--second_folder',
help='the second-level folder names of train set, validation set, test set',
type=str,
nargs='*',
default=['train', 'val', 'test']
)
parser.add_argument(
'--format',
help='data format of images and labels, e.g. jpg or png.',
type=str,
nargs=2,
default=['jpg', 'png']
)
parser.add_argument(
'--postfix',
help='postfix of images or labels',
type=str,
nargs=2,
default=['', '']
)
return parser.parse_args()
def cityscape_cfg(args):
args.postfix = ['_leftImg8bit', '_gtFine_labelTrainIds']
args.folder = ['leftImg8bit', 'gtFine']
args.format = ['png', 'png']
def get_files(image_or_label, dataset_split, args):
dataset_root = args.dataset_root
postfix = args.postfix
format = args.format
folder = args.folder
pattern = '*%s.%s' % (postfix[image_or_label], format[image_or_label])
search_files = os.path.join(dataset_root, folder[image_or_label],
dataset_split, pattern)
search_files2 = os.path.join(dataset_root, folder[image_or_label],
dataset_split, "*", pattern) # 包含子目录
search_files3 = os.path.join(dataset_root, folder[image_or_label],
dataset_split, "*", "*", pattern) # 包含三级目录
filenames = glob.glob(search_files)
filenames2 = glob.glob(search_files2)
filenames3 = glob.glob(search_files3)
filenames = filenames + filenames2 + filenames3
return sorted(filenames)
def generate_list(args):
dataset_root = args.dataset_root
separator = args.separator
for dataset_split in args.second_folder:
print("Creating {}.txt...".format(dataset_split))
image_files = get_files(0, dataset_split, args)
label_files = get_files(1, dataset_split, args)
if not image_files:
img_dir = os.path.join(dataset_root, args.folder[0], dataset_split)
print("No files in {}".format(img_dir))
num_images = len(image_files)
if not label_files:
label_dir = os.path.join(dataset_root, args.folder[1], dataset_split)
print("No files in {}".format(label_dir))
num_label = len(label_files)
if num_images < num_label:
warnings.warn("number of images = {} < number of labels = {}."
.format(num_images, num_label))
continue
file_list = os.path.join(dataset_root, dataset_split + '.txt')
with open(file_list, "w") as f:
for item in range(num_images):
left = image_files[item].replace(dataset_root, '')
if left[0] == os.path.sep:
left = left.lstrip(os.path.sep)
try:
right = label_files[item].replace(dataset_root, '')
if right[0] == os.path.sep:
right = right.lstrip(os.path.sep)
line = left + separator + right + '\n'
except:
line = left + '\n'
f.write(line)
print(line)
if __name__ == '__main__':
args = parse_args()
if args.type == 'cityscapes':
cityscape_cfg(args)
generate_list(args)
...@@ -88,7 +88,7 @@ class SegConfig(dict): ...@@ -88,7 +88,7 @@ class SegConfig(dict):
except KeyError: except KeyError:
raise KeyError('Non-existent config key: {}'.format(key)) raise KeyError('Non-existent config key: {}'.format(key))
def check_and_infer(self, reset_dataset=False): def check_and_infer(self):
if self.DATASET.IMAGE_TYPE in ['rgb', 'gray']: if self.DATASET.IMAGE_TYPE in ['rgb', 'gray']:
self.DATASET.DATA_DIM = 3 self.DATASET.DATA_DIM = 3
elif self.DATASET.IMAGE_TYPE in ['rgba']: elif self.DATASET.IMAGE_TYPE in ['rgba']:
...@@ -110,17 +110,13 @@ class SegConfig(dict): ...@@ -110,17 +110,13 @@ class SegConfig(dict):
'EVAL_CROP_SIZE is empty! Please set a pair of values in format (width, height)' 'EVAL_CROP_SIZE is empty! Please set a pair of values in format (width, height)'
) )
if reset_dataset: # Ensure file list is use UTF-8 encoding
# Ensure file list is use UTF-8 encoding train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', 'utf-8').readlines()
train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', 'utf-8').readlines()
'utf-8').readlines() test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', 'utf-8').readlines()
val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets)
'utf-8').readlines() self.DATASET.VAL_TOTAL_IMAGES = len(val_sets)
test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', self.DATASET.TEST_TOTAL_IMAGES = len(test_sets)
'utf-8').readlines()
self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets)
self.DATASET.VAL_TOTAL_IMAGES = len(val_sets)
self.DATASET.TEST_TOTAL_IMAGES = len(test_sets)
if self.MODEL.MODEL_NAME == 'icnet' and \ if self.MODEL.MODEL_NAME == 'icnet' and \
len(self.MODEL.MULTI_LOSS_WEIGHT) != 3: len(self.MODEL.MULTI_LOSS_WEIGHT) != 3:
......
...@@ -31,10 +31,10 @@ MODEL: ...@@ -31,10 +31,10 @@ MODEL:
ASPP_WITH_SEP_CONV: True ASPP_WITH_SEP_CONV: True
DECODER_USE_SEP_CONV: True DECODER_USE_SEP_CONV: True
TEST: TEST:
TEST_MODEL: "snapshots/cityscape_v5/final/" TEST_MODEL: "./saved_model/cityscape_v5/final/"
TRAIN: TRAIN:
MODEL_SAVE_DIR: "snapshots/cityscape_v5/" MODEL_SAVE_DIR: "./saved_model/cityscape_v5/"
PRETRAINED_MODEL_DIR: "pretrain/deeplabv3plus_gn_init" PRETRAINED_MODEL_DIR: "pretrained_model/deeplabv3plus_gn_init"
SNAPSHOT_EPOCH: 10 SNAPSHOT_EPOCH: 10
SOLVER: SOLVER:
LR: 0.001 LR: 0.001
......
...@@ -12,15 +12,15 @@ AUG: ...@@ -12,15 +12,15 @@ AUG:
MIN_SCALE_FACTOR: 0.75 # for stepscaling MIN_SCALE_FACTOR: 0.75 # for stepscaling
SCALE_STEP_SIZE: 0.25 # for stepscaling SCALE_STEP_SIZE: 0.25 # for stepscaling
MIRROR: True MIRROR: True
BATCH_SIZE: 6 BATCH_SIZE: 4
DATASET: DATASET:
DATA_DIR: "./dataset/pet/" DATA_DIR: "./dataset/mini_pet/"
IMAGE_TYPE: "rgb" # choice rgb or rgba IMAGE_TYPE: "rgb" # choice rgb or rgba
NUM_CLASSES: 4 # including ignore NUM_CLASSES: 3
TEST_FILE_LIST: "./dataset/pet/test_list.txt" TEST_FILE_LIST: "./dataset/mini_pet/file_list/test_list.txt"
TRAIN_FILE_LIST: "./dataset/pet/train_list.txt" TRAIN_FILE_LIST: "./dataset/mini_pet/file_list/train_list.txt"
VAL_FILE_LIST: "./dataset/pet/val_list.txt" VAL_FILE_LIST: "./dataset/mini_pet/file_list/val_list.txt"
VIS_FILE_LIST: "./dataset/pet/val_list.txt" VIS_FILE_LIST: "./dataset/mini_pet/file_list/test_list.txt"
IGNORE_INDEX: 255 IGNORE_INDEX: 255
SEPARATOR: " " SEPARATOR: " "
FREEZE: FREEZE:
...@@ -30,13 +30,13 @@ MODEL: ...@@ -30,13 +30,13 @@ MODEL:
MODEL_NAME: "unet" MODEL_NAME: "unet"
DEFAULT_NORM_TYPE: "bn" DEFAULT_NORM_TYPE: "bn"
TEST: TEST:
TEST_MODEL: "./test/saved_model/unet_pet/final/" TEST_MODEL: "./saved_model/unet_pet/final/"
TRAIN: TRAIN:
MODEL_SAVE_DIR: "./test/saved_models/unet_pet/" MODEL_SAVE_DIR: "./saved_model/unet_pet/"
PRETRAINED_MODEL_DIR: "./test/models/unet_coco/" PRETRAINED_MODEL_DIR: "./test/models/unet_coco_init/"
SNAPSHOT_EPOCH: 10 SNAPSHOT_EPOCH: 10
SOLVER: SOLVER:
NUM_EPOCHS: 500 NUM_EPOCHS: 100
LR: 0.005 LR: 0.005
LR_POLICY: "poly" LR_POLICY: "poly"
OPTIMIZER: "adam" OPTIMIZER: "adam"
...@@ -50,7 +50,7 @@ if __name__ == "__main__": ...@@ -50,7 +50,7 @@ if __name__ == "__main__":
dest="devices", dest="devices",
help="GPU id of running. if more than one, use spacing to separate.", help="GPU id of running. if more than one, use spacing to separate.",
nargs="+", nargs="+",
default=0, default=[0],
type=int) type=int)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -51,7 +51,7 @@ if __name__ == "__main__": ...@@ -51,7 +51,7 @@ if __name__ == "__main__":
dest="devices", dest="devices",
help="GPU id of running. if more than one, use spacing to separate.", help="GPU id of running. if more than one, use spacing to separate.",
nargs="+", nargs="+",
default=0, default=[0],
type=int) type=int)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -21,7 +21,7 @@ python dataset/download_pet.py ...@@ -21,7 +21,7 @@ python dataset/download_pet.py
接着下载对应的预训练模型 接着下载对应的预训练模型
```shell ```shell
python pretrained_model/download_model.py deeplabv3p_xception65_bn_cityscapes python pretrained_model/download_model.py deeplabv3p_xception65_bn_coco
``` ```
## 三. 准备配置 ## 三. 准备配置
...@@ -47,7 +47,7 @@ python pretrained_model/download_model.py deeplabv3p_xception65_bn_cityscapes ...@@ -47,7 +47,7 @@ python pretrained_model/download_model.py deeplabv3p_xception65_bn_cityscapes
数据集的配置和数据路径有关,在本教程中,数据存放在`dataset/mini_pet` 数据集的配置和数据路径有关,在本教程中,数据存放在`dataset/mini_pet`
其他配置则根据数据集和机器环境的情况进行调节,最终我们保存一个如下内容的yaml配置文件,存放路径为`configs/test_deeplabv3p_pet.yaml` 其他配置则根据数据集和机器环境的情况进行调节,最终我们保存一个如下内容的yaml配置文件,存放路径为**configs/deeplabv3p_xception65_pet.yaml**
```yaml ```yaml
# 数据集配置 # 数据集配置
...@@ -91,7 +91,7 @@ SOLVER: ...@@ -91,7 +91,7 @@ SOLVER:
在开始训练和评估之前,我们还需要对配置和数据进行一次校验,确保数据和配置是正确的。使用下述命令启动校验流程 在开始训练和评估之前,我们还需要对配置和数据进行一次校验,确保数据和配置是正确的。使用下述命令启动校验流程
```shell ```shell
python pdseg/check.py --cfg ./configs/test_deeplabv3p_pet.yaml python pdseg/check.py --cfg ./configs/deeplabv3p_xception65_pet.yaml
``` ```
...@@ -100,7 +100,7 @@ python pdseg/check.py --cfg ./configs/test_deeplabv3p_pet.yaml ...@@ -100,7 +100,7 @@ python pdseg/check.py --cfg ./configs/test_deeplabv3p_pet.yaml
校验通过后,使用下述命令启动训练 校验通过后,使用下述命令启动训练
```shell ```shell
python pdseg/train.py --use_gpu --cfg ./configs/test_deeplabv3p_pet.yaml python pdseg/train.py --use_gpu --cfg ./configs/deeplabv3p_xception65_pet.yaml
``` ```
## 六. 进行评估 ## 六. 进行评估
...@@ -108,7 +108,7 @@ python pdseg/train.py --use_gpu --cfg ./configs/test_deeplabv3p_pet.yaml ...@@ -108,7 +108,7 @@ python pdseg/train.py --use_gpu --cfg ./configs/test_deeplabv3p_pet.yaml
模型训练完成,使用下述命令启动评估 模型训练完成,使用下述命令启动评估
```shell ```shell
python pdseg/eval.py --use_gpu --cfg ./configs/test_deeplabv3p_pet.yaml python pdseg/eval.py --use_gpu --cfg ./configs/deeplabv3p_xception65_pet.yaml
``` ```
## 模型组合 ## 模型组合
...@@ -123,7 +123,7 @@ python pdseg/eval.py --use_gpu --cfg ./configs/test_deeplabv3p_pet.yaml ...@@ -123,7 +123,7 @@ python pdseg/eval.py --use_gpu --cfg ./configs/test_deeplabv3p_pet.yaml
|xception41_imagenet|-|bn|ImageNet|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_41 <br> MODEL.DEFAULT_NORM_TYPE: bn| |xception41_imagenet|-|bn|ImageNet|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_41 <br> MODEL.DEFAULT_NORM_TYPE: bn|
|xception65_imagenet|-|bn|ImageNet|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_65 <br> MODEL.DEFAULT_NORM_TYPE: bn| |xception65_imagenet|-|bn|ImageNet|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_65 <br> MODEL.DEFAULT_NORM_TYPE: bn|
|deeplabv3p_mobilenetv2-1-0_bn_coco|MobileNet V2|bn|COCO|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: mobilenet <br> MODEL.DEEPLAB.DEPTH_MULTIPLIER: 1.0 <br> MODEL.DEEPLAB.ENCODER_WITH_ASPP: False <br> MODEL.DEEPLAB.ENABLE_DECODER: False <br> MODEL.DEFAULT_NORM_TYPE: bn| |deeplabv3p_mobilenetv2-1-0_bn_coco|MobileNet V2|bn|COCO|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: mobilenet <br> MODEL.DEEPLAB.DEPTH_MULTIPLIER: 1.0 <br> MODEL.DEEPLAB.ENCODER_WITH_ASPP: False <br> MODEL.DEEPLAB.ENABLE_DECODER: False <br> MODEL.DEFAULT_NORM_TYPE: bn|
|deeplabv3p_xception65_bn_coco|Xception|bn|COCO|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_65 <br> MODEL.DEFAULT_NORM_TYPE: bn | |**deeplabv3p_xception65_bn_coco**|Xception|bn|COCO|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_65 <br> MODEL.DEFAULT_NORM_TYPE: bn |
|deeplabv3p_mobilenetv2-1-0_bn_cityscapes|MobileNet V2|bn|Cityscapes|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: mobilenet <br> MODEL.DEEPLAB.DEPTH_MULTIPLIER: 1.0 <br> MODEL.DEEPLAB.ENCODER_WITH_ASPP: False <br> MODEL.DEEPLAB.ENABLE_DECODER: False <br> MODEL.DEFAULT_NORM_TYPE: bn| |deeplabv3p_mobilenetv2-1-0_bn_cityscapes|MobileNet V2|bn|Cityscapes|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: mobilenet <br> MODEL.DEEPLAB.DEPTH_MULTIPLIER: 1.0 <br> MODEL.DEEPLAB.ENCODER_WITH_ASPP: False <br> MODEL.DEEPLAB.ENABLE_DECODER: False <br> MODEL.DEFAULT_NORM_TYPE: bn|
|deeplabv3p_xception65_gn_cityscapes|Xception|gn|Cityscapes|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_65 <br> MODEL.DEFAULT_NORM_TYPE: gn| |deeplabv3p_xception65_gn_cityscapes|Xception|gn|Cityscapes|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_65 <br> MODEL.DEFAULT_NORM_TYPE: gn|
|**deeplabv3p_xception65_bn_cityscapes**|Xception|bn|Cityscapes|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_65 <br> MODEL.DEFAULT_NORM_TYPE: bn| |deeplabv3p_xception65_bn_cityscapes|Xception|bn|Cityscapes|MODEL.MODEL_NAME: deeplabv3p <br> MODEL.DEEPLAB.BACKBONE: xception_65 <br> MODEL.DEFAULT_NORM_TYPE: bn|
...@@ -47,7 +47,7 @@ python pretrained_model/download_model.py icnet_bn_cityscapes ...@@ -47,7 +47,7 @@ python pretrained_model/download_model.py icnet_bn_cityscapes
数据集的配置和数据路径有关,在本教程中,数据存放在`dataset/mini_pet` 数据集的配置和数据路径有关,在本教程中,数据存放在`dataset/mini_pet`
其他配置则根据数据集和机器环境的情况进行调节,最终我们保存一个如下内容的yaml配置文件,存放路径为`configs/test_pet.yaml` 其他配置则根据数据集和机器环境的情况进行调节,最终我们保存一个如下内容的yaml配置文件,存放路径为**configs/icnet_pet.yaml**
```yaml ```yaml
# 数据集配置 # 数据集配置
...@@ -93,7 +93,7 @@ SOLVER: ...@@ -93,7 +93,7 @@ SOLVER:
在开始训练和评估之前,我们还需要对配置和数据进行一次校验,确保数据和配置是正确的。使用下述命令启动校验流程 在开始训练和评估之前,我们还需要对配置和数据进行一次校验,确保数据和配置是正确的。使用下述命令启动校验流程
```shell ```shell
python pdseg/check.py --cfg ./configs/test_pet.yaml python pdseg/check.py --cfg ./configs/icnet_pet.yaml
``` ```
...@@ -102,7 +102,7 @@ python pdseg/check.py --cfg ./configs/test_pet.yaml ...@@ -102,7 +102,7 @@ python pdseg/check.py --cfg ./configs/test_pet.yaml
校验通过后,使用下述命令启动训练 校验通过后,使用下述命令启动训练
```shell ```shell
python pdseg/train.py --use_gpu --cfg ./configs/test_pet.yaml python pdseg/train.py --use_gpu --cfg ./configs/icnet_pet.yaml
``` ```
## 六. 进行评估 ## 六. 进行评估
...@@ -110,11 +110,11 @@ python pdseg/train.py --use_gpu --cfg ./configs/test_pet.yaml ...@@ -110,11 +110,11 @@ python pdseg/train.py --use_gpu --cfg ./configs/test_pet.yaml
模型训练完成,使用下述命令启动评估 模型训练完成,使用下述命令启动评估
```shell ```shell
python pdseg/eval.py --use_gpu --cfg ./configs/test_pet.yaml python pdseg/eval.py --use_gpu --cfg ./configs/icnet_pet.yaml
``` ```
## 模型组合 ## 模型组合
|预训练模型名称|BackBone|Norm|数据集|配置| |预训练模型名称|BackBone|Norm|数据集|配置|
|-|-|-|-|-| |-|-|-|-|-|
|icnet_bn_cityscapes|-|bn|Cityscapes|MODEL.MODEL_NAME: icnet <br> MODEL.DEFAULT_NORM_TYPE: bn <br> MULTI_LOSS_WEIGHT: [1.0, 0.4, 0.16]| |icnet_bn_cityscapes|-|bn|Cityscapes|MODEL.MODEL_NAME: icnet <br> MODEL.DEFAULT_NORM_TYPE: bn <br> MODEL.MULTI_LOSS_WEIGHT: [1.0, 0.4, 0.16]|
...@@ -47,7 +47,7 @@ python pretrained_model/download_model.py unet_bn_coco ...@@ -47,7 +47,7 @@ python pretrained_model/download_model.py unet_bn_coco
数据集的配置和数据路径有关,在本教程中,数据存放在`dataset/mini_pet` 数据集的配置和数据路径有关,在本教程中,数据存放在`dataset/mini_pet`
其他配置则根据数据集和机器环境的情况进行调节,最终我们保存一个如下内容的yaml配置文件,存放路径为`configs/test_unet_pet.yaml` 其他配置则根据数据集和机器环境的情况进行调节,最终我们保存一个如下内容的yaml配置文件,存放路径为**configs/unet_pet.yaml**
```yaml ```yaml
# 数据集配置 # 数据集配置
...@@ -90,7 +90,7 @@ SOLVER: ...@@ -90,7 +90,7 @@ SOLVER:
在开始训练和评估之前,我们还需要对配置和数据进行一次校验,确保数据和配置是正确的。使用下述命令启动校验流程 在开始训练和评估之前,我们还需要对配置和数据进行一次校验,确保数据和配置是正确的。使用下述命令启动校验流程
```shell ```shell
python pdseg/check.py --cfg ./configs/test_unet_pet.yaml python pdseg/check.py --cfg ./configs/unet_pet.yaml
``` ```
...@@ -99,7 +99,7 @@ python pdseg/check.py --cfg ./configs/test_unet_pet.yaml ...@@ -99,7 +99,7 @@ python pdseg/check.py --cfg ./configs/test_unet_pet.yaml
校验通过后,使用下述命令启动训练 校验通过后,使用下述命令启动训练
```shell ```shell
python pdseg/train.py --use_gpu --cfg ./configs/test_unet_pet.yaml python pdseg/train.py --use_gpu --cfg ./configs/unet_pet.yaml
``` ```
## 六. 进行评估 ## 六. 进行评估
...@@ -107,11 +107,11 @@ python pdseg/train.py --use_gpu --cfg ./configs/test_unet_pet.yaml ...@@ -107,11 +107,11 @@ python pdseg/train.py --use_gpu --cfg ./configs/test_unet_pet.yaml
模型训练完成,使用下述命令启动评估 模型训练完成,使用下述命令启动评估
```shell ```shell
python pdseg/eval.py --use_gpu --cfg ./configs/test_unet_pet.yaml python pdseg/eval.py --use_gpu --cfg ./configs/unet_pet.yaml
``` ```
## 模型组合 ## 模型组合
|预训练模型名称|BackBone|Norm|数据集|配置| |预训练模型名称|BackBone|Norm|数据集|配置|
|-|-|-|-|-| |-|-|-|-|-|
|unet_bn_coco|-|bn|Cityscapes|MODEL.MODEL_NAME: unet <br> MODEL.DEFAULT_NORM_TYPE: bn| |unet_bn_coco|-|bn|COCO|MODEL.MODEL_NAME: unet <br> MODEL.DEFAULT_NORM_TYPE: bn|
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册