Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
b02375b2
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b02375b2
编写于
4月 09, 2021
作者:
W
wangxinxin08
提交者:
GitHub
4月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add docs for custom dataset, reader and model technical (#2561)
上级
2fb39e1d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
1179 addition
and
2 deletion
+1179
-2
dygraph/README.md
dygraph/README.md
+3
-0
dygraph/docs/advanced_tutorials/MODEL_TECHNICAL.md
dygraph/docs/advanced_tutorials/MODEL_TECHNICAL.md
+407
-0
dygraph/docs/advanced_tutorials/READER.md
dygraph/docs/advanced_tutorials/READER.md
+328
-0
dygraph/docs/images/model_figure.png
dygraph/docs/images/model_figure.png
+0
-0
dygraph/docs/images/reader_figure.png
dygraph/docs/images/reader_figure.png
+0
-0
dygraph/docs/tutorials/PrepareDataSet.md
dygraph/docs/tutorials/PrepareDataSet.md
+2
-2
dygraph/tools/x2coco.py
dygraph/tools/x2coco.py
+439
-0
未找到文件。
dygraph/README.md
浏览文件 @
b02375b2
...
@@ -142,6 +142,9 @@ PaddleDetection模块化地实现了多种主流目标检测算法,提供了
...
@@ -142,6 +142,9 @@ PaddleDetection模块化地实现了多种主流目标检测算法,提供了
-
[
Python端推理部署
](
deploy/python
)
-
[
Python端推理部署
](
deploy/python
)
-
[
C++端推理部署
](
deploy/cpp
)
-
[
C++端推理部署
](
deploy/cpp
)
-
[
服务端部署
](
deploy/serving
)
-
[
服务端部署
](
deploy/serving
)
-
[进阶开发]
-
[
数据处理模块
](
docs/advanced_tutorials/READER.md
)
-
[
新增检测模型
](
docs/advanced_tutorials/MODEL_TECHNICAL.md
)
## 模型库
## 模型库
...
...
dygraph/docs/advanced_tutorials/MODEL_TECHNICAL.md
0 → 100644
浏览文件 @
b02375b2
# 新增模型算法
为了让用户更好的使用PaddleDetection,本文档中,我们将介绍PaddleDetection的主要模型技术细节及应用
## 目录
-
[
1.简介
](
#1.简介
)
-
[
2.新增模型
](
#2.新增模型
)
-
[
2.1新增网络结构
](
#2.1新增网络结构
)
-
[
2.1.1新增Backbone
](
#2.1.1新增Backbone
)
-
[
2.1.2新增Neck
](
#2.1.2新增Neck
)
-
[
2.1.3新增Head
](
#2.1.3新增Head
)
-
[
2.1.4新增Loss
](
#2.1.4新增Loss
)
-
[
2.1.5新增后处理模块
](
#2.1.5新增后处理模块
)
-
[
2.1.6新增Architecture
](
#2.1.6新增Architecture
)
-
[
2.2新增配置文件
](
#2.2新增配置文件
)
-
[
2.2.1网络结构配置文件
](
#2.2.1网络结构配置文件
)
-
[
2.2.2优化器配置文件
](
#2.2.2优化器配置文件
)
-
[
2.2.3Reader配置文件
](
#2.2.3Reader配置文件
)
### 1.简介
PaddleDetecion中的每一种模型对应一个文件夹,以yolov3为例,yolov3系列的模型对应于
`configs/yolov3`
文件夹,其中yolov3_darknet的总配置文件
`configs/yolov3/yolov3_darknet53_270e_coco.yml`
的内容如下:
```
_BASE_: [
'../datasets/coco_detection.yml', # 数据集配置文件,所有模型共用
'../runtime.yml', # 运行时相关配置
'_base_/optimizer_270e.yml', # 优化器相关配置
'_base_/yolov3_darknet53.yml', # yolov3网络结构配置文件
'_base_/yolov3_reader.yml', # yolov3 Reader模块配置
]
# 定义在此处的相关配置可以覆盖上述文件中的同名配置
snapshot_epoch: 5
weights: output/yolov3_darknet53_270e_coco/model_final
```
可以看到,配置文件中的模块进行了清晰的划分,除了公共的数据集配置以及运行时配置,其他配置被划分为优化器,网络结构以及Reader模块。PaddleDetection中支持丰富的优化器,学习率调整策略,预处理算子等,因此大多数情况下不需要编写优化器以及Reader相关的代码,而只需要在配置文件中配置即可。因此,新增一个模型的主要在于搭建网络结构。
PaddleDetection网络结构的代码在
`ppdet/modeling/`
中,所有网络结构以组件的形式进行定义与组合,网络结构的主要构成如下所示:
```
ppdet/modeling/
├── architectures
│ ├── faster_rcnn.py # Faster Rcnn模型
│ ├── ssd.py # SSD模型
│ ├── yolo.py # YOLOv3模型
│ │ ...
├── heads # 检测头模块
│ ├── xxx_head.py # 定义各类检测头
│ ├── roi_extractor.py #检测感兴趣区域提取
├── backbones # 基干网络模块
│ ├── resnet.py # ResNet网络
│ ├── mobilenet.py # MobileNet网络
│ │ ...
├── losses # 损失函数模块
│ ├── xxx_loss.py # 定义注册各类loss函数
├── necks # 特征融合模块
│ ├── xxx_fpn.py # 定义各种FPN模块
├── proposal_generator # anchor & proposal生成与匹配模块
│ ├── anchor_generator.py # anchor生成模块
│ ├── proposal_generator.py # proposal生成模块
│ ├── target.py # anchor & proposal的匹配函数
│ ├── target_layer.py # anchor & proposal的匹配模块
├── tests # 单元测试模块
│ ├── test_xxx.py # 对网络中的算子以及模块结构进行单元测试
├── ops.py # 封装各类PaddlePaddle物体检测相关公共检测组件/算子
├── layers.py # 封装及注册各类PaddlePaddle物体检测相关公共检测组件/算子
├── bbox_utils.py # 封装检测框相关的函数
├── post_process.py # 封装及注册后处理相关模块
├── shape_spec.py # 定义模块输出shape的类
```
![](
../images/model_figure.png
)
### 2.新增模型
接下来,以单阶段检测器YOLOv3为例,对建立模型过程进行详细描述,按照此思路您可以快速搭建新的模型。
#### 2.1新增网络结构
##### 2.1.1新增Backbone
PaddleDetection中现有所有Backbone网络代码都放置在
`ppdet/modeling/backbones`
目录下,所以我们在其中新建
`darknet.py`
如下:
```
python
import
paddle.nn
as
nn
from
ppdet.core.workspace
import
register
,
serializable
@
register
@
serializable
class
DarkNet
(
nn
.
Layer
):
__shared__
=
[
'norm_type'
]
def
__init__
(
self
,
depth
=
53
,
return_idx
=
[
2
,
3
,
4
],
norm_type
=
'bn'
,
norm_decay
=
0.
):
super
(
DarkNet
,
self
).
__init__
()
# 省略内容
def
forward
(
self
,
inputs
):
# 省略处理逻辑
pass
@
property
def
out_shape
(
self
):
# 省略内容
pass
```
然后在
`backbones/__init__.py`
中加入引用:
```
python
from
.
import
darknet
from
.darknet
import
*
```
**几点说明:**
-
为了在yaml配置文件中灵活配置网络,所有Backbone需要利用
`ppdet.core.workspace`
里的
`register`
进行注册,形式请参考如上示例。此外,可以使用
`serializable`
以使backbone支持序列化;
-
所有的Backbone需继承
`paddle.nn.Layer`
类,并实现forward函数。此外,还需实现out_shape属性定义输出的feature map的channel信息,具体可参见源码;
-
`__shared__`
为了实现一些参数的配置全局共享,这些参数可以被backbone, neck,head,loss等所有注册模块共享。
##### 2.1.2新增Neck
特征融合模块放置在
`ppdet/modeling/necks`
目录下,我们在其中新建
`yolo_fpn.py`
如下:
```
python
import
paddle.nn
as
nn
from
ppdet.core.workspace
import
register
,
serializable
@
register
@
serializable
class
YOLOv3FPN
(
nn
.
Layer
):
__shared__
=
[
'norm_type'
]
def
__init__
(
self
,
in_channels
=
[
256
,
512
,
1024
],
norm_type
=
'bn'
):
super
(
YOLOv3FPN
,
self
).
__init__
()
# 省略内容
def
forward
(
self
,
blocks
):
# 省略内容
pass
@
classmethod
def
from_config
(
cls
,
cfg
,
input_shape
):
# 省略内容
pass
@
property
def
out_shape
(
self
):
# 省略内容
pass
```
然后在
`necks/__init__.py`
中加入引用:
```
python
from
.
import
yolo_fpn
from
.yolo_fpn
import
*
```
**几点说明:**
-
neck模块需要使用
`register`
进行注册,可以使用
`serializable`
进行序列化;
-
neck模块需要继承
`paddle.nn.Layer`
类,并实现forward函数。除此之外,还需要实现
`out_shape`
属性,用于定义输出的feature map的channel信息,还需要实现类函数
`from_config`
用于在配置文件中推理出输入channel,并用于
`YOLOv3FPN`
的初始化;
-
neck模块可以使用
`__shared__`
实现一些参数的配置全局共享。
##### 2.1.3新增Head
Head模块全部存放在
`ppdet/modeling/heads`
目录下,我们在其中新建
`yolo_head.py`
如下
```
python
import
paddle.nn
as
nn
from
ppdet.core.workspace
import
register
@
register
class
YOLOv3Head
(
nn
.
Layer
):
__shared__
=
[
'num_classes'
]
__inject__
=
[
'loss'
]
def
__init__
(
self
,
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]],
anchor_masks
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]],
num_classes
=
80
,
loss
=
'YOLOv3Loss'
,
iou_aware
=
False
,
iou_aware_factor
=
0.4
):
super
(
YOLOv3Head
,
self
).
__init__
()
# 省略内容
def
forward
(
self
,
feats
,
targets
=
None
):
# 省略内容
pass
```
然后在
`heads/__init__.py`
中加入引用:
```
python
from
.
import
yolo_head
from
.yolo_head
import
*
```
**几点说明:**
-
Head模块需要使用
`register`
进行注册;
-
Head模块需要继承
`paddle.nn.Layer`
类,并实现forward函数。
-
`__inject__`
表示引入全局字典中已经封装好的模块。如loss等。
##### 2.1.4新增Loss
Loss模块全部存放在
`ppdet/modeling/losses`
目录下,我们在其中新建
`yolo_loss.py`
下
```
python
import
paddle.nn
as
nn
from
ppdet.core.workspace
import
register
@
register
class
YOLOv3Loss
(
nn
.
Layer
):
__inject__
=
[
'iou_loss'
,
'iou_aware_loss'
]
__shared__
=
[
'num_classes'
]
def
__init__
(
self
,
num_classes
=
80
,
ignore_thresh
=
0.7
,
label_smooth
=
False
,
downsample
=
[
32
,
16
,
8
],
scale_x_y
=
1.
,
iou_loss
=
None
,
iou_aware_loss
=
None
):
super
(
YOLOv3Loss
,
self
).
__init__
()
# 省略内容
def
forward
(
self
,
inputs
,
targets
,
anchors
):
# 省略内容
pass
```
然后在
`losses/__init__.py`
中加入引用:
```
python
from
.
import
yolo_loss
from
.yolo_loss
import
*
```
**几点说明:**
-
loss模块需要使用
`register`
进行注册;
-
loss模块需要继承
`paddle.nn.Layer`
类,并实现forward函数。
-
可以使用
`__inject__`
表示引入全局字典中已经封装好的模块,使用
`__shared__`
可以实现一些参数的配置全局共享。
##### 2.1.5新增后处理模块
后处理模块定义在
`ppdet/modeling/post_process.py`
中,其中定义了
`BBoxPostProcess`
类来进行后处理操作,如下所示:
```
python
from
ppdet.core.workspace
import
register
@
register
class
BBoxPostProcess
(
object
):
__shared__
=
[
'num_classes'
]
__inject__
=
[
'decode'
,
'nms'
]
def
__init__
(
self
,
num_classes
=
80
,
decode
=
None
,
nms
=
None
):
# 省略内容
pass
def
__call__
(
self
,
head_out
,
rois
,
im_shape
,
scale_factor
):
# 省略内容
pass
```
**几点说明:**
-
后处理模块需要使用
`register`
进行注册
-
`__inject__`
注入了全局字典中封装好的模块,如decode和nms等。decode和nms定义在
`ppdet/modeling/layers.py`
中。
##### 2.1.6新增Architecture
所有architecture网络代码都放置在
`ppdet/modeling/architectures`
目录下,
`meta_arch.py`
中定义了
`BaseArch`
类,代码如下:
```
python
import
paddle.nn
as
nn
from
ppdet.core.workspace
import
register
@
register
class
BaseArch
(
nn
.
Layer
):
def
__init__
(
self
):
super
(
BaseArch
,
self
).
__init__
()
def
forward
(
self
,
inputs
):
self
.
inputs
=
inputs
self
.
model_arch
()
if
self
.
training
:
out
=
self
.
get_loss
()
else
:
out
=
self
.
get_pred
()
return
out
def
model_arch
(
self
,
):
pass
def
get_loss
(
self
,
):
raise
NotImplementedError
(
"Should implement get_loss method!"
)
def
get_pred
(
self
,
):
raise
NotImplementedError
(
"Should implement get_pred method!"
)
```
所有的architecture需要继承
`BaseArch`
类,如
`yolo.py`
中的
`YOLOv3`
定义如下:
```
python
@
register
class
YOLOv3
(
BaseArch
):
__category__
=
'architecture'
__inject__
=
[
'post_process'
]
def
__init__
(
self
,
backbone
=
'DarkNet'
,
neck
=
'YOLOv3FPN'
,
yolo_head
=
'YOLOv3Head'
,
post_process
=
'BBoxPostProcess'
):
super
(
YOLOv3
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
neck
=
neck
self
.
yolo_head
=
yolo_head
self
.
post_process
=
post_process
@
classmethod
def
from_config
(
cls
,
cfg
,
*
args
,
**
kwargs
):
# 省略内容
pass
def
get_loss
(
self
):
# 省略内容
pass
def
get_pred
(
self
):
# 省略内容
pass
```
**几点说明:**
-
所有的architecture需要使用
`register`
进行注册
-
在组建一个完整的网络时必须要设定
`__category__ = 'architecture'`
来表示一个完整的物体检测模型;
-
backbone, neck, yolo_head以及post_process等检测组件传入到architecture中组成最终的网络。像这样将检测模块化,提升了检测模型的复用性,可以通过组合不同的检测组件得到多个模型。
-
from_config类函数实现了模块间组合时channel的自动配置。
#### 2.2新增配置文件
##### 2.2.1网络结构配置文件
上面详细地介绍了如何新增一个architecture,接下来演示如何配置一个模型,yolov3关于网络结构的配置在
`configs/yolov3/_base_/`
文件夹中定义,如
`yolov3_darknet53.yml`
定义了yolov3_darknet的网络结构,其定义如下:
```
architecture: YOLOv3
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/DarkNet53_pretrained.pdparams
norm_type: sync_bn
YOLOv3:
backbone: DarkNet
neck: YOLOv3FPN
yolo_head: YOLOv3Head
post_process: BBoxPostProcess
DarkNet:
depth: 53
return_idx: [2, 3, 4]
# use default config
# YOLOv3FPN:
YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.7
downsample: [32, 16, 8]
label_smooth: false
BBoxPostProcess:
decode:
name: YOLOBox
conf_thresh: 0.005
downsample_ratio: 32
clip_bbox: true
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.01
nms_threshold: 0.45
nms_top_k: 1000
```
可以看到在配置文件中,首先需要指定网络的architecture,pretrain_weights指定训练模型的url或者路径,norm_type等可以作为全局参数共享。模型的定义自上而下依次在文件中定义,与上节中的模型组件一一对应。对于一些模型组件,如果采用默认
的参数,可以不用配置,如上文中的
`yolo_fpn`
。通过改变相关配置,我们可以轻易地组合出另一个模型,比如
`configs/yolov3/_base_/yolov3_mobilenet_v1.yml`
将backbone从Darknet切换成MobileNet。
##### 2.2.2优化器配置文件
优化器配置文件定义模型使用的优化器以及学习率的调度策略,目前PaddleDetection中已经集成了多种多样的优化器和学习率策略,具体可参见代码
`ppdet/optimizer.py`
。比如,yolov3的优化器配置文件定义在
`configs/yolov3/_base_/optimizer_270e.yml`
,其定义如下:
```
epoch: 270
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
# epoch数目
- 216
- 243
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
```
**几点说明:**
-
可以通过OptimizerBuilder.optimizer指定优化器的类型及参数,目前支持的优化可以参考
[
PaddlePaddle官方文档
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/Overview_cn.html
)
-
可以设置LearningRate.schedulers设置不同学习率调整策略的组合,PaddlePaddle目前支持多种学习率调整策略,具体也可参考
[
PaddlePaddle官方文档
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/Overview_cn.html
)
。需要注意的是,你需要对于PaddlePaddle中的学习率调整策略进行简单的封装,具体可参考源码
`ppdet/optimizer.py`
。
##### 2.2.3Reader配置文件
关于Reader的配置可以参考
[
Reader配置文档
](
./READER.md#5.配置及运行
)
。
> 看过此文档,您应该对PaddleDetection中模型搭建与配置有了一定经验,结合源码会理解的更加透彻。关于模型技术,如您有其他问题或建议,请给我们提issue,我们非常欢迎您的反馈。
dygraph/docs/advanced_tutorials/READER.md
0 → 100644
浏览文件 @
b02375b2
# 数据处理模块
## 目录
-
[
1.简介
](
#1.简介
)
-
[
2.数据集
](
#2.数据集
)
-
[
2.1COCO数据集
](
#2.1COCO数据集
)
-
[
2.2Pascal VOC数据集
](
#2.2Pascal-VOC数据集
)
-
[
2.3自定义数据集
](
#2.3自定义数据集
)
-
[
3.数据预处理
](
#3.数据预处理
)
-
[
3.1数据增强算子
](
#3.1数据增强算子
)
-
[
3.2自定义数据增强算子
](
#3.2自定义数据增强算子
)
-
[
4.Raeder
](
#4.Reader
)
-
[
5.配置及运行
](
#5.配置及运行
)
-
[
5.1配置
](
#5.1配置
)
-
[
5.2运行
](
#5.2运行
)
### 1.简介
PaddleDetection的数据处理模块的所有代码逻辑在
`ppdet/data/`
中,数据处理模块用于加载数据并将其转换成适用于物体检测模型的训练、评估、推理所需要的格式。
数据处理模块的主要构成如下架构所示:
```
bash
ppdet/data/
├── reader.py
# 基于Dataloader封装的Reader模块
├──
source
# 数据源管理模块
│ ├── dataset.py
# 定义数据源基类,各类数据集继承于此
│ ├── coco.py
# COCO数据集解析与格式化数据
│ ├── voc.py
# Pascal VOC数据集解析与格式化数据
│ ├── widerface.py
# WIDER-FACE数据集解析与格式化数据
│ ├── category.py
# 相关数据集的类别信息
├── transform
# 数据预处理模块
│ ├── batch_operators.py
# 定义各类基于批量数据的预处理算子
│ ├── op_helper.py
# 预处理算子的辅助函数
│ ├── operators.py
# 定义各类基于单张图片的预处理算子
│ ├── gridmask_utils.py
# GridMask数据增强函数
│ ├── autoaugment_utils.py
# AutoAugment辅助函数
├── shm_utils.py
# 用于使用共享内存的辅助函数
```
### 2.数据集
数据集定义在
`
source
`
目录下,其中
`
dataset.py
`
中定义了数据集的基类
`
DetDataSet
`
, 所有的数据集均继承于基类,
`
DetDataset
`
基类里定义了如下等方法:
| 方法 | 输入 | 输出 | 备注 |
| :------------------------: | :----: | :------------: | :--------------: |
|
\_\_
len
\_\_
| 无 | int, 数据集中样本的数量 | 过滤掉了无标注的样本 |
|
\_\_
getitem
\_\_
| int, 样本的索引idx | dict, 索引idx对应的样本roidb | 得到transform之后的样本roidb |
| check_or_download_dataset | 无 | 无 | 检查数据集是否存在,如果不存在则下载,目前支持COCO, VOC,widerface等数据集 |
| set_kwargs | 可选参数,以键值对的形式给出 | 无 | 目前用于支持接收mixup, cutmix等参数的设置 |
| set_transform | 一系列的transform函数 | 无 | 设置数据集的transform函数 |
| set_epoch | int, 当前的epoch | 无 | 用于dataset与训练过程的交互 |
| parse_dataset | 无 | 无 | 用于从数据中读取所有的样本 |
| get_anno | 无 | 无 | 用于获取标注文件的路径 |
当一个数据集类继承自
`
DetDataSet
`
,那么它只需要实现parse_dataset函数即可。parse_dataset根据数据集设置的数据集根路径dataset_dir,图片文件夹image_dir, 标注文件路径anno_path取出所有的样本,并将其保存在一个列表roidbs中,每一个列表中的元素为一个样本xxx_rec
(
比如coco_rec或者voc_rec
)
,用dict表示,dict中包含样本的image, gt_bbox, gt_class等字段。COCO和Pascal-VOC数据集中的xxx_rec的数据结构定义如下:
```
python
xxx_rec
=
{
'im_file'
: im_fname,
# 一张图像的完整路径
'im_id'
: np.array
([
img_id]
)
,
# 一张图像的ID序号
'h'
: im_h,
# 图像高度
'w'
: im_w,
# 图像宽度
'is_crowd'
: is_crowd,
# 是否是群落对象, 默认为0 (VOC中无此字段)
'gt_class'
: gt_class,
# 标注框标签名称的ID序号
'gt_bbox'
: gt_bbox,
# 标注框坐标(xmin, ymin, xmax, ymax)
'gt_poly'
: gt_poly,
# 分割掩码,此字段只在coco_rec中出现,默认为None
'difficult'
: difficult
# 是否是困难样本,此字段只在voc_rec中出现,默认为0
}
```
xxx_rec中的内容也可以通过
`
DetDataSet
`
的data_fields参数来控制,即可以过滤掉一些不需要的字段,但大多数情况下不需要修改,按照
`
configs/dataset
`
中的默认配置即可。
此外,在parse_dataset函数中,保存了类别名到id的映射的一个字典
`
cname2cid
`
。在coco数据集中,会利用[COCO API]
(
https://github.com/cocodataset/cocoapi
)
从标注文件中加载数据集的类别名,并设置此字典。在voc数据集中,如果设置
`
use_default_label
=
False
`
,将从
`
label_list.txt
`
中读取类别列表,反之将使用voc默认的类别列表。
#### 2.1COCO数据集
COCO数据集目前分为COCO2014和COCO2017,主要由json文件和image文件组成,其组织结构如下所示:
```
dataset/coco/
├── annotations
│ ├── instances_train2014.json
│ ├── instances_train2017.json
│ ├── instances_val2014.json
│ ├── instances_val2017.json
│ │ ...
├── train2017
│ ├── 000000000009.jpg
│ ├── 000000580008.jpg
│ │ ...
├── val2017
│ ├── 000000000139.jpg
│ ├── 000000000285.jpg
│ │ ...
```
在
`
source
/coco.py
`
中定义并注册了
`
COCODataSet
`
数据集类,其继承自
`
DetDataSet
`
,并实现了parse_dataset方法,调用[COCO API]
(
https://github.com/cocodataset/cocoapi
)
加载并解析COCO格式数据源
`
roidbs
`
和
`
cname2cid
`
,具体可参见
`
source
/coco.py
`
源码。将其他数据集转换成COCO格式可以参考[用户数据转成COCO数据]
(
../tutorials/PrepareDataSet.md#用户数据转成COCO数据
)
#### 2.2Pascal VOC数据集
该数据集目前分为VOC2007和VOC2012,主要由xml文件和image文件组成,其组织结构如下所示:
```
dataset/voc/
├── trainval.txt
├── test.txt
├── label_list.txt (optional)
├── VOCdevkit/VOC2007
│ ├── Annotations
│ ├── 001789.xml
│ │ ...
│ ├── JPEGImages
│ ├── 001789.jpg
│ │ ...
│ ├── ImageSets
│ | ...
├── VOCdevkit/VOC2012
│ ├── Annotations
│ ├── 2011_003876.xml
│ │ ...
│ ├── JPEGImages
│ ├── 2011_003876.jpg
│ │ ...
│ ├── ImageSets
│ │ ...
```
在`source/voc.py`中定义并注册了`VOCDataSet`数据集,它继承自`DetDataSet`基类,并重写了`parse_dataset`方法,解析VOC数据集中xml格式标注文件,更新`roidbs`和`cname2cid`。将其他数据集转换成VOC格式可以参考[用户数据转成VOC数据](../tutorials/PrepareDataSet.md#用户数据转成VOC数据)
#### 2.3自定义数据集
如果COCODataSet和VOCDataSet不能满足你的需求,可以通过自定义数据集的方式来加载你的数据集。只需要以下两步即可实现自定义数据集
1. 新建`source/xxx.py`,定义类`XXXDataSet`继承自`DetDataSet`基类,完成注册与序列化,并重写`parse_dataset`方法对`roidbs`与`cname2cid`更新:
```
python
from ppdet.core.workspace import register, serializable
#注册并序列化
@register
@serializable
class XXXDataSet(DetDataSet):
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
...
):
self.roidbs = None
self.cname2cid = None
...
def parse_dataset(self):
...
省略具体解析数据逻辑
...
self.roidbs, self.cname2cid = records, cname2cid
```
2. 在`source/__init__.py`中添加引用:
```
python
from . import xxx
from .xxx import
*
```
完成以上两步就将新的数据源`XXXDataSet`添加好了,你可以参考[配置及运行](#配置及运行)实现自定义数据集的使用。
### 3.数据预处理
#### 3.1数据增强算子
PaddleDetection中支持了种类丰富的数据增强算子,有单图像数据增强算子与批数据增强算子两种方式,您可选取合适的算子组合使用。单图像数据增强算子定义在`transform/operators.py`中,已支持的单图像数据增强算子详见下表:
| 名称 | 作用 |
| :---------------------: | :--------------: |
| Decode | 从图像文件或内存buffer中加载图像,格式为RGB格式 |
| Permute | 假如输入是HWC顺序变成CHW |
| RandomErasingImage | 对图像进行随机擦除 |
| NormalizeImage | 对图像像素值进行归一化,如果设置is_scale=True,则先将像素值除以255.0, 再进行归一化。 |
| GridMask | GridMask数据增广 |
| RandomDistort | 随机扰动图片亮度、对比度、饱和度和色相 |
| AutoAugment | AutoAugment数据增广,包含一系列数据增强方法 |
| RandomFlip | 随机水平翻转图像 |
| Resize | 对于图像进行resize,并对标注进行相应的变换 |
| MultiscaleTestResize | 将图像重新缩放为多尺度list的每个尺寸 |
| RandomResize | 对于图像进行随机Resize,可以Resize到不同的尺寸以及使用不同的插值策略 |
| RandomExpand | 将原始图片放入用像素均值填充的扩张图中,对此图进行裁剪、缩放和翻转 |
| CropWithSampling | 根据缩放比例、长宽比例生成若干候选框,再依据这些候选框和标注框的面积交并比(IoU)挑选出符合要求的裁剪结果 |
| CropImageWithDataAchorSampling | 基于CropImage,在人脸检测中,随机将图片尺度变换到一定范围的尺度,大大增强人脸的尺度变化 |
| RandomCrop | 原理同CropImage,以随机比例与IoU阈值进行处理 |
| RandomScaledCrop | 根据长边对图像进行随机裁剪,并对标注做相应的变换 |
| Cutmix | Cutmix数据增强,对两张图片做拼接 |
| Mixup | Mixup数据增强,按比例叠加两张图像 |
| NormalizeBox | 对bounding box进行归一化 |
| PadBox | 如果bounding box的数量少于num_max_boxes,则将零填充到bbox |
| BboxXYXY2XYWH | 将bounding box从(xmin,ymin,xmax,ymin)形式转换为(xmin,ymin,width,height)格式 |
| Pad | 将图片Pad某一个数的整数倍或者指定的size,并支持指定Pad的方式 |
| Poly2Mask | Poly2Mask数据增强 |
批数据增强算子定义在`transform/batch_operators.py`中, 目前支持的算子列表如下:
| 名称 | 作用 |
| :---------------------: | :--------------: |
| PadBatch | 随机对每个batch的数据图片进行Pad操作,使得batch中的图片具有相同的shape |
| BatchRandomResize | 对一个batch的图片进行resize,使得batch中的图片随机缩放到相同的尺寸 |
| Gt2YoloTarget | 通过gt数据生成YOLO系列模型的目标 |
| Gt2FCOSTarget | 通过gt数据生成FCOS模型的目标 |
| Gt2TTFTarget | 通过gt数据生成TTFNet模型的目标 |
| Gt2Solov2Target | 通过gt数据生成SOLOv2模型的目标 |
**几点说明:**
- 数据增强算子的输入为sample或者samples,每一个sample对应上文所说的`DetDataSet`输出的roidbs中的一个样本,如coco_rec或者voc_rec
- 单图像数据增强算子(Mixup, Cutmix等除外)也可用于批数据处理中。但是,单图像处理算子和批图像处理算子仍有一些差异,以RandomResize和BatchRandomResize为例,RandomResize会将一个Batch中的每张图片进行随机缩放,但是每一张图像Resize之后的形状不尽相同,BatchRandomResize则会将一个Batch中的所有图片随机缩放到相同的形状。
- 除BatchRandomResize外,定义在`transform/batch_operators.py`的批数据增强算子接收的输入图像均为CHW形式,所以使用这些批数据增强算子前请先使用Permute进行处理。如果用到Gt2xxxTarget算子,需要将其放置在靠后的位置。NormalizeBox算子建议放置在Gt2xxxTarget之前。将这些限制条件总结下来,推荐的预处理算子的顺序为
```
-
XXX: {}
-
...
-
BatchRandomResize: {...} # 如果不需要,可以移除,如果需要,放置在Permute之前
-
Permute: {} # 必须项
-
NormalizeBox: {} # 如果需要,建议放在Gt2XXXTarget之前
-
PadBatch: {...} # 如果不需要可移除,如果需要,建议放置在Permute之后
-
Gt2XXXTarget: {...} # 建议与PadBatch放置在最后的位置
```
#### 3.2自定义数据增强算子
如果需要自定义数据增强算子,那么您需要了解下数据增强算子的相关逻辑。数据增强算子基类为定义在`transform/operators.py`中的`BaseOperator`类,单图像数据增强算子与批数据增强算子均继承自这个基类。完整定义参考源码,以下代码显示了`BaseOperator`类的关键函数: apply和__call__方法
```
python
class BaseOperator(object):
...
def apply(self, sample, context=None):
return sample
def __call__(self, sample, context=None):
if isinstance(sample, Sequence):
for i in range(len(sample)):
sample[i] = self.apply(sample[i], context)
else:
sample = self.apply(sample, context)
return sample
```
__call__方法为`BaseOperator`的调用入口,接收一个sample(单图像)或者多个sample(多图像)作为输入,并调用apply函数对一个或者多个sample进行处理。大多数情况下,你只需要继承`BaseOperator`重写apply方法或者重写__call__方法即可,如下所示,定义了一个XXXOp继承自BaseOperator,并注册:
```
python
@register_op
class XXXOp(BaseOperator):
def __init__(self,...):
super(XXXImage, self).__init__()
...
# 大多数情况下只需要重写apply方法
def apply(self, sample, context=None):
...
省略对输入的sample具体操作
...
return sample
# 如果有需要,可以重写__call__方法,如Mixup, Gt2XXXTarget等
# def __call__(self, sample, context=None):
# ...
# 省略对输入的sample具体操作
# ...
# return sample
```
大多数情况下,只需要重写apply方法即可,如`transform/operators.py`中除Mixup和Cutmix外的预处理算子。对于批处理的情况一般需要重写__call__方法,如`transform/batch_operators.py`的预处理算子。
### 4.Reader
Reader相关的类定义在`reader.py`, 其中定义了`BaseDataLoader`类。`BaseDataLoader`在`paddle.io.DataLoader`的基础上封装了一层,其具备`paddle.io.DataLoader`的所有功能,并能够实现不同模型对于`DetDataset`的不同需求,如可以通过对Reader进行设置,以控制`DetDataset`支持Mixup, Cutmix等操作。除此之外,数据预处理算子通过`Compose`类和`BatchCompose`类组合起来分别传入`DetDataset`和`paddle.io.DataLoader`中。
所有的Reader类都继承自`BaseDataLoader`类,具体可参见源码。
### 5.配置及运行
#### 5.1配置
与数据预处理相关的模块的配置文件包含所有模型公用的Datas set的配置文件以及不同模型专用的Reader的配置文件。关于Dataset的配置文件存在于`configs/datasets`文件夹。比如COCO数据集的配置文件如下:
```
metric: COCO # 目前支持COCO, VOC, OID, WiderFace等评估标准
num_classes: 80 # num_classes数据集的类别数,不包含背景类
TrainDataset:
!COCODataSet
image_dir: train2017 # 训练集的图片所在文件相对于dataset_dir的路径
anno_path: annotations/instances_train2017.json # 训练集的标注文件相对于dataset_dir的路径
dataset_dir: dataset/coco #数据集所在路径,相对于PaddleDetection路径
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] # 控制dataset输出的sample所包含的字段
EvalDataset:
!COCODataSet
image_dir: val2017 # 验证集的图片所在文件夹相对于dataset_dir的路径
anno_path: annotations/instances_val2017.json # 验证集的标注文件相对于dataset_dir的路径
dataset_dir: dataset/coco # 数据集所在路径,相对于PaddleDetection路径
TestDataset:
!ImageFolder
anno_path: dataset/coco/annotations/instances_val2017.json # 验证集的标注文件所在路径,相对于PaddleDetection的路径
```
在PaddleDetection的yml配置文件中,使用`!`直接序列化模块实例(可以是函数,实例等),上述的配置文件均使用Dataset进行了序列化。
不同模型专用的Reader定义在每一个模型的文件夹下,如yolov3的Reader配置文件定义在`configs/yolov3/_base_/yolov3_reader.yml`。一个Reader的示例配置如下:
```
worker_num: 2
TrainReader:
sample_transforms:
-
Decode: {}
...
batch_transforms:
...
batch_size: 8
shuffle: true
drop_last: true
use_shared_memory: true
EvalReader:
sample_transforms:
-
Decode: {}
...
batch_size: 1
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 608, 608]
sample_transforms:
-
Decode: {}
...
batch_size: 1
```
你可以在Reader中定义不同的预处理算子,每张卡的batch_size以及DataLoader的worker_num等。
#### 5.2运行
在PaddleDetection的训练、评估和测试运行程序中,都通过创建Reader迭代器。Reader在`ppdet/engine/trainer.py`中创建。下面的代码展示了如何创建训练时的Reader
```
python
from ppdet.core.workspace import create
# build data loader
self.dataset = cfg['TrainDataset']
self.loader = create('TrainReader')(selfdataset, cfg.worker_num)
```
相应的预测以及评估时的Reader与之类似,具体可参考`ppdet/engine/trainer.py`源码。
> 关于数据处理模块,如您有其他问题或建议,请给我们提issue,我们非常欢迎您的反馈。
dygraph/docs/images/model_figure.png
0 → 100644
浏览文件 @
b02375b2
147.8 KB
dygraph/docs/images/reader_figure.png
0 → 100644
浏览文件 @
b02375b2
195.0 KB
dygraph/docs/tutorials/PrepareDataSet.md
浏览文件 @
b02375b2
...
@@ -289,7 +289,7 @@ classname2
...
@@ -289,7 +289,7 @@ classname2
...
...
```
```
##### 用户数据转成COCO
##### 用户数据转成COCO
数据
在
`./tools/`
中提供了
`x2coco.py`
用于将VOC数据集、labelme标注的数据集或cityscape数据集转换为COCO数据,例如:
在
`./tools/`
中提供了
`x2coco.py`
用于将VOC数据集、labelme标注的数据集或cityscape数据集转换为COCO数据,例如:
(1)labelme数据转换为COCO数据:
(1)labelme数据转换为COCO数据:
...
@@ -328,7 +328,7 @@ dataset/xxx/
...
@@ -328,7 +328,7 @@ dataset/xxx/
```
```
##### 用户数据自定义reader
##### 用户数据自定义reader
如果数据集有新的数据需要添加进PaddleDetection中,您可参考数据处理文档中的
[
添加新数据源
](
../advanced_tutorials/READER.md#
添加新数据源
)
文档部分,开发相应代码完成新的数据源支持,同时数据处理具体代码解析等可阅读
[
数据处理文档
](
../advanced_tutorials/READER.md
)
如果数据集有新的数据需要添加进PaddleDetection中,您可参考数据处理文档中的
[
添加新数据源
](
../advanced_tutorials/READER.md#
2.3自定义数据集
)
文档部分,开发相应代码完成新的数据源支持,同时数据处理具体代码解析等可阅读
[
数据处理文档
](
../advanced_tutorials/READER.md
)
#### 用户数据数据转换示例
#### 用户数据数据转换示例
...
...
dygraph/
ppdet/data/
tools/x2coco.py
→
dygraph/tools/x2coco.py
浏览文件 @
b02375b2
...
@@ -21,6 +21,9 @@ import os
...
@@ -21,6 +21,9 @@ import os
import
os.path
as
osp
import
os.path
as
osp
import
sys
import
sys
import
shutil
import
shutil
import
xml.etree.ElementTree
as
ET
from
tqdm
import
tqdm
import
re
import
numpy
as
np
import
numpy
as
np
import
PIL.ImageDraw
import
PIL.ImageDraw
...
@@ -42,12 +45,6 @@ class MyEncoder(json.JSONEncoder):
...
@@ -42,12 +45,6 @@ class MyEncoder(json.JSONEncoder):
return
super
(
MyEncoder
,
self
).
default
(
obj
)
return
super
(
MyEncoder
,
self
).
default
(
obj
)
def
getbbox
(
self
,
points
):
polygons
=
points
mask
=
self
.
polygons_to_mask
([
self
.
height
,
self
.
width
],
polygons
)
return
self
.
mask2box
(
mask
)
def
images_labelme
(
data
,
num
):
def
images_labelme
(
data
,
num
):
image
=
{}
image
=
{}
image
[
'height'
]
=
data
[
'imageHeight'
]
image
[
'height'
]
=
data
[
'imageHeight'
]
...
@@ -154,17 +151,19 @@ def deal_json(ds_type, img_path, json_path):
...
@@ -154,17 +151,19 @@ def deal_json(ds_type, img_path, json_path):
categories_list
.
append
(
categories
(
label
,
labels_list
))
categories_list
.
append
(
categories
(
label
,
labels_list
))
labels_list
.
append
(
label
)
labels_list
.
append
(
label
)
label_to_num
[
label
]
=
len
(
labels_list
)
label_to_num
[
label
]
=
len
(
labels_list
)
points
=
shapes
[
'points'
]
p_type
=
shapes
[
'shape_type'
]
p_type
=
shapes
[
'shape_type'
]
if
p_type
==
'polygon'
:
if
p_type
==
'polygon'
:
points
=
shapes
[
'points'
]
annotations_list
.
append
(
annotations_list
.
append
(
annotations_polygon
(
data
[
'imageHeight'
],
data
[
annotations_polygon
(
data
[
'imageHeight'
],
data
[
'imageWidth'
],
points
,
label
,
image_num
,
'imageWidth'
],
points
,
label
,
image_num
,
object_num
,
label_to_num
))
object_num
,
label_to_num
))
if
p_type
==
'rectangle'
:
if
p_type
==
'rectangle'
:
points
.
append
([
points
[
0
][
0
],
points
[
1
][
1
]])
(
x1
,
y1
),
(
x2
,
y2
)
=
shapes
[
'points'
]
points
.
append
([
points
[
1
][
0
],
points
[
0
][
1
]])
x1
,
x2
=
sorted
([
x1
,
x2
])
y1
,
y2
=
sorted
([
y1
,
y2
])
points
=
[[
x1
,
y1
],
[
x2
,
y2
],
[
x1
,
y2
],
[
x2
,
y1
]]
annotations_list
.
append
(
annotations_list
.
append
(
annotations_rectangle
(
points
,
label
,
image_num
,
annotations_rectangle
(
points
,
label
,
image_num
,
object_num
,
label_to_num
))
object_num
,
label_to_num
))
...
@@ -187,6 +186,99 @@ def deal_json(ds_type, img_path, json_path):
...
@@ -187,6 +186,99 @@ def deal_json(ds_type, img_path, json_path):
return
data_coco
return
data_coco
def
voc_get_label_anno
(
ann_dir_path
,
ann_ids_path
,
labels_path
):
with
open
(
labels_path
,
'r'
)
as
f
:
labels_str
=
f
.
read
().
split
()
labels_ids
=
list
(
range
(
1
,
len
(
labels_str
)
+
1
))
with
open
(
ann_ids_path
,
'r'
)
as
f
:
ann_ids
=
f
.
read
().
split
()
ann_paths
=
[]
for
aid
in
ann_ids
:
if
aid
.
endswith
(
'xml'
):
ann_path
=
os
.
path
.
join
(
ann_dir_path
,
aid
)
else
:
ann_path
=
os
.
path
.
join
(
ann_dir_path
,
aid
+
'.xml'
)
ann_paths
.
append
(
ann_path
)
return
dict
(
zip
(
labels_str
,
labels_ids
)),
ann_paths
def
voc_get_image_info
(
annotation_root
,
im_id
):
filename
=
annotation_root
.
findtext
(
'filename'
)
assert
filename
is
not
None
img_name
=
os
.
path
.
basename
(
filename
)
size
=
annotation_root
.
find
(
'size'
)
width
=
float
(
size
.
findtext
(
'width'
))
height
=
float
(
size
.
findtext
(
'height'
))
image_info
=
{
'file_name'
:
filename
,
'height'
:
height
,
'width'
:
width
,
'id'
:
im_id
}
return
image_info
def
voc_get_coco_annotation
(
obj
,
label2id
):
label
=
obj
.
findtext
(
'name'
)
assert
label
in
label2id
,
"label is not in label2id."
category_id
=
label2id
[
label
]
bndbox
=
obj
.
find
(
'bndbox'
)
xmin
=
float
(
bndbox
.
findtext
(
'xmin'
))
-
1
ymin
=
float
(
bndbox
.
findtext
(
'ymin'
))
-
1
xmax
=
float
(
bndbox
.
findtext
(
'xmax'
))
ymax
=
float
(
bndbox
.
findtext
(
'ymax'
))
assert
xmax
>
xmin
and
ymax
>
ymin
,
"Box size error."
o_width
=
xmax
-
xmin
o_height
=
ymax
-
ymin
anno
=
{
'area'
:
o_width
*
o_height
,
'iscrowd'
:
0
,
'bbox'
:
[
xmin
,
ymin
,
o_width
,
o_height
],
'category_id'
:
category_id
,
'ignore'
:
0
,
}
return
anno
def
voc_xmls_to_cocojson
(
annotation_paths
,
label2id
,
output_dir
,
output_file
):
output_json_dict
=
{
"images"
:
[],
"type"
:
"instances"
,
"annotations"
:
[],
"categories"
:
[]
}
bnd_id
=
1
# bounding box start id
im_id
=
0
print
(
'Start converting !'
)
for
a_path
in
tqdm
(
annotation_paths
):
# Read annotation xml
ann_tree
=
ET
.
parse
(
a_path
)
ann_root
=
ann_tree
.
getroot
()
img_info
=
voc_get_image_info
(
ann_root
,
im_id
)
im_id
+=
1
img_id
=
img_info
[
'id'
]
output_json_dict
[
'images'
].
append
(
img_info
)
for
obj
in
ann_root
.
findall
(
'object'
):
ann
=
voc_get_coco_annotation
(
obj
=
obj
,
label2id
=
label2id
)
ann
.
update
({
'image_id'
:
img_id
,
'id'
:
bnd_id
})
output_json_dict
[
'annotations'
].
append
(
ann
)
bnd_id
=
bnd_id
+
1
for
label
,
label_id
in
label2id
.
items
():
category_info
=
{
'supercategory'
:
'none'
,
'id'
:
label_id
,
'name'
:
label
}
output_json_dict
[
'categories'
].
append
(
category_info
)
output_file
=
os
.
path
.
join
(
output_dir
,
output_file
)
with
open
(
output_file
,
'w'
)
as
f
:
output_json
=
json
.
dumps
(
output_json_dict
)
f
.
write
(
output_json
)
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
...
@@ -194,7 +286,7 @@ def main():
...
@@ -194,7 +286,7 @@ def main():
parser
.
add_argument
(
'--json_input_dir'
,
help
=
'input annotated directory'
)
parser
.
add_argument
(
'--json_input_dir'
,
help
=
'input annotated directory'
)
parser
.
add_argument
(
'--image_input_dir'
,
help
=
'image directory'
)
parser
.
add_argument
(
'--image_input_dir'
,
help
=
'image directory'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--output_dir'
,
help
=
'output dataset directory'
,
default
=
'.
./../..
/'
)
'--output_dir'
,
help
=
'output dataset directory'
,
default
=
'./'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--train_proportion'
,
'--train_proportion'
,
help
=
'the proportion of train dataset'
,
help
=
'the proportion of train dataset'
,
...
@@ -210,96 +302,137 @@ def main():
...
@@ -210,96 +302,137 @@ def main():
help
=
'the proportion of test dataset'
,
help
=
'the proportion of test dataset'
,
type
=
float
,
type
=
float
,
default
=
0.0
)
default
=
0.0
)
parser
.
add_argument
(
'--voc_anno_dir'
,
help
=
'In Voc format dataset, path to annotation files directory.'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--voc_anno_list'
,
help
=
'In Voc format dataset, path to annotation files ids list.'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--voc_label_list'
,
help
=
'In Voc format dataset, path to label list. The content of each line is a category.'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--voc_out_name'
,
type
=
str
,
default
=
'voc.json'
,
help
=
'In Voc format dataset, path to output json file'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
try
:
try
:
assert
args
.
dataset_type
in
[
'labelme'
,
'cityscape'
]
assert
args
.
dataset_type
in
[
'voc'
,
'labelme'
,
'cityscape'
]
except
AssertionError
as
e
:
print
(
'Now only support the cityscape dataset and labelme dataset!!'
)
os
.
_exit
(
0
)
try
:
assert
os
.
path
.
exists
(
args
.
json_input_dir
)
except
AssertionError
as
e
:
print
(
'The json folder does not exist!'
)
os
.
_exit
(
0
)
try
:
assert
os
.
path
.
exists
(
args
.
image_input_dir
)
except
AssertionError
as
e
:
print
(
'The image folder does not exist!'
)
os
.
_exit
(
0
)
try
:
assert
abs
(
args
.
train_proportion
+
args
.
val_proportion
\
+
args
.
test_proportion
-
1.0
)
<
1e-5
except
AssertionError
as
e
:
except
AssertionError
as
e
:
print
(
print
(
'The sum of pqoportion of training, validation and test datase must be 1!'
'Now only support the voc, cityscape dataset and labelme dataset!!'
)
)
os
.
_exit
(
0
)
os
.
_exit
(
0
)
# Allocate the dataset.
if
args
.
dataset_type
==
'voc'
:
total_num
=
len
(
glob
.
glob
(
osp
.
join
(
args
.
json_input_dir
,
'*.json'
)))
assert
args
.
voc_anno_dir
and
args
.
voc_anno_list
and
args
.
voc_label_list
if
args
.
train_proportion
!=
0
:
label2id
,
ann_paths
=
voc_get_label_anno
(
train_num
=
int
(
total_num
*
args
.
train_proportion
)
args
.
voc_anno_dir
,
args
.
voc_anno_list
,
args
.
voc_label_list
)
os
.
makedirs
(
args
.
output_dir
+
'/train'
)
voc_xmls_to_cocojson
(
annotation_paths
=
ann_paths
,
label2id
=
label2id
,
output_dir
=
args
.
output_dir
,
output_file
=
args
.
voc_out_name
)
else
:
else
:
train_num
=
0
try
:
if
args
.
val_proportion
==
0.0
:
assert
os
.
path
.
exists
(
args
.
json_input_dir
)
val_num
=
0
except
AssertionError
as
e
:
test_num
=
total_num
-
train_num
print
(
'The json folder does not exist!'
)
if
args
.
test_proportion
!=
0.0
:
os
.
_exit
(
0
)
os
.
makedirs
(
args
.
output_dir
+
'/test'
)
try
:
else
:
assert
os
.
path
.
exists
(
args
.
image_input_dir
)
val_num
=
int
(
total_num
*
args
.
val_proportion
)
except
AssertionError
as
e
:
test_num
=
total_num
-
train_num
-
val_num
print
(
'The image folder does not exist!'
)
os
.
makedirs
(
args
.
output_dir
+
'/val'
)
os
.
_exit
(
0
)
if
args
.
test_proportion
!=
0.0
:
try
:
os
.
makedirs
(
args
.
output_dir
+
'/test'
)
assert
abs
(
args
.
train_proportion
+
args
.
val_proportion
\
count
=
1
+
args
.
test_proportion
-
1.0
)
<
1e-5
for
img_name
in
os
.
listdir
(
args
.
image_input_dir
):
except
AssertionError
as
e
:
if
count
<=
train_num
:
print
(
if
osp
.
exists
(
args
.
output_dir
+
'/train/'
):
'The sum of pqoportion of training, validation and test datase must be 1!'
shutil
.
copyfile
(
)
osp
.
join
(
args
.
image_input_dir
,
img_name
),
os
.
_exit
(
0
)
osp
.
join
(
args
.
output_dir
+
'/train/'
,
img_name
))
# Allocate the dataset.
total_num
=
len
(
glob
.
glob
(
osp
.
join
(
args
.
json_input_dir
,
'*.json'
)))
if
args
.
train_proportion
!=
0
:
train_num
=
int
(
total_num
*
args
.
train_proportion
)
os
.
makedirs
(
args
.
output_dir
+
'/train'
)
else
:
else
:
if
count
<=
train_num
+
val_num
:
train_num
=
0
if
osp
.
exists
(
args
.
output_dir
+
'/val/'
):
if
args
.
val_proportion
==
0.0
:
val_num
=
0
test_num
=
total_num
-
train_num
if
args
.
test_proportion
!=
0.0
:
os
.
makedirs
(
args
.
output_dir
+
'/test'
)
else
:
val_num
=
int
(
total_num
*
args
.
val_proportion
)
test_num
=
total_num
-
train_num
-
val_num
os
.
makedirs
(
args
.
output_dir
+
'/val'
)
if
args
.
test_proportion
!=
0.0
:
os
.
makedirs
(
args
.
output_dir
+
'/test'
)
count
=
1
for
img_name
in
os
.
listdir
(
args
.
image_input_dir
):
if
count
<=
train_num
:
if
osp
.
exists
(
args
.
output_dir
+
'/train/'
):
shutil
.
copyfile
(
shutil
.
copyfile
(
osp
.
join
(
args
.
image_input_dir
,
img_name
),
osp
.
join
(
args
.
image_input_dir
,
img_name
),
osp
.
join
(
args
.
output_dir
+
'/
val
/'
,
img_name
))
osp
.
join
(
args
.
output_dir
+
'/
train
/'
,
img_name
))
else
:
else
:
if
osp
.
exists
(
args
.
output_dir
+
'/test/'
):
if
count
<=
train_num
+
val_num
:
shutil
.
copyfile
(
if
osp
.
exists
(
args
.
output_dir
+
'/val/'
):
osp
.
join
(
args
.
image_input_dir
,
img_name
),
shutil
.
copyfile
(
osp
.
join
(
args
.
output_dir
+
'/test/'
,
img_name
))
osp
.
join
(
args
.
image_input_dir
,
img_name
),
count
=
count
+
1
osp
.
join
(
args
.
output_dir
+
'/val/'
,
img_name
))
else
:
# Deal with the json files.
if
osp
.
exists
(
args
.
output_dir
+
'/test/'
):
if
not
os
.
path
.
exists
(
args
.
output_dir
+
'/annotations'
):
shutil
.
copyfile
(
os
.
makedirs
(
args
.
output_dir
+
'/annotations'
)
osp
.
join
(
args
.
image_input_dir
,
img_name
),
if
args
.
train_proportion
!=
0
:
osp
.
join
(
args
.
output_dir
+
'/test/'
,
img_name
))
train_data_coco
=
deal_json
(
count
=
count
+
1
args
.
dataset_type
,
args
.
output_dir
+
'/train'
,
args
.
json_input_dir
)
train_json_path
=
osp
.
join
(
args
.
output_dir
+
'/annotations'
,
# Deal with the json files.
'instance_train.json'
)
if
not
os
.
path
.
exists
(
args
.
output_dir
+
'/annotations'
):
json
.
dump
(
os
.
makedirs
(
args
.
output_dir
+
'/annotations'
)
train_data_coco
,
if
args
.
train_proportion
!=
0
:
open
(
train_json_path
,
'w'
),
train_data_coco
=
deal_json
(
args
.
dataset_type
,
indent
=
4
,
args
.
output_dir
+
'/train'
,
cls
=
MyEncoder
)
args
.
json_input_dir
)
if
args
.
val_proportion
!=
0
:
train_json_path
=
osp
.
join
(
args
.
output_dir
+
'/annotations'
,
val_data_coco
=
deal_json
(
args
.
dataset_type
,
args
.
output_dir
+
'/val'
,
'instance_train.json'
)
args
.
json_input_dir
)
json
.
dump
(
val_json_path
=
osp
.
join
(
args
.
output_dir
+
'/annotations'
,
train_data_coco
,
'instance_val.json'
)
open
(
train_json_path
,
'w'
),
json
.
dump
(
indent
=
4
,
val_data_coco
,
open
(
val_json_path
,
'w'
),
indent
=
4
,
cls
=
MyEncoder
)
cls
=
MyEncoder
)
if
args
.
test_proportion
!=
0
:
if
args
.
val_proportion
!=
0
:
test_data_coco
=
deal_json
(
args
.
dataset_type
,
args
.
output_dir
+
'/test'
,
val_data_coco
=
deal_json
(
args
.
dataset_type
,
args
.
json_input_dir
)
args
.
output_dir
+
'/val'
,
test_json_path
=
osp
.
join
(
args
.
output_dir
+
'/annotations'
,
args
.
json_input_dir
)
'instance_test.json'
)
val_json_path
=
osp
.
join
(
args
.
output_dir
+
'/annotations'
,
json
.
dump
(
'instance_val.json'
)
test_data_coco
,
open
(
test_json_path
,
'w'
),
indent
=
4
,
cls
=
MyEncoder
)
json
.
dump
(
val_data_coco
,
open
(
val_json_path
,
'w'
),
indent
=
4
,
cls
=
MyEncoder
)
if
args
.
test_proportion
!=
0
:
test_data_coco
=
deal_json
(
args
.
dataset_type
,
args
.
output_dir
+
'/test'
,
args
.
json_input_dir
)
test_json_path
=
osp
.
join
(
args
.
output_dir
+
'/annotations'
,
'instance_test.json'
)
json
.
dump
(
test_data_coco
,
open
(
test_json_path
,
'w'
),
indent
=
4
,
cls
=
MyEncoder
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录