Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
baecabee
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
baecabee
编写于
3月 09, 2020
作者:
littletomatodonkey
提交者:
GitHub
3月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add AutoAugment (#309)
* add autoaugment, which is validated on faster-rcnn
上级
ad603173
变更
5
展开全部
隐藏空白更改
内联
并排
Showing
5 changed file
with
1934 addition
and
0 deletion
+1934
-0
configs/autoaugment/README.md
configs/autoaugment/README.md
+23
-0
configs/autoaugment/faster_rcnn_r101_vd_fpn_aa_3x.yml
configs/autoaugment/faster_rcnn_r101_vd_fpn_aa_3x.yml
+127
-0
configs/autoaugment/faster_rcnn_r50_vd_fpn_aa_3x.yml
configs/autoaugment/faster_rcnn_r50_vd_fpn_aa_3x.yml
+127
-0
ppdet/data/transform/autoaugment_utils.py
ppdet/data/transform/autoaugment_utils.py
+1588
-0
ppdet/data/transform/operators.py
ppdet/data/transform/operators.py
+69
-0
未找到文件。
configs/autoaugment/README.md
0 → 100644
浏览文件 @
baecabee
# Learning Data Augmentation Strategies for Object Detection
## Introduction
-
Learning Data Augmentation Strategies for Object Detection:
[
https://arxiv.org/abs/1906.11172
](
https://arxiv.org/abs/1906.11172
)
```
@article{Zoph2019LearningDA,
title={Learning Data Augmentation Strategies for Object Detection},
author={Barret Zoph and Ekin Dogus Cubuk and Golnaz Ghiasi and Tsung-Yi Lin and Jonathon Shlens and Quoc V. Le},
journal={ArXiv},
year={2019},
volume={abs/1906.11172}
}
```
## Model Zoo
| Backbone | Type | AutoAug policy | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download |
| :---------------------- | :-------------:| :-------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: |
| ResNet50-vd-FPN | Faster | v1 | 2 | 3x | 22.800 | 39.9 | - |
[
model
](
https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_aa_3x.tar
)
|
| ResNet101-vd-FPN | Faster | v1 | 2 | 3x | 17.652 | 42.5 | - |
[
model
](
https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_aa_3x.tar
)
|
configs/autoaugment/faster_rcnn_r101_vd_fpn_aa_3x.yml
0 → 100644
浏览文件 @
baecabee
architecture
:
FasterRCNN
max_iters
:
270000
snapshot_iter
:
30000
use_gpu
:
true
log_smooth_window
:
20
save_dir
:
output
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar
weights
:
output/faster_rcnn_r101_vd_fpn_aa_3x/model_final
metric
:
COCO
num_classes
:
81
FasterRCNN
:
backbone
:
ResNet
fpn
:
FPN
rpn_head
:
FPNRPNHead
roi_extractor
:
FPNRoIAlign
bbox_head
:
BBoxHead
bbox_assigner
:
BBoxAssigner
ResNet
:
depth
:
101
feature_maps
:
[
2
,
3
,
4
,
5
]
freeze_at
:
2
norm_type
:
bn
variant
:
d
FPN
:
max_level
:
6
min_level
:
2
num_chan
:
256
spatial_scale
:
[
0.03125
,
0.0625
,
0.125
,
0.25
]
FPNRPNHead
:
anchor_generator
:
anchor_sizes
:
[
32
,
64
,
128
,
256
,
512
]
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
stride
:
[
16.0
,
16.0
]
variance
:
[
1.0
,
1.0
,
1.0
,
1.0
]
anchor_start_size
:
32
max_level
:
6
min_level
:
2
num_chan
:
256
rpn_target_assign
:
rpn_batch_size_per_im
:
256
rpn_fg_fraction
:
0.5
rpn_negative_overlap
:
0.3
rpn_positive_overlap
:
0.7
rpn_straddle_thresh
:
0.0
train_proposal
:
min_size
:
0.0
nms_thresh
:
0.7
post_nms_top_n
:
2000
pre_nms_top_n
:
2000
test_proposal
:
min_size
:
0.0
nms_thresh
:
0.7
post_nms_top_n
:
1000
pre_nms_top_n
:
1000
FPNRoIAlign
:
canconical_level
:
4
canonical_size
:
224
max_level
:
5
min_level
:
2
box_resolution
:
7
sampling_ratio
:
2
BBoxAssigner
:
batch_size_per_im
:
512
bbox_reg_weights
:
[
0.1
,
0.1
,
0.2
,
0.2
]
bg_thresh_hi
:
0.5
bg_thresh_lo
:
0.0
fg_fraction
:
0.25
fg_thresh
:
0.5
BBoxHead
:
head
:
TwoFCHead
nms
:
keep_top_k
:
100
nms_threshold
:
0.5
score_threshold
:
0.05
TwoFCHead
:
mlp_dim
:
1024
LearningRate
:
base_lr
:
0.02
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
180000
,
240000
]
-
!LinearWarmup
start_factor
:
0.1
steps
:
1000
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0001
type
:
L2
_READER_
:
'
../faster_fpn_reader.yml'
TrainReader
:
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
-
!RandomFlipImage
prob
:
0.5
-
!AutoAugmentImage
autoaug_type
:
v1
-
!NormalizeImage
is_channel_first
:
false
is_scale
:
true
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
-
!ResizeImage
target_size
:
800
max_size
:
1333
interp
:
1
use_cv2
:
true
-
!Permute
to_bgr
:
false
channel_first
:
true
batch_size
:
2
use_process
:
true
configs/autoaugment/faster_rcnn_r50_vd_fpn_aa_3x.yml
0 → 100644
浏览文件 @
baecabee
architecture
:
FasterRCNN
max_iters
:
270000
snapshot_iter
:
30000
use_gpu
:
true
log_smooth_window
:
20
save_dir
:
output
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights
:
output/faster_rcnn_r50_vd_fpn_aa_3x/model_final
metric
:
COCO
num_classes
:
81
FasterRCNN
:
backbone
:
ResNet
fpn
:
FPN
rpn_head
:
FPNRPNHead
roi_extractor
:
FPNRoIAlign
bbox_head
:
BBoxHead
bbox_assigner
:
BBoxAssigner
ResNet
:
depth
:
50
feature_maps
:
[
2
,
3
,
4
,
5
]
freeze_at
:
2
norm_type
:
bn
variant
:
d
FPN
:
max_level
:
6
min_level
:
2
num_chan
:
256
spatial_scale
:
[
0.03125
,
0.0625
,
0.125
,
0.25
]
FPNRPNHead
:
anchor_generator
:
anchor_sizes
:
[
32
,
64
,
128
,
256
,
512
]
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
stride
:
[
16.0
,
16.0
]
variance
:
[
1.0
,
1.0
,
1.0
,
1.0
]
anchor_start_size
:
32
max_level
:
6
min_level
:
2
num_chan
:
256
rpn_target_assign
:
rpn_batch_size_per_im
:
256
rpn_fg_fraction
:
0.5
rpn_negative_overlap
:
0.3
rpn_positive_overlap
:
0.7
rpn_straddle_thresh
:
0.0
train_proposal
:
min_size
:
0.0
nms_thresh
:
0.7
post_nms_top_n
:
2000
pre_nms_top_n
:
2000
test_proposal
:
min_size
:
0.0
nms_thresh
:
0.7
post_nms_top_n
:
1000
pre_nms_top_n
:
1000
FPNRoIAlign
:
canconical_level
:
4
canonical_size
:
224
max_level
:
5
min_level
:
2
box_resolution
:
7
sampling_ratio
:
2
BBoxAssigner
:
batch_size_per_im
:
512
bbox_reg_weights
:
[
0.1
,
0.1
,
0.2
,
0.2
]
bg_thresh_hi
:
0.5
bg_thresh_lo
:
0.0
fg_fraction
:
0.25
fg_thresh
:
0.5
BBoxHead
:
head
:
TwoFCHead
nms
:
keep_top_k
:
100
nms_threshold
:
0.5
score_threshold
:
0.05
TwoFCHead
:
mlp_dim
:
1024
LearningRate
:
base_lr
:
0.02
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
180000
,
240000
]
-
!LinearWarmup
start_factor
:
0.1
steps
:
1000
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0001
type
:
L2
_READER_
:
'
../faster_fpn_reader.yml'
TrainReader
:
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
-
!RandomFlipImage
prob
:
0.5
-
!AutoAugmentImage
autoaug_type
:
v1
-
!NormalizeImage
is_channel_first
:
false
is_scale
:
true
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
-
!ResizeImage
target_size
:
800
max_size
:
1333
interp
:
1
use_cv2
:
true
-
!Permute
to_bgr
:
false
channel_first
:
true
batch_size
:
2
use_process
:
true
ppdet/data/transform/autoaugment_utils.py
0 → 100644
浏览文件 @
baecabee
此差异已折叠。
点击以展开。
ppdet/data/transform/operators.py
浏览文件 @
baecabee
...
@@ -435,6 +435,75 @@ class RandomFlipImage(BaseOperator):
...
@@ -435,6 +435,75 @@ class RandomFlipImage(BaseOperator):
return
sample
return
sample
@
register_op
class
AutoAugmentImage
(
BaseOperator
):
def
__init__
(
self
,
is_normalized
=
False
,
autoaug_type
=
"v1"
):
"""
Args:
is_normalized (bool): whether the bbox scale to [0,1]
autoaug_type (str): autoaug type, support v0, v1, v2, v3, test
"""
super
(
AutoAugmentImage
,
self
).
__init__
()
self
.
is_normalized
=
is_normalized
self
.
autoaug_type
=
autoaug_type
if
not
isinstance
(
self
.
is_normalized
,
bool
):
raise
TypeError
(
"{}: input type is invalid."
.
format
(
self
))
def
__call__
(
self
,
sample
,
context
=
None
):
"""
Learning Data Augmentation Strategies for Object Detection, see https://arxiv.org/abs/1906.11172
"""
samples
=
sample
batch_input
=
True
if
not
isinstance
(
samples
,
Sequence
):
batch_input
=
False
samples
=
[
samples
]
for
sample
in
samples
:
gt_bbox
=
sample
[
'gt_bbox'
]
im
=
sample
[
'image'
]
if
not
isinstance
(
im
,
np
.
ndarray
):
raise
TypeError
(
"{}: image is not a numpy array."
.
format
(
self
))
if
len
(
im
.
shape
)
!=
3
:
raise
ImageError
(
"{}: image is not 3-dimensional."
.
format
(
self
))
if
len
(
gt_bbox
)
==
0
:
continue
# gt_boxes : [x1, y1, x2, y2]
# norm_gt_boxes: [y1, x1, y2, x2]
height
,
width
,
_
=
im
.
shape
norm_gt_bbox
=
np
.
ones_like
(
gt_bbox
,
dtype
=
np
.
float32
)
if
not
self
.
is_normalized
:
norm_gt_bbox
[:,
0
]
=
gt_bbox
[:,
1
]
/
float
(
height
)
norm_gt_bbox
[:,
1
]
=
gt_bbox
[:,
0
]
/
float
(
width
)
norm_gt_bbox
[:,
2
]
=
gt_bbox
[:,
3
]
/
float
(
height
)
norm_gt_bbox
[:,
3
]
=
gt_bbox
[:,
2
]
/
float
(
width
)
else
:
norm_gt_bbox
[:,
0
]
=
gt_bbox
[:,
1
]
norm_gt_bbox
[:,
1
]
=
gt_bbox
[:,
0
]
norm_gt_bbox
[:,
2
]
=
gt_bbox
[:,
3
]
norm_gt_bbox
[:,
3
]
=
gt_bbox
[:,
2
]
from
.autoaugment_utils
import
distort_image_with_autoaugment
im
,
norm_gt_bbox
=
distort_image_with_autoaugment
(
im
,
norm_gt_bbox
,
self
.
autoaug_type
)
if
not
self
.
is_normalized
:
gt_bbox
[:,
0
]
=
norm_gt_bbox
[:,
1
]
*
float
(
width
)
gt_bbox
[:,
1
]
=
norm_gt_bbox
[:,
0
]
*
float
(
height
)
gt_bbox
[:,
2
]
=
norm_gt_bbox
[:,
3
]
*
float
(
width
)
gt_bbox
[:,
3
]
=
norm_gt_bbox
[:,
2
]
*
float
(
height
)
else
:
gt_bbox
[:,
0
]
=
norm_gt_bbox
[:,
1
]
gt_bbox
[:,
1
]
=
norm_gt_bbox
[:,
0
]
gt_bbox
[:,
2
]
=
norm_gt_bbox
[:,
3
]
gt_bbox
[:,
3
]
=
norm_gt_bbox
[:,
2
]
sample
[
'gt_bbox'
]
=
gt_bbox
sample
[
'image'
]
=
im
sample
=
samples
if
batch_input
else
samples
[
0
]
return
sample
@
register_op
@
register_op
class
NormalizeImage
(
BaseOperator
):
class
NormalizeImage
(
BaseOperator
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录