Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
c9a7e0b2
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c9a7e0b2
编写于
1月 12, 2022
作者:
A
A. Unique TensorFlower
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add builder that applies bounding box-specific ops for RandAugment
PiperOrigin-RevId: 421439862
上级
49a5706c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
58 addition
and
4 deletion
+58
-4
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+0
-1
official/vision/beta/dataloaders/retinanet_input.py
official/vision/beta/dataloaders/retinanet_input.py
+10
-3
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+31
-0
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+17
-0
未找到文件。
official/vision/beta/configs/retinanet.py
浏览文件 @
c9a7e0b2
...
...
@@ -58,7 +58,6 @@ class Parser(hyperparams.Config):
skip_crowd_during_training
:
bool
=
True
max_num_instances
:
int
=
100
# Can choose AutoAugment and RandAugment.
# TODO(b/205346436) Support RandAugment.
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Keep for backward compatibility. Not used.
...
...
official/vision/beta/dataloaders/retinanet_input.py
浏览文件 @
c9a7e0b2
...
...
@@ -75,7 +75,7 @@ class Parser(parser.Parser):
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
The latter is not supported, and will raise ValueError.
RandAugment.
aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for
...
...
@@ -122,8 +122,16 @@ class Parser(parser.Parser):
augmentation_name
=
aug_type
.
autoaug
.
augmentation_name
,
cutout_const
=
aug_type
.
autoaug
.
cutout_const
,
translate_const
=
aug_type
.
autoaug
.
translate_const
)
elif
aug_type
.
type
==
'randaug'
:
logging
.
info
(
'Using RandAugment.'
)
self
.
_augmenter
=
augment
.
RandAugment
.
build_for_detection
(
num_layers
=
aug_type
.
randaug
.
num_layers
,
magnitude
=
aug_type
.
randaug
.
magnitude
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
,
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
,
exclude_ops
=
aug_type
.
randaug
.
exclude_ops
)
else
:
# TODO(b/205346436) Support RandAugment.
raise
ValueError
(
f
'Augmentation policy
{
aug_type
.
type
}
not supported.'
)
# Deprecated. Data Augmentation with AutoAugment.
...
...
@@ -162,7 +170,6 @@ class Parser(parser.Parser):
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
image
,
boxes
=
self
.
_augmenter
.
distort_with_boxes
(
image
,
boxes
)
image_shape
=
tf
.
shape
(
input
=
image
)[
0
:
2
]
# Normalizes image with mean and std pixel values.
...
...
official/vision/beta/ops/augment.py
浏览文件 @
c9a7e0b2
...
...
@@ -1950,6 +1950,37 @@ class RandAugment(ImageAugment):
op
for
op
in
self
.
available_ops
if
op
not
in
exclude_ops
]
@
classmethod
def
build_for_detection
(
cls
,
num_layers
:
int
=
2
,
magnitude
:
float
=
10.
,
cutout_const
:
float
=
40.
,
translate_const
:
float
=
100.
,
magnitude_std
:
float
=
0.0
,
prob_to_apply
:
Optional
[
float
]
=
None
,
exclude_ops
:
Optional
[
List
[
str
]]
=
None
):
"""Builds a RandAugment that modifies bboxes for geometric transforms."""
augmenter
=
cls
(
num_layers
=
num_layers
,
magnitude
=
magnitude
,
cutout_const
=
cutout_const
,
translate_const
=
translate_const
,
magnitude_std
=
magnitude_std
,
prob_to_apply
=
prob_to_apply
,
exclude_ops
=
exclude_ops
)
box_aware_ops_by_base_name
=
{
'Rotate'
:
'Rotate_BBox'
,
'ShearX'
:
'ShearX_BBox'
,
'ShearY'
:
'ShearY_BBox'
,
'TranslateX'
:
'TranslateX_BBox'
,
'TranslateY'
:
'TranslateY_BBox'
,
}
augmenter
.
available_ops
=
[
box_aware_ops_by_base_name
.
get
(
op_name
)
or
op_name
for
op_name
in
augmenter
.
available_ops
]
return
augmenter
def
_distort_common
(
self
,
image
:
tf
.
Tensor
,
...
...
official/vision/beta/ops/augment_test.py
浏览文件 @
c9a7e0b2
...
...
@@ -140,6 +140,23 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_randaug_build_for_detection
(
self
):
"""Smoke test to be sure there are no syntax errors built for detection."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
bboxes
=
tf
.
ones
((
2
,
4
),
dtype
=
tf
.
float32
)
augmenter
=
augment
.
RandAugment
.
build_for_detection
()
self
.
assertCountEqual
(
augmenter
.
available_ops
,
[
'AutoContrast'
,
'Equalize'
,
'Invert'
,
'Posterize'
,
'Solarize'
,
'Color'
,
'Contrast'
,
'Brightness'
,
'Sharpness'
,
'Cutout'
,
'SolarizeAdd'
,
'Rotate_BBox'
,
'ShearX_BBox'
,
'ShearY_BBox'
,
'TranslateX_BBox'
,
'TranslateY_BBox'
])
aug_image
,
aug_bboxes
=
augmenter
.
distort_with_boxes
(
image
,
bboxes
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertEqual
((
2
,
4
),
aug_bboxes
.
shape
)
def
test_all_policy_ops
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录