Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
1c79ece9
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
1c79ece9
编写于
9月 16, 2021
作者:
A
A. Unique TensorFlower
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10227 from sigeisler:master
PiperOrigin-RevId: 397161611
上级
bea8998b
01b21983
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
854 addition
and
40 deletion
+854
-40
official/vision/beta/configs/common.py
official/vision/beta/configs/common.py
+26
-1
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+5
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+32
-1
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+1
-0
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+331
-13
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+77
-0
official/vision/beta/ops/preprocess_ops.py
official/vision/beta/ops/preprocess_ops.py
+103
-1
official/vision/beta/ops/preprocess_ops_test.py
official/vision/beta/ops/preprocess_ops_test.py
+13
-0
official/vision/beta/projects/vit/README.md
official/vision/beta/projects/vit/README.md
+7
-5
official/vision/beta/projects/vit/configs/backbones.py
official/vision/beta/projects/vit/configs/backbones.py
+2
-0
official/vision/beta/projects/vit/configs/image_classification.py
.../vision/beta/projects/vit/configs/image_classification.py
+86
-1
official/vision/beta/projects/vit/modeling/nn_blocks.py
official/vision/beta/projects/vit/modeling/nn_blocks.py
+107
-0
official/vision/beta/projects/vit/modeling/vit.py
official/vision/beta/projects/vit/modeling/vit.py
+34
-11
official/vision/beta/projects/yolo/configs/darknet_classification.py
...sion/beta/projects/yolo/configs/darknet_classification.py
+1
-1
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+29
-6
未找到文件。
official/vision/beta/configs/common.py
浏览文件 @
1c79ece9
...
...
@@ -16,7 +16,7 @@
"""Common configurations."""
import
dataclasses
from
typing
import
Optional
from
typing
import
List
,
Optional
# Import libraries
...
...
@@ -60,7 +60,9 @@ class RandAugment(hyperparams.Config):
magnitude
:
float
=
10
cutout_const
:
float
=
40
translate_const
:
float
=
10
magnitude_std
:
float
=
0.0
prob_to_apply
:
Optional
[
float
]
=
None
exclude_ops
:
List
[
str
]
=
dataclasses
.
field
(
default_factory
=
list
)
@
dataclasses
.
dataclass
...
...
@@ -71,6 +73,29 @@ class AutoAugment(hyperparams.Config):
translate_const
:
float
=
250
@
dataclasses
.
dataclass
class
RandomErasing
(
hyperparams
.
Config
):
"""Configuration for RandomErasing."""
probability
:
float
=
0.25
min_area
:
float
=
0.02
max_area
:
float
=
1
/
3
min_aspect
:
float
=
0.3
max_aspect
=
None
min_count
=
1
max_count
=
1
trials
=
10
@
dataclasses
.
dataclass
class
MixupAndCutmix
(
hyperparams
.
Config
):
"""Configuration for MixupAndCutmix."""
mixup_alpha
:
float
=
.
8
cutmix_alpha
:
float
=
1.
prob
:
float
=
1.0
switch_prob
:
float
=
0.5
label_smoothing
:
float
=
0.1
@
dataclasses
.
dataclass
class
Augmentation
(
hyperparams
.
OneOfConfig
):
"""Configuration for input data augmentation.
...
...
official/vision/beta/configs/image_classification.py
浏览文件 @
1c79ece9
...
...
@@ -39,10 +39,13 @@ class DataConfig(cfg.DataConfig):
aug_rand_hflip
:
bool
=
True
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
color_jitter
:
float
=
0.
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
file_type
:
str
=
'tfrecord'
image_field_key
:
str
=
'image/encoded'
label_field_key
:
str
=
'image/class/label'
decode_jpeg_only
:
bool
=
True
mixup_and_cutmix
:
Optional
[
common
.
MixupAndCutmix
]
=
None
decoder
:
Optional
[
common
.
DataDecoder
]
=
common
.
DataDecoder
()
# Keep for backward compatibility.
...
...
@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn
=
False
)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm
:
bool
=
False
kernel_initializer
:
str
=
'random_uniform'
@
dataclasses
.
dataclass
...
...
@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
one_hot
:
bool
=
True
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
soft_labels
:
bool
=
False
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/classification_input.py
浏览文件 @
1c79ece9
...
...
@@ -69,6 +69,8 @@ class Parser(parser.Parser):
decode_jpeg_only
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
color_jitter
:
float
=
0.
,
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
,
is_multilabel
:
bool
=
False
,
dtype
:
str
=
'float32'
):
"""Initializes parameters for parsing annotations in the dataset.
...
...
@@ -85,6 +87,11 @@ class Parser(parser.Parser):
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
color_jitter: Magnitude of color jitter. If > 0, the value is used to
generate random scale factor for brightness, contrast and saturation.
See `preprocess_ops.color_jitter` for more details.
random_erasing: if not None, augment input image by random erasing. See
`augment.RandomErasing` for more details.
is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
...
...
@@ -113,13 +120,27 @@ class Parser(parser.Parser):
magnitude
=
aug_type
.
randaug
.
magnitude
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
,
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
)
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
,
exclude_ops
=
aug_type
.
randaug
.
exclude_ops
)
else
:
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
aug_type
.
type
))
else
:
self
.
_augmenter
=
None
self
.
_label_field_key
=
label_field_key
self
.
_color_jitter
=
color_jitter
if
random_erasing
:
self
.
_random_erasing
=
augment
.
RandomErasing
(
probability
=
random_erasing
.
probability
,
min_area
=
random_erasing
.
min_area
,
max_area
=
random_erasing
.
max_area
,
min_aspect
=
random_erasing
.
min_aspect
,
max_aspect
=
random_erasing
.
max_aspect
,
min_count
=
random_erasing
.
min_count
,
max_count
=
random_erasing
.
max_count
,
trials
=
random_erasing
.
trials
)
else
:
self
.
_random_erasing
=
None
self
.
_is_multilabel
=
is_multilabel
self
.
_decode_jpeg_only
=
decode_jpeg_only
...
...
@@ -173,6 +194,12 @@ class Parser(parser.Parser):
if
self
.
_aug_rand_hflip
:
image
=
tf
.
image
.
random_flip_left_right
(
image
)
# Color jitter.
if
self
.
_color_jitter
>
0
:
image
=
preprocess_ops
.
color_jitter
(
image
,
self
.
_color_jitter
,
self
.
_color_jitter
,
self
.
_color_jitter
)
# Resizes image.
image
=
tf
.
image
.
resize
(
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
...
...
@@ -187,6 +214,10 @@ class Parser(parser.Parser):
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
# Random erasing after the image has been normalized
if
self
.
_random_erasing
is
not
None
:
image
=
self
.
_random_erasing
.
distort
(
image
)
# Convert image to self._dtype.
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
...
...
official/vision/beta/modeling/factory.py
浏览文件 @
1c79ece9
...
...
@@ -56,6 +56,7 @@ def build_classification_model(
num_classes
=
model_config
.
num_classes
,
input_specs
=
input_specs
,
dropout_rate
=
model_config
.
dropout_rate
,
kernel_initializer
=
model_config
.
kernel_initializer
,
kernel_regularizer
=
l2_regularizer
,
add_head_batch_norm
=
model_config
.
add_head_batch_norm
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
...
...
official/vision/beta/ops/augment.py
浏览文件 @
1c79ece9
...
...
@@ -12,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Au
toAugment and RandAugment
policies for enhanced image/video preprocessing.
"""Au
gmentation
policies for enhanced image/video preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models
"""
import
math
from
typing
import
Any
,
List
,
Iterable
,
Optional
,
Text
,
Tuple
...
...
@@ -295,10 +303,26 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
cutout_center_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
lower_pad
=
tf
.
maximum
(
0
,
cutout_center_height
-
pad_size
)
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
cutout_center_height
-
pad_size
)
left_pad
=
tf
.
maximum
(
0
,
cutout_center_width
-
pad_size
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
cutout_center_width
-
pad_size
)
image
=
_fill_rectangle
(
image
,
cutout_center_width
,
cutout_center_height
,
pad_size
,
pad_size
,
replace
)
return
image
def
_fill_rectangle
(
image
,
center_width
,
center_height
,
half_width
,
half_height
,
replace
=
None
):
"""Fill blank area."""
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
lower_pad
=
tf
.
maximum
(
0
,
center_height
-
half_height
)
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
center_height
-
half_height
)
left_pad
=
tf
.
maximum
(
0
,
center_width
-
half_width
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
center_width
-
half_width
)
cutout_shape
=
[
image_height
-
(
lower_pad
+
upper_pad
),
...
...
@@ -311,9 +335,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
constant_values
=
1
)
mask
=
tf
.
expand_dims
(
mask
,
-
1
)
mask
=
tf
.
tile
(
mask
,
[
1
,
1
,
3
])
image
=
tf
.
where
(
tf
.
equal
(
mask
,
0
),
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
,
image
)
if
replace
is
None
:
fill
=
tf
.
random
.
normal
(
tf
.
shape
(
image
),
dtype
=
image
.
dtype
)
elif
isinstance
(
replace
,
tf
.
Tensor
):
fill
=
replace
else
:
fill
=
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
image
=
tf
.
where
(
tf
.
equal
(
mask
,
0
),
fill
,
image
)
return
image
...
...
@@ -803,11 +833,20 @@ def level_to_arg(cutout_const: float, translate_const: float):
return
args
def
_parse_policy_info
(
name
:
Text
,
prob
:
float
,
level
:
float
,
replace_value
:
List
[
int
],
cutout_const
:
float
,
translate_const
:
float
)
->
Tuple
[
Any
,
float
,
Any
]:
def
_parse_policy_info
(
name
:
Text
,
prob
:
float
,
level
:
float
,
replace_value
:
List
[
int
],
cutout_const
:
float
,
translate_const
:
float
,
level_std
:
float
=
0.
)
->
Tuple
[
Any
,
float
,
Any
]:
"""Return the function that corresponds to `name` and update `level` param."""
func
=
NAME_TO_FUNC
[
name
]
if
level_std
>
0
:
level
+=
tf
.
random
.
normal
([],
dtype
=
tf
.
float32
)
level
=
tf
.
clip_by_value
(
level
,
0.
,
_MAX_LEVEL
)
args
=
level_to_arg
(
cutout_const
,
translate_const
)[
name
](
level
)
if
name
in
REPLACE_FUNCS
:
...
...
@@ -1184,7 +1223,9 @@ class RandAugment(ImageAugment):
magnitude
:
float
=
10.
,
cutout_const
:
float
=
40.
,
translate_const
:
float
=
100.
,
prob_to_apply
:
Optional
[
float
]
=
None
):
magnitude_std
:
float
=
0.0
,
prob_to_apply
:
Optional
[
float
]
=
None
,
exclude_ops
:
Optional
[
List
[
str
]]
=
None
):
"""Applies the RandAugment policy to images.
Args:
...
...
@@ -1196,8 +1237,11 @@ class RandAugment(ImageAugment):
[5, 10].
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
magnitude_std: randomness of the severity as proposed by the authors of
the timm library.
prob_to_apply: The probability to apply the selected augmentation at each
layer.
exclude_ops: exclude selected operations.
"""
super
(
RandAugment
,
self
).
__init__
()
...
...
@@ -1212,6 +1256,11 @@ class RandAugment(ImageAugment):
'Color'
,
'Contrast'
,
'Brightness'
,
'Sharpness'
,
'ShearX'
,
'ShearY'
,
'TranslateX'
,
'TranslateY'
,
'Cutout'
,
'SolarizeAdd'
]
self
.
magnitude_std
=
magnitude_std
if
exclude_ops
:
self
.
available_ops
=
[
op
for
op
in
self
.
available_ops
if
op
not
in
exclude_ops
]
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies the RandAugment policy to `image`.
...
...
@@ -1246,7 +1295,8 @@ class RandAugment(ImageAugment):
dtype
=
tf
.
float32
)
func
,
_
,
args
=
_parse_policy_info
(
op_name
,
prob
,
self
.
magnitude
,
replace_value
,
self
.
cutout_const
,
self
.
translate_const
)
self
.
translate_const
,
self
.
magnitude_std
)
branch_fns
.
append
((
i
,
# pylint:disable=g-long-lambda
...
...
@@ -1267,3 +1317,271 @@ class RandAugment(ImageAugment):
image
=
tf
.
cast
(
image
,
dtype
=
input_image_type
)
return
image
class
RandomErasing
(
ImageAugment
):
"""Applies RandomErasing to a single image.
Reference: https://arxiv.org/abs/1708.04896
Implementaion is inspired by https://github.com/rwightman/pytorch-image-models
"""
def
__init__
(
self
,
probability
:
float
=
0.25
,
min_area
:
float
=
0.02
,
max_area
:
float
=
1
/
3
,
min_aspect
:
float
=
0.3
,
max_aspect
=
None
,
min_count
=
1
,
max_count
=
1
,
trials
=
10
):
"""Applies RandomErasing to a single image.
Args:
probability (float, optional): Probability of augmenting the image.
Defaults to 0.25.
min_area (float, optional): Minimum area of the random erasing rectangle.
Defaults to 0.02.
max_area (float, optional): Maximum area of the random erasing rectangle.
Defaults to 1/3.
min_aspect (float, optional): Minimum aspect rate of the random erasing
rectangle. Defaults to 0.3.
max_aspect ([type], optional): Maximum aspect rate of the random erasing
rectangle. Defaults to None.
min_count (int, optional): Minimum number of erased rectangles. Defaults
to 1.
max_count (int, optional): Maximum number of erased rectangles. Defaults
to 1.
trials (int, optional): Maximum number of trials to randomly sample a
rectangle that fulfills constraint. Defaults to 10.
"""
self
.
_probability
=
probability
self
.
_min_area
=
float
(
min_area
)
self
.
_max_area
=
float
(
max_area
)
self
.
_min_log_aspect
=
math
.
log
(
min_aspect
)
self
.
_max_log_aspect
=
math
.
log
(
max_aspect
or
1
/
min_aspect
)
self
.
_min_count
=
min_count
self
.
_max_count
=
max_count
self
.
_trials
=
trials
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies RandomErasing to single `image`.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
Returns:
tf.Tensor: The augmented version of `image`.
"""
uniform_random
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
)
mirror_cond
=
tf
.
less
(
uniform_random
,
self
.
_probability
)
image
=
tf
.
cond
(
mirror_cond
,
lambda
:
self
.
_erase
(
image
),
lambda
:
image
)
return
image
@
tf
.
function
def
_erase
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Erase an area."""
if
self
.
_min_count
==
self
.
_max_count
:
count
=
self
.
_min_count
else
:
count
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
int
(
self
.
_min_count
),
maxval
=
int
(
self
.
_max_count
-
self
.
_min_count
+
1
),
dtype
=
tf
.
int32
)
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
area
=
tf
.
cast
(
image_width
*
image_height
,
tf
.
float32
)
for
_
in
range
(
count
):
# Work around since break is not supported in tf.function
is_trial_successfull
=
False
for
_
in
range
(
self
.
_trials
):
if
not
is_trial_successfull
:
erase_area
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
area
*
self
.
_min_area
,
maxval
=
area
*
self
.
_max_area
)
aspect_ratio
=
tf
.
math
.
exp
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
self
.
_min_log_aspect
,
maxval
=
self
.
_max_log_aspect
))
half_height
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
*
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
half_width
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
/
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
if
2
*
half_height
<
image_height
and
2
*
half_width
<
image_width
:
center_height
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_height
-
2
*
half_height
),
dtype
=
tf
.
int32
)
center_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_width
-
2
*
half_width
),
dtype
=
tf
.
int32
)
image
=
_fill_rectangle
(
image
,
center_width
,
center_height
,
half_width
,
half_height
,
replace
=
None
)
is_trial_successfull
=
True
return
image
class
MixupAndCutmix
:
"""Applies Mixup and/or Cutmix to a batch of images.
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
Implementaion is inspired by https://github.com/rwightman/pytorch-image-models
"""
def
__init__
(
self
,
mixup_alpha
:
float
=
.
8
,
cutmix_alpha
:
float
=
1.
,
prob
:
float
=
1.0
,
switch_prob
:
float
=
0.5
,
label_smoothing
:
float
=
0.1
,
num_classes
:
int
=
1001
):
"""Applies Mixup and/or Cutmix to a batch of images.
Args:
mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero Mixup is deactivated.
Defaults to .8.
cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero Cutmix is deactivated.
Defaults to 1..
prob (float, optional): Of augmenting the batch. Defaults to 1.0.
switch_prob (float, optional): Probability of applying Cutmix for the
batch. Defaults to 0.5.
label_smoothing (float, optional): Constant for label smoothing. Defaults
to 0.1.
num_classes (int, optional): Number of classes. Defaults to 1001.
"""
self
.
mixup_alpha
=
mixup_alpha
self
.
cutmix_alpha
=
cutmix_alpha
self
.
mix_prob
=
prob
self
.
switch_prob
=
switch_prob
self
.
label_smoothing
=
label_smoothing
self
.
num_classes
=
num_classes
self
.
mode
=
'batch'
self
.
mixup_enabled
=
True
if
self
.
mixup_alpha
and
not
self
.
cutmix_alpha
:
self
.
switch_prob
=
-
1
elif
not
self
.
mixup_alpha
and
self
.
cutmix_alpha
:
self
.
switch_prob
=
1
def
__call__
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
return
self
.
distort
(
images
,
labels
)
def
distort
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Applies Mixup and/or Cutmix to batch of images and transforms labels.
Args:
images (tf.Tensor): Of shape [batch_size,height, width, 3] representing a
batch of image.
labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
each image of the batch.
Returns:
Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and
`labels`.
"""
augment_cond
=
tf
.
less
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
),
self
.
mix_prob
)
# pylint: disable=g-long-lambda
augment_a
=
lambda
:
self
.
_update_labels
(
*
tf
.
cond
(
tf
.
less
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
),
self
.
switch_prob
),
lambda
:
self
.
_cutmix
(
images
,
labels
),
lambda
:
self
.
_mixup
(
images
,
labels
)))
augment_b
=
lambda
:
(
images
,
self
.
_smooth_labels
(
labels
))
# pylint: enable=g-long-lambda
return
tf
.
cond
(
augment_cond
,
augment_a
,
augment_b
)
@
staticmethod
def
_sample_from_beta
(
alpha
,
beta
,
shape
):
sample_alpha
=
tf
.
random
.
gamma
(
shape
,
1.
,
beta
=
alpha
)
sample_beta
=
tf
.
random
.
gamma
(
shape
,
1.
,
beta
=
beta
)
return
sample_alpha
/
(
sample_alpha
+
sample_beta
)
def
_cutmix
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
tf
.
Tensor
]:
"""Apply cutmix."""
lam
=
MixupAndCutmix
.
_sample_from_beta
(
self
.
cutmix_alpha
,
self
.
cutmix_alpha
,
labels
.
shape
)
ratio
=
tf
.
math
.
sqrt
(
1
-
lam
)
batch_size
=
tf
.
shape
(
images
)[
0
]
image_height
,
image_width
=
tf
.
shape
(
images
)[
1
],
tf
.
shape
(
images
)[
2
]
cut_height
=
tf
.
cast
(
ratio
*
tf
.
cast
(
image_height
,
dtype
=
tf
.
float32
),
dtype
=
tf
.
int32
)
cut_width
=
tf
.
cast
(
ratio
*
tf
.
cast
(
image_height
,
dtype
=
tf
.
float32
),
dtype
=
tf
.
int32
)
random_center_height
=
tf
.
random
.
uniform
(
shape
=
[
batch_size
],
minval
=
0
,
maxval
=
image_height
,
dtype
=
tf
.
int32
)
random_center_width
=
tf
.
random
.
uniform
(
shape
=
[
batch_size
],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
bbox_area
=
cut_height
*
cut_width
lam
=
1.
-
bbox_area
/
(
image_height
*
image_width
)
lam
=
tf
.
cast
(
lam
,
dtype
=
tf
.
float32
)
images
=
tf
.
map_fn
(
lambda
x
:
_fill_rectangle
(
*
x
),
(
images
,
random_center_width
,
random_center_height
,
cut_width
//
2
,
cut_height
//
2
,
tf
.
reverse
(
images
,
[
0
])),
dtype
=
(
tf
.
float32
,
tf
.
int32
,
tf
.
int32
,
tf
.
int32
,
tf
.
int32
,
tf
.
float32
),
fn_output_signature
=
tf
.
TensorSpec
(
images
.
shape
[
1
:],
dtype
=
tf
.
float32
))
return
images
,
labels
,
lam
def
_mixup
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
tf
.
Tensor
]:
lam
=
MixupAndCutmix
.
_sample_from_beta
(
self
.
mixup_alpha
,
self
.
mixup_alpha
,
labels
.
shape
)
lam
=
tf
.
reshape
(
lam
,
[
-
1
,
1
,
1
,
1
])
images
=
lam
*
images
+
(
1.
-
lam
)
*
tf
.
reverse
(
images
,
[
0
])
return
images
,
labels
,
tf
.
squeeze
(
lam
)
def
_smooth_labels
(
self
,
labels
:
tf
.
Tensor
)
->
tf
.
Tensor
:
off_value
=
self
.
label_smoothing
/
self
.
num_classes
on_value
=
1.
-
self
.
label_smoothing
+
off_value
smooth_labels
=
tf
.
one_hot
(
labels
,
self
.
num_classes
,
on_value
=
on_value
,
off_value
=
off_value
)
return
smooth_labels
def
_update_labels
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
,
lam
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
labels_1
=
self
.
_smooth_labels
(
labels
)
labels_2
=
tf
.
reverse
(
labels_1
,
[
0
])
lam
=
tf
.
reshape
(
lam
,
[
-
1
,
1
])
labels
=
lam
*
labels_1
+
(
1.
-
lam
)
*
labels_2
return
images
,
labels
official/vision/beta/ops/augment_test.py
浏览文件 @
1c79ece9
...
...
@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
augmenter
.
distort
(
image
)
class
RandomErasingTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_random_erase_replaces_some_pixels
(
self
):
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
float32
)
augmenter
=
augment
.
RandomErasing
(
probability
=
1.
,
max_count
=
10
)
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertNotEqual
(
0
,
tf
.
reduce_max
(
aug_image
))
class
MixupAndCutmixTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_mixup_and_cutmix_smoothes_labels
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
num_classes
=
num_classes
,
label_smoothing
=
label_smoothing
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
def
test_mixup_changes_image
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
1.
,
cutmix_alpha
=
0.
,
num_classes
=
num_classes
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
self
.
assertFalse
(
tf
.
math
.
reduce_all
(
images
==
aug_images
))
def
test_cutmix_changes_image
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
0.
,
cutmix_alpha
=
1.
,
num_classes
=
num_classes
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
self
.
assertFalse
(
tf
.
math
.
reduce_all
(
images
==
aug_images
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/ops/preprocess_ops.py
浏览文件 @
1c79ece9
...
...
@@ -15,12 +15,13 @@
"""Preprocessing ops."""
import
math
from
typing
import
Optional
from
six.moves
import
range
import
tensorflow
as
tf
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
box_ops
CENTER_CROP_FRACTION
=
0.875
...
...
@@ -557,6 +558,107 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
return
image
,
normalized_boxes
,
masks
def
color_jitter
(
image
:
tf
.
Tensor
,
brightness
:
Optional
[
float
]
=
0.
,
contrast
:
Optional
[
float
]
=
0.
,
saturation
:
Optional
[
float
]
=
0.
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Applies color jitter to an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
brightness (float, optional): Magnitude for brightness jitter. Defaults to
0.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
saturation (float, optional): Magnitude for saturation jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
image
=
tf
.
cast
(
image
,
dtype
=
tf
.
uint8
)
image
=
random_brightness
(
image
,
brightness
,
seed
=
seed
)
image
=
random_contrast
(
image
,
contrast
,
seed
=
seed
)
image
=
random_saturation
(
image
,
saturation
,
seed
=
seed
)
return
image
def
random_brightness
(
image
:
tf
.
Tensor
,
brightness
:
float
=
0.
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Jitters brightness of an image.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
brightness (float, optional): Magnitude for brightness jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert
brightness
>=
0
,
'`brightness` must be positive'
brightness
=
tf
.
random
.
uniform
([],
max
(
0
,
1
-
brightness
),
1
+
brightness
,
seed
=
seed
,
dtype
=
tf
.
float32
)
return
augment
.
brightness
(
image
,
brightness
)
def
random_contrast
(
image
:
tf
.
Tensor
,
contrast
:
float
=
0.
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Jitters contrast of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert
contrast
>=
0
,
'`contrast` must be positive'
contrast
=
tf
.
random
.
uniform
([],
max
(
0
,
1
-
contrast
),
1
+
contrast
,
seed
=
seed
,
dtype
=
tf
.
float32
)
return
augment
.
contrast
(
image
,
contrast
)
def
random_saturation
(
image
:
tf
.
Tensor
,
saturation
:
float
=
0.
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Jitters saturation of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
saturation (float, optional): Magnitude for saturation jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert
saturation
>=
0
,
'`saturation` must be positive'
saturation
=
tf
.
random
.
uniform
([],
max
(
0
,
1
-
saturation
),
1
+
saturation
,
seed
=
seed
,
dtype
=
tf
.
float32
)
return
_saturation
(
image
,
saturation
)
def
_saturation
(
image
:
tf
.
Tensor
,
saturation
:
Optional
[
float
]
=
0.
)
->
tf
.
Tensor
:
return
augment
.
blend
(
tf
.
repeat
(
tf
.
image
.
rgb_to_grayscale
(
image
),
3
,
axis
=-
1
),
image
,
saturation
)
def
random_crop_image_with_boxes_and_labels
(
img
,
boxes
,
labels
,
min_scale
,
aspect_ratio_range
,
min_overlap_params
,
max_retry
):
...
...
official/vision/beta/ops/preprocess_ops_test.py
浏览文件 @
1c79ece9
...
...
@@ -197,6 +197,19 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
_
=
preprocess_ops
.
random_crop_image_v2
(
image_bytes
,
tf
.
constant
([
input_height
,
input_width
,
3
],
tf
.
int32
))
@
parameterized
.
parameters
((
400
,
600
,
0
),
(
400
,
600
,
0.4
),
(
600
,
400
,
1.4
))
def
testColorJitter
(
self
,
input_height
,
input_width
,
color_jitter
):
image
=
tf
.
convert_to_tensor
(
np
.
random
.
rand
(
input_height
,
input_width
,
3
))
jittered_image
=
preprocess_ops
.
color_jitter
(
image
,
color_jitter
,
color_jitter
,
color_jitter
)
assert
jittered_image
.
shape
==
image
.
shape
@
parameterized
.
parameters
((
400
,
600
,
0
),
(
400
,
600
,
0.4
),
(
600
,
400
,
1
))
def
testSaturation
(
self
,
input_height
,
input_width
,
saturation
):
image
=
tf
.
convert_to_tensor
(
np
.
random
.
rand
(
input_height
,
input_width
,
3
))
jittered_image
=
preprocess_ops
.
_saturation
(
image
,
saturation
)
assert
jittered_image
.
shape
==
image
.
shape
@
parameterized
.
parameters
((
640
,
640
,
20
),
(
1280
,
1280
,
30
))
def
test_random_crop
(
self
,
input_height
,
input_width
,
num_boxes
):
image
=
tf
.
convert_to_tensor
(
np
.
random
.
rand
(
input_height
,
input_width
,
3
))
...
...
official/vision/beta/projects/vit/README.md
浏览文件 @
1c79ece9
# Vision Transformer (ViT)
# Vision Transformer (ViT)
and Data-Efficient Image Transformer (DEIT)
**DISCLAIMER**
: This implementation is still under development. No support will
be provided during the development phase.
[
![Paper
](
http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv
)
](https://arxiv.org/abs/2010.11929)
-
[
![ViT Paper
](
http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv
)
](https://arxiv.org/abs/2010.11929)
-
[
![DEIT Paper
](
http://img.shields.io/badge/Paper-arXiv.2012.12877-B3181B?logo=arXiv
)
](https://arxiv.org/abs/2012.12877)
This repository is the implementations of Vision Transformer (ViT)
in
TensorFlow 2.
This repository is the implementations of Vision Transformer (ViT)
and
Data-Efficient Image Transformer (DEIT) in
TensorFlow 2.
*
Paper title:
[
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
](
https://arxiv.org/pdf/2010.11929.pdf
)
.
\ No newline at end of file
-
[
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
](
https://arxiv.org/pdf/2010.11929.pdf
)
.
-
[
Training data-efficient image transformers & distillation through attention
](
https://arxiv.org/pdf/2012.12877.pdf
)
.
official/vision/beta/projects/vit/configs/backbones.py
浏览文件 @
1c79ece9
...
...
@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
hidden_size
:
int
=
1
patch_size
:
int
=
16
transformer
:
Transformer
=
Transformer
()
init_stochastic_depth_rate
:
float
=
0.0
original_init
:
bool
=
True
@
dataclasses
.
dataclass
...
...
official/vision/beta/projects/vit/configs/image_classification.py
浏览文件 @
1c79ece9
...
...
@@ -44,6 +44,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn
=
False
)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm
:
bool
=
False
kernel_initializer
:
str
=
'random_uniform'
@
dataclasses
.
dataclass
...
...
@@ -51,6 +52,7 @@ class Losses(hyperparams.Config):
one_hot
:
bool
=
True
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
soft_labels
:
bool
=
False
@
dataclasses
.
dataclass
...
...
@@ -79,6 +81,87 @@ task_factory.register_task_cls(ImageClassificationTask)(
image_classification
.
ImageClassificationTask
)
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain'
)
def
image_classification_imagenet_deit_pretrain
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# originally was 1024 but 4096 better for tpu v3-32
eval_batch_size
=
4096
# originally was 1024 but 4096 better for tpu v3-32
num_classes
=
1001
label_smoothing
=
0.1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
num_classes
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
label_smoothing
,
one_hot
=
False
,
soft_labels
=
True
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
aug_type
=
common
.
Augmentation
(
type
=
'randaug'
,
randaug
=
common
.
RandAugment
(
magnitude
=
9
,
exclude_ops
=
[
'Cutout'
])),
mixup_and_cutmix
=
common
.
MixupAndCutmix
(
label_smoothing
=
label_smoothing
)),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'vit_imagenet_pretrain'
)
def
image_classification_imagenet_vit_pretrain
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
...
...
@@ -90,6 +173,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
model
=
ImageClassificationModel
(
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
...
...
@@ -116,12 +200,13 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
'adamw'
:
{
'weight_decay_rate'
:
0.3
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.003
,
'initial_learning_rate'
:
0.003
*
train_batch_size
/
4096
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
...
...
official/vision/beta/projects/vit/modeling/nn_blocks.py
0 → 100644
浏览文件 @
1c79ece9
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
import
tensorflow
as
tf
from
official.nlp
import
keras_nlp
from
official.vision.beta.modeling.layers.nn_layers
import
StochasticDepth
class
TransformerEncoderBlock
(
keras_nlp
.
layers
.
TransformerEncoderBlock
):
"""TransformerEncoderBlock layer with stochastic depth."""
def
__init__
(
self
,
*
args
,
stochastic_depth_drop_rate
=
0.0
,
**
kwargs
):
"""Initializes TransformerEncoderBlock."""
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
def
build
(
self
,
input_shape
):
if
self
.
_stochastic_depth_drop_rate
:
self
.
_stochastic_depth
=
StochasticDepth
(
self
.
_stochastic_depth_drop_rate
)
else
:
self
.
_stochastic_depth
=
lambda
x
,
*
args
,
**
kwargs
:
tf
.
identity
(
x
)
super
().
build
(
input_shape
)
def
get_config
(
self
):
config
=
{
"stochastic_depth_drop_rate"
:
self
.
_stochastic_depth_drop_rate
}
base_config
=
super
().
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
,
training
=
None
):
"""Transformer self-attention encoder block call."""
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
len
(
inputs
)
==
2
:
input_tensor
,
attention_mask
=
inputs
key_value
=
None
elif
len
(
inputs
)
==
3
:
input_tensor
,
key_value
,
attention_mask
=
inputs
else
:
raise
ValueError
(
"Unexpected inputs to %s with length at %d"
%
(
self
.
__class__
,
len
(
inputs
)))
else
:
input_tensor
,
key_value
,
attention_mask
=
(
inputs
,
None
,
None
)
if
self
.
_output_range
:
if
self
.
_norm_first
:
source_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm
(
key_value
)
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
else
:
if
self
.
_norm_first
:
source_tensor
=
input_tensor
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm
(
key_value
)
target_tensor
=
input_tensor
if
key_value
is
None
:
key_value
=
input_tensor
attention_output
=
self
.
_attention_layer
(
query
=
target_tensor
,
value
=
key_value
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
if
self
.
_norm_first
:
attention_output
=
source_tensor
+
self
.
_stochastic_depth
(
attention_output
,
training
=
training
)
else
:
attention_output
=
self
.
_attention_layer_norm
(
target_tensor
+
self
.
_stochastic_depth
(
attention_output
,
training
=
training
))
if
self
.
_norm_first
:
source_attention_output
=
attention_output
attention_output
=
self
.
_output_layer_norm
(
attention_output
)
inner_output
=
self
.
_intermediate_dense
(
attention_output
)
inner_output
=
self
.
_intermediate_activation_layer
(
inner_output
)
inner_output
=
self
.
_inner_dropout_layer
(
inner_output
)
layer_output
=
self
.
_output_dense
(
inner_output
)
layer_output
=
self
.
_output_dropout
(
layer_output
)
if
self
.
_norm_first
:
return
source_attention_output
+
self
.
_stochastic_depth
(
layer_output
,
training
=
training
)
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output
=
tf
.
cast
(
layer_output
,
tf
.
float32
)
return
self
.
_output_layer_norm
(
layer_output
+
self
.
_stochastic_depth
(
attention_output
,
training
=
training
))
official/vision/beta/projects/vit/modeling/vit.py
浏览文件 @
1c79ece9
...
...
@@ -17,17 +17,24 @@
import
tensorflow
as
tf
from
official.modeling
import
activations
from
official.nlp
import
keras_nlp
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_layers
from
official.vision.beta.projects.vit.modeling
import
nn_blocks
layers
=
tf
.
keras
.
layers
VIT_SPECS
=
{
'vit-t
esting
'
:
'vit-t
i16
'
:
dict
(
hidden_size
=
1
,
hidden_size
=
1
92
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
1
,
num_heads
=
1
,
num_layers
=
1
),
transformer
=
dict
(
mlp_dim
=
768
,
num_heads
=
3
,
num_layers
=
12
),
),
'vit-s16'
:
dict
(
hidden_size
=
384
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
1536
,
num_heads
=
6
,
num_layers
=
12
),
),
'vit-b16'
:
dict
(
...
...
@@ -112,6 +119,8 @@ class Encoder(tf.keras.layers.Layer):
attention_dropout_rate
=
0.1
,
kernel_regularizer
=
None
,
inputs_positions
=
None
,
init_stochastic_depth_rate
=
0.0
,
kernel_initializer
=
'glorot_uniform'
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_num_layers
=
num_layers
...
...
@@ -121,6 +130,8 @@ class Encoder(tf.keras.layers.Layer):
self
.
_attention_dropout_rate
=
attention_dropout_rate
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_inputs_positions
=
inputs_positions
self
.
_init_stochastic_depth_rate
=
init_stochastic_depth_rate
self
.
_kernel_initializer
=
kernel_initializer
def
build
(
self
,
input_shape
):
self
.
_pos_embed
=
AddPositionEmbs
(
...
...
@@ -131,15 +142,18 @@ class Encoder(tf.keras.layers.Layer):
self
.
_encoder_layers
=
[]
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
for
_
in
range
(
self
.
_num_layers
):
encoder_layer
=
keras_nlp
.
layer
s
.
TransformerEncoderBlock
(
for
i
in
range
(
self
.
_num_layers
):
encoder_layer
=
nn_block
s
.
TransformerEncoderBlock
(
inner_activation
=
activations
.
gelu
,
num_attention_heads
=
self
.
_num_heads
,
inner_dim
=
self
.
_mlp_dim
,
output_dropout
=
self
.
_dropout_rate
,
attention_dropout
=
self
.
_attention_dropout_rate
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_initializer
=
self
.
_kernel_initializer
,
norm_first
=
True
,
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
1
,
self
.
_num_layers
),
norm_epsilon
=
1e-6
)
self
.
_encoder_layers
.
append
(
encoder_layer
)
self
.
_norm
=
layers
.
LayerNormalization
(
epsilon
=
1e-6
)
...
...
@@ -164,12 +178,14 @@ class VisionTransformer(tf.keras.Model):
num_layers
=
12
,
attention_dropout_rate
=
0.0
,
dropout_rate
=
0.1
,
init_stochastic_depth_rate
=
0.0
,
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
patch_size
=
16
,
hidden_size
=
768
,
representation_size
=
0
,
classifier
=
'token'
,
kernel_regularizer
=
None
):
kernel_regularizer
=
None
,
original_init
=
True
):
"""VisionTransformer initialization function."""
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
...
...
@@ -178,7 +194,8 @@ class VisionTransformer(tf.keras.Model):
kernel_size
=
patch_size
,
strides
=
patch_size
,
padding
=
'valid'
,
kernel_regularizer
=
kernel_regularizer
)(
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
'lecun_normal'
if
original_init
else
'he_uniform'
)(
inputs
)
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
rows_axis
,
cols_axis
=
(
1
,
2
)
...
...
@@ -203,7 +220,10 @@ class VisionTransformer(tf.keras.Model):
num_heads
=
num_heads
,
dropout_rate
=
dropout_rate
,
attention_dropout_rate
=
attention_dropout_rate
,
kernel_regularizer
=
kernel_regularizer
)(
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
'glorot_uniform'
if
original_init
else
dict
(
class_name
=
'TruncatedNormal'
,
config
=
dict
(
stddev
=
.
02
)),
init_stochastic_depth_rate
=
init_stochastic_depth_rate
)(
x
)
if
classifier
==
'token'
:
...
...
@@ -215,7 +235,8 @@ class VisionTransformer(tf.keras.Model):
x
=
tf
.
keras
.
layers
.
Dense
(
representation_size
,
kernel_regularizer
=
kernel_regularizer
,
name
=
'pre_logits'
)(
name
=
'pre_logits'
,
kernel_initializer
=
'lecun_normal'
if
original_init
else
'he_uniform'
)(
x
)
x
=
tf
.
nn
.
tanh
(
x
)
else
:
...
...
@@ -247,9 +268,11 @@ def build_vit(input_specs,
num_layers
=
backbone_cfg
.
transformer
.
num_layers
,
attention_dropout_rate
=
backbone_cfg
.
transformer
.
attention_dropout_rate
,
dropout_rate
=
backbone_cfg
.
transformer
.
dropout_rate
,
init_stochastic_depth_rate
=
backbone_cfg
.
init_stochastic_depth_rate
,
input_specs
=
input_specs
,
patch_size
=
backbone_cfg
.
patch_size
,
hidden_size
=
backbone_cfg
.
hidden_size
,
representation_size
=
backbone_cfg
.
representation_size
,
classifier
=
backbone_cfg
.
classifier
,
kernel_regularizer
=
l2_regularizer
)
kernel_regularizer
=
l2_regularizer
,
original_init
=
backbone_cfg
.
original_init
)
official/vision/beta/projects/yolo/configs/darknet_classification.py
浏览文件 @
1c79ece9
...
...
@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):
@
exp_factory
.
register_config_factory
(
'darknet_classification'
)
def
image
_classification
()
->
cfg
.
ExperimentConfig
:
def
darknet
_classification
()
->
cfg
.
ExperimentConfig
:
"""Image classification general."""
return
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(),
...
...
official/vision/beta/tasks/image_classification.py
浏览文件 @
1c79ece9
...
...
@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.ops
import
augment
@
task_factory
.
register_task_cls
(
exp_cfg
.
ImageClassificationTask
)
...
...
@@ -103,14 +104,26 @@ class ImageClassificationTask(base_task.Task):
decode_jpeg_only
=
params
.
decode_jpeg_only
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_type
=
params
.
aug_type
,
color_jitter
=
params
.
color_jitter
,
random_erasing
=
params
.
random_erasing
,
is_multilabel
=
is_multilabel
,
dtype
=
params
.
dtype
)
postprocess_fn
=
None
if
params
.
mixup_and_cutmix
:
postprocess_fn
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
params
.
mixup_and_cutmix
.
mixup_alpha
,
cutmix_alpha
=
params
.
mixup_and_cutmix
.
cutmix_alpha
,
prob
=
params
.
mixup_and_cutmix
.
prob
,
label_smoothing
=
params
.
mixup_and_cutmix
.
label_smoothing
,
num_classes
=
num_classes
)
reader
=
input_reader_factory
.
input_reader_generator
(
params
,
dataset_fn
=
dataset_fn
.
pick_dataset_fn
(
params
.
file_type
),
decoder_fn
=
decoder
.
decode
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
))
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
),
postprocess_fn
=
postprocess_fn
)
dataset
=
reader
.
read
(
input_context
=
input_context
)
...
...
@@ -140,6 +153,9 @@ class ImageClassificationTask(base_task.Task):
model_outputs
,
from_logits
=
True
,
label_smoothing
=
losses_config
.
label_smoothing
)
elif
losses_config
.
soft_labels
:
total_loss
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
,
model_outputs
)
else
:
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
)
...
...
@@ -161,7 +177,8 @@ class ImageClassificationTask(base_task.Task):
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
not
is_multilabel
:
k
=
self
.
task_config
.
evaluation
.
top_k
if
self
.
task_config
.
losses
.
one_hot
:
if
(
self
.
task_config
.
losses
.
one_hot
or
self
.
task_config
.
losses
.
soft_labels
):
metrics
=
[
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
...
...
@@ -223,7 +240,9 @@ class ImageClassificationTask(base_task.Task):
# Computes per-replica loss.
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss
=
loss
/
num_replicas
...
...
@@ -266,14 +285,18 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs.
"""
features
,
labels
=
inputs
one_hot
=
self
.
task_config
.
losses
.
one_hot
soft_labels
=
self
.
task_config
.
losses
.
soft_labels
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
self
.
task_config
.
losses
.
one_hot
and
not
is_multilabel
:
if
(
one_hot
or
soft_labels
)
and
not
is_multilabel
:
labels
=
tf
.
one_hot
(
labels
,
self
.
task_config
.
model
.
num_classes
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录