未验证 提交 c10da939 编写于 作者: Z Zihan Wang 提交者: GitHub

Merge branch 'tensorflow:master' into master

......@@ -25,7 +25,7 @@ class FakeKerasModel(tf.keras.Model):
self.dense = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
return self.dense2(self.dense(inputs))
......
......@@ -919,10 +919,6 @@ class AutoAugment(ImageAugment):
the policy.
"""
# TODO(dankondratyuk): tensorflow_addons defines custom ops, which
# for some reason are not included when building/linking
# This results in the error, "Op type not registered
# 'Addons>ImageProjectiveTransformV2' in binary" when running on borg TPUs
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
......
......@@ -17,17 +17,171 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Dict, Optional, Text
from typing import Any, Dict, Optional, Text, Union
from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from official.legacy.image_classification import learning_rate
from official.legacy.image_classification.configs import base_configs
from official.modeling import optimization
from official.modeling.optimization import legacy_adamw
# pylint: disable=protected-access
FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]
class Lookahead(tf.keras.optimizers.legacy.Optimizer):
"""This class allows to extend optimizers with the lookahead mechanism.
The mechanism is proposed by Michael R. Zhang et.al in the paper [Lookahead
Optimizer: k steps forward, 1 step back] (https://arxiv.org/abs/1907.08610v1).
The optimizer iteratively updates two sets of weights: the search directions
for weights are chosen by the inner optimizer, while the "slow weights" are
updated each `k` steps based on the directions of the "fast weights" and the
two sets of weights are synchronized. This method improves the learning
stability and lowers the variance of its inner optimizer.
Example of usage:
```python
opt = tf.keras.optimizers.SGD(learning_rate) opt =
tfa.optimizers.Lookahead(opt)
```
"""
def __init__(
self,
optimizer: tf.keras.optimizers.Optimizer,
sync_period: int = 6,
slow_step_size: FloatTensorLike = 0.5,
name: str = 'Lookahead',
**kwargs,
):
"""Wrap optimizer with the lookahead mechanism.
Args:
optimizer: The original optimizer that will be used to compute and apply
the gradients.
sync_period: An integer. The synchronization period of lookahead. Enable
lookahead mechanism by setting it with a positive value.
slow_step_size: A floating point value. The ratio for updating the slow
weights.
name: Optional name for the operations created when applying gradients.
Defaults to "Lookahead".
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
`decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
gradients by value, `decay` is included for backward compatibility to
allow time inverse decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
"""
super().__init__(name, **kwargs)
if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)
if not isinstance(
optimizer,
(tf.keras.optimizers.Optimizer, tf.keras.optimizers.legacy.Optimizer),
):
raise TypeError(
'optimizer is not an object of tf.keras.optimizers.Optimizer'
)
self._optimizer = optimizer
self._set_hyper('sync_period', sync_period)
self._set_hyper('slow_step_size', slow_step_size)
self._initialized = False
self._track_trackable(self._optimizer, 'lh_base_optimizer')
def _create_slots(self, var_list):
self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access
for var in var_list:
self.add_slot(var, 'slow', initializer=var)
def _create_hypers(self):
self._optimizer._create_hypers() # pylint: disable=protected-access
def _prepare(self, var_list):
return self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access
def apply_gradients(
self, grads_and_vars, name=None, skip_gradients_aggregation=None, **kwargs
):
self._optimizer._iterations = self.iterations # pylint: disable=protected-access
return super().apply_gradients(grads_and_vars, name, **kwargs)
def _look_ahead_op(self, var):
var_dtype = var.dtype.base_dtype
slow_var = self.get_slot(var, 'slow')
local_step = tf.cast(self.iterations + 1, tf.dtypes.int64)
sync_period = self._get_hyper('sync_period', tf.dtypes.int64)
slow_step_size = self._get_hyper('slow_step_size', var_dtype)
step_back = slow_var + slow_step_size * (var - slow_var)
sync_cond = tf.equal(
tf.math.floordiv(local_step, sync_period) * sync_period, local_step
)
with tf.control_dependencies([step_back]):
slow_update = slow_var.assign(
tf.where(sync_cond, step_back, slow_var),
use_locking=self._use_locking,
)
var_update = var.assign(
tf.where(sync_cond, step_back, var), use_locking=self._use_locking
)
return tf.group(slow_update, var_update)
@property
def weights(self):
return self._weights + self._optimizer.weights
def _resource_apply_dense(self, grad, var):
train_op = self._optimizer._resource_apply_dense(grad, var) # pylint: disable=protected-access
with tf.control_dependencies([train_op]):
look_ahead_op = self._look_ahead_op(var)
return tf.group(train_op, look_ahead_op)
def _resource_apply_sparse(self, grad, var, indices):
train_op = self._optimizer._resource_apply_sparse( # pylint: disable=protected-access
grad, var, indices
)
with tf.control_dependencies([train_op]):
look_ahead_op = self._look_ahead_op(var)
return tf.group(train_op, look_ahead_op)
def get_config(self):
config = {
'optimizer': tf.keras.optimizers.serialize(self._optimizer),
'sync_period': self._serialize_hyperparameter('sync_period'),
'slow_step_size': self._serialize_hyperparameter('slow_step_size'),
}
base_config = super().get_config()
return {**base_config, **config}
@property
def learning_rate(self):
return self._optimizer._get_hyper('learning_rate')
@learning_rate.setter
def learning_rate(self, value):
self._optimizer._set_hyper('learning_rate', value)
@property
def lr(self):
return self.learning_rate
@lr.setter
def lr(self, lr):
self.learning_rate = lr
@classmethod
def from_config(cls, config, custom_objects=None):
optimizer = tf.keras.optimizers.deserialize(
config.pop('optimizer'), custom_objects=custom_objects
)
return cls(optimizer, **config)
def build_optimizer(
optimizer_name: Text,
......@@ -95,18 +249,19 @@ def build_optimizer(
beta_1 = params.get('beta_1', 0.9)
beta_2 = params.get('beta_2', 0.999)
epsilon = params.get('epsilon', 1e-07)
optimizer = tfa.optimizers.AdamW(
weight_decay=weight_decay,
optimizer = legacy_adamw.AdamWeightDecay(
learning_rate=base_learning_rate,
weight_decay_rate=weight_decay,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon)
epsilon=epsilon,
)
else:
raise ValueError('Unknown optimizer %s' % optimizer_name)
if params.get('lookahead', None):
logging.info('Using lookahead optimizer.')
optimizer = tfa.optimizers.Lookahead(optimizer)
optimizer = Lookahead(optimizer)
# Moving average should be applied last, as it's applied at test time
moving_average_decay = params.get('moving_average_decay', 0.)
......
......@@ -57,7 +57,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
base_learning_rate=params['learning_rate'],
params=params,
model=model)
self.assertTrue(issubclass(type(optimizer), tf.keras.optimizers.Optimizer))
self.assertTrue(
issubclass(type(optimizer), tf.keras.optimizers.legacy.Optimizer)
)
def test_unknown_optimizer(self):
with self.assertRaises(ValueError):
......
......@@ -31,7 +31,7 @@ class MockFooModel(tf.keras.Model):
self.inputs = {"foo": tf.keras.Input(shape=(2,), dtype=tf.float32),
"bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
self.add_loss(tf.zeros((1,), dtype=tf.float32))
if "foo" in inputs:
input_tensor = inputs["foo"]
......@@ -49,7 +49,7 @@ class MockBarModel(tf.keras.Model):
self._bar_specific_layer = tf.keras.layers.Dense(1)
self.inputs = {"bar": tf.keras.Input(shape=(2,), dtype=tf.float32)}
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
self.add_loss(tf.zeros((2,), dtype=tf.float32))
return self._bar_specific_layer(self._share_layer(inputs["bar"]))
......
......@@ -14,7 +14,7 @@
"""Exponential moving average optimizer."""
from typing import List, Optional, Text
from typing import List, Optional
import tensorflow as tf
......@@ -79,7 +79,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.legacy.Optimizer):
average_decay: float = 0.99,
start_step: int = 0,
dynamic_decay: bool = True,
name: Text = 'ExponentialMovingAverage',
name: str = 'ExponentialMovingAverage',
**kwargs):
"""Construct a new ExponentialMovingAverage optimizer.
......@@ -107,7 +107,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.legacy.Optimizer):
self._start_step = tf.constant(start_step, tf.float32)
self._dynamic_decay = dynamic_decay
self._optimizer = optimizer
self._track_trackable(self._optimizer, 'base_optimizer')
self._track_trackable(self._optimizer, 'ema_base_optimizer')
self._average_weights = None
self._model_weights = None
......
......@@ -460,10 +460,6 @@ class StepCosineDecayWithOffset(
tf.constant(math.pi) * (global_step) /
(init_total_steps)) + 1.0) / 2.0 + next_init_lr)
learning_rate = cosine_learning_rate
tf.compat.v1.logging.info("DEBUG lr %r next lr %r", learning_rate,
cosine_learning_rate)
tf.compat.v1.logging.info("DEBUG lr %r next lr %r inittotalstep %r",
init_lr, next_init_lr, init_total_steps)
for i in range(1, num_levels):
next_init_lr = lr_levels[i]
......@@ -471,9 +467,6 @@ class StepCosineDecayWithOffset(
next_total_steps = level_total_steps[i]
next_next_init_lr = lr_levels[i + 1] if num_levels > i + 1 else 0.
tf.compat.v1.logging.info(
"DEBUG step %r nilr %r nss %r nts %r nnilr %r", global_step,
next_init_lr, next_start_step, next_total_steps, next_next_init_lr)
next_cosine_learning_rate = ((next_init_lr - next_next_init_lr) *
(tf.cos(
tf.constant(math.pi) *
......@@ -482,8 +475,6 @@ class StepCosineDecayWithOffset(
next_next_init_lr)
learning_rate = tf.where(global_step >= next_start_step,
next_cosine_learning_rate, learning_rate)
tf.compat.v1.logging.info("DEBUG lr %r next lr %r", learning_rate,
next_cosine_learning_rate)
return learning_rate
......
......@@ -10,7 +10,6 @@ scipy>=0.19.1
tensorflow-hub>=0.6.0
tensorflow-model-optimization>=0.4.1
tensorflow-datasets
tfa-nightly
gin-config
tf_slim>=1.1.0
Cython
......
......@@ -458,7 +458,7 @@ class BigBirdAttention(tf.keras.layers.MultiHeadAttention):
to_block_size=self._to_block_size,
rand_attn=rand_attn)
def call(self, query, value, key=None, attention_mask=None, **kwargs):
def call(self, query, value, key=None, attention_mask=None, **kwargs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
......
......@@ -281,5 +281,5 @@ def _pick_fourier_transform(
return functools.partial(
two_dim_matmul,
matrix_dim_one=tf.convert_to_tensor(dft_mat_seq),
matrix_dim_two=tf.convert_to_tensor(dft_mat_hidden))
matrix_dim_one=dft_mat_seq,
matrix_dim_two=dft_mat_hidden)
......@@ -67,7 +67,7 @@ class PerDimScaleAttention(tf.keras.layers.MultiHeadAttention):
attention_scores_dropout, value)
return attention_output, attention_scores
def call(
def call( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
self,
query,
value,
......
......@@ -228,7 +228,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
value)
return attention_output
def call(self,
def call(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
query,
value,
content_attention_bias,
......
......@@ -77,7 +77,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
super().__init__(
layer, name=wrapper_name, **kwargs)
def build(self, input_shape):
def build(self, input_shape): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
super().build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
......@@ -195,7 +195,7 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
.format(input=layer))
super().__init__(layer, **kwargs)
def build(self, input_shape):
def build(self, input_shape): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
if not self.layer.built:
self.layer.build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
......
......@@ -27,7 +27,7 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
"""
def setUp(self):
super(TNLayerTest, self).setUp()
super().setUp()
self.labels = np.concatenate((np.ones((50, 1)), np.zeros((50, 1))), axis=0)
def _build_model(self, data, proj_multiple=2):
......@@ -41,21 +41,6 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
return model
@parameterized.parameters((768, 6), (1024, 2))
def test_keras_layer(self, input_dim, proj_multiple):
data = np.random.normal(size=(100, input_dim))
data = data.astype(np.float32)
tf.keras.__internal__.utils.layer_test(
TNExpandCondense,
kwargs={
'proj_multiplier': proj_multiple,
'input_shape': data.shape
},
input_shape=data.shape,
input_data=data,
expected_output_shape=(None, data.shape[-1]),
expected_output_dtype=data.dtype)
@parameterized.parameters((768, 6), (1024, 2))
def test_train(self, input_dim, proj_multiple):
tf.keras.utils.set_random_seed(0)
......
......@@ -226,7 +226,7 @@ class BertPretrainerV2(tf.keras.Model):
inputs.append(masked_lm_positions)
self.inputs = inputs
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
if isinstance(inputs, list):
logging.warning('List inputs to BertPretrainer are discouraged.')
inputs = dict([
......
......@@ -113,7 +113,7 @@ class ElectraPretrainer(tf.keras.Model):
units=1,
kernel_initializer=tf_utils.clone_initializer(mlm_initializer))
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
"""ELECTRA forward pass.
Args:
......
......@@ -144,7 +144,7 @@ class Seq2SeqTransformer(tf.keras.Model):
return embedded_inputs, boolean_mask, input_shape, source_dtype
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
"""Calculate target logits or inferred target sequences.
Args:
......
......@@ -117,7 +117,7 @@ class XLNetPretrainer(tf.keras.Model):
hidden_size=self._hidden_size,
initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]):
def call(self, inputs: Mapping[str, Any]): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
input_word_ids = inputs['input_word_ids']
input_type_ids = inputs['input_type_ids']
masked_tokens = inputs['masked_tokens']
......@@ -212,7 +212,7 @@ class XLNetClassifier(tf.keras.Model):
cls_token_idx=cls_token_idx,
name=head_name)
def call(self, inputs: Mapping[str, Any]):
def call(self, inputs: Mapping[str, Any]): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
input_ids = inputs['input_word_ids']
segment_ids = inputs['input_type_ids']
input_mask = tf.cast(inputs['input_mask'], tf.float32)
......@@ -305,7 +305,7 @@ class XLNetSpanLabeler(tf.keras.Model):
dropout_rate=self._dropout_rate,
initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]):
def call(self, inputs: Mapping[str, Any]): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
input_word_ids = inputs['input_word_ids']
input_type_ids = inputs['input_type_ids']
input_mask = inputs['input_mask']
......
......@@ -452,6 +452,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
def call(self, inputs, output_range: Optional[tf.Tensor] = None):
# inputs are [word_ids, mask, type_ids]
word_embeddings = None
if isinstance(inputs, (list, tuple)):
logging.warning('List inputs to %s are discouraged.', self.__class__)
if len(inputs) == 3:
......@@ -472,6 +473,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
word_ids = inputs.get('input_word_ids')
mask = inputs.get('input_mask')
type_ids = inputs.get('input_type_ids')
word_embeddings = inputs.get('input_word_embeddings', None)
dense_inputs = inputs.get('dense_inputs', None)
dense_mask = inputs.get('dense_mask', None)
......@@ -479,7 +481,8 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
else:
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
word_embeddings = self._embedding_layer(word_ids)
if word_embeddings is None:
word_embeddings = self._embedding_layer(word_ids)
if dense_inputs is not None:
# Concat the dense embeddings at sequence begin so unpool_len can control
......
......@@ -320,6 +320,43 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(outputs[0].shape[-1], hidden_size)
self.assertTrue(hasattr(test_network, "_embedding_projection"))
def test_embeddings_as_inputs(self):
hidden_size = 32
sequence_length = 21
# Create a small BertEncoder for testing.
test_network = funnel_transformer.FunnelTransformerEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=3,
pool_stride=2,
)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
test_network.build(
dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids)
)
embeddings = test_network.get_embedding_layer()(word_ids)
# Calls with the embeddings.
dict_outputs = test_network(
dict(
input_word_embeddings=embeddings,
input_mask=mask,
input_type_ids=type_ids,
)
)
all_encoder_outputs = dict_outputs["encoder_outputs"]
pooled = dict_outputs["pooled_output"]
expected_pooled_shape = [None, hidden_size]
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
# The default output dtype is float32.
self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
self.assertAllEqual(tf.float32, pooled.dtype)
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
kwargs = dict(
......
......@@ -88,7 +88,7 @@ class BASNetModel(tf.keras.Model):
self.decoder = decoder
self.refinement = refinement
def call(self, inputs, training=None):
def call(self, inputs, training=None): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
features = self.backbone(inputs)
if self.decoder:
......
......@@ -41,7 +41,7 @@ class CenterNetModel(tf.keras.Model):
self._detection_generator = detection_generator
self._head = head
def call(self,
def call(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
inputs: tf.Tensor,
training: bool = None,
**kwargs) -> Mapping[str, tf.Tensor]:
......
......@@ -439,7 +439,7 @@ class HourglassNetwork(tf.keras.Model):
self.intermediate_relu = tf.keras.layers.ReLU()
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
if self.initial_downsample:
inputs = self.downsample_input(inputs)
......
......@@ -222,7 +222,7 @@ class DETR(tf.keras.Model):
mask, target_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return mask
def call(self, inputs: tf.Tensor, training: bool = None) -> List[Any]:
def call(self, inputs: tf.Tensor, training: bool = None) -> List[Any]: # pytype: disable=signature-mismatch # overriding-parameter-count-checks
batch_size = tf.shape(inputs)[0]
features = self._backbone(inputs)[self._backbone_endpoint_name]
shape = tf.shape(features)
......
......@@ -346,7 +346,7 @@ class GroupConv2DKerasModel(tf.keras.Model):
self.batch_norm_layer(
axis=-1, momentum=bn_momentum, epsilon=bn_epsilon)) # pytype: disable=bad-return-type # typed-keras
def call(self, inputs: Any) -> Any:
def call(self, inputs: Any) -> Any: # pytype: disable=signature-mismatch # overriding-parameter-count-checks
"""Applies 2d group convolution on the inputs."""
input_shape = inputs.get_shape().as_list()
if input_shape[-1] % self._groups != 0:
......
......@@ -208,7 +208,7 @@ class AutosegEdgeTPU(tf.keras.Model):
fullres_output=fullres_output,
num_classes=num_classes)
def call(self, inputs, training):
def call(self, inputs, training): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
# call backbone network.
all_feats = self.backbone(inputs, training=training)
if self.use_original_backbone_features:
......
......@@ -43,7 +43,7 @@ class ViTClassifier(tf.keras.Model):
num_classes,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=2e-5))
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
encoded = self.encoder({'images': inputs})
return self.linear(encoded[:, 0])
......@@ -64,7 +64,7 @@ class ViTLinearClassifier(tf.keras.Model):
self.batch_norm = self._norm(
axis=-1, epsilon=1e-6, center=False, scale=False, momentum=0.9)
def call(self, inputs, training=False):
def call(self, inputs, training=False): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
encoded = self.encoder({'images': inputs})
features = self.batch_norm(encoded[:, 0], training=training)
return self.linear(features)
......@@ -108,7 +108,7 @@ class VisionTransformer(tf.keras.Model):
return patch_embeds + utils.position_embedding_sine(
tf.ones_like(patch_embeds[..., 0]), 1024, normalize=False)
def call(self, inputs):
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
if isinstance(inputs, dict):
images = inputs.get('images', None)
patch_embeds = inputs.get('embeddings', None)
......
# MaxViT: Multi-Axis Vision Transformer (ECCV 2022)
[![Paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2204.01697)
⚠️ **DISCLAIMER**: This implementation is still under development.
[TOC]
[MaxViT](https://arxiv.org/abs/2204.01697) is a family of hybrid (CNN + ViT)
vision backbone models, that achieves better performances across the board
for both parameter and FLOPs efficiency than both state-of-the-art ConvNets and
Transformers ([Blog](https://ai.googleblog.com/2022/09/a-multi-axis-approach-for-vision.html)).
They can also scale well on large dataset sizes like ImageNet-21K.
Notably, due to the linear-complexity of the grid attention used, MaxViT scales
well on tasks requiring large image sizes, such as object detection and
segmentation.
MaxViT meta-architecture: a homogeneously stacked backbone, wherein each MaxViT
block contains [MBConv](https://arxiv.org/abs/2104.00298), block attention
(window-based local attention), and grid attention (dilated global attention).
<p align="center">
<img src = "./docs/maxvit_arch.png" width="80%">
</p>
Results on ImageNet-1k standard train and test:
<p align="center">
<img src = "./docs/imagenet_results.png" width="80%">
</p>
Results on ImageNet-21k and JFT pre-trained models:
<p align="center">
<img src = "./docs/i21k_jft_results.png" width="80%">
</p>
### Model Performance
Note: Deit ImageNet pretrain experimental settings are different from the
paper. These experiments follows the pre-training hyperparameters in
[paper](https://arxiv.org/abs/2204.01697) and only run pre-training for similar
number of steps. The paper suggested a short fine-tuning with different
hyper-parameters and EMA.
<section class="tabs">
#### Deit ImageNet pretrain {.new-tab}
Model | Eval Size | Top-1 Acc | Acc on Paper | #Param | #FLOPs | Config
------------- | --------- | :---------: | :----------: | :----: | :----: | :----:
MaxViT-Tiny | 224x224 | 83.1 (-0.5) | 83.6 | 31M | 5.6G | [config](configs/experiments/maxvit_tiny_imagenet.yaml)
MaxViT-Small | 224x224 | 84.1 (-0.3) | 84.4 | 69M | 11.7G | [config](configs/experiments/maxvit_small_imagenet.yaml)
MaxViT-Base | 224x224 | 84.2 (-0.7) | 84.9 | 120M | 23.4G | [config](configs/experiments/maxvit_base_imagenet.yaml)
MaxViT-Large | 224x224 | 84.6 (-0.6) | 85.2 | 212M | 43.9G | [config](configs/experiments/maxvit_large_imagenet.yaml)
MaxViT-XLarge | 224x224 | 84.8 | - | 475M | 97.9G | [config](configs/experiments/maxvit_xlarge_imagenet.yaml)
#### Cascade RCNN models {.new-tab}
Model | Image Size | Window Size | Epochs | box AP | box AP on paper | mask AP | Config
------------ | ---------: | :---------: | :----: | :-----------: | :-------------: | :-----: | :----:
MaxViT-Tiny | 640x640 | 20x20 | 200 | 49.97 | - | 42.69 | [config](configs/experiments/coco_maxvitt_i640_crcnn.yaml)
MaxViT-Tiny | 896x896 | 28x28 | 200 | 52.35 (+0.25) | 52.1 | 44.69 | -
MaxViT-Small | 640x640 | 20x20 | 200 | 50.79 | - | 43.36 | -
MaxViT-Small | 896x896 | 28x28 | 200 | 53.54 (+0.44) | 53.1 | 45.79 | [config](configs/experiments/coco_maxvits_i896_crcnn.yaml)
MaxViT-Base | 640x640 | 20x20 | 200 | 51.59 | - | 44.07 | [config](configs/experiments/coco_maxvitb_i640_crcnn.yaml)
MaxViT-Base | 896x896 | 28x28 | 200 | 53.47 (+0.07) | 53.4 | 45.96 | [config](configs/experiments/coco_maxvitb_i896_crcnn.yaml)
</section>
<section class="tabs">
#### JFT-300M supervised pretrain {.new-tab}
Model | Pretrain Size | #Param | #FLOPs | globalPR-AUC
------------- | :------------ | :----: | :----: | :----------:
MaxViT-Base | 224x224 | 120M | 23.4G | 52.75%
MaxViT-Large | 224x224 | 212M | 43.9G | 53.77%
MaxViT-XLarge | 224x224 | 475M | - | 54.71%
#### ImageNet Finetuning {.new-tab}
Model | Image Size | Top-1 Acc | Acc on Paper | #Param | #FLOPs | Config
------------- | :--------- | :-------------: | :----------: | :----: | :----: | :----:
MaxViT-Base | 384x384 | 88.37% (-0.32%) | 88.69% | 120M | 74.2G | [config](configs/experiments/finetune_maxvitb_imagenet_i384.yaml)
MaxViT-Base | 512x512 | 88.63% (-0.19%) | 88.82% | 120M | 138.3G | [config](configs/experiments/finetune_maxvitb_imagenet_i512.yaml)
MaxViT-Large | 384x384 | 88.86% (-0.26%) | 89.12% | 212M | 128.7G | [config](configs/experiments/finetune_maxvitl_imagenet_i384.yaml)
MaxViT-Large | 512x512 | 89.02% (-0.39%) | 89.41% | 212M | 245.2G | [config](configs/experiments/finetune_maxvitl_imagenet_i512.yaml)
MaxViT-XLarge | 384x384 | 89.21% (-0.15%) | 89.36% | 475M | 293.7G | [config](configs/experiments/finetune_maxvitxl_imagenet_i384.yaml)
MaxViT-XLarge | 512x512 | 89.31% (-0.22%) | 89.53% | 475M | 535.2G | [config](configs/experiments/finetune_maxvitxl_imagenet_i512.yaml)
#### Cascade RCNN models {.new-tab}
Model | Image Size | Window Size | Epochs | box AP | box AP on paper | mask AP | Config
------------ | ---------: | :---------: | :----: | :-----------: | :-------------: | :-----: | :----:
MaxViT-Base | 896x896 | 28x28 | 200 | 54.31 (+0.91) | 53.4 | 46.31 | [config](configs/experiments/coco_maxvitb_i896_crcnn.yaml)
MaxViT-Large | 896x896 | 28x28 | 200 | 54.69 | - | 46.59 | [config](configs/experiments/coco_maxvitl_i896_crcnn.yaml)
</section>
### Citation
Should you find this repository useful, please consider citing:
```
@article{tu2022maxvit,
title={MaxViT: Multi-Axis Vision Transformer},
author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
journal={ECCV},
year={2022},
}
```
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.projects.maxvit.configs import backbones # pylint:disable=unused-import
from official.projects.maxvit.configs import rcnn # pylint:disable=unused-import
from official.projects.maxvit.configs import retinanet # pylint:disable=unused-import
from official.projects.maxvit.configs import semantic_segmentation # pylint:disable=unused-import
from official.projects.maxvit.configs.google import image_classification # pylint:disable=unused-import
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CoAtNet Image classification configuration definition."""
import dataclasses
from typing import Optional, Tuple
import tensorflow as tf
from official.modeling import hyperparams
from official.vision.configs import backbones
@dataclasses.dataclass
class MaxViT(hyperparams.Config):
"""MaxViT config."""
model_name: str = 'maxvit-tiny'
# These configs are specified according to `model_name` in default.
# Set values will override the default configs.
stem_hsize: Optional[Tuple[int, ...]] = None
block_type: Optional[Tuple[str, ...]] = None
num_blocks: Optional[Tuple[int, ...]] = None
hidden_size: Optional[Tuple[int, ...]] = None
# specific to the multi-axis attention in MaxViT
# Note that the window_size and grid_size should be divisible by all the
# feature map sizes along the entire network. Say, if you train on ImageNet
# classification at 224x224, set both to 7 is almost the only choice.
# If you train on COCO object detection at 896x896, set it to 28 is suggested,
# as following Swin Transformer, window size should scales with feature size.
# You may as well set it as 14 or 7.
window_size: int = 7 # window size for conducting block attention module.
grid_size: int = 7 # grid size for conducting sparse global grid attention.
# tfm specific
head_size: int = 32
dropatt: Optional[float] = None
dropout: Optional[float] = None
rel_attn_type: str = '2d_multi_head'
num_heads: Optional[int] = None
# A string of `current_window_size/ckpt_window_size` for finetuning from a
# checkpoint trained with `ckpt_window_size`.
scale_ratio: Optional[str] = None
ln_epsilon: float = 1e-5
ln_dtype: Optional[tf.DType] = None
# conv specific
downsample_loc: str = 'depth_conv'
kernel_size: int = 3
se_ratio: float = 0.25
dropcnn: Optional[float] = None
# Only channels_last is supported for now.
data_format: str = 'channels_last'
norm_type: str = 'sync_batch_norm'
# shared
add_pos_enc: bool = False
pool_type: str = '2d:avg'
pool_stride: int = 2
expansion_rate: int = 4
# Stochastic depth keep probability for the residual connection in. Smaller
# value means stronger regularization. If using anneal, it decays linearly
# from 1.0 to this value with the depth of each layer."
survival_prob: Optional[float] = None # from [0, 1]
survival_prob_anneal: bool = True
kernel_initializer: str = 'glorot_uniform'
bias_initializer: str = 'zeros'
# For cls head, should be same as the last `hidden_size` of backbone.
representation_size: Optional[int] = None
# Only effective when representation_size > 0.
add_gap_layer_norm: bool = True
@dataclasses.dataclass
class Backbone(backbones.Backbone):
"""Configuration for backbones."""
type: Optional[str] = 'maxvit'
maxvit: MaxViT = MaxViT()
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
losses:
l2_weight_decay: 1.0e-07
model:
anchor:
anchor_size: 3.0
backbone:
maxvit:
model_name: 'maxvit-base'
window_size: 20
grid_size: 20
scale_ratio: '20/7'
survival_prob: 0.7
input_size: [640, 640, 3]
max_level: 7
min_level: 3
rpn_head:
num_convs: 2
train_data:
global_batch_size: 256
validation_data:
global_batch_size: 64
trainer:
optimizer_config:
ema:
average_decay: 0.9998
trainable_weights_only: false
learning_rate:
cosine:
decay_steps: 115750
initial_learning_rate: 0.002
optimizer:
adamw:
weight_decay_rate: 0.0001
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 6000
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
losses:
l2_weight_decay: 2.0e-07
model:
anchor:
anchor_size: 3.0
backbone:
maxvit:
model_name: 'maxvit-base'
window_size: 28
grid_size: 28
scale_ratio: '28/7'
survival_prob: 0.2
input_size: [896, 896, 3]
max_level: 7
min_level: 3
rpn_head:
num_convs: 2
train_data:
global_batch_size: 256
validation_data:
global_batch_size: 128
trainer:
train_steps: 90000
validation_steps: 39
optimizer_config:
ema:
average_decay: 0.9998
trainable_weights_only: false
learning_rate:
cosine:
decay_steps: 90000
initial_learning_rate: 0.003
optimizer:
adamw:
weight_decay_rate: 0.05
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 6000
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
losses:
l2_weight_decay: 0.0
model:
anchor:
anchor_size: 3.0
backbone:
maxvit:
model_name: 'maxvit-large'
window_size: 28
grid_size: 28
scale_ratio: '28/7'
survival_prob: 0.2
input_size: [896, 896, 3]
max_level: 7
min_level: 3
rpn_head:
num_convs: 2
train_data:
global_batch_size: 256
validation_data:
global_batch_size: 64
trainer:
train_steps: 90000
validation_steps: 78
optimizer_config:
ema:
average_decay: 0.9998
trainable_weights_only: false
learning_rate:
cosine:
decay_steps: 90000
initial_learning_rate: 0.003
optimizer:
adamw:
weight_decay_rate: 0.05
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 6000
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
losses:
l2_weight_decay: 1.0e-07
model:
anchor:
anchor_size: 3.0
backbone:
maxvit:
model_name: 'maxvit-small'
window_size: 28
grid_size: 28
scale_ratio: '28/7'
survival_prob: 0.5
input_size: [896, 896, 3]
max_level: 7
min_level: 3
rpn_head:
num_convs: 2
train_data:
global_batch_size: 256
validation_data:
global_batch_size: 128
trainer:
train_steps: 115750
validation_steps: 39
optimizer_config:
ema:
average_decay: 0.9998
trainable_weights_only: false
learning_rate:
cosine:
decay_steps: 90000
initial_learning_rate: 0.003
optimizer:
adamw:
weight_decay_rate: 0.05
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 6000
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
losses:
l2_weight_decay: 1.0e-07
model:
anchor:
anchor_size: 3.0
backbone:
maxvit:
model_name: 'maxvit-tiny'
window_size: 20
grid_size: 20
scale_ratio: '20/7'
survival_prob: 0.3
input_size: [640, 640, 3]
max_level: 7
min_level: 3
rpn_head:
num_convs: 2
train_data:
global_batch_size: 256
validation_data:
global_batch_size: 64
trainer:
optimizer_config:
ema:
average_decay: 0.9998
trainable_weights_only: false
learning_rate:
cosine:
decay_steps: 115750
initial_learning_rate: 0.002
optimizer:
adamw:
weight_decay_rate: 0.0001
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 6000
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
losses:
l2_weight_decay: 1.0e-07
model:
anchor:
anchor_size: 3.0
backbone:
maxvit:
model_name: 'maxvit-base'
window_size: 28
grid_size: 28
scale_ratio: '28/7'
survival_prob: 0.7
input_size: [896, 896, 3]
max_level: 7
min_level: 3
rpn_head:
num_convs: 2
train_data:
global_batch_size: 256
validation_data:
global_batch_size: 64
trainer:
optimizer_config:
ema:
average_decay: 0.9998
trainable_weights_only: false
learning_rate:
cosine:
decay_steps: 115750
initial_learning_rate: 0.003
optimizer:
adamw:
weight_decay_rate: 0.0001
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 6000
runtime:
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: 'backbone'
model:
backbone:
maxvit:
model_name: 'maxvit-base'
representation_size: 768
survival_prob: 0.8
window_size: 12
grid_size: 12
scale_ratio: '12/7'
input_size: [384, 384, 3]
train_data:
global_batch_size: 512
dtype: 'bfloat16'
aug_crop: false
mixup_and_cutmix:
cutmix_alpha: 0.1
label_smoothing: 0.1
mixup_alpha: 0.1
prob: 0.1
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
dtype: 'bfloat16'
aug_crop: false
losses:
one_hot: false
soft_labels: true
use_binary_cross_entropy: false
trainer:
train_steps: 100080
steps_per_loop: 2000
summary_interval: 2000
validation_interval: 2000
checkpoint_interval: 2000
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 1.0e-4
gradient_clip_norm: 1.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
type: constant
constant:
learning_rate: 5.0e-5
warmup:
type: null
runtime:
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: 'backbone'
model:
backbone:
maxvit:
model_name: 'maxvit-base'
representation_size: 768
survival_prob: 0.8
window_size: 16
grid_size: 16
scale_ratio: '16/7'
input_size: [512, 512, 3]
train_data:
global_batch_size: 512
dtype: 'bfloat16'
aug_crop: false
mixup_and_cutmix:
cutmix_alpha: 0.1
label_smoothing: 0.1
mixup_alpha: 0.1
prob: 0.1
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
dtype: 'bfloat16'
aug_crop: false
losses:
one_hot: false
soft_labels: true
use_binary_cross_entropy: false
trainer:
train_steps: 100080
steps_per_loop: 2000
summary_interval: 2000
validation_interval: 2000
checkpoint_interval: 2000
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 1.0e-4
gradient_clip_norm: 1.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
type: constant
constant:
learning_rate: 5.0e-5
warmup:
type: null
runtime:
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: 'backbone'
model:
backbone:
maxvit:
model_name: 'maxvit-large'
representation_size: 1024
survival_prob: 0.7
window_size: 12
grid_size: 12
scale_ratio: '12/7'
input_size: [384, 384, 3]
train_data:
global_batch_size: 512
dtype: 'bfloat16'
aug_crop: false
mixup_and_cutmix:
cutmix_alpha: 0.1
label_smoothing: 0.1
mixup_alpha: 0.1
prob: 0.1
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
dtype: 'bfloat16'
aug_crop: false
losses:
one_hot: false
soft_labels: true
use_binary_cross_entropy: false
trainer:
train_steps: 100080
steps_per_loop: 2000
summary_interval: 2000
validation_interval: 2000
checkpoint_interval: 2000
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 1.0e-4
gradient_clip_norm: 1.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
type: constant
constant:
learning_rate: 5.0e-5
warmup:
type: null
runtime:
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: 'backbone'
model:
backbone:
maxvit:
model_name: 'maxvit-large'
representation_size: 1024
survival_prob: 0.8
window_size: 16
grid_size: 16
scale_ratio: '16/7'
input_size: [512, 512, 3]
train_data:
global_batch_size: 512
dtype: 'bfloat16'
aug_crop: false
mixup_and_cutmix:
cutmix_alpha: 0.1
label_smoothing: 0.1
mixup_alpha: 0.1
prob: 0.1
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
dtype: 'bfloat16'
aug_crop: false
losses:
one_hot: false
soft_labels: true
use_binary_cross_entropy: false
trainer:
train_steps: 100080
steps_per_loop: 2000
summary_interval: 2000
validation_interval: 2000
checkpoint_interval: 2000
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 1.0e-4
gradient_clip_norm: 1.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
type: constant
constant:
learning_rate: 5.0e-5
warmup:
type: null
runtime:
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: 'backbone'
model:
backbone:
maxvit:
model_name: 'maxvit-xlarge'
representation_size: 1536
survival_prob: 0.8
window_size: 12
grid_size: 12
scale_ratio: '12/7'
input_size: [384, 384, 3]
train_data:
global_batch_size: 512
dtype: 'bfloat16'
aug_crop: false
mixup_and_cutmix:
cutmix_alpha: 0.1
label_smoothing: 0.1
mixup_alpha: 0.1
prob: 0.1
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
dtype: 'bfloat16'
aug_crop: false
losses:
one_hot: false
soft_labels: true
use_binary_cross_entropy: false
trainer:
train_steps: 100080
steps_per_loop: 2000
summary_interval: 2000
validation_interval: 2000
checkpoint_interval: 2000
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 1.0e-4
gradient_clip_norm: 1.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
type: constant
constant:
learning_rate: 5.0e-5
warmup:
type: null
runtime:
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: 'backbone'
model:
backbone:
maxvit:
model_name: 'maxvit-xlarge'
representation_size: 1536
survival_prob: 0.8
window_size: 16
grid_size: 16
scale_ratio: '16/7'
input_size: [512, 512, 3]
train_data:
global_batch_size: 512
dtype: 'bfloat16'
aug_crop: false
mixup_and_cutmix:
cutmix_alpha: 0.1
label_smoothing: 0.1
mixup_alpha: 0.1
prob: 0.1
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
dtype: 'bfloat16'
aug_crop: false
losses:
one_hot: false
soft_labels: true
use_binary_cross_entropy: false
trainer:
train_steps: 100080
steps_per_loop: 2000
summary_interval: 2000
validation_interval: 2000
checkpoint_interval: 2000
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 1.0e-4
gradient_clip_norm: 1.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
type: constant
constant:
learning_rate: 5.0e-5
warmup:
type: null
task:
init_checkpoint: ''
model:
backbone:
maxvit:
model_name: 'maxvit-base'
representation_size: 768
input_size: [224, 224, 3]
trainer:
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 0.05
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
cosine:
initial_learning_rate: 0.003
alpha: 0.01
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 10000
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float16'
task:
model:
backbone:
maxvit:
model_name: 'maxvit-base'
representation_size: 768
norm_type: 'batch_norm'
input_size: [224, 224, 3]
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: false
train_data:
global_batch_size: 192
shuffle_buffer_size: 1024
dtype: 'float16'
validation_data:
global_batch_size: 256
shuffle_buffer_size: 1024
dtype: 'float16'
trainer:
train_steps: 1500000
steps_per_loop: 10000
summary_interval: 10000
validation_interval: 10000
validation_steps: 195
optimizer_config:
ema: null
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 0.05
learning_rate:
cosine:
initial_learning_rate: 0.0001
alpha: 0.01
decay_steps: 1500000
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 10000
task:
init_checkpoint: ''
losses:
l2_weight_decay: 1.0e-07
model:
backbone:
maxvit:
model_name: 'maxvit-large'
representation_size: 1024
input_size: [224, 224, 3]
trainer:
max_to_keep: 5
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 0.05
gradient_clip_norm: 0.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
cosine:
initial_learning_rate: 0.001
alpha: 0.00
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 10000
task:
init_checkpoint: ''
losses:
l2_weight_decay: 1.0e-07
model:
backbone:
maxvit:
model_name: 'maxvit-small'
representation_size: 768
input_size: [224, 224, 3]
trainer:
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 0.05
gradient_clip_norm: 0.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
cosine:
initial_learning_rate: 0.002
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 10000
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float16'
task:
model:
backbone:
maxvit:
model_name: 'maxvit-small'
representation_size: 768
norm_type: 'batch_norm'
input_size: [224, 224, 3]
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: false
train_data:
global_batch_size: 256
shuffle_buffer_size: 1024
dtype: 'float16'
validation_data:
global_batch_size: 256
shuffle_buffer_size: 1024
dtype: 'float16'
trainer:
train_steps: 1500000
steps_per_loop: 8000
summary_interval: 8000
validation_interval: 8000
validation_steps: 195
optimizer_config:
ema: null
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 0.05
learning_rate:
cosine:
initial_learning_rate: 0.0001
alpha: 0.01
decay_steps: 1500000
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 8000
task:
init_checkpoint: ''
losses:
l2_weight_decay: 1.0e-07
model:
backbone:
maxvit:
model_name: 'maxvit-tiny'
representation_size: 512
add_gap_layer_norm: true
kernel_initializer: 'glorot_uniform'
kernel_initializer: 'glorot_uniform'
input_size: [224, 224, 3]
trainer:
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 0.05
gradient_clip_norm: 0.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
cosine:
initial_learning_rate: 0.002
alpha: 0.0
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 10000
task:
init_checkpoint: ''
losses:
l2_weight_decay: 1.0e-07
model:
backbone:
maxvit:
model_name: 'maxvit-xlarge'
representation_size: 1536
input_size: [224, 224, 3]
trainer:
max_to_keep: 5
optimizer_config:
optimizer:
type: 'adamw'
adamw:
weight_decay_rate: 0.05
gradient_clip_norm: 0.0
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
cosine:
initial_learning_rate: 0.001
alpha: 0.01
warmup:
linear:
warmup_learning_rate: 0.0
warmup_steps: 10000
# RetinaNet with MaxViT backbone COCO detection.
# Required flags:
# --experiment_type=retinanet_maxvit_coco
#
# Expected AP on DF TPU 8x8: 50.38%.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
model:
anchor:
anchor_size: 3
aspect_ratios: [0.5, 1.0, 2.0]
num_scales: 3
backbone:
type: 'maxvit'
maxvit:
model_name: 'maxvit-base'
window_size: 40
grid_size: 40
scale_ratio: '40/7'
survival_prob: 0.3
input_size: [1280, 1280, 3]
# RetinaNet with MaxViT backbone COCO detection.
# Required flags:
# --experiment_type=retinanet_maxvit_coco
#
# Expected AP on DF TPU 4x4: 46.63%.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
model:
backbone:
type: 'maxvit'
maxvit:
model_name: 'maxvit-base'
window_size: 20
grid_size: 20
scale_ratio: '20/7'
survival_prob: 0.3
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
model:
num_classes: 91
input_size: [640, 640, 3]
backbone:
type: 'maxvit'
maxvit:
model_name: 'maxvit-small'
window_size: 20
grid_size: 20
scale_ratio: '20/7'
survival_prob: 0.7
decoder:
fpn:
fusion_type: 'concat'
type: 'fpn'
head:
level: 3
losses:
l2_weight_decay: 0
top_k_percent_pixels: 1.0
train_data:
output_size: [640, 640]
global_batch_size: 32
dtype: 'bfloat16'
aug_rand_hflip: true
aug_scale_max: 1.5
aug_scale_min: 0.5
validation_data:
output_size: [640, 640]
global_batch_size: 32
dtype: 'bfloat16'
groundtruth_padded_size: [640, 640]
trainer:
optimizer_config:
learning_rate:
type: cosine
cosine:
decay_steps: 64000
initial_learning_rate: 0.000001
optimizer:
adamw:
beta_1: 0.9
beta_2: 0.999
weight_decay_rate: 0.0001
type: adamw
warmup:
linear:
name: linear
warmup_learning_rate: 0
warmup_steps: 4000
type: linear
ema:
average_decay: 0.9998
trainable_weights_only: false
best_checkpoint_eval_metric: 'mean_iou'
best_checkpoint_export_subdir: 'best_ckpt'
best_checkpoint_metric_comp: 'higher'
steps_per_loop: 200
summary_interval: 200
train_steps: 64000
checkpoint_interval: 200
validation_interval: 200
validation_steps: 39
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: 'Please provide'
init_checkpoint_modules: ['backbone']
model:
num_classes: 21
input_size: [512, 512, 3]
backbone:
type: 'maxvit'
maxvit:
model_name: 'maxvit-small'
window_size: 16
grid_size: 16
scale_ratio: '16/7'
survival_prob: 0.7
decoder:
fpn:
fusion_type: 'sum'
type: 'fpn'
head:
level: 3
losses:
l2_weight_decay: 0
top_k_percent_pixels: 1.0
train_data:
output_size: [512, 512]
global_batch_size: 32
dtype: 'bfloat16'
aug_rand_hflip: true
aug_scale_max: 2.0
aug_scale_min: 0.5
validation_data:
output_size: [512, 512]
global_batch_size: 32
dtype: 'bfloat16'
groundtruth_padded_size: [512, 512]
trainer:
optimizer_config:
learning_rate:
cosine:
initial_learning_rate: 0.00001
alpha: 0.01
optimizer:
adamw:
beta_1: 0.9
beta_2: 0.999
weight_decay_rate: 0.0001
type: adamw
warmup:
linear:
name: linear
warmup_learning_rate: 0
warmup_steps: 500
type: linear
ema:
average_decay: 0.9998
trainable_weights_only: false
best_checkpoint_eval_metric: 'mean_iou'
best_checkpoint_export_subdir: 'best_ckpt'
best_checkpoint_metric_comp: 'higher'
steps_per_loop: 330
summary_interval: 330
train_steps: 20000
validation_interval: 330
checkpoint_interval: 330
validation_steps: 45
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MaxViT Image classification configuration definition."""
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling.optimization.configs import optimization_config
from official.projects.maxvit.configs import backbones
from official.vision.configs import image_classification as img_cls_cfg
@exp_factory.register_config_factory('maxvit_imagenet')
def maxvit_imagenet() -> cfg.ExperimentConfig:
"""Returns MaxViT-Tiny on imagenet-1k.
Expected to be trained on DF 4x4 or bigger. Can eval on DF 4x2.
Returns:
The full experiment config.
"""
# Reuse ViT deit pretraining config.
exp = img_cls_cfg.image_classification_imagenet_deit_pretrain()
exp.task.model = img_cls_cfg.ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='glorot_uniform',
backbone=backbones.Backbone(
type='maxvit',
maxvit=backbones.MaxViT(
model_name='maxvit-tiny', representation_size=768
),
),
norm_activation=img_cls_cfg.common.NormActivation(activation='relu'),
)
exp.task.train_data.aug_type.randaug.num_layers = 2
exp.task.train_data.aug_type.randaug.magnitude = 15
exp.runtime.mixed_precision_dtype = 'bfloat16'
exp.trainer.optimizer_config.optimizer.adamw.gradient_clip_norm = 0.0
exp.trainer.optimizer_config.warmup.linear.warmup_steps = 10000
exp.trainer.optimizer_config.ema = optimization_config.opt_cfg.EMAConfig(
average_decay=0.9999,
trainable_weights_only=False,
)
return exp
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.maxvit.configs import image_classification # pylint:disable=unused-import
from official.vision.configs import image_classification as img_cls_config
class MaxViTImageClassificationConfigTest(tf.test.TestCase):
def test_maxvit_build_model(self):
config = exp_factory.get_exp_config('maxvit_imagenet')
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(
config.task, img_cls_config.ImageClassificationTask
)
self.assertIsInstance(
config.task.model, img_cls_config.ImageClassificationModel
)
self.assertIsInstance(
config.task.train_data, img_cls_config.DataConfig
)
config.validate()
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask R-CNN configuration definition."""
import os
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling.optimization.configs import optimization_config
from official.projects.maxvit.configs import backbones
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.configs import maskrcnn
Parser = maskrcnn.Parser
Anchor = maskrcnn.Anchor
Losses = maskrcnn.Losses
ROISampler = maskrcnn.ROISampler
DetectionHead = maskrcnn.DetectionHead
DataConfig = maskrcnn.DataConfig
MaskRCNN = maskrcnn.MaskRCNN
MaskRCNNTask = maskrcnn.MaskRCNNTask
COCO_INPUT_PATH_BASE = (
'/readahead/200M/placer/prod/home/tensorflow-performance-data/datasets/coco'
)
@exp_factory.register_config_factory('rcnn_maxvit_coco')
def rcnn_maxvit_coco() -> cfg.ExperimentConfig:
"""COCO object detection with MaxViT and Cascade R-CNN."""
steps_per_epoch = 1848 # based on 463 steps @ bs=256
train_batch_size = 256
coco_val_samples = 5000
eval_batch_size = 64
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=MaskRCNNTask(
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=MaskRCNN(
anchor=Anchor(num_scales=3, anchor_size=3.0),
backbone=backbones.Backbone(
type='maxvit',
maxvit=backbones.MaxViT(model_name='maxvit-base')
),
decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
num_classes=91,
input_size=[640, 640, 3],
include_mask=True,
roi_sampler=ROISampler(
cascade_iou_thresholds=[0.7], foreground_iou_threshold=0.6),
detection_head=DetectionHead(
cascade_class_ensemble=True, class_agnostic_bbox_pred=True),
norm_activation=common.NormActivation(
use_sync_bn=True,
activation='relu',
norm_epsilon=0.001,
norm_momentum=0.99),
min_level=3,
max_level=7,
),
losses=Losses(l2_weight_decay=0.0),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
parser=Parser(
aug_rand_hflip=True, aug_scale_min=0.1, aug_scale_max=2.5)),
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=True)),
trainer=cfg.TrainerConfig(
train_steps=90000,
validation_steps=coco_val_samples // eval_batch_size,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
checkpoint_interval=steps_per_epoch * 4,
optimizer_config=optimization_config.OptimizationConfig({
'ema': {
'average_decay': 0.9998,
'trainable_weights_only': False,
},
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.0001,
'beta_1': 0.9,
'beta_2': 0.999,
'include_in_weight_decay': r'.*(kernel|weight):0$',
},
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'decay_steps': 90000,
'initial_learning_rate': 0.0001,
'alpha': 0.03,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 6000,
'warmup_learning_rate': 0.,
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.maxvit.configs import rcnn as exp_cfg
class MaskRCNNConfigTest(tf.test.TestCase):
def test_maskrcnn_configs(self):
config = exp_factory.get_exp_config('rcnn_maxvit_coco')
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.MaskRCNNTask)
self.assertIsInstance(config.task.model, exp_cfg.MaskRCNN)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None
with self.assertRaisesRegex(KeyError, 'Found inconsistency between key'):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RetinaNet configuration definition."""
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.maxvit.configs import backbones
from official.vision.configs import retinanet
@exp_factory.register_config_factory('retinanet_maxvit_coco')
def retinanet_maxvit_coco() -> cfg.ExperimentConfig:
"""COCO object detection with RetinaNet using MaxViT backbone."""
config = retinanet.retinanet_resnetfpn_coco()
config.task.model.backbone = backbones.Backbone(
type='maxvit', maxvit=backbones.MaxViT(
model_name='maxvit-base',
window_size=20,
grid_size=20,
scale_ratio='20/7',
survival_prob=0.7,
)
)
config.task.validation_data.global_batch_size = 32
config.trainer.validation_steps = 156
config.trainer.validation_interval = 1560
return config
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for retinanet."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.maxvit.configs import retinanet
from official.vision.configs import retinanet as exp_cfg
class RetinaNetConfigTest(tf.test.TestCase, parameterized.TestCase):
def test_retinanet_configs(self):
config = exp_factory.get_exp_config('retinanet_maxvit_coco')
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.RetinaNetTask)
self.assertIsInstance(config.task.model, exp_cfg.RetinaNet)
self.assertIsInstance(
config.task.model.backbone.maxvit, retinanet.backbones.MaxViT
)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None
with self.assertRaisesRegex(KeyError, 'Found inconsistency between key'):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Semantic segmentation configuration definition."""
import os
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.projects.maxvit.configs import backbones
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.configs import semantic_segmentation
DataConfig = semantic_segmentation.DataConfig
Losses = semantic_segmentation.Losses
Evaluation = semantic_segmentation.Evaluation
SegmentationHead = semantic_segmentation.SegmentationHead
SemanticSegmentationModel = semantic_segmentation.SemanticSegmentationModel
SemanticSegmentationTask = semantic_segmentation.SemanticSegmentationTask
# PASCAL VOC 2012 Dataset
PASCAL_TRAIN_EXAMPLES = 10582
PASCAL_VAL_EXAMPLES = 1449
PASCAL_INPUT_PATH_BASE = 'gs://**/pascal_voc_seg'
@exp_factory.register_config_factory('maxvit_seg_pascal')
def maxvit_seg_pascal() -> cfg.ExperimentConfig:
"""Image segmentation on Pascal VOC with MaxViT."""
train_batch_size = 32
eval_batch_size = 32
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=SemanticSegmentationTask(
model=SemanticSegmentationModel(
num_classes=21,
input_size=[512, 512, 3],
min_level=3,
max_level=7,
backbone=backbones.Backbone(
type='maxvit',
maxvit=backbones.MaxViT(
model_name='maxvit-tiny',
window_size=16,
grid_size=16,
scale_ratio='16/7',
),
),
decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
head=SegmentationHead(level=3, num_convs=3),
norm_activation=common.NormActivation(
use_sync_bn=True,
activation='relu',
norm_epsilon=0.001,
norm_momentum=0.99,
),
),
losses=Losses(l2_weight_decay=1e-5, top_k_percent_pixels=1.0),
train_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
output_size=[512, 512],
is_training=True,
global_batch_size=train_batch_size,
aug_rand_hflip=True,
aug_scale_min=0.2,
aug_scale_max=1.5,
),
validation_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
output_size=[512, 512],
is_training=True,
global_batch_size=eval_batch_size,
resize_eval_groundtruth=True,
groundtruth_padded_size=[512, 512],
drop_remainder=True,
),
),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=20000,
validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'ema': {
'average_decay': 0.9998,
'trainable_weights_only': False,
},
'optimizer': {
'type': 'adamw',
'adamw': {
'beta_1': 0.9,
'beta_2': 0.999,
'weight_decay_rate': 0.0001,
'include_in_weight_decay': r'.*(kernel|weight):0$',
},
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005,
'decay_steps': 20000,
'alpha': 0.03,
},
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 500,
'warmup_learning_rate': 0,
},
},
}),
),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
],
)
return config
# COCO segmentation.
COCO_TRAIN_EXAMPLES = 25600
COCO_VAL_EXAMPLES = 5000
COCO_INPUT_PATH_BASE = 'mscoco'
@exp_factory.register_config_factory('maxvit_seg_coco')
def maxvit_seg_coco() -> cfg.ExperimentConfig:
"""Image segmentation on COCO with MaxViT."""
train_batch_size = 32
eval_batch_size = 32
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=SemanticSegmentationTask(
model=SemanticSegmentationModel(
num_classes=91,
input_size=[640, 640, 3],
backbone=backbones.Backbone(
type='maxvit',
maxvit=backbones.MaxViT(
model_name='maxvit-tiny',
window_size=20,
grid_size=20,
scale_ratio='20/7',
),
),
decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
head=SegmentationHead(level=3, num_convs=3),
norm_activation=common.NormActivation(
use_sync_bn=True,
activation='relu',
norm_epsilon=0.001,
norm_momentum=0.99,
),
),
losses=Losses(l2_weight_decay=1e-5, top_k_percent_pixels=1.0),
train_data=DataConfig(
input_path=os.path.join(
COCO_INPUT_PATH_BASE,
'mscoco_alltasks_trainvalminusminival2014*',
),
output_size=[640, 640],
is_training=True,
global_batch_size=train_batch_size,
aug_rand_hflip=True,
aug_scale_min=0.2,
aug_scale_max=2.0,
),
validation_data=DataConfig(
input_path=os.path.join(
COCO_INPUT_PATH_BASE, 'mscoco_alltasks_minival2014*'
),
output_size=[640, 640],
is_training=True,
global_batch_size=eval_batch_size,
resize_eval_groundtruth=True,
groundtruth_padded_size=[640, 640],
drop_remainder=True,
),
),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=64000,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'ema': {
'average_decay': 0.9998,
'trainable_weights_only': False,
},
'optimizer': {
'type': 'adamw',
'adamw': {
'beta_1': 0.9,
'beta_2': 0.999,
'weight_decay_rate': 0.00001,
'include_in_weight_decay': r'.*(kernel|weight):0$',
},
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.00005,
'decay_steps': 64000,
'alpha': 0.03,
},
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 1600,
'warmup_learning_rate': 0,
},
},
}),
),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
],
)
return config
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
# pylint: disable=unused-import
from official import vision
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.maxvit.configs import semantic_segmentation as exp_cfg
class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('maxvit_seg_pascal',),
('maxvit_seg_coco',))
def test_semantic_segmentation_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.SemanticSegmentationTask)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common operations."""
import functools
import math
from typing import Optional
from absl import logging
import numpy as np
import tensorflow as tf
def activation_fn(features: tf.Tensor, act_fn: str):
"""Customized non-linear activation type."""
if act_fn in ('silu', 'swish'):
return tf.nn.swish(features)
elif act_fn == 'silu_native':
return features * tf.sigmoid(features)
elif act_fn == 'hswish':
return features * tf.nn.relu6(features + 3) / 6
elif act_fn == 'relu':
return tf.nn.relu(features)
elif act_fn == 'relu6':
return tf.nn.relu6(features)
elif act_fn == 'elu':
return tf.nn.elu(features)
elif act_fn == 'leaky_relu':
return tf.nn.leaky_relu(features)
elif act_fn == 'selu':
return tf.nn.selu(features)
elif act_fn == 'mish':
return features * tf.math.tanh(tf.math.softplus(features))
elif act_fn == 'gelu':
return (
0.5
* features
* (
1
+ tf.tanh(
np.sqrt(2 / np.pi) * (features + 0.044715 * tf.pow(features, 3))
)
)
)
else:
raise ValueError('Unsupported act_fn {}'.format(act_fn))
def get_act_fn(act_fn):
if act_fn is None:
act_fn = 'gelu'
if isinstance(act_fn, str):
return functools.partial(activation_fn, act_fn=act_fn)
elif callable(act_fn):
return act_fn
else:
raise ValueError('Unsupported act_fn %s.' % act_fn)
def pooling_2d(inputs, pool_type, stride, **kwargs):
"""Perform 2D pooling."""
if stride > 1:
if pool_type == 'max':
pool_op = tf.keras.layers.MaxPool2D
elif pool_type == 'avg':
pool_op = tf.keras.layers.AveragePooling2D
else:
raise ValueError('Unsurpported pool_type %s' % pool_type)
output = pool_op(
pool_size=(stride, stride), strides=(stride, stride), **kwargs
)(inputs)
else:
output = inputs
return output
def drop_connect(inputs, training, survival_prob):
"""Drop the entire conv with given survival probability."""
# "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
if not training:
return inputs
# Compute tensor.
batch_size = tf.shape(inputs)[0]
random_tensor = survival_prob
random_tensor += tf.random.uniform([batch_size], dtype=inputs.dtype)
for _ in range(inputs.shape.rank - 1):
random_tensor = tf.expand_dims(random_tensor, axis=-1)
binary_tensor = tf.floor(random_tensor)
# Unlike conventional way that multiply survival_prob at test time, here we
# divide survival_prob at training time, such that no addition compute is
# needed at test time.
output = inputs / survival_prob * binary_tensor
return output
def residual_add(residual, shortcut, survival_prob, training):
"""Combine residual and shortcut."""
if survival_prob is not None and 0 < survival_prob < 1:
residual = drop_connect(residual, training, survival_prob)
return shortcut + residual
def maybe_reshape_to_2d(x, height=None):
"""Reshape tensor to 2d if not already 2d."""
if x.shape.rank == 3:
_, length, num_channel = x.shape.as_list()
if height is None:
height = int(np.sqrt(length))
else:
assert length % height == 0
width = length // height
logging.debug(
'Reshape %s -> %s', [length, num_channel], [height, width, num_channel]
)
return tf.reshape(x, [-1, height, width, num_channel])
elif x.shape.rank == 4:
return x
else:
raise ValueError('Unsupport shape {}'.format(x.shape))
def maybe_reshape_to_1d(x):
"""Reshape tensor to 1d if not already 1d."""
if x.shape.rank == 4:
_, h, w, num_channel = x.shape.as_list()
logging.debug('Reshape %s -> %s', [h, w, num_channel], [h * w, num_channel])
return tf.reshape(x, [-1, h * w, num_channel])
elif x.shape.rank == 3:
return x
else:
raise ValueError('Unsupport shape {}'.format(x.shape))
def generate_lookup_tensor(
length: int,
max_relative_position: Optional[int] = None,
clamp_out_of_range: bool = False,
dtype: tf.DType = tf.float32) -> tf.Tensor:
"""Generate a one_hot lookup tensor to reindex embeddings along one dimension.
Args:
length: the length to reindex to.
max_relative_position: the maximum relative position to consider.
Relative position embeddings for distances above this threshold
are zeroed out.
clamp_out_of_range: bool. Whether to clamp out of range locations to the
maximum relative distance. If False, the out of range locations will be
filled with all-zero vectors.
dtype: dtype for the returned lookup tensor.
Returns:
ret: [length, length, vocab_size] lookup tensor that satisfies
ret[n,m,v] = 1{m - n + max_relative_position = v}.
"""
if max_relative_position is None:
max_relative_position = length - 1
vocab_size = 2 * max_relative_position + 1
ret = np.zeros((length, length, vocab_size))
for i in range(length):
for x in range(length):
v = x - i + max_relative_position
if abs(x - i) > max_relative_position:
if clamp_out_of_range:
v = np.clip(v, 0, vocab_size - 1)
else:
continue
ret[i, x, v] = 1
return tf.constant(ret, dtype)
def reindex_2d_einsum_lookup(
relative_position_tensor: tf.Tensor,
height: int,
width: int,
max_relative_height: Optional[int] = None,
max_relative_width: Optional[int] = None,
h_axis=None) -> tf.Tensor:
"""Reindex 2d relative position bias with 2 independent einsum lookups.
Args:
relative_position_tensor: tensor of shape
[..., vocab_height, vocab_width, ...].
height: height to reindex to.
width: width to reindex to.
max_relative_height: maximum relative height.
Position embeddings corresponding to vertical distances larger
than max_relative_height are zeroed out. None to disable.
max_relative_width: maximum relative width.
Position embeddings corresponding to horizontal distances larger
than max_relative_width are zeroed out. None to disable.
h_axis: Axis corresponding to vocab_height. Default to 0 if None.
Returns:
reindexed_bias: a Tensor of shape
[..., height * width, height * width, ...]
"""
height_lookup = generate_lookup_tensor(
height, max_relative_position=max_relative_height,
dtype=relative_position_tensor.dtype)
width_lookup = generate_lookup_tensor(
width, max_relative_position=max_relative_width,
dtype=relative_position_tensor.dtype)
if h_axis is None:
h_axis = 0
non_spatial_rank = relative_position_tensor.shape.rank - 2
non_spatial_expr = ''.join(chr(ord('n') + i) for i in range(non_spatial_rank))
prefix = non_spatial_expr[:h_axis]
suffix = non_spatial_expr[h_axis:]
reindexed_tensor = tf.einsum(
'{0}hw{1},ixh->{0}ixw{1}'.format(prefix, suffix),
relative_position_tensor, height_lookup, name='height_lookup')
reindexed_tensor = tf.einsum(
'{0}ixw{1},jyw->{0}ijxy{1}'.format(prefix, suffix),
reindexed_tensor, width_lookup, name='width_lookup')
ret_shape = relative_position_tensor.shape.as_list()
ret_shape[h_axis] = height * width
ret_shape[h_axis + 1] = height * width
reindexed_tensor = tf.reshape(reindexed_tensor, ret_shape)
return reindexed_tensor
def float32_softmax(x: tf.Tensor, *args, **kwargs) -> tf.Tensor:
y = tf.cast(tf.nn.softmax(tf.cast(x, tf.float32), *args, **kwargs), x.dtype)
return y
def get_shape_from_length(length: int, height: int = 1, width: int = 1):
"""Gets input 2D shape from 1D sequence length."""
input_height = int(math.sqrt(length * height // width))
input_width = input_height * width // height
if input_height * input_width != length:
raise ValueError(
f'Invalid sequence length: {length} or shape: ({height, width}).'
)
return (input_height, input_width)
此差异已折叠。
此差异已折叠。
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for MaxViT."""
import collections
from typing import Optional, Sequence
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.projects.maxvit.configs import backbones
from official.projects.maxvit.modeling import maxvit
from official.vision.configs import common
class MaxViTBlockTest(tf.test.TestCase):
"""Test the layers of MaxViT."""
def testMaxViTBlockCreation(self) -> None:
"""Ensures that layers can be constructed and forward-props can run."""
inputs_shape = [2, 64, 64, 3]
inp = tf.random.uniform(
shape=inputs_shape, minval=-1.0, maxval=1.0, dtype=tf.float32
)
model = maxvit.MaxViTBlock(
hidden_size=8, head_size=4, window_size=4, grid_size=4
)
out = model(inp, training=False)
self.assertAllEqual([2, 64, 64, 8], out.get_shape().as_list())
self.assertDTypeEqual(tf.reduce_mean(out).numpy(), np.float32)
class MaxViTTest(tf.test.TestCase, parameterized.TestCase):
"""Test the layers of MaxViT."""
@parameterized.named_parameters(
collections.OrderedDict(
testcase_name='MaxViTTest',
input_shape=[2, 64, 64, 3],
input_dtype=tf.float32,
training=False,
stem_hsize=[12, 12],
num_blocks=[2, 2, 2, 2],
window_size=2,
grid_size=2,
block_type=['maxvit', 'maxvit', 'maxvit'],
hidden_size=[16, 32, 64],
expected_shape=[2, 4, 4, 64],
name='maxvit_test',
),
collections.OrderedDict(
testcase_name='MaxViTTiny',
input_shape=[2, 64, 64, 3],
input_dtype=tf.float32,
training=False,
block_type=['maxvit', 'maxvit', 'maxvit', 'maxvit'],
stem_hsize=[64, 64],
num_blocks=[2, 3, 5, 2],
window_size=2,
grid_size=2,
hidden_size=[96, 192, 384, 768],
expected_shape=[2, 2, 2, 768],
name='maxvit_tiny',
),
collections.OrderedDict(
testcase_name='MaxViTTinyWithPrelogits',
input_shape=[2, 64, 64, 3],
input_dtype=tf.float32,
training=False,
representation_size=16,
add_gap_layer_norm=True,
block_type=['maxvit', 'maxvit', 'maxvit', 'maxvit'],
stem_hsize=[64, 64],
num_blocks=[2, 3, 5, 2],
window_size=2,
grid_size=2,
hidden_size=[96, 192, 384, 768],
expected_shape=[2, 2, 2, 768],
name='maxvit_tiny',
),
)
def testForward(
self,
input_shape: Sequence[int],
input_dtype: Optional[tf.DType] = tf.float32,
**kwargs
) -> None:
"""Ensures that layers can be constructed and forward-props can run."""
inp = tf.random.uniform(
input_shape,
minval=-1.0,
maxval=1.0,
dtype=input_dtype,
)
model = maxvit.MaxViT(**kwargs)
out = model(inp, training=kwargs.get('training', None))
add_gap_layer_norm = kwargs.get('add_gap_layer_norm', False)
if add_gap_layer_norm:
self.assertAllEqual([input_shape[0], kwargs['representation_size']],
out['pre_logits'].get_shape().as_list())
# Remove `pre_logits` if exists.
out.pop('pre_logits', None)
out = out[max(out.keys())]
self.assertAllEqual(kwargs['expected_shape'], out.get_shape().as_list())
self.assertDTypeEqual(tf.reduce_mean(out).numpy(), np.float32)
def testBuildMaxViTWithConfig(self):
backbone_config = backbones.Backbone(
type='maxvit',
maxvit=backbones.MaxViT(
stem_hsize=[32, 32],
num_blocks=[2, 3, 5, 2],
window_size=2,
grid_size=2,
hidden_size=[32, 32, 32, 32],
),
)
backbone = maxvit.build_maxvit(
input_specs=tf.keras.layers.InputSpec(shape=[None] + [64, 64, 3]),
backbone_config=backbone_config,
norm_activation_config=common.NormActivation(),
)
self.assertSetEqual(
set(['2', '3', '4', '5']), set(backbone.output_specs.keys())
)
if __name__ == '__main__':
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver, including ViT configs.."""
from absl import app
from official.common import flags as tfm_flags
from official.projects.maxvit import configs # pylint: disable=unused-import
from official.projects.maxvit.modeling import maxvit # pylint: disable=unused-import
from official.vision import train
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(train.main)
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from absl import flags
from absl.testing import flagsaver
import gin
import tensorflow as tf
from official.projects.maxvit import train as train_lib
from official.vision.dataloaders import tfexample_utils
FLAGS = flags.FLAGS
class TrainTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
self._test_tfrecord_file = os.path.join(
self.get_temp_dir(), 'test.tfrecord'
)
num_samples = 3
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=224, image_width=224
)
)
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
record_file=self._test_tfrecord_file, tf_examples=examples
)
def test_run(self):
saved_flag_values = flagsaver.save_flag_values()
train_lib.tfm_flags.define_flags()
FLAGS.mode = 'train'
FLAGS.model_dir = self._model_dir
FLAGS.experiment = 'maxvit_imagenet'
params_override = json.dumps({
'runtime': {
'mixed_precision_dtype': 'float32',
},
'trainer': {
'train_steps': 1,
'validation_steps': 1,
'optimizer_config': {
'ema': None,
},
},
'task': {
'init_checkpoint': '',
'model': {
'backbone': {
'maxvit': {
'model_name': 'maxvit-tiny-for-test',
'representation_size': 64,
'add_gap_layer_norm': True,
}
},
'input_size': [224, 224, 3],
'num_classes': 3,
},
'train_data': {
'global_batch_size': 2,
'input_path': self._test_tfrecord_file,
},
'validation_data': {
'global_batch_size': 2,
'input_path': self._test_tfrecord_file,
},
},
})
FLAGS.params_override = params_override
train_lib.train.main('unused_args')
FLAGS.mode = 'eval'
with gin.unlock_config():
train_lib.train.main('unused_args')
flagsaver.restore_flag_values(saved_flag_values)
if __name__ == '__main__':
tf.test.main()
......@@ -66,7 +66,7 @@ class MosaicSegmentationModel(tf.keras.Model):
self.head = head
self.mask_scoring_head = mask_scoring_head
def call(self,
def call(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
inputs: tf.Tensor,
training: bool = None) -> Dict[str, tf.Tensor]:
backbone_features = self.backbone(inputs)
......
......@@ -110,6 +110,11 @@ flags.DEFINE_bool(
flags.DEFINE_string(
'checkpoint_path', '',
'Checkpoint path to load. Leave blank for default initialization.')
flags.DEFINE_bool(
'assert_checkpoint_objects_matched',
True,
'Whether to check the checkpoint objects exactly match those of the model.',
)
FLAGS = flags.FLAGS
......@@ -120,21 +125,25 @@ def export_saved_model(
export_path: str = '/tmp/movinet/',
causal: bool = False,
bundle_input_init_states_fn: bool = True,
checkpoint_path: Optional[str] = None) -> None:
checkpoint_path: Optional[str] = None,
assert_checkpoint_objects_matched: bool = True,
) -> None:
"""Exports a MoViNet model to a saved model.
Args:
model: the tf.keras.Model to export.
input_shape: The 5D spatiotemporal input shape of size
[batch_size, num_frames, image_height, image_width, num_channels].
Set the field or a shape position in the field to None for dynamic input.
input_shape: The 5D spatiotemporal input shape of size [batch_size,
num_frames, image_height, image_width, num_channels]. Set the field or a
shape position in the field to None for dynamic input.
export_path: Export path to save the saved_model file.
causal: Run the model in causal mode.
bundle_input_init_states_fn: Add init_states as a function signature to the
saved model. This is not necessary if the input shape is static (e.g.,
for TF Lite).
saved model. This is not necessary if the input shape is static (e.g., for
TF Lite).
checkpoint_path: Checkpoint path to load. Leave blank to keep the model's
initialization.
assert_checkpoint_objects_matched: Whether to check the checkpoint objects
exactly match those of the model.
"""
# Use dimensions of 1 except the channels to export faster,
......@@ -149,7 +158,8 @@ def export_saved_model(
if checkpoint_path:
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore(checkpoint_path)
status.assert_existing_objects_matched()
if assert_checkpoint_objects_matched:
status.assert_existing_objects_matched()
if causal:
# Call the model once to get the output states. Call again with `states`
......@@ -205,23 +215,23 @@ def build_and_export_saved_model(
num_classes: int = 600,
input_shape: Optional[Tuple[int, int, int, int, int]] = None,
bundle_input_init_states_fn: bool = True,
checkpoint_path: Optional[str] = None) -> None:
checkpoint_path: Optional[str] = None,
assert_checkpoint_objects_matched: bool = True,
) -> None:
"""Builds and exports a MoViNet model to a saved model.
Args:
export_path: Export path to save the saved_model file.
model_id: MoViNet model name.
causal: Run the model in causal mode.
conv_type: 3d, 2plus1d, or 3d_2plus1d. 3d configures the network
to use the default 3D convolution. 2plus1d uses (2+1)D convolution
with Conv2D operations and 2D reshaping (e.g., a 5x3x3 kernel becomes
3x3 followed by 5x1 conv). 3d_2plus1d uses (2+1)D convolution with
Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3
followed by 5x1x1 conv).
se_type:
3d, 2d, or 2plus3d. 3d uses the default 3D spatiotemporal global average
pooling for squeeze excitation. 2d uses 2D spatial global average pooling
on each frame. 2plus3d concatenates both 3D and 2D global average
conv_type: 3d, 2plus1d, or 3d_2plus1d. 3d configures the network to use the
default 3D convolution. 2plus1d uses (2+1)D convolution with Conv2D
operations and 2D reshaping (e.g., a 5x3x3 kernel becomes 3x3 followed by
5x1 conv). 3d_2plus1d uses (2+1)D convolution with Conv3D and no 2D
reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed by 5x1x1 conv).
se_type: 3d, 2d, or 2plus3d. 3d uses the default 3D spatiotemporal global
average pooling for squeeze excitation. 2d uses 2D spatial global average
pooling on each frame. 2plus3d concatenates both 3D and 2D global average
pooling.
activation: The main activation to use across layers.
classifier_activation: The classifier activation to use.
......@@ -230,14 +240,16 @@ def build_and_export_saved_model(
use_positional_encoding: Whether to use positional encoding (only applied
when causal=True).
num_classes: The number of classes for prediction.
input_shape: The 5D spatiotemporal input shape of size
[batch_size, num_frames, image_height, image_width, num_channels].
Set the field or a shape position in the field to None for dynamic input.
input_shape: The 5D spatiotemporal input shape of size [batch_size,
num_frames, image_height, image_width, num_channels]. Set the field or a
shape position in the field to None for dynamic input.
bundle_input_init_states_fn: Add init_states as a function signature to the
saved model. This is not necessary if the input shape is static (e.g.,
for TF Lite).
saved model. This is not necessary if the input shape is static (e.g., for
TF Lite).
checkpoint_path: Checkpoint path to load. Leave blank for default
initialization.
assert_checkpoint_objects_matched: Whether to check the checkpoint objects
exactly match those of the model.
"""
input_specs = tf.keras.layers.InputSpec(shape=input_shape)
......@@ -272,7 +284,9 @@ def build_and_export_saved_model(
export_path=export_path,
causal=causal,
bundle_input_init_states_fn=bundle_input_init_states_fn,
checkpoint_path=checkpoint_path)
checkpoint_path=checkpoint_path,
assert_checkpoint_objects_matched=assert_checkpoint_objects_matched,
)
def main(_) -> None:
......@@ -291,7 +305,9 @@ def main(_) -> None:
num_classes=FLAGS.num_classes,
input_shape=input_shape,
bundle_input_init_states_fn=FLAGS.bundle_input_init_states_fn,
checkpoint_path=FLAGS.checkpoint_path)
checkpoint_path=FLAGS.checkpoint_path,
assert_checkpoint_objects_matched=FLAGS.assert_checkpoint_objects_matched,
)
print(' ----- Done. Saved Model is saved at {}'.format(FLAGS.export_path))
......
......@@ -60,7 +60,7 @@ class PanopticDeeplabModel(tf.keras.Model):
self.instance_head = instance_head
self.post_processor = post_processor
def call(
def call( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
self, inputs: tf.Tensor,
image_info: tf.Tensor,
training: bool = None):
......
# Perceiver IO: A General Architecture for Structured Inputs & Outputs
TF2 implementation of [Perceiver](https://arxiv.org/abs/2107.14795).
## Default setup command:
Scripts to pretrain, finetune, train from scratch can be found under
perceiver/experiments.
## BERT Wiki Books Pretrain
Configurations can be seen on Table 8 and Table 9 of the
[paper](https://arxiv.org/abs/2107.14795). Our model configuration can be
deduced in the configs and experiment folder, where we follow the
configuration in the paper except for the tokenization and data.
Model | Tokenizer | Pretrain Data | Batch Size | Steps | Val MLM Accuracy
----- | --------: | ------------: | ---------: | ----: | ---------------:
Perceiver IO Base (paper) | SentencePiece | T5 + Wiki | 512 | 500 k | N/A
Perceiver IO Base (ours) | WordPiece | Wiki + Books | 512 | 500 k | 68.69 %
## GLUE Finetune
Our perceiver model is fine-tuned on GLUE upon the pre-trained model shown
above. These are all single-task fine-tuning only.
These are run with configurations shown on Table 10 in the [paper](https://arxiv.org/abs/2107.14795).
Model | Tokenizer | Pretrain Data | CoLA | MNLI-m/mm | MRPC | QNLI | QQP | RTE | SST-2 | STS-B | Average
----- | --------: | ------------: | ---: | --------: | ----:| ----:| --: | --: | ----: | ----: | -----:
Perceiver IO Base (paper) | SentencePiece | T5 + Wiki | 47.11 % | 84.53/85.03 % | 87.25 % | 92.12 % | 90.22 % | 65.23 % | 94.38 % | 88.18 % | 81.16 %
Perceiver IO Base (ours) | WordPiece | Wiki + Books | 63.23 % | 84.29/84.52 % | 87.74 % | 91.43 % | 91.22 % | 70.76 % | 94.15 % | 89.85 % | 84.09 %
Note: The average is computed by first averaging the results of MNLI-matched and
MNLI-mismatched, which is then counted as a single task in the overall average.
`Average = (63.23 + (84.29 + 84.52) / 2 + 87.74 + 91.43 + 91.22 + 70.76 + 94.15 + 89.85) / 8`
## Discrepancy with the paper:
* ~+2.93 average GLUE accuracy compared to paper results.
## Citing TensorFlow Model Garden
If you find this codebase helpful in your research, please cite this repository.
```
@misc{tensorflowmodelgarden2022,
author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
Frederick Liu and Jaeyoun Kim and Jing Li},
title = {{TensorFlow Model Garden}},
howpublished = {\url{https://github.com/tensorflow/models}},
year = {2020}
}
```
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册