Commit b29aeceb authored by Chaochao Yan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 481733792
Parent 5f4e8936
......@@ -75,7 +75,7 @@ task_factory.register_task_cls(ImageClassificationTask)(
image_classification.ImageClassificationTask)
@exp_factory.register_config_factory('deit_imagenet_pretrain')
@exp_factory.register_config_factory('legacy_deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # originally 1024, but 4096 is better for TPU v3-32
......@@ -156,7 +156,7 @@ def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
return config
@exp_factory.register_config_factory('vit_imagenet_pretrain')
@exp_factory.register_config_factory('legacy_vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096
......@@ -220,7 +220,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
return config
@exp_factory.register_config_factory('vit_imagenet_finetune')
@exp_factory.register_config_factory('legacy_vit_imagenet_finetune')
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 512
......
......@@ -294,7 +294,7 @@ class VisionTransformer(tf.keras.Model):
super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
@factory.register_backbone_builder('vit')
@factory.register_backbone_builder('legacy_vit')
def build_vit(input_specs,
backbone_config,
norm_activation_config,
......
......@@ -14,13 +14,37 @@
"""Backbones configurations."""
import dataclasses
from typing import Optional, List
# Import libraries
from typing import List, Optional, Tuple
from official.modeling import hyperparams
@dataclasses.dataclass
class Transformer(hyperparams.Config):
"""Transformer config."""
mlp_dim: int = 1
num_heads: int = 1
num_layers: int = 1
attention_dropout_rate: float = 0.0
dropout_rate: float = 0.1
@dataclasses.dataclass
class VisionTransformer(hyperparams.Config):
"""VisionTransformer config."""
model_name: str = 'vit-b16'
# pylint: disable=line-too-long
pooler: str = 'token' # 'token', 'gap' or 'none'. If set to 'token', an extra classification token is added to the sequence.
# pylint: enable=line-too-long
representation_size: int = 0
hidden_size: int = 1
patch_size: int = 16
transformer: Transformer = Transformer()
init_stochastic_depth_rate: float = 0.0
original_init: bool = True
pos_embed_shape: Optional[Tuple[int, int]] = None
@dataclasses.dataclass
class ResNet(hyperparams.Config):
"""ResNet config."""
......@@ -120,6 +144,7 @@ class Backbone(hyperparams.OneOfConfig):
spinenet_mobile: mobile spinenet backbone config.
mobilenet: mobilenet backbone config.
mobiledet: mobiledet backbone config.
vit: vision transformer backbone config.
"""
type: Optional[str] = None
resnet: ResNet = ResNet()
......@@ -130,4 +155,4 @@ class Backbone(hyperparams.OneOfConfig):
spinenet_mobile: SpineNetMobile = SpineNetMobile()
mobilenet: MobileNet = MobileNet()
mobiledet: MobileDet = MobileDet()
vit: VisionTransformer = VisionTransformer()
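# A minimal usage sketch (illustrative, not part of this diff): because
# `Backbone` is a `hyperparams.OneOfConfig`, the `type` field names the active
# member and `.get()` returns it, which is how `build_vit` later retrieves the
# ViT settings. The import path assumes this file lives at
# official/vision/configs/backbones.py.
from official.vision.configs import backbones as backbones_cfg

backbone = backbones_cfg.Backbone(
    type='vit',
    vit=backbones_cfg.VisionTransformer(model_name='vit-s16', pooler='gap'))
vit_cfg = backbone.get()  # Returns the VisionTransformer sub-config selected by `type`.
assert vit_cfg.model_name == 'vit-s16'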
......@@ -402,3 +402,201 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # originally 1024, but 4096 is better for TPU v3-32
eval_batch_size = 4096 # originally 1024, but 4096 is better for TPU v3-32
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(
l2_weight_decay=0.0,
label_smoothing=label_smoothing,
one_hot=False,
soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
aug_type=common.Augmentation(
type='randaug',
randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout'])),
mixup_and_cutmix=common.MixupAndCutmix(
label_smoothing=label_smoothing)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096
eval_batch_size = 4096
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16', representation_size=768))),
losses=Losses(l2_weight_decay=0.0),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.3,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.003 * train_batch_size / 4096,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 10000,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('vit_imagenet_finetune')
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 512
eval_batch_size = 512
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[384, 384, 3],
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(model_name='vit-b16'))),
losses=Losses(l2_weight_decay=0.0),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=20000,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9,
'global_clipnorm': 1.0,
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.003,
'decay_steps': 20000,
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
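# A minimal usage sketch (illustrative, not part of this diff): each factory
# above is registered under a name and can be fetched through `exp_factory`,
# exactly as the config test below does, then overridden before training.
# The override values here are only examples.
from official.core import exp_factory

config = exp_factory.get_exp_config('deit_imagenet_pretrain')
config.task.train_data.global_batch_size = 1024
config.trainer.train_steps = 10000
config.validate()  # Enforces the `restrictions` declared in the config.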
......@@ -29,7 +29,10 @@ class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
('resnet_imagenet',),
('resnet_rs_imagenet',),
('revnet_imagenet',),
('mobilenet_imagenet'),
('mobilenet_imagenet',),
('deit_imagenet_pretrain',),
('vit_imagenet_pretrain',),
('vit_imagenet_finetune',),
)
def test_image_classification_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
......
......@@ -23,3 +23,4 @@ from official.vision.modeling.backbones.resnet_deeplab import DilatedResNet
from official.vision.modeling.backbones.revnet import RevNet
from official.vision.modeling.backbones.spinenet import SpineNet
from official.vision.modeling.backbones.spinenet_mobile import SpineNetMobile
from official.vision.modeling.backbones.vit import VisionTransformer
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VisionTransformer models."""
from typing import Optional, Tuple
from absl import logging
import tensorflow as tf
from official.modeling import activations
from official.vision.modeling.backbones import factory
from official.vision.modeling.backbones.vit_specs import VIT_SPECS
from official.vision.modeling.layers import nn_blocks
from official.vision.modeling.layers import nn_layers
layers = tf.keras.layers
class AddPositionEmbs(tf.keras.layers.Layer):
"""Adds (optionally learned) positional embeddings to the inputs."""
def __init__(self,
posemb_init: Optional[tf.keras.initializers.Initializer] = None,
posemb_origin_shape: Optional[Tuple[int, int]] = None,
posemb_target_shape: Optional[Tuple[int, int]] = None,
**kwargs):
"""Constructs Postional Embedding module.
The logic of this module is: the learnable positional embeddings length will
be determined by the inputs_shape or posemb_origin_shape (if provided)
during the construction. If the posemb_target_shape is provided and is
different from the positional embeddings length, the embeddings will be
interpolated during the forward call.
Args:
posemb_init: The positional embedding initializer.
posemb_origin_shape: The intended positional embedding shape.
posemb_target_shape: The target shape that the positional embedding may be
interpolated to.
**kwargs: other args.
"""
super().__init__(**kwargs)
self.posemb_init = posemb_init
self.posemb_origin_shape = posemb_origin_shape
self.posemb_target_shape = posemb_target_shape
def build(self, inputs_shape):
if self.posemb_origin_shape is not None:
pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
else:
pos_emb_length = inputs_shape[1]
pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
self.pos_embedding = self.add_weight(
'pos_embedding', pos_emb_shape, initializer=self.posemb_init)
def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
to_shape: Tuple[int, int]) -> tf.Tensor:
"""Interpolates the positional embeddings."""
logging.info('Interpolating positional embedding from shape %s to %s',
from_shape, to_shape)
grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
# NOTE: Using BILINEAR interpolation by default.
grid_emb = tf.image.resize(grid_emb, to_shape)
return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])
def call(self, inputs, inputs_positions=None):
del inputs_positions
pos_embedding = self.pos_embedding
# inputs.shape is (batch_size, seq_len, emb_dim).
if inputs.shape[1] != pos_embedding.shape[1]:
pos_embedding = self._interpolate(
pos_embedding,
from_shape=self.posemb_origin_shape,
to_shape=self.posemb_target_shape)
pos_embedding = tf.cast(pos_embedding, inputs.dtype)
return inputs + pos_embedding
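# A minimal sketch of the interpolation path (illustrative, not part of this
# diff): embeddings built for a 14x14 token grid are resized to a 16x16 grid
# when the incoming sequence length does not match.
import tensorflow as tf

pos_embed = AddPositionEmbs(
    posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
    posemb_origin_shape=(14, 14),  # Embeddings are created with length 14 * 14 = 196.
    posemb_target_shape=(16, 16))  # Interpolated to 16 * 16 = 256 when needed.
tokens = tf.zeros([2, 256, 768])   # (batch, seq_len, emb_dim) for a 256x256 image, patch 16.
out = pos_embed(tokens)            # pos_embedding is bilinearly resized to length 256.
assert out.shape == (2, 256, 768)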
class TokenLayer(tf.keras.layers.Layer):
"""A simple layer to wrap token parameters."""
def build(self, inputs_shape):
self.cls = self.add_weight(
'cls', (1, 1, inputs_shape[-1]), initializer='zeros')
def call(self, inputs):
cls = tf.cast(self.cls, inputs.dtype)
cls = cls + tf.zeros_like(inputs[:, 0:1]) # A hacky way to tile.
x = tf.concat([cls, inputs], axis=1)
return x
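# A small sketch (illustrative, not part of this diff): TokenLayer prepends a
# single learnable classification token, so the sequence length grows by one.
patch_tokens = tf.zeros([2, 196, 768])          # 14x14 patch tokens of a 224x224 image.
with_cls = TokenLayer(name='cls')(patch_tokens)
assert with_cls.shape == (2, 197, 768)          # The `cls` token is broadcast over the batch.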
class Encoder(tf.keras.layers.Layer):
"""Transformer Encoder."""
def __init__(self,
num_layers,
mlp_dim,
num_heads,
dropout_rate=0.1,
attention_dropout_rate=0.1,
kernel_regularizer=None,
inputs_positions=None,
init_stochastic_depth_rate=0.0,
kernel_initializer='glorot_uniform',
add_pos_embed=True,
pos_embed_origin_shape=None,
pos_embed_target_shape=None,
**kwargs):
super().__init__(**kwargs)
self._num_layers = num_layers
self._mlp_dim = mlp_dim
self._num_heads = num_heads
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._kernel_regularizer = kernel_regularizer
self._inputs_positions = inputs_positions
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._kernel_initializer = kernel_initializer
self._add_pos_embed = add_pos_embed
self._pos_embed_origin_shape = pos_embed_origin_shape
self._pos_embed_target_shape = pos_embed_target_shape
def build(self, input_shape):
if self._add_pos_embed:
self._pos_embed = AddPositionEmbs(
posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
posemb_origin_shape=self._pos_embed_origin_shape,
posemb_target_shape=self._pos_embed_target_shape,
name='posembed_input')
self._dropout = layers.Dropout(rate=self._dropout_rate)
self._encoder_layers = []
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
for i in range(self._num_layers):
encoder_layer = nn_blocks.TransformerEncoderBlock(
inner_activation=activations.gelu,
num_attention_heads=self._num_heads,
inner_dim=self._mlp_dim,
output_dropout=self._dropout_rate,
attention_dropout=self._attention_dropout_rate,
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
norm_first=True,
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i + 1, self._num_layers),
norm_epsilon=1e-6)
self._encoder_layers.append(encoder_layer)
self._norm = layers.LayerNormalization(epsilon=1e-6)
super().build(input_shape)
def call(self, inputs, training=None):
x = inputs
if self._add_pos_embed:
x = self._pos_embed(x, inputs_positions=self._inputs_positions)
x = self._dropout(x, training=training)
for encoder_layer in self._encoder_layers:
x = encoder_layer(x, training=training)
x = self._norm(x)
return x
def get_config(self):
config = super().get_config()
updates = {
'num_layers': self._num_layers,
'mlp_dim': self._mlp_dim,
'num_heads': self._num_heads,
'dropout_rate': self._dropout_rate,
'attention_dropout_rate': self._attention_dropout_rate,
'kernel_regularizer': self._kernel_regularizer,
'inputs_positions': self._inputs_positions,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'kernel_initializer': self._kernel_initializer,
'add_pos_embed': self._add_pos_embed,
'pos_embed_origin_shape': self._pos_embed_origin_shape,
'pos_embed_target_shape': self._pos_embed_target_shape,
}
config.update(updates)
return config
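# A small standalone sketch (illustrative, not part of this diff): the Encoder
# bundles positional embeddings, dropout, `num_layers` pre-norm transformer
# blocks and a final layer norm. A deliberately tiny configuration is used;
# it is not one of the ViT specs.
tiny_encoder = Encoder(num_layers=2, mlp_dim=768, num_heads=3)
x = tf.zeros([2, 196, 192])               # The hidden size must be divisible by num_heads.
y = tiny_encoder(x, training=False)
assert y.shape == (2, 196, 192)           # Sequence and feature dimensions are preserved.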
class VisionTransformer(tf.keras.Model):
"""Class to build VisionTransformer family model."""
def __init__(self,
mlp_dim=3072,
num_heads=12,
num_layers=12,
attention_dropout_rate=0.0,
dropout_rate=0.1,
init_stochastic_depth_rate=0.0,
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
patch_size=16,
hidden_size=768,
representation_size=0,
pooler='token',
kernel_regularizer=None,
original_init: bool = True,
pos_embed_shape: Optional[Tuple[int, int]] = None):
"""VisionTransformer initialization function."""
self._mlp_dim = mlp_dim
self._num_heads = num_heads
self._num_layers = num_layers
self._hidden_size = hidden_size
self._patch_size = patch_size
inputs = tf.keras.Input(shape=input_specs.shape[1:])
x = layers.Conv2D(
filters=hidden_size,
kernel_size=patch_size,
strides=patch_size,
padding='valid',
kernel_regularizer=kernel_regularizer,
kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
inputs)
if tf.keras.backend.image_data_format() == 'channels_last':
rows_axis, cols_axis = (1, 2)
else:
rows_axis, cols_axis = (2, 3)
# The reshape below assumes the data_format is 'channels_last', so
# transpose to that. Once the data is flattened by the reshape, the
# data_format is irrelevant, so no need to update
# tf.keras.backend.image_data_format.
x = tf.transpose(x, perm=[0, 2, 3, 1])
pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
seq_len = (input_specs.shape[rows_axis] // patch_size) * (
input_specs.shape[cols_axis] // patch_size)
x = tf.reshape(x, [-1, seq_len, hidden_size])
# If we want to add a class token, add it here.
if pooler == 'token':
x = TokenLayer(name='cls')(x)
x = Encoder(
num_layers=num_layers,
mlp_dim=mlp_dim,
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
kernel_regularizer=kernel_regularizer,
kernel_initializer='glorot_uniform' if original_init else dict(
class_name='TruncatedNormal', config=dict(stddev=.02)),
init_stochastic_depth_rate=init_stochastic_depth_rate,
pos_embed_origin_shape=pos_embed_shape,
pos_embed_target_shape=pos_embed_target_shape)(
x)
if pooler == 'token':
x = x[:, 0]
elif pooler == 'gap':
x = tf.reduce_mean(x, axis=1)
elif pooler == 'none':
x = tf.identity(x, name='encoded_tokens')
else:
raise ValueError(f'unrecognized pooler type: {pooler}')
if representation_size:
x = tf.keras.layers.Dense(
representation_size,
kernel_regularizer=kernel_regularizer,
name='pre_logits',
kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
x)
x = tf.nn.tanh(x)
else:
x = tf.identity(x, name='pre_logits')
if pooler == 'none':
endpoints = {'encoded_tokens': x}
else:
endpoints = {
'pre_logits':
tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
}
super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
@factory.register_backbone_builder('vit')
def build_vit(input_specs,
backbone_config,
norm_activation_config,
l2_regularizer=None):
"""Build ViT model."""
del norm_activation_config
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'vit', (f'Inconsistent backbone type '
f'{backbone_type}')
backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])
return VisionTransformer(
mlp_dim=backbone_cfg.transformer.mlp_dim,
num_heads=backbone_cfg.transformer.num_heads,
num_layers=backbone_cfg.transformer.num_layers,
attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
dropout_rate=backbone_cfg.transformer.dropout_rate,
init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
input_specs=input_specs,
patch_size=backbone_cfg.patch_size,
hidden_size=backbone_cfg.hidden_size,
representation_size=backbone_cfg.representation_size,
pooler=backbone_cfg.pooler,
kernel_regularizer=l2_regularizer,
original_init=backbone_cfg.original_init,
pos_embed_shape=backbone_cfg.pos_embed_shape)
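# A minimal usage sketch (illustrative, not part of this diff): with the
# registration above, the backbone can be built through the factory, or
# `build_vit` can be called directly; `norm_activation_config` is unused and
# may be None. The config import path is assumed to be
# official/vision/configs/backbones.py.
from official.vision.configs import backbones as backbones_cfg

input_specs = tf.keras.layers.InputSpec(shape=[None, 224, 224, 3])
backbone_config = backbones_cfg.Backbone(
    type='vit',
    vit=backbones_cfg.VisionTransformer(model_name='vit-s16'))
model = build_vit(input_specs, backbone_config, norm_activation_config=None)
features = model(tf.zeros([1, 224, 224, 3]))
print(features['pre_logits'].shape)  # (1, 1, 1, 384): vit-s16 hidden size, default 'token' pooler.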
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VisionTransformer backbone specs."""
import immutabledict
VIT_SPECS = immutabledict.immutabledict({
'vit-ti16':
dict(
hidden_size=192,
patch_size=16,
transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
),
'vit-s16':
dict(
hidden_size=384,
patch_size=16,
transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
),
'vit-b16':
dict(
hidden_size=768,
patch_size=16,
transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
),
'vit-b32':
dict(
hidden_size=768,
patch_size=32,
transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
),
'vit-l16':
dict(
hidden_size=1024,
patch_size=16,
transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
),
'vit-l32':
dict(
hidden_size=1024,
patch_size=32,
transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
),
'vit-h14':
dict(
hidden_size=1280,
patch_size=14,
transformer=dict(mlp_dim=5120, num_heads=16, num_layers=32),
),
'vit-g14':
dict(
hidden_size=1664,
patch_size=14,
transformer=dict(mlp_dim=8192, num_heads=16, num_layers=48),
),
})
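# A small sketch of how these specs are consumed (illustrative, not part of
# this diff): `build_vit` merges the named spec into the experiment's
# VisionTransformer config via `override`, so fields absent from the spec,
# such as pooler or representation_size, are left untouched. The config import
# path is assumed to be official/vision/configs/backbones.py.
from official.vision.configs import backbones as backbones_cfg

cfg = backbones_cfg.VisionTransformer(model_name='vit-l16', representation_size=1024)
cfg.override(VIT_SPECS[cfg.model_name])
assert cfg.hidden_size == 1024            # Filled in from the spec.
assert cfg.transformer.num_layers == 24   # The nested Transformer config is overridden too.
assert cfg.representation_size == 1024    # Not part of the spec, so unchanged.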
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for VIT."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.modeling.backbones import vit
class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(224, 85798656),
(256, 85844736),
)
def test_network_creation(self, input_size, params_count):
"""Test creation of VisionTransformer family models."""
tf.keras.backend.set_image_data_format('channels_last')
input_specs = tf.keras.layers.InputSpec(
shape=[2, input_size, input_size, 3])
network = vit.VisionTransformer(input_specs=input_specs)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
_ = network(inputs)
self.assertEqual(network.count_params(), params_count)
def test_network_none_pooler(self):
tf.keras.backend.set_image_data_format('channels_last')
input_size = 256
input_specs = tf.keras.layers.InputSpec(
shape=[2, input_size, input_size, 3])
network = vit.VisionTransformer(
input_specs=input_specs,
patch_size=16,
pooler='none',
representation_size=128,
pos_embed_shape=(14, 14)) # (224 // 16)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
output = network(inputs)['encoded_tokens']
self.assertEqual(output.shape, [1, 256, 128])
def test_posembedding_interpolation(self):
tf.keras.backend.set_image_data_format('channels_last')
input_size = 256
input_specs = tf.keras.layers.InputSpec(
shape=[2, input_size, input_size, 3])
network = vit.VisionTransformer(
input_specs=input_specs,
patch_size=16,
pooler='gap',
pos_embed_shape=(14, 14)) # (224 // 16)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
output = network(inputs)['pre_logits']
self.assertEqual(output.shape, [1, 1, 1, 768])
if __name__ == '__main__':
tf.test.main()
......@@ -27,20 +27,49 @@ from official.vision.modeling import classification_model
class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(192 * 4, 3, 12, 192, 5524416),
(384 * 4, 6, 12, 384, 21665664),
)
def test_vision_transformer_creation(self, mlp_dim, num_heads, num_layers,
hidden_size, num_params):
"""Test for creation of a Vision Transformer classifier."""
inputs = np.random.rand(2, 224, 224, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.VisionTransformer(
mlp_dim=mlp_dim,
num_heads=num_heads,
num_layers=num_layers,
hidden_size=hidden_size,
input_specs=tf.keras.layers.InputSpec(shape=[None, 224, 224, 3]),
)
self.assertEqual(backbone.count_params(), num_params)
num_classes = 1000
model = classification_model.ClassificationModel(
backbone=backbone,
num_classes=num_classes,
dropout_rate=0.2,
)
logits = model(inputs)
self.assertAllEqual([2, num_classes], logits.numpy().shape)
@parameterized.parameters(
(128, 50, 'relu'),
(128, 50, 'relu'),
(128, 50, 'swish'),
)
def test_resnet_network_creation(
self, input_size, resnet_model_id, activation):
def test_resnet_network_creation(self, input_size, resnet_model_id,
activation):
"""Test for creation of a ResNet-50 classifier."""
inputs = np.random.rand(2, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.ResNet(
model_id=resnet_model_id, activation=activation)
backbone = backbones.ResNet(model_id=resnet_model_id, activation=activation)
self.assertEqual(backbone.count_params(), 23561152)
num_classes = 1000
......
......@@ -21,6 +21,7 @@ from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp import modeling as nlp_modeling
from official.vision.modeling.layers import nn_layers
......@@ -538,8 +539,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
se_inner_activation: A `str` name of squeeze-excitation inner activation.
se_gating_activation: A `str` name of squeeze-excitation gating
activation.
se_round_down_protect: A `bool` of whether round down more than 10%
will be allowed in SE layer.
se_round_down_protect: A `bool` of whether round down more than 10% will
be allowed in SE layer.
expand_se_in_filters: A `bool` of whether or not to expand in_filter in
squeeze and excitation layer.
depthwise_activation: A `str` name of the activation function for
......@@ -547,9 +548,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
dilation_rate: An `int` that specifies the dilation rate to use for.
divisible_by: An `int` that ensures all inner dimensions are divisible by
this number.
dilated convolution: An `int` to specify the same value for all spatial
dimensions.
this number. dilated convolution: An `int` to specify the same value for
all spatial dimensions.
regularize_depthwise: A `bool` of whether or not apply regularization on
depthwise.
use_depthwise: A `bool` of whether to uses fused convolutions instead of
......@@ -1048,7 +1048,7 @@ class ReversibleLayer(tf.keras.layers.Layer):
(bottleneck) residual functions. Where the input to the reversible layer
is x, the input gets partitioned in the channel dimension and the
forward pass follows (eq8): x = [x1; x2], z1 = x1 + f(x2), y2 = x2 +
g(z1), y1 = stop_gradient(z1).
g(z1), y1 = stop_gradient(z1).
g: A `tf.keras.layers.Layer` instance of `g` inner block referred to in
paper. Detailed explanation same as above as `f` arg.
manual_grads: A `bool` [Testing Only] of whether to manually take
......@@ -1204,7 +1204,8 @@ class ReversibleLayer(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='Vision')
class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
"""Creates an depthwise separable convolution block with batch normalization."""
"""Creates a depthwise separable convolution block with batch normalization.
"""
def __init__(
self,
......@@ -1354,10 +1355,10 @@ class TuckerConvBlock(tf.keras.layers.Layer):
Args:
in_filters: An `int` number of filters of the input tensor.
out_filters: An `int` number of filters of the output tensor.
input_compression_ratio: An `float` of compression ratio for
input filters.
output_compression_ratio: An `float` of compression ratio for
output filters.
input_compression_ratio: An `float` of compression ratio for input
filters.
output_compression_ratio: An `float` of compression ratio for output
filters.
strides: An `int` block stride. If greater than 1, this block will
ultimately downsample the input.
kernel_size: An `int` kernel_size of the depthwise conv layer.
......@@ -1510,11 +1511,114 @@ class TuckerConvBlock(tf.keras.layers.Layer):
x = self._conv2(x)
x = self._norm2(x)
if (self._use_residual and
self._in_filters == self._out_filters and
if (self._use_residual and self._in_filters == self._out_filters and
self._strides == 1):
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
x = self._add([x, shortcut])
return x
class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
"""TransformerEncoderBlock layer with stochastic depth."""
def __init__(self,
*args,
stochastic_depth_drop_rate=0.0,
return_attention=False,
**kwargs):
"""Initializes TransformerEncoderBlock."""
super().__init__(*args, **kwargs)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._return_attention = return_attention
def build(self, input_shape):
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
super().build(input_shape)
def get_config(self):
config = {'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
"""Transformer self-attention encoder block call."""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError('Unexpected inputs to %s with length %d' %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range:
if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor
if key_value is None:
key_value = input_tensor
attention_output, attention_scores = self._attention_layer(
query=target_tensor,
value=key_value,
attention_mask=attention_mask,
return_attention_scores=True)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + self._stochastic_depth(
attention_output, training=training)
else:
attention_output = self._attention_layer_norm(
target_tensor +
self._stochastic_depth(attention_output, training=training))
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
if self._return_attention:
return source_attention_output + self._stochastic_depth(
layer_output, training=training), attention_scores
else:
return source_attention_output + self._stochastic_depth(
layer_output, training=training)
# During mixed precision training, layer norm output is always fp32 for now.
# Cast to fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
if self._return_attention:
return self._output_layer_norm(layer_output + self._stochastic_depth(
attention_output, training=training)), attention_scores
else:
return self._output_layer_norm(
layer_output +
self._stochastic_depth(attention_output, training=training))
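# A small sketch (illustrative, not part of this diff): the block keeps the
# parent layer's constructor interface; only the residual adds go through
# stochastic depth, which is a no-op at inference time.
block = TransformerEncoderBlock(
    num_attention_heads=12,
    inner_dim=3072,
    inner_activation='gelu',
    norm_first=True,
    stochastic_depth_drop_rate=0.1)
x = tf.zeros([2, 197, 768])         # For example, 196 patch tokens plus one cls token.
y = block(x, training=False)
assert y.shape == (2, 197, 768)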