提交 29fa8c42 编写于 作者: F Fan Yang 提交者: A. Unique TensorFlower

Internal change.

PiperOrigin-RevId: 421362994
上级 bfc36ef8
# Image Classification
**Warning:** the features in the `image_classification/` folder have been fully
intergrated into vision/beta. Please use the [new code base](../beta/README.md).
This folder contains TF 2.0 model examples for image classification:
* [MNIST](#mnist)
* [Classifier Trainer](#classifier-trainer), a framework that uses the Keras
compile/fit methods for image classification models, including:
* ResNet
* EfficientNet[^1]
[^1]: Currently a work in progress. We cannot match "AutoAugment (AA)" in [the original version](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet).
For more information about other types of models, please refer to this
[README file](../../README.md).
## Before you begin
Please make sure that you have the latest version of TensorFlow
installed and
[add the models folder to your Python path](/official/#running-the-models).
### ImageNet preparation
#### Using TFDS
`classifier_trainer.py` supports ImageNet with
[TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/overview).
Please see the following [example snippet](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/scripts/download_and_prepare.py)
for more information on how to use TFDS to download and prepare datasets, and
specifically the [TFDS ImageNet readme](https://github.com/tensorflow/datasets/blob/master/docs/catalog/imagenet2012.md)
for manual download instructions.
#### Legacy TFRecords
Download the ImageNet dataset and convert it to TFRecord format.
The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
provide a few options.
Note that the legacy ResNet runners, e.g. [resnet/resnet_ctl_imagenet_main.py](resnet/resnet_ctl_imagenet_main.py)
require TFRecords whereas `classifier_trainer.py` can use both by setting the
builder to 'records' or 'tfds' in the configurations.
### Running on Cloud TPUs
Note: These models will **not** work with TPUs on Colab.
You can train image classification models on Cloud TPUs using
[tf.distribute.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf.distribute.TPUStrategy?version=nightly).
If you are not familiar with Cloud TPUs, it is strongly recommended that you go
through the
[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
create a TPU and GCE VM.
### Running on multiple GPU hosts
You can also train these models on multiple hosts, each with GPUs, using
[tf.distribute.experimental.MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy).
The easiest way to run multi-host benchmarks is to set the
[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)
appropriately at each host. e.g., to run using `MultiWorkerMirroredStrategy` on
2 hosts, the `cluster` in `TF_CONFIG` should have 2 `host:port` entries, and
host `i` should have the `task` in `TF_CONFIG` set to `{"type": "worker",
"index": i}`. `MultiWorkerMirroredStrategy` will automatically use all the
available GPUs at each host.
## MNIST
To download the data and run the MNIST sample model locally for the first time,
run one of the following command:
```bash
python3 mnist_main.py \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--train_epochs=10 \
--distribution_strategy=one_device \
--num_gpus=$NUM_GPUS \
--download
```
To train the model on a Cloud TPU, run the following command:
```bash
python3 mnist_main.py \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--train_epochs=10 \
--distribution_strategy=tpu \
--download
```
Note: the `--download` flag is only required the first time you run the model.
## Classifier Trainer
The classifier trainer is a unified framework for running image classification
models using Keras's compile/fit methods. Experiments should be provided in the
form of YAML files, some examples are included within the configs/examples
folder. Please see [configs/examples](./configs/examples) for more example
configurations.
The provided configuration files use a per replica batch size and is scaled
by the number of devices. For instance, if `batch size` = 64, then for 1 GPU
the global batch size would be 64 * 1 = 64. For 8 GPUs, the global batch size
would be 64 * 8 = 512. Similarly, for a v3-8 TPU, the global batch size would
be 64 * 8 = 512, and for a v3-32, the global batch size is 64 * 32 = 2048.
### ResNet50
#### On GPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=resnet \
--dataset=imagenet \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/resnet/imagenet/gpu.yaml \
--params_override='runtime.num_gpus=$NUM_GPUS'
```
To train on multiple hosts, each with GPUs attached using
[MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)
please update `runtime` section in gpu.yaml
(or override using `--params_override`) with:
```YAML
# gpu.yaml
runtime:
distribution_strategy: 'multi_worker_mirrored'
worker_hosts: '$HOST1:port,$HOST2:port'
num_gpus: $NUM_GPUS
task_index: 0
```
By having `task_index: 0` on the first host and `task_index: 1` on the second
and so on. `$HOST1` and `$HOST2` are the IP addresses of the hosts, and `port`
can be chosen any free port on the hosts. Only the first host will write
TensorBoard Summaries and save checkpoints.
#### On TPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=resnet \
--dataset=imagenet \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/resnet/imagenet/tpu.yaml
```
### EfficientNet
**Note: EfficientNet development is a work in progress.**
#### On GPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=efficientnet \
--dataset=imagenet \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml \
--params_override='runtime.num_gpus=$NUM_GPUS'
```
#### On TPU:
```bash
python3 classifier_trainer.py \
--mode=train_and_eval \
--model_type=efficientnet \
--dataset=imagenet \
--tpu=$TPU_NAME \
--model_dir=$MODEL_DIR \
--data_dir=$DATA_DIR \
--config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
```
Note that the number of GPU devices can be overridden in the command line using
`--params_overrides`. The TPU does not need this override as the device is fixed
by providing the TPU address or name with the `--tpu` flag.
This repository is deprecated and replaced by the solid
implementations inside vision/beta/. All the content has been moved to
[official/legacy/image_classification](https://github.com/tensorflow/models/tree/master/official/legacy/image_classification).
......@@ -12,3 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deprecating the vision/detection folder."""
raise ImportError(
'This module has been moved to official/legacy/image_classification')
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for autoaugment."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from official.vision.image_classification import augment
def get_dtype_test_cases():
return [
('uint8', tf.uint8),
('int32', tf.int32),
('float16', tf.float16),
('float32', tf.float32),
]
@parameterized.named_parameters(get_dtype_test_cases())
class TransformsTest(parameterized.TestCase, tf.test.TestCase):
"""Basic tests for fundamental transformations."""
def test_to_from_4d(self, dtype):
for shape in [(10, 10), (10, 10, 10), (10, 10, 10, 10)]:
original_ndims = len(shape)
image = tf.zeros(shape, dtype=dtype)
image_4d = augment.to_4d(image)
self.assertEqual(4, tf.rank(image_4d))
self.assertAllEqual(image, augment.from_4d(image_4d, original_ndims))
def test_transform(self, dtype):
image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
self.assertAllEqual(
augment.transform(image, transforms=[1] * 8), [[4, 4], [4, 4]])
def test_translate(self, dtype):
image = tf.constant(
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=dtype)
translations = [-1, -1]
translated = augment.translate(image=image, translations=translations)
expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype):
translation = [0, 0]
for shape in [(3, 3), (5, 5), (224, 224, 3)]:
image = tf.zeros(shape, dtype=dtype)
self.assertAllEqual(image, augment.translate(image, translation))
def test_translate_invalid_translation(self, dtype):
image = tf.zeros((1, 1), dtype=dtype)
invalid_translation = [[[1, 1]]]
with self.assertRaisesRegex(TypeError, 'rank 1 or 2'):
_ = augment.translate(image, invalid_translation)
def test_rotate(self, dtype):
image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3))
rotation = 90.
transformed = augment.rotate(image=image, degrees=rotation)
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
self.assertAllEqual(transformed, expected)
def test_rotate_shapes(self, dtype):
degrees = 0.
for shape in [(3, 3), (5, 5), (224, 224, 3)]:
image = tf.zeros(shape, dtype=dtype)
self.assertAllEqual(image, augment.rotate(image, degrees))
class AutoaugmentTest(tf.test.TestCase):
def test_autoaugment(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.AutoAugment()
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.RandAugment()
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_all_policy_ops(self):
"""Smoke test to be sure all augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8)
for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image = func(image, *args)
self.assertEqual((224, 224, 3), image.shape)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Common modules for callbacks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from typing import Any, List, MutableMapping, Optional, Text
from absl import logging
import tensorflow as tf
from official.modeling import optimization
from official.utils.misc import keras_utils
def get_callbacks(
model_checkpoint: bool = True,
include_tensorboard: bool = True,
time_history: bool = True,
track_lr: bool = True,
write_model_weights: bool = True,
apply_moving_average: bool = False,
initial_step: int = 0,
batch_size: int = 0,
log_steps: int = 0,
model_dir: Optional[str] = None,
backup_and_restore: bool = False) -> List[tf.keras.callbacks.Callback]:
"""Get all callbacks."""
model_dir = model_dir or ''
callbacks = []
if model_checkpoint:
ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
callbacks.append(
tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if backup_and_restore:
backup_dir = os.path.join(model_dir, 'tmp')
callbacks.append(
tf.keras.callbacks.experimental.BackupAndRestore(backup_dir))
if include_tensorboard:
callbacks.append(
CustomTensorBoard(
log_dir=model_dir,
track_lr=track_lr,
initial_step=initial_step,
write_images=write_model_weights,
profile_batch=0))
if time_history:
callbacks.append(
keras_utils.TimeHistory(
batch_size,
log_steps,
logdir=model_dir if include_tensorboard else None))
if apply_moving_average:
# Save moving average model to a different file so that
# we can resume training from a checkpoint
ckpt_full_path = os.path.join(model_dir, 'average',
'model.ckpt-{epoch:04d}')
callbacks.append(
AverageModelCheckpoint(
update_weights=False,
filepath=ckpt_full_path,
save_weights_only=True,
verbose=1))
callbacks.append(MovingAverageCallback())
return callbacks
def get_scalar_from_tensor(t: tf.Tensor) -> int:
"""Utility function to convert a Tensor to a scalar."""
t = tf.keras.backend.get_value(t)
if callable(t):
return t()
else:
return t
class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
"""A customized TensorBoard callback that tracks additional datapoints.
Metrics tracked:
- Global learning rate
Attributes:
log_dir: the path of the directory where to save the log files to be parsed
by TensorBoard.
track_lr: `bool`, whether or not to track the global learning rate.
initial_step: the initial step, used for preemption recovery.
**kwargs: Additional arguments for backwards compatibility. Possible key is
`period`.
"""
# TODO(b/146499062): track params, flops, log lr, l2 loss,
# classification loss
def __init__(self,
log_dir: str,
track_lr: bool = False,
initial_step: int = 0,
**kwargs):
super(CustomTensorBoard, self).__init__(log_dir=log_dir, **kwargs)
self.step = initial_step
self._track_lr = track_lr
def on_batch_begin(self,
epoch: int,
logs: Optional[MutableMapping[str, Any]] = None) -> None:
self.step += 1
if logs is None:
logs = {}
logs.update(self._calculate_metrics())
super(CustomTensorBoard, self).on_batch_begin(epoch, logs)
def on_epoch_begin(self,
epoch: int,
logs: Optional[MutableMapping[str, Any]] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
logs.update(metrics)
for k, v in metrics.items():
logging.info('Current %s: %f', k, v)
super(CustomTensorBoard, self).on_epoch_begin(epoch, logs)
def on_epoch_end(self,
epoch: int,
logs: Optional[MutableMapping[str, Any]] = None) -> None:
if logs is None:
logs = {}
metrics = self._calculate_metrics()
logs.update(metrics)
super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
def _calculate_metrics(self) -> MutableMapping[str, Any]:
logs = {}
# TODO(b/149030439): disable LR reporting.
# if self._track_lr:
# logs['learning_rate'] = self._calculate_lr()
return logs
def _calculate_lr(self) -> int:
"""Calculates the learning rate given the current step."""
return get_scalar_from_tensor(
self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32)) # pylint:disable=protected-access
def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Get the base optimizer used by the current model."""
optimizer = self.model.optimizer
# The optimizer might be wrapped by another class, so unwrap it
while hasattr(optimizer, '_optimizer'):
optimizer = optimizer._optimizer # pylint:disable=protected-access
return optimizer
class MovingAverageCallback(tf.keras.callbacks.Callback):
"""A Callback to be used with a `ExponentialMovingAverage` optimizer.
Applies moving average weights to the model during validation time to test
and predict on the averaged weights rather than the current model weights.
Once training is complete, the model weights will be overwritten with the
averaged weights (by default).
Attributes:
overwrite_weights_on_train_end: Whether to overwrite the current model
weights with the averaged weights from the moving average optimizer.
**kwargs: Any additional callback arguments.
"""
def __init__(self, overwrite_weights_on_train_end: bool = False, **kwargs):
super(MovingAverageCallback, self).__init__(**kwargs)
self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
def set_model(self, model: tf.keras.Model):
super(MovingAverageCallback, self).set_model(model)
assert isinstance(self.model.optimizer,
optimization.ExponentialMovingAverage)
self.model.optimizer.shadow_copy(self.model)
def on_test_begin(self, logs: Optional[MutableMapping[Text, Any]] = None):
self.model.optimizer.swap_weights()
def on_test_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
self.model.optimizer.swap_weights()
def on_train_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
if self.overwrite_weights_on_train_end:
self.model.optimizer.assign_average_vars(self.model.variables)
class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
"""Saves and, optionally, assigns the averaged weights.
Taken from tfa.callbacks.AverageModelCheckpoint.
Attributes:
update_weights: If True, assign the moving average weights to the model, and
save them. If False, keep the old non-averaged weights, but the saved
model uses the average weights. See `tf.keras.callbacks.ModelCheckpoint`
for the other args.
"""
def __init__(self,
update_weights: bool,
filepath: str,
monitor: str = 'val_loss',
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = 'auto',
save_freq: str = 'epoch',
**kwargs):
self.update_weights = update_weights
super().__init__(filepath, monitor, verbose, save_best_only,
save_weights_only, mode, save_freq, **kwargs)
def set_model(self, model):
if not isinstance(model.optimizer, optimization.ExponentialMovingAverage):
raise TypeError('AverageModelCheckpoint is only used when training'
'with MovingAverage')
return super().set_model(model)
def _save_model(self, epoch, logs):
assert isinstance(self.model.optimizer,
optimization.ExponentialMovingAverage)
if self.update_weights:
self.model.optimizer.assign_average_vars(self.model.variables)
return super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
else:
# Note: `model.get_weights()` gives us the weights (non-ref)
# whereas `model.variables` returns references to the variables.
non_avg_weights = self.model.get_weights()
self.model.optimizer.assign_average_vars(self.model.variables)
# result is currently None, since `super._save_model` doesn't
# return anything, but this may change in the future.
result = super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
self.model.set_weights(non_avg_weights)
return result
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Runs an Image Classification model."""
import os
import pprint
from typing import Any, Tuple, Text, Optional, Mapping
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.modeling import hyperparams
from official.modeling import performance
from official.utils import hyperparams_flags
from official.utils.misc import keras_utils
from official.vision.image_classification import callbacks as custom_callbacks
from official.vision.image_classification import dataset_factory
from official.vision.image_classification import optimizer_factory
from official.vision.image_classification.configs import base_configs
from official.vision.image_classification.configs import configs
from official.vision.image_classification.efficientnet import efficientnet_model
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_model
def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model."""
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
}
def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
"""Returns the mapping from dtype string representations to TF dtypes."""
return {
'float32': tf.float32,
'bfloat16': tf.bfloat16,
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
"""Get a dict of available metrics to track."""
if one_hot:
return {
# (name, metric_fn)
'acc':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'accuracy':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_1':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_5':
tf.keras.metrics.TopKCategoricalAccuracy(
k=5, name='top_5_accuracy'),
}
else:
return {
# (name, metric_fn)
'acc':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'accuracy':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_1':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_5':
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=5, name='top_5_accuracy'),
}
def get_image_size_from_model(
params: base_configs.ExperimentConfig) -> Optional[int]:
"""If the given model has a preferred image size, return it."""
if params.model_name == 'efficientnet':
efficientnet_name = params.model.model_params.model_name
if efficientnet_name in efficientnet_model.MODEL_CONFIGS:
return efficientnet_model.MODEL_CONFIGS[efficientnet_name].resolution
return None
def _get_dataset_builders(params: base_configs.ExperimentConfig,
strategy: tf.distribute.Strategy,
one_hot: bool) -> Tuple[Any, Any]:
"""Create and return train and validation dataset builders."""
if one_hot:
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
else:
logging.warning('label_smoothing not applied, so datasets will not be one '
'hot encoded.')
num_devices = strategy.num_replicas_in_sync if strategy else 1
image_size = get_image_size_from_model(params)
dataset_configs = [params.train_dataset, params.validation_dataset]
builders = []
for config in dataset_configs:
if config is not None and config.has_data:
builder = dataset_factory.DatasetBuilder(
config,
image_size=image_size or config.image_size,
num_devices=num_devices,
one_hot=one_hot)
else:
builder = None
builders.append(builder)
return builders
def get_loss_scale(params: base_configs.ExperimentConfig,
fp16_default: float = 128.) -> float:
"""Returns the loss scale for initializations."""
loss_scale = params.runtime.loss_scale
if loss_scale == 'dynamic':
return loss_scale
elif loss_scale is not None:
return float(loss_scale)
elif (params.train_dataset.dtype == 'float32' or
params.train_dataset.dtype == 'bfloat16'):
return 1.
else:
assert params.train_dataset.dtype == 'float16'
return fp16_default
def _get_params_from_flags(flags_obj: flags.FlagValues):
"""Get ParamsDict from flags."""
model = flags_obj.model_type.lower()
dataset = flags_obj.dataset.lower()
params = configs.get_config(model=model, dataset=dataset)
flags_overrides = {
'model_dir': flags_obj.model_dir,
'mode': flags_obj.mode,
'model': {
'name': model,
},
'runtime': {
'run_eagerly': flags_obj.run_eagerly,
'tpu': flags_obj.tpu,
},
'train_dataset': {
'data_dir': flags_obj.data_dir,
},
'validation_dataset': {
'data_dir': flags_obj.data_dir,
},
'train': {
'time_history': {
'log_steps': flags_obj.log_steps,
},
},
}
overriding_configs = (flags_obj.config_file, flags_obj.params_override,
flags_overrides)
pp = pprint.PrettyPrinter()
logging.info('Base params: %s', pp.pformat(params.as_dict()))
for param in overriding_configs:
logging.info('Overriding params: %s', param)
params = hyperparams.override_params_dict(params, param, is_strict=True)
params.validate()
params.lock()
logging.info('Final model parameters: %s', pp.pformat(params.as_dict()))
return params
def resume_from_checkpoint(model: tf.keras.Model, model_dir: str,
train_steps: int) -> int:
"""Resumes from the latest checkpoint, if possible.
Loads the model weights and optimizer settings from a checkpoint.
This function should be used in case of preemption recovery.
Args:
model: The model whose weights should be restored.
model_dir: The directory where model weights were saved.
train_steps: The number of steps to train.
Returns:
The epoch of the latest checkpoint, or 0 if not restoring.
"""
logging.info('Load from checkpoint is enabled.')
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
logging.info('latest_checkpoint: %s', latest_checkpoint)
if not latest_checkpoint:
logging.info('No checkpoint detected.')
return 0
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint)
model.load_weights(latest_checkpoint)
initial_epoch = model.optimizer.iterations // train_steps
logging.info('Completed loading from checkpoint.')
logging.info('Resuming from epoch %d', initial_epoch)
return int(initial_epoch)
def initialize(params: base_configs.ExperimentConfig,
dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations."""
keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
performance.set_mixed_precision_policy(dataset_builder.dtype)
if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first'
else:
data_format = 'channels_last'
tf.keras.backend.set_image_data_format(data_format)
if params.runtime.run_eagerly:
# Enable eager execution to allow step-by-step debugging
tf.config.experimental_run_functions_eagerly(True)
if tf.config.list_physical_devices('GPU'):
if params.runtime.gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=params.runtime.per_gpu_thread_count,
gpu_thread_mode=params.runtime.gpu_thread_mode,
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime
.dataset_num_private_threads) # pylint:disable=line-too-long
if params.runtime.batchnorm_spatial_persistent:
os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
def define_classifier_flags():
"""Defines common flags for image classification."""
hyperparams_flags.initialize_common_flags()
flags.DEFINE_string(
'data_dir', default=None, help='The location of the input data.')
flags.DEFINE_string(
'mode',
default=None,
help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
flags.DEFINE_bool(
'run_eagerly',
default=None,
help='Use eager execution and disable autograph for debugging.')
flags.DEFINE_string(
'model_type',
default=None,
help='The type of the model, e.g. EfficientNet, etc.')
flags.DEFINE_string(
'dataset',
default=None,
help='The name of the dataset, e.g. ImageNet, etc.')
flags.DEFINE_integer(
'log_steps',
default=100,
help='The interval of steps between logging of batch level stats.')
def serialize_config(params: base_configs.ExperimentConfig, model_dir: str):
"""Serializes and saves the experiment config."""
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
hyperparams.save_params_dict_to_yaml(params, params_save_path)
def train_and_eval(
params: base_configs.ExperimentConfig,
strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
"""Runs the train and eval path using compile/fit."""
logging.info('Running train and eval.')
distribute_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index)
# Note: for TPUs, strategy and scope should be created before the dataset
strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
logging.info('Detected %d devices.',
strategy.num_replicas_in_sync if strategy else 1)
label_smoothing = params.model.loss.label_smoothing
one_hot = label_smoothing and label_smoothing > 0
builders = _get_dataset_builders(params, strategy, one_hot)
datasets = [
builder.build(strategy) if builder else None for builder in builders
]
# Unpack datasets and builders based on train/val/test splits
train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
train_dataset, validation_dataset = datasets
train_epochs = params.train.epochs
train_steps = params.train.steps or train_builder.num_steps
validation_steps = params.evaluation.steps or validation_builder.num_steps
initialize(params, train_builder)
logging.info('Global batch size: %d', train_builder.global_batch_size)
with strategy_scope:
model_params = params.model.model_params.as_dict()
model = get_models()[params.model.name](**model_params)
learning_rate = optimizer_factory.build_learning_rate(
params=params.model.learning_rate,
batch_size=train_builder.global_batch_size,
train_epochs=train_epochs,
train_steps=train_steps)
optimizer = optimizer_factory.build_optimizer(
optimizer_name=params.model.optimizer.name,
base_learning_rate=learning_rate,
params=params.model.optimizer.as_dict(),
model=model)
optimizer = performance.configure_optimizer(
optimizer,
use_float16=train_builder.dtype == 'float16',
loss_scale=get_loss_scale(params))
metrics_map = _get_metrics(one_hot)
metrics = [metrics_map[metric] for metric in params.train.metrics]
steps_per_loop = train_steps if params.train.set_epoch_loop else 1
if one_hot:
loss_obj = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=params.model.loss.label_smoothing)
else:
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(
optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
steps_per_execution=steps_per_loop)
initial_epoch = 0
if params.train.resume_checkpoint:
initial_epoch = resume_from_checkpoint(
model=model, model_dir=params.model_dir, train_steps=train_steps)
callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
include_tensorboard=params.train.callbacks.enable_tensorboard,
time_history=params.train.callbacks.enable_time_history,
track_lr=params.train.tensorboard.track_lr,
write_model_weights=params.train.tensorboard.write_model_weights,
initial_step=initial_epoch * train_steps,
batch_size=train_builder.global_batch_size,
log_steps=params.train.time_history.log_steps,
model_dir=params.model_dir,
backup_and_restore=params.train.callbacks.enable_backup_and_restore)
serialize_config(params=params, model_dir=params.model_dir)
if params.evaluation.skip_eval:
validation_kwargs = {}
else:
validation_kwargs = {
'validation_data': validation_dataset,
'validation_steps': validation_steps,
'validation_freq': params.evaluation.epochs_between_evals,
}
history = model.fit(
train_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
verbose=2,
**validation_kwargs)
validation_output = None
if not params.evaluation.skip_eval:
validation_output = model.evaluate(
validation_dataset, steps=validation_steps, verbose=2)
# TODO(dankondratyuk): eval and save final test accuracy
stats = common.build_stats(history, validation_output, callbacks)
return stats
def export(params: base_configs.ExperimentConfig):
"""Runs the model export functionality."""
logging.info('Exporting model.')
model_params = params.model.model_params.as_dict()
model = get_models()[params.model.name](**model_params)
checkpoint = params.export.checkpoint
if checkpoint is None:
logging.info('No export checkpoint was provided. Using the latest '
'checkpoint from model_dir.')
checkpoint = tf.train.latest_checkpoint(params.model_dir)
model.load_weights(checkpoint)
model.save(params.export.destination)
def run(flags_obj: flags.FlagValues,
strategy_override: tf.distribute.Strategy = None) -> Mapping[str, Any]:
"""Runs Image Classification model using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns:
Dictionary of training/eval stats
"""
params = _get_params_from_flags(flags_obj)
if params.mode == 'train_and_eval':
return train_and_eval(params, strategy_override)
elif params.mode == 'export_only':
export(params)
else:
raise ValueError('{} is not a valid mode.'.format(params.mode))
def main(_):
stats = run(flags.FLAGS)
if stats:
logging.info('Run stats:\n%s', stats)
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
define_classifier_flags()
flags.mark_flag_as_required('data_dir')
flags.mark_flag_as_required('mode')
flags.mark_flag_as_required('model_type')
flags.mark_flag_as_required('dataset')
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Unit tests for the classifier trainer models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import os
import sys
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.utils.flags import core as flags_core
from official.vision.image_classification import classifier_trainer
classifier_trainer.define_classifier_flags()
def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
"""Returns the combinations of end-to-end tests to run."""
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
model=[
'efficientnet',
'resnet',
],
dataset=[
'imagenet',
],
)
def get_params_override(params_override: Mapping[str, Any]) -> str:
"""Converts params_override dict to string command."""
return '--params_override=' + json.dumps(params_override)
def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
"""Returns a basic parameter configuration for testing."""
return {
'train_dataset': {
'builder': 'synthetic',
'use_per_replica_batch_size': True,
'batch_size': 1,
'image_size': 224,
'dtype': dtype,
},
'validation_dataset': {
'builder': 'synthetic',
'batch_size': 1,
'use_per_replica_batch_size': True,
'image_size': 224,
'dtype': dtype,
},
'train': {
'steps': 1,
'epochs': 1,
'callbacks': {
'enable_checkpoint_and_export': True,
'enable_tensorboard': False,
},
},
'evaluation': {
'steps': 1,
},
}
@flagsaver.flagsaver
def run_end_to_end(main: Callable[[Any], None],
extra_flags: Optional[Iterable[str]] = None,
model_dir: Optional[str] = None):
"""Runs the classifier trainer end-to-end."""
extra_flags = [] if extra_flags is None else extra_flags
args = [sys.argv[0], '--model_dir', model_dir] + extra_flags
flags_core.parse_flags(argv=args)
main(flags.FLAGS)
class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
"""Unit tests for Keras models."""
_tempdir = None
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(ClassifierTest, cls).setUpClass()
def tearDown(self):
super(ClassifierTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_train_and_eval(self, distribution, model, dataset):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.create_tempdir().full_path
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override()),
'--mode=train_and_eval',
]
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
],
model=[
'efficientnet',
'resnet',
],
dataset='imagenet',
dtype='float16',
))
def test_gpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.create_tempdir().full_path
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override(dtype)),
'--mode=train_and_eval',
]
export_params = basic_params_override()
export_path = os.path.join(model_dir, 'export')
export_params['export'] = {}
export_params['export']['destination'] = export_path
export_flags = base_flags + [
'--mode=export_only',
get_params_override(export_params)
]
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
run_end_to_end(main=run, extra_flags=export_flags, model_dir=model_dir)
self.assertTrue(os.path.exists(export_path))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
],
model=[
'efficientnet',
'resnet',
],
dataset='imagenet',
dtype='bfloat16',
))
def test_tpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# classifier_train.py --batch_size=...) by design, so use
# "--params_override=..." instead
model_dir = self.create_tempdir().full_path
base_flags = [
'--data_dir=not_used',
'--model_type=' + model,
'--dataset=' + dataset,
]
train_and_eval_flags = base_flags + [
get_params_override(basic_params_override(dtype)),
'--mode=train_and_eval',
]
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_invalid_mode(self, distribution, model, dataset):
"""Test the Keras EfficientNet model with `strategy`."""
model_dir = self.create_tempdir().full_path
extra_flags = [
'--data_dir=not_used',
'--mode=invalid_mode',
'--model_type=' + model,
'--dataset=' + dataset,
get_params_override(basic_params_override()),
]
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
with self.assertRaises(ValueError):
run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Unit tests for the classifier trainer models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
from absl.testing import parameterized
import tensorflow as tf
from official.vision.image_classification import classifier_trainer
from official.vision.image_classification import dataset_factory
from official.vision.image_classification import test_utils
from official.vision.image_classification.configs import base_configs
def get_trivial_model(num_classes: int) -> tf.keras.Model:
"""Creates and compiles trivial model for ImageNet dataset."""
model = test_utils.trivial_model(num_classes=num_classes)
lr = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, loss=loss_obj, run_eagerly=True)
return model
def get_trivial_data() -> tf.data.Dataset:
"""Gets trivial data in the ImageNet size."""
def generate_data(_) -> tf.data.Dataset:
image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
label = tf.zeros([1], dtype=tf.int32)
return image, label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=1).batch(1)
return dataset
class UtilTests(parameterized.TestCase, tf.test.TestCase):
"""Tests for individual utility functions within classifier_trainer.py."""
@parameterized.named_parameters(
('efficientnet-b0', 'efficientnet', 'efficientnet-b0', 224),
('efficientnet-b1', 'efficientnet', 'efficientnet-b1', 240),
('efficientnet-b2', 'efficientnet', 'efficientnet-b2', 260),
('efficientnet-b3', 'efficientnet', 'efficientnet-b3', 300),
('efficientnet-b4', 'efficientnet', 'efficientnet-b4', 380),
('efficientnet-b5', 'efficientnet', 'efficientnet-b5', 456),
('efficientnet-b6', 'efficientnet', 'efficientnet-b6', 528),
('efficientnet-b7', 'efficientnet', 'efficientnet-b7', 600),
('resnet', 'resnet', '', None),
)
def test_get_model_size(self, model, model_name, expected):
config = base_configs.ExperimentConfig(
model_name=model,
model=base_configs.ModelConfig(
model_params={
'model_name': model_name,
},))
size = classifier_trainer.get_image_size_from_model(config)
self.assertEqual(size, expected)
@parameterized.named_parameters(
('dynamic', 'dynamic', None, 'dynamic'),
('scalar', 128., None, 128.),
('float32', None, 'float32', 1),
('float16', None, 'float16', 128),
)
def test_get_loss_scale(self, loss_scale, dtype, expected):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(loss_scale=loss_scale),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
self.assertEqual(ls, expected)
@parameterized.named_parameters(('float16', 'float16'),
('bfloat16', 'bfloat16'))
def test_initialize(self, dtype):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(
run_eagerly=False,
enable_xla=False,
per_gpu_thread_count=1,
gpu_thread_mode='gpu_private',
num_gpus=1,
dataset_num_private_threads=1,
),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
model=base_configs.ModelConfig(),
)
class EmptyClass:
pass
fake_ds_builder = EmptyClass()
fake_ds_builder.dtype = dtype
fake_ds_builder.config = EmptyClass()
classifier_trainer.initialize(config, fake_ds_builder)
def test_resume_from_checkpoint(self):
"""Tests functionality for resuming from checkpoint."""
# Set the keras policy
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
# Get the model, datasets, and compile it.
model = get_trivial_model(10)
# Create the checkpoint
model_dir = self.create_tempdir().full_path
train_epochs = 1
train_steps = 10
ds = get_trivial_data()
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
os.path.join(model_dir, 'model.ckpt-{epoch:04d}'),
save_weights_only=True)
]
model.fit(
ds,
callbacks=callbacks,
epochs=train_epochs,
steps_per_epoch=train_steps)
# Test load from checkpoint
clean_model = get_trivial_model(10)
weights_before_load = copy.deepcopy(clean_model.get_weights())
initial_epoch = classifier_trainer.resume_from_checkpoint(
model=clean_model, model_dir=model_dir, train_steps=train_steps)
self.assertEqual(initial_epoch, 1)
self.assertNotAllClose(weights_before_load, clean_model.get_weights())
tf.io.gfile.rmtree(model_dir)
def test_serialize_config(self):
"""Tests functionality for serializing data."""
config = base_configs.ExperimentConfig()
model_dir = self.create_tempdir().full_path
classifier_trainer.serialize_config(params=config, model_dir=model_dir)
saved_params_path = os.path.join(model_dir, 'params.yaml')
self.assertTrue(os.path.exists(saved_params_path))
tf.io.gfile.rmtree(model_dir)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Definitions for high level configuration groups.."""
import dataclasses
from typing import Any, List, Mapping, Optional
from official.core import config_definitions
from official.modeling import hyperparams
RuntimeConfig = config_definitions.RuntimeConfig
@dataclasses.dataclass
class TensorBoardConfig(hyperparams.Config):
"""Configuration for TensorBoard.
Attributes:
track_lr: Whether or not to track the learning rate in TensorBoard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as images in
TensorBoard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(hyperparams.Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable TensorBoard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
@dataclasses.dataclass
class ExportConfig(hyperparams.Config):
"""Configuration for exports.
Attributes:
checkpoint: the path to the checkpoint to export.
destination: the path to where the checkpoint should be exported.
"""
checkpoint: str = None
destination: str = None
@dataclasses.dataclass
class MetricsConfig(hyperparams.Config):
"""Configuration for Metrics.
Attributes:
accuracy: Whether or not to track accuracy as a Callback. Defaults to None.
top_5: Whether or not to track top_5_accuracy as a Callback. Defaults to
None.
"""
accuracy: bool = None
top_5: bool = None
@dataclasses.dataclass
class TimeHistoryConfig(hyperparams.Config):
"""Configuration for the TimeHistory callback.
Attributes:
log_steps: Interval of steps between logging of batch level stats.
"""
log_steps: int = None
@dataclasses.dataclass
class TrainConfig(hyperparams.Config):
"""Configuration for training.
Attributes:
resume_checkpoint: Whether or not to enable load checkpoint loading.
Defaults to None.
epochs: The number of training epochs to run. Defaults to None.
steps: The number of steps to run per epoch. If None, then this will be
inferred based on the number of images and batch size. Defaults to None.
callbacks: An instance of CallbacksConfig.
metrics: An instance of MetricsConfig.
tensorboard: An instance of TensorBoardConfig.
set_epoch_loop: Whether or not to set `steps_per_execution` to
equal the number of training steps in `model.compile`. This reduces the
number of callbacks run per epoch which significantly improves end-to-end
TPU training time.
"""
resume_checkpoint: bool = None
epochs: int = None
steps: int = None
callbacks: CallbacksConfig = CallbacksConfig()
metrics: MetricsConfig = None
tensorboard: TensorBoardConfig = TensorBoardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig()
set_epoch_loop: bool = False
@dataclasses.dataclass
class EvalConfig(hyperparams.Config):
"""Configuration for evaluation.
Attributes:
epochs_between_evals: The number of train epochs to run between evaluations.
Defaults to None.
steps: The number of eval steps to run during evaluation. If None, this will
be inferred based on the number of images and batch size. Defaults to
None.
skip_eval: Whether or not to skip evaluation.
"""
epochs_between_evals: int = None
steps: int = None
skip_eval: bool = False
@dataclasses.dataclass
class LossConfig(hyperparams.Config):
"""Configuration for Loss.
Attributes:
name: The name of the loss. Defaults to None.
label_smoothing: Whether or not to apply label smoothing to the loss. This
only applies to 'categorical_cross_entropy'.
"""
name: str = None
label_smoothing: float = None
@dataclasses.dataclass
class OptimizerConfig(hyperparams.Config):
"""Configuration for Optimizers.
Attributes:
name: The name of the optimizer. Defaults to None.
decay: Decay or rho, discounting factor for gradient. Defaults to None.
epsilon: Small value used to avoid 0 denominator. Defaults to None.
momentum: Plain momentum constant. Defaults to None.
nesterov: Whether or not to apply Nesterov momentum. Defaults to None.
moving_average_decay: The amount of decay to apply. If 0 or None, then
exponential moving average is not used. Defaults to None.
lookahead: Whether or not to apply the lookahead optimizer. Defaults to
None.
beta_1: The exponential decay rate for the 1st moment estimates. Used in the
Adam optimizers. Defaults to None.
beta_2: The exponential decay rate for the 2nd moment estimates. Used in the
Adam optimizers. Defaults to None.
epsilon: Small value used to avoid 0 denominator. Defaults to 1e-7.
"""
name: str = None
decay: float = None
epsilon: float = None
momentum: float = None
nesterov: bool = None
moving_average_decay: Optional[float] = None
lookahead: Optional[bool] = None
beta_1: float = None
beta_2: float = None
epsilon: float = None
@dataclasses.dataclass
class LearningRateConfig(hyperparams.Config):
"""Configuration for learning rates.
Attributes:
name: The name of the learning rate. Defaults to None.
initial_lr: The initial learning rate. Defaults to None.
decay_epochs: The number of decay epochs. Defaults to None.
decay_rate: The rate of decay. Defaults to None.
warmup_epochs: The number of warmup epochs. Defaults to None.
batch_lr_multiplier: The multiplier to apply to the base learning rate, if
necessary. Defaults to None.
examples_per_epoch: the number of examples in a single epoch. Defaults to
None.
boundaries: boundaries used in piecewise constant decay with warmup.
multipliers: multipliers used in piecewise constant decay with warmup.
scale_by_batch_size: Scale the learning rate by a fraction of the batch
size. Set to 0 for no scaling (default).
staircase: Apply exponential decay at discrete values instead of continuous.
"""
name: str = None
initial_lr: float = None
decay_epochs: float = None
decay_rate: float = None
warmup_epochs: int = None
examples_per_epoch: int = None
boundaries: List[int] = None
multipliers: List[float] = None
scale_by_batch_size: float = 0.
staircase: bool = None
@dataclasses.dataclass
class ModelConfig(hyperparams.Config):
"""Configuration for Models.
Attributes:
name: The name of the model. Defaults to None.
model_params: The parameters used to create the model. Defaults to None.
num_classes: The number of classes in the model. Defaults to None.
loss: A `LossConfig` instance. Defaults to None.
optimizer: An `OptimizerConfig` instance. Defaults to None.
"""
name: str = None
model_params: hyperparams.Config = None
num_classes: int = None
loss: LossConfig = None
optimizer: OptimizerConfig = None
@dataclasses.dataclass
class ExperimentConfig(hyperparams.Config):
"""Base configuration for an image classification experiment.
Attributes:
model_dir: The directory to use when running an experiment.
mode: e.g. 'train_and_eval', 'export'
runtime: A `RuntimeConfig` instance.
train: A `TrainConfig` instance.
evaluation: An `EvalConfig` instance.
model: A `ModelConfig` instance.
export: An `ExportConfig` instance.
"""
model_dir: str = None
model_name: str = None
mode: str = None
runtime: RuntimeConfig = None
train_dataset: Any = None
validation_dataset: Any = None
train: TrainConfig = None
evaluation: EvalConfig = None
model: ModelConfig = None
export: ExportConfig = None
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configuration utils for image classification experiments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dataclasses
from official.vision.image_classification import dataset_factory
from official.vision.image_classification.configs import base_configs
from official.vision.image_classification.efficientnet import efficientnet_config
from official.vision.image_classification.resnet import resnet_config
@dataclasses.dataclass
class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
"""Base configuration to train efficientnet-b0 on ImageNet.
Attributes:
export: An `ExportConfig` instance
runtime: A `RuntimeConfig` instance.
dataset: A `DatasetConfig` instance.
train: A `TrainConfig` instance.
evaluation: An `EvalConfig` instance.
model: A `ModelConfig` instance.
"""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
train_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='train')
validation_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation')
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=500,
steps=None,
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = \
efficientnet_config.EfficientNetModelConfig()
@dataclasses.dataclass
class ResNetImagenetConfig(base_configs.ExperimentConfig):
"""Base configuration to train resnet-50 on ImageNet."""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
train_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='train',
one_hot=False,
mean_subtract=True,
standardize=True)
validation_dataset: dataset_factory.DatasetConfig = \
dataset_factory.ImageNetConfig(split='validation',
one_hot=False,
mean_subtract=True,
standardize=True)
train: base_configs.TrainConfig = base_configs.TrainConfig(
resume_checkpoint=True,
epochs=90,
steps=None,
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorBoardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
"""Given model and dataset names, return the ExperimentConfig."""
dataset_model_config_map = {
'imagenet': {
'efficientnet': EfficientNetImageNetConfig(),
'resnet': ResNetImagenetConfig(),
}
}
try:
return dataset_model_config_map[dataset][model]
except KeyError:
if dataset not in dataset_model_config_map:
raise KeyError('Invalid dataset received. Received: {}. Supported '
'datasets include: {}'.format(
dataset, ', '.join(dataset_model_config_map.keys())))
raise KeyError('Invalid model received. Received: {}. Supported models for'
'{} include: {}'.format(
model, dataset,
', '.join(dataset_model_config_map[dataset].keys())))
# Training configuration for EfficientNet-b0 trained on ImageNet on GPUs.
# Takes ~32 minutes per epoch for 8 V100s.
# Reaches ~76.1% within 350 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
distribution_strategy: 'mirrored'
num_gpus: 1
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'train'
num_classes: 1000
num_examples: 1281167
batch_size: 32
use_per_replica_batch_size: True
dtype: 'float32'
augmenter:
name: 'autoaugment'
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'validation'
num_classes: 1000
num_examples: 50000
batch_size: 32
use_per_replica_batch_size: True
dtype: 'float32'
model:
model_params:
model_name: 'efficientnet-b0'
overrides:
num_classes: 1000
batch_norm: 'default'
dtype: 'float32'
activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
label_smoothing: 0.1
train:
resume_checkpoint: True
epochs: 500
evaluation:
epochs_between_evals: 1
# Training configuration for EfficientNet-b0 trained on ImageNet on TPUs.
# Takes ~2 minutes, 50 seconds per epoch for v3-32.
# Reaches ~76.1% within 350 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'train'
num_classes: 1000
num_examples: 1281167
batch_size: 128
use_per_replica_batch_size: True
dtype: 'bfloat16'
augmenter:
name: 'autoaugment'
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'validation'
num_classes: 1000
num_examples: 50000
batch_size: 128
use_per_replica_batch_size: True
dtype: 'bfloat16'
model:
model_params:
model_name: 'efficientnet-b0'
overrides:
num_classes: 1000
batch_norm: 'tpu'
dtype: 'bfloat16'
activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
label_smoothing: 0.1
train:
resume_checkpoint: True
epochs: 500
set_epoch_loop: True
evaluation:
epochs_between_evals: 1
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
distribution_strategy: 'mirrored'
num_gpus: 1
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'train'
num_classes: 1000
num_examples: 1281167
batch_size: 32
use_per_replica_batch_size: True
dtype: 'float32'
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'validation'
num_classes: 1000
num_examples: 50000
batch_size: 32
use_per_replica_batch_size: True
dtype: 'float32'
model:
model_params:
model_name: 'efficientnet-b1'
overrides:
num_classes: 1000
batch_norm: 'default'
dtype: 'float32'
activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
label_smoothing: 0.1
train:
resume_checkpoint: True
epochs: 500
evaluation:
epochs_between_evals: 1
# Training configuration for EfficientNet-b1 trained on ImageNet on TPUs.
# Takes ~3 minutes, 15 seconds per epoch for v3-32.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'train'
num_classes: 1000
num_examples: 1281167
batch_size: 128
use_per_replica_batch_size: True
dtype: 'bfloat16'
augmenter:
name: 'autoaugment'
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'records'
split: 'validation'
num_classes: 1000
num_examples: 50000
batch_size: 128
use_per_replica_batch_size: True
dtype: 'bfloat16'
model:
model_params:
model_name: 'efficientnet-b1'
overrides:
num_classes: 1000
batch_norm: 'tpu'
dtype: 'bfloat16'
activation: 'swish'
optimizer:
name: 'rmsprop'
momentum: 0.9
decay: 0.9
moving_average_decay: 0.0
lookahead: false
learning_rate:
name: 'exponential'
loss:
label_smoothing: 0.1
train:
resume_checkpoint: True
epochs: 500
set_epoch_loop: True
evaluation:
epochs_between_evals: 1
# Training configuration for ResNet trained on ImageNet on GPUs.
# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
distribution_strategy: 'mirrored'
num_gpus: 1
batchnorm_spatial_persistent: True
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'tfds'
split: 'train'
image_size: 224
num_classes: 1000
num_examples: 1281167
batch_size: 256
use_per_replica_batch_size: True
dtype: 'float16'
mean_subtract: True
standardize: True
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'tfds'
split: 'validation'
image_size: 224
num_classes: 1000
num_examples: 50000
batch_size: 256
use_per_replica_batch_size: True
dtype: 'float16'
mean_subtract: True
standardize: True
model:
name: 'resnet'
model_params:
rescale_inputs: False
optimizer:
name: 'momentum'
momentum: 0.9
decay: 0.9
epsilon: 0.001
loss:
label_smoothing: 0.1
train:
resume_checkpoint: True
epochs: 90
evaluation:
epochs_between_evals: 1
# Training configuration for ResNet trained on ImageNet on TPUs.
# Takes ~4 minutes, 30 seconds seconds per epoch for a v3-32.
# Reaches > 76.1% within 90 epochs.
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
runtime:
distribution_strategy: 'tpu'
train_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'tfds'
split: 'train'
one_hot: False
image_size: 224
num_classes: 1000
num_examples: 1281167
batch_size: 128
use_per_replica_batch_size: True
mean_subtract: False
standardize: False
dtype: 'bfloat16'
validation_dataset:
name: 'imagenet2012'
data_dir: null
builder: 'tfds'
split: 'validation'
one_hot: False
image_size: 224
num_classes: 1000
num_examples: 50000
batch_size: 128
use_per_replica_batch_size: True
mean_subtract: False
standardize: False
dtype: 'bfloat16'
model:
name: 'resnet'
model_params:
rescale_inputs: True
optimizer:
name: 'momentum'
momentum: 0.9
decay: 0.9
epsilon: 0.001
moving_average_decay: 0.
lookahead: False
loss:
label_smoothing: 0.1
train:
callbacks:
enable_checkpoint_and_export: True
resume_checkpoint: True
epochs: 90
set_epoch_loop: True
evaluation:
epochs_between_evals: 1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Dataset utilities for vision tasks using TFDS and tf.data.Dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from typing import Any, List, Optional, Tuple, Mapping, Union
from absl import logging
from dataclasses import dataclass
import tensorflow as tf
import tensorflow_datasets as tfds
from official.modeling.hyperparams import base_config
from official.vision.image_classification import augment
from official.vision.image_classification import preprocessing
AUGMENTERS = {
'autoaugment': augment.AutoAugment,
'randaugment': augment.RandAugment,
}
@dataclass
class AugmentConfig(base_config.Config):
"""Configuration for image augmenters.
Attributes:
name: The name of the image augmentation to use. Possible options are None
(default), 'autoaugment', or 'randaugment'.
params: Any paramaters used to initialize the augmenter.
"""
name: Optional[str] = None
params: Optional[Mapping[str, Any]] = None
def build(self) -> augment.ImageAugment:
"""Build the augmenter using this config."""
params = self.params or {}
augmenter = AUGMENTERS.get(self.name, None)
return augmenter(**params) if augmenter is not None else None
@dataclass
class DatasetConfig(base_config.Config):
"""The base configuration for building datasets.
Attributes:
name: The name of the Dataset. Usually should correspond to a TFDS dataset.
data_dir: The path where the dataset files are stored, if available.
filenames: Optional list of strings representing the TFRecord names.
builder: The builder type used to load the dataset. Value should be one of
'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
(generate dummy synthetic data without reading from files).
split: The split of the dataset. Usually 'train', 'validation', or 'test'.
image_size: The size of the image in the dataset. This assumes that `width`
== `height`. Set to 'infer' to infer the image size from TFDS info. This
requires `name` to be a registered dataset in TFDS.
num_classes: The number of classes given by the dataset. Set to 'infer' to
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
num_channels: The number of channels given by the dataset. Set to 'infer' to
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
num_examples: The number of examples given by the dataset. Set to 'infer' to
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
batch_size: The base batch size for the dataset.
use_per_replica_batch_size: Whether to scale the batch size based on
available resources. If set to `True`, the dataset builder will return
batch_size multiplied by `num_devices`, the number of device replicas
(e.g., the number of GPUs or TPU cores). This setting should be `True` if
the strategy argument is passed to `build()` and `num_devices > 1`.
num_devices: The number of replica devices to use. This should be set by
`strategy.num_replicas_in_sync` when using a distribution strategy.
dtype: The desired dtype of the dataset. This will be set during
preprocessing.
one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
label smoothing.
augmenter: The augmenter config to use. No augmentation is used by default.
download: Whether to download data using TFDS.
shuffle_buffer_size: The buffer size used for shuffling training data.
file_shuffle_buffer_size: The buffer size used for shuffling raw training
files.
skip_decoding: Whether to skip image decoding when loading from TFDS.
cache: whether to cache to dataset examples. Can be used to avoid re-reading
from disk on the second epoch. Requires significant memory overhead.
tf_data_service: The URI of a tf.data service to offload preprocessing onto
during training. The URI should be in the format "protocol://address",
e.g. "grpc://tf-data-service:5050".
mean_subtract: whether or not to apply mean subtraction to the dataset.
standardize: whether or not to apply standardization to the dataset.
"""
name: Optional[str] = None
data_dir: Optional[str] = None
filenames: Optional[List[str]] = None
builder: str = 'tfds'
split: str = 'train'
image_size: Union[int, str] = 'infer'
num_classes: Union[int, str] = 'infer'
num_channels: Union[int, str] = 'infer'
num_examples: Union[int, str] = 'infer'
batch_size: int = 128
use_per_replica_batch_size: bool = True
num_devices: int = 1
dtype: str = 'float32'
one_hot: bool = True
augmenter: AugmentConfig = AugmentConfig()
download: bool = False
shuffle_buffer_size: int = 10000
file_shuffle_buffer_size: int = 1024
skip_decoding: bool = True
cache: bool = False
tf_data_service: Optional[str] = None
mean_subtract: bool = False
standardize: bool = False
@property
def has_data(self):
"""Whether this dataset is has any data associated with it."""
return self.name or self.data_dir or self.filenames
@dataclass
class ImageNetConfig(DatasetConfig):
"""The base ImageNet dataset config."""
name: str = 'imagenet2012'
# Note: for large datasets like ImageNet, using records is faster than tfds
builder: str = 'records'
image_size: int = 224
num_channels: int = 3
num_examples: int = 1281167
num_classes: int = 1000
batch_size: int = 128
@dataclass
class Cifar10Config(DatasetConfig):
"""The base CIFAR-10 dataset config."""
name: str = 'cifar10'
image_size: int = 224
batch_size: int = 128
download: bool = True
cache: bool = True
class DatasetBuilder:
"""An object for building datasets.
Allows building various pipelines fetching examples, preprocessing, etc.
Maintains additional state information calculated from the dataset, i.e.,
training set split, batch size, and number of steps (batches).
"""
def __init__(self, config: DatasetConfig, **overrides: Any):
"""Initialize the builder from the config."""
self.config = config.replace(**overrides)
self.builder_info = None
if self.config.augmenter is not None:
logging.info('Using augmentation: %s', self.config.augmenter.name)
self.augmenter = self.config.augmenter.build()
else:
self.augmenter = None
@property
def is_training(self) -> bool:
"""Whether this is the training set."""
return self.config.split == 'train'
@property
def batch_size(self) -> int:
"""The batch size, multiplied by the number of replicas (if configured)."""
if self.config.use_per_replica_batch_size:
return self.config.batch_size * self.config.num_devices
else:
return self.config.batch_size
@property
def global_batch_size(self):
"""The global batch size across all replicas."""
return self.batch_size
@property
def local_batch_size(self):
"""The base unscaled batch size."""
if self.config.use_per_replica_batch_size:
return self.config.batch_size
else:
return self.config.batch_size // self.config.num_devices
@property
def num_steps(self) -> int:
"""The number of steps (batches) to exhaust this dataset."""
# Always divide by the global batch size to get the correct # of steps
return self.num_examples // self.global_batch_size
@property
def dtype(self) -> tf.dtypes.DType:
"""Converts the config's dtype string to a tf dtype.
Returns:
A mapping from string representation of a dtype to the `tf.dtypes.DType`.
Raises:
ValueError if the config's dtype is not supported.
"""
dtype_map = {
'float32': tf.float32,
'bfloat16': tf.bfloat16,
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
try:
return dtype_map[self.config.dtype]
except:
raise ValueError('Invalid DType provided. Supported types: {}'.format(
dtype_map.keys()))
@property
def image_size(self) -> int:
"""The size of each image (can be inferred from the dataset)."""
if self.config.image_size == 'infer':
return self.info.features['image'].shape[0]
else:
return int(self.config.image_size)
@property
def num_channels(self) -> int:
"""The number of image channels (can be inferred from the dataset)."""
if self.config.num_channels == 'infer':
return self.info.features['image'].shape[-1]
else:
return int(self.config.num_channels)
@property
def num_examples(self) -> int:
"""The number of examples (can be inferred from the dataset)."""
if self.config.num_examples == 'infer':
return self.info.splits[self.config.split].num_examples
else:
return int(self.config.num_examples)
@property
def num_classes(self) -> int:
"""The number of classes (can be inferred from the dataset)."""
if self.config.num_classes == 'infer':
return self.info.features['label'].num_classes
else:
return int(self.config.num_classes)
@property
def info(self) -> tfds.core.DatasetInfo:
"""The TFDS dataset info, if available."""
try:
if self.builder_info is None:
self.builder_info = tfds.builder(self.config.name).info
except ConnectionError as e:
logging.error('Failed to use TFDS to load info. Please set dataset info '
'(image_size, num_channels, num_examples, num_classes) in '
'the dataset config.')
raise e
return self.builder_info
def build(
self,
strategy: Optional[tf.distribute.Strategy] = None) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it using an optional strategy.
Args:
strategy: a strategy that, if passed, will distribute the dataset
according to that strategy. If passed and `num_devices > 1`,
`use_per_replica_batch_size` must be set to `True`.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
if strategy:
if strategy.num_replicas_in_sync != self.config.num_devices:
logging.warn(
'Passed a strategy with %d devices, but expected'
'%d devices.', strategy.num_replicas_in_sync,
self.config.num_devices)
dataset = strategy.distribute_datasets_from_function(self._build)
else:
dataset = self._build()
return dataset
def _build(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it.
Args:
input_context: An optional context provided by `tf.distribute` for
cross-replica training.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
builders = {
'tfds': self.load_tfds,
'records': self.load_records,
'synthetic': self.load_synthetic,
}
builder = builders.get(self.config.builder, None)
if builder is None:
raise ValueError('Unknown builder type {}'.format(self.config.builder))
self.input_context = input_context
dataset = builder()
dataset = self.pipeline(dataset)
return dataset
def load_tfds(self) -> tf.data.Dataset:
"""Return a dataset loading files from TFDS."""
logging.info('Using TFDS to load data.')
builder = tfds.builder(self.config.name, data_dir=self.config.data_dir)
if self.config.download:
builder.download_and_prepare()
decoders = {}
if self.config.skip_decoding:
decoders['image'] = tfds.decode.SkipDecoding()
read_config = tfds.ReadConfig(
interleave_cycle_length=10,
interleave_block_length=1,
input_context=self.input_context)
dataset = builder.as_dataset(
split=self.config.split,
as_supervised=True,
shuffle_files=True,
decoders=decoders,
read_config=read_config)
return dataset
def load_records(self) -> tf.data.Dataset:
"""Return a dataset loading files with TFRecords."""
logging.info('Using TFRecords to load data.')
if self.config.filenames is None:
if self.config.data_dir is None:
raise ValueError('Dataset must specify a path for the data files.')
file_pattern = os.path.join(self.config.data_dir,
'{}*'.format(self.config.split))
dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
else:
dataset = tf.data.Dataset.from_tensor_slices(self.config.filenames)
return dataset
def load_synthetic(self) -> tf.data.Dataset:
"""Return a dataset generating dummy synthetic data."""
logging.info('Generating a synthetic dataset.')
def generate_data(_):
image = tf.zeros([self.image_size, self.image_size, self.num_channels],
dtype=self.dtype)
label = tf.zeros([1], dtype=tf.int32)
return image, label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
"""Build a pipeline fetching, shuffling, and preprocessing the dataset.
Args:
dataset: A `tf.data.Dataset` that loads raw files.
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
if (self.config.builder != 'tfds' and self.input_context and
self.input_context.num_input_pipelines > 1):
dataset = dataset.shard(self.input_context.num_input_pipelines,
self.input_context.input_pipeline_id)
logging.info(
'Sharding the dataset: input_pipeline_id=%d '
'num_input_pipelines=%d', self.input_context.num_input_pipelines,
self.input_context.input_pipeline_id)
if self.is_training and self.config.builder == 'records':
# Shuffle the input files.
dataset.shuffle(buffer_size=self.config.file_shuffle_buffer_size)
if self.is_training and not self.config.cache:
dataset = dataset.repeat()
if self.config.builder == 'records':
# Read the data from disk in parallel
dataset = dataset.interleave(
tf.data.TFRecordDataset,
cycle_length=10,
block_length=1,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
if self.config.cache:
dataset = dataset.cache()
if self.is_training:
dataset = dataset.shuffle(self.config.shuffle_buffer_size)
dataset = dataset.repeat()
# Parse, pre-process, and batch the data in parallel
if self.config.builder == 'records':
preprocess = self.parse_record
else:
preprocess = self.preprocess
dataset = dataset.map(
preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
if self.input_context and self.config.num_devices > 1:
if not self.config.use_per_replica_batch_size:
raise ValueError(
'The builder does not support a global batch size with more than '
'one replica. Got {} replicas. Please set a '
'`per_replica_batch_size` and enable '
'`use_per_replica_batch_size=True`.'.format(
self.config.num_devices))
# The batch size of the dataset will be multiplied by the number of
# replicas automatically when strategy.distribute_datasets_from_function
# is called, so we use local batch size here.
dataset = dataset.batch(
self.local_batch_size, drop_remainder=self.is_training)
else:
dataset = dataset.batch(
self.global_batch_size, drop_remainder=self.is_training)
# Prefetch overlaps in-feed with training
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
if self.config.tf_data_service:
if not hasattr(tf.data.experimental, 'service'):
raise ValueError('The tf_data_service flag requires Tensorflow version '
'>= 2.3.0, but the version is {}'.format(
tf.__version__))
dataset = dataset.apply(
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self.config.tf_data_service,
job_name='resnet_train'))
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Parse an ImageNet record from a serialized string Tensor."""
keys_to_features = {
'image/encoded': tf.io.FixedLenFeature((), tf.string, ''),
'image/format': tf.io.FixedLenFeature((), tf.string, 'jpeg'),
'image/class/label': tf.io.FixedLenFeature([], tf.int64, -1),
'image/class/text': tf.io.FixedLenFeature([], tf.string, ''),
'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/class/label': tf.io.VarLenFeature(dtype=tf.int64),
}
parsed = tf.io.parse_single_example(record, keys_to_features)
label = tf.reshape(parsed['image/class/label'], shape=[1])
# Subtract one so that labels are in [0, 1000)
label -= 1
image_bytes = tf.reshape(parsed['image/encoded'], shape=[])
image, label = self.preprocess(image_bytes, label)
return image, label
def preprocess(self, image: tf.Tensor,
label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Apply image preprocessing and augmentation to the image and label."""
if self.is_training:
image = preprocessing.preprocess_for_train(
image,
image_size=self.image_size,
mean_subtract=self.config.mean_subtract,
standardize=self.config.standardize,
dtype=self.dtype,
augmenter=self.augmenter)
else:
image = preprocessing.preprocess_for_eval(
image,
image_size=self.image_size,
num_channels=self.num_channels,
mean_subtract=self.config.mean_subtract,
standardize=self.config.standardize,
dtype=self.dtype)
label = tf.cast(label, tf.int32)
if self.config.one_hot:
label = tf.one_hot(label, self.num_classes)
label = tf.reshape(label, [self.num_classes])
return image, label
@classmethod
def from_params(cls, *args, **kwargs):
"""Construct a dataset builder from a default config and any overrides."""
config = DatasetConfig.from_args(*args, **kwargs)
return cls(config)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common modeling utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
from typing import Text, Optional
from tensorflow.python.tpu import tpu_function
@tf.keras.utils.register_keras_serializable(package='Vision')
class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
"""Cross replica batch normalization."""
def __init__(self, fused: Optional[bool] = False, **kwargs):
if fused in (True, None):
raise ValueError('TpuBatchNormalization does not support fused=True.')
super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
def _cross_replica_average(self, t: tf.Tensor, num_shards_per_group: int):
"""Calculates the average value of input tensor across TPU replicas."""
num_shards = tpu_function.get_tpu_context().number_of_shards
group_assignment = None
if num_shards_per_group > 1:
if num_shards % num_shards_per_group != 0:
raise ValueError(
'num_shards: %d mod shards_per_group: %d, should be 0' %
(num_shards, num_shards_per_group))
num_groups = num_shards // num_shards_per_group
group_assignment = [[
x for x in range(num_shards) if x // num_shards_per_group == y
] for y in range(num_groups)]
return tf1.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
num_shards_per_group, t.dtype)
def _moments(self, inputs: tf.Tensor, reduction_axes: int, keep_dims: int):
"""Compute the mean and variance: it overrides the original _moments."""
shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
inputs, reduction_axes, keep_dims=keep_dims)
num_shards = tpu_function.get_tpu_context().number_of_shards or 1
if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.
num_shards_per_group = 1
else:
num_shards_per_group = max(8, num_shards // 8)
if num_shards_per_group > 1:
# Compute variance using: Var[X]= E[X^2] - E[X]^2.
shard_square_of_mean = tf.math.square(shard_mean)
shard_mean_of_square = shard_variance + shard_square_of_mean
group_mean = self._cross_replica_average(shard_mean, num_shards_per_group)
group_mean_of_square = self._cross_replica_average(
shard_mean_of_square, num_shards_per_group)
group_variance = group_mean_of_square - tf.math.square(group_mean)
return (group_mean, group_variance)
else:
return (shard_mean, shard_variance)
def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
"""A helper to create a batch normalization getter.
Args:
batch_norm_type: The type of batch normalization layer implementation. `tpu`
will use `TpuBatchNormalization`.
Returns:
An instance of `tf.keras.layers.BatchNormalization`.
"""
if batch_norm_type == 'tpu':
return TpuBatchNormalization
return tf.keras.layers.BatchNormalization # pytype: disable=bad-return-type # typed-keras
def count_params(model, trainable_only=True):
"""Returns the count of all model parameters, or just trainable ones."""
if not trainable_only:
return model.count_params()
else:
return int(
np.sum([
tf.keras.backend.count_params(p) for p in model.trainable_weights
]))
def load_weights(model: tf.keras.Model,
model_weights_path: Text,
weights_format: Text = 'saved_model'):
"""Load model weights from the given file path.
Args:
model: the model to load weights into
model_weights_path: the path of the model weights
weights_format: the model weights format. One of 'saved_model', 'h5', or
'checkpoint'.
"""
if weights_format == 'saved_model':
loaded_model = tf.keras.models.load_model(model_weights_path)
model.set_weights(loaded_model.get_weights())
else:
model.load_weights(model_weights_path)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configuration definitions for EfficientNet losses, learning rates, and optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Mapping
import dataclasses
from official.modeling.hyperparams import base_config
from official.vision.image_classification.configs import base_configs
@dataclasses.dataclass
class EfficientNetModelConfig(base_configs.ModelConfig):
"""Configuration for the EfficientNet model.
This configuration will default to settings used for training efficientnet-b0
on a v3-8 TPU on ImageNet.
Attributes:
name: The name of the model. Defaults to 'EfficientNet'.
num_classes: The number of classes in the model.
model_params: A dictionary that represents the parameters of the
EfficientNet model. These will be passed in to the "from_name" function.
loss: The configuration for loss. Defaults to a categorical cross entropy
implementation.
optimizer: The configuration for optimizations. Defaults to an RMSProp
configuration.
learning_rate: The configuration for learning rate. Defaults to an
exponential configuration.
"""
name: str = 'EfficientNet'
num_classes: int = 1000
model_params: base_config.Config = dataclasses.field(
default_factory=lambda: {
'model_name': 'efficientnet-b0',
'model_weights_path': '',
'weights_format': 'saved_model',
'overrides': {
'batch_norm': 'default',
'rescale_input': True,
'num_classes': 1000,
'activation': 'swish',
'dtype': 'float32',
}
})
loss: base_configs.LossConfig = base_configs.LossConfig(
name='categorical_crossentropy', label_smoothing=0.1)
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
name='rmsprop',
decay=0.9,
epsilon=0.001,
momentum=0.9,
moving_average_decay=None)
learning_rate: base_configs.LearningRateConfig = base_configs.LearningRateConfig( # pylint: disable=line-too-long
name='exponential',
initial_lr=0.008,
decay_epochs=2.4,
decay_rate=0.97,
warmup_epochs=5,
scale_by_batch_size=1. / 128.,
staircase=True)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Contains definitions for EfficientNet model.
[1] Mingxing Tan, Quoc V. Le
EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
ICML'19, https://arxiv.org/abs/1905.11946
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
from typing import Any, Dict, Optional, Text, Tuple
from absl import logging
from dataclasses import dataclass
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.vision.image_classification import preprocessing
from official.vision.image_classification.efficientnet import common_modules
@dataclass
class BlockConfig(base_config.Config):
"""Config for a single MB Conv Block."""
input_filters: int = 0
output_filters: int = 0
kernel_size: int = 3
num_repeat: int = 1
expand_ratio: int = 1
strides: Tuple[int, int] = (1, 1)
se_ratio: Optional[float] = None
id_skip: bool = True
fused_conv: bool = False
conv_type: str = 'depthwise'
@dataclass
class ModelConfig(base_config.Config):
"""Default Config for Efficientnet-B0."""
width_coefficient: float = 1.0
depth_coefficient: float = 1.0
resolution: int = 224
dropout_rate: float = 0.2
blocks: Tuple[BlockConfig, ...] = (
# (input_filters, output_filters, kernel_size, num_repeat,
# expand_ratio, strides, se_ratio)
# pylint: disable=bad-whitespace
BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25),
BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25),
BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25),
BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25),
BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25),
BlockConfig.from_args(112, 192, 5, 4, 6, (2, 2), 0.25),
BlockConfig.from_args(192, 320, 3, 1, 6, (1, 1), 0.25),
# pylint: enable=bad-whitespace
)
stem_base_filters: int = 32
top_base_filters: int = 1280
activation: str = 'simple_swish'
batch_norm: str = 'default'
bn_momentum: float = 0.99
bn_epsilon: float = 1e-3
# While the original implementation used a weight decay of 1e-5,
# tf.nn.l2_loss divides it by 2, so we halve this to compensate in Keras
weight_decay: float = 5e-6
drop_connect_rate: float = 0.2
depth_divisor: int = 8
min_depth: Optional[int] = None
use_se: bool = True
input_channels: int = 3
num_classes: int = 1000
model_name: str = 'efficientnet'
rescale_input: bool = True
data_format: str = 'channels_last'
dtype: str = 'float32'
MODEL_CONFIGS = {
# (width, depth, resolution, dropout)
'efficientnet-b0': ModelConfig.from_args(1.0, 1.0, 224, 0.2),
'efficientnet-b1': ModelConfig.from_args(1.0, 1.1, 240, 0.2),
'efficientnet-b2': ModelConfig.from_args(1.1, 1.2, 260, 0.3),
'efficientnet-b3': ModelConfig.from_args(1.2, 1.4, 300, 0.3),
'efficientnet-b4': ModelConfig.from_args(1.4, 1.8, 380, 0.4),
'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4),
'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5),
'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5),
'efficientnet-b8': ModelConfig.from_args(2.2, 3.6, 672, 0.5),
'efficientnet-l2': ModelConfig.from_args(4.3, 5.3, 800, 0.5),
}
CONV_KERNEL_INITIALIZER = {
'class_name': 'VarianceScaling',
'config': {
'scale': 2.0,
'mode': 'fan_out',
# Note: this is a truncated normal distribution
'distribution': 'normal'
}
}
DENSE_KERNEL_INITIALIZER = {
'class_name': 'VarianceScaling',
'config': {
'scale': 1 / 3.0,
'mode': 'fan_out',
'distribution': 'uniform'
}
}
def round_filters(filters: int, config: ModelConfig) -> int:
"""Round number of filters based on width coefficient."""
width_coefficient = config.width_coefficient
min_depth = config.min_depth
divisor = config.depth_divisor
orig_filters = filters
if not width_coefficient:
return filters
filters *= width_coefficient
min_depth = min_depth or divisor
new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_filters < 0.9 * filters:
new_filters += divisor
logging.info('round_filter input=%s output=%s', orig_filters, new_filters)
return int(new_filters)
def round_repeats(repeats: int, depth_coefficient: float) -> int:
"""Round number of repeats based on depth coefficient."""
return int(math.ceil(depth_coefficient * repeats))
def conv2d_block(inputs: tf.Tensor,
conv_filters: Optional[int],
config: ModelConfig,
kernel_size: Any = (1, 1),
strides: Any = (1, 1),
use_batch_norm: bool = True,
use_bias: bool = False,
activation: Optional[Any] = None,
depthwise: bool = False,
name: Optional[Text] = None):
"""A conv2d followed by batch norm and an activation."""
batch_norm = common_modules.get_batch_norm(config.batch_norm)
bn_momentum = config.bn_momentum
bn_epsilon = config.bn_epsilon
data_format = tf.keras.backend.image_data_format()
weight_decay = config.weight_decay
name = name or ''
# Collect args based on what kind of conv2d block is desired
init_kwargs = {
'kernel_size': kernel_size,
'strides': strides,
'use_bias': use_bias,
'padding': 'same',
'name': name + '_conv2d',
'kernel_regularizer': tf.keras.regularizers.l2(weight_decay),
'bias_regularizer': tf.keras.regularizers.l2(weight_decay),
}
if depthwise:
conv2d = tf.keras.layers.DepthwiseConv2D
init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER})
else:
conv2d = tf.keras.layers.Conv2D
init_kwargs.update({
'filters': conv_filters,
'kernel_initializer': CONV_KERNEL_INITIALIZER
})
x = conv2d(**init_kwargs)(inputs)
if use_batch_norm:
bn_axis = 1 if data_format == 'channels_first' else -1
x = batch_norm(
axis=bn_axis,
momentum=bn_momentum,
epsilon=bn_epsilon,
name=name + '_bn')(
x)
if activation is not None:
x = tf.keras.layers.Activation(activation, name=name + '_activation')(x)
return x
def mb_conv_block(inputs: tf.Tensor,
block: BlockConfig,
config: ModelConfig,
prefix: Optional[Text] = None):
"""Mobile Inverted Residual Bottleneck.
Args:
inputs: the Keras input to the block
block: BlockConfig, arguments to create a Block
config: ModelConfig, a set of model parameters
prefix: prefix for naming all layers
Returns:
the output of the block
"""
use_se = config.use_se
activation = tf_utils.get_activation(config.activation)
drop_connect_rate = config.drop_connect_rate
data_format = tf.keras.backend.image_data_format()
use_depthwise = block.conv_type != 'no_depthwise'
prefix = prefix or ''
filters = block.input_filters * block.expand_ratio
x = inputs
if block.fused_conv:
# If we use fused mbconv, skip expansion and use regular conv.
x = conv2d_block(
x,
filters,
config,
kernel_size=block.kernel_size,
strides=block.strides,
activation=activation,
name=prefix + 'fused')
else:
if block.expand_ratio != 1:
# Expansion phase
kernel_size = (1, 1) if use_depthwise else (3, 3)
x = conv2d_block(
x,
filters,
config,
kernel_size=kernel_size,
activation=activation,
name=prefix + 'expand')
# Depthwise Convolution
if use_depthwise:
x = conv2d_block(
x,
conv_filters=None,
config=config,
kernel_size=block.kernel_size,
strides=block.strides,
activation=activation,
depthwise=True,
name=prefix + 'depthwise')
# Squeeze and Excitation phase
if use_se:
assert block.se_ratio is not None
assert 0 < block.se_ratio <= 1
num_reduced_filters = max(1, int(block.input_filters * block.se_ratio))
if data_format == 'channels_first':
se_shape = (filters, 1, 1)
else:
se_shape = (1, 1, filters)
se = tf.keras.layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)
se = tf.keras.layers.Reshape(se_shape, name=prefix + 'se_reshape')(se)
se = conv2d_block(
se,
num_reduced_filters,
config,
use_bias=True,
use_batch_norm=False,
activation=activation,
name=prefix + 'se_reduce')
se = conv2d_block(
se,
filters,
config,
use_bias=True,
use_batch_norm=False,
activation='sigmoid',
name=prefix + 'se_expand')
x = tf.keras.layers.multiply([x, se], name=prefix + 'se_excite')
# Output phase
x = conv2d_block(
x, block.output_filters, config, activation=None, name=prefix + 'project')
# Add identity so that quantization-aware training can insert quantization
# ops correctly.
x = tf.keras.layers.Activation(
tf_utils.get_activation('identity'), name=prefix + 'id')(
x)
if (block.id_skip and all(s == 1 for s in block.strides) and
block.input_filters == block.output_filters):
if drop_connect_rate and drop_connect_rate > 0:
# Apply dropconnect
# The only difference between dropout and dropconnect in TF is scaling by
# drop_connect_rate during training. See:
# https://github.com/keras-team/keras/pull/9898#issuecomment-380577612
x = tf.keras.layers.Dropout(
drop_connect_rate, noise_shape=(None, 1, 1, 1), name=prefix + 'drop')(
x)
x = tf.keras.layers.add([x, inputs], name=prefix + 'add')
return x
def efficientnet(image_input: tf.keras.layers.Input, config: ModelConfig): # pytype: disable=invalid-annotation # typed-keras
"""Creates an EfficientNet graph given the model parameters.
This function is wrapped by the `EfficientNet` class to make a tf.keras.Model.
Args:
image_input: the input batch of images
config: the model config
Returns:
the output of efficientnet
"""
depth_coefficient = config.depth_coefficient
blocks = config.blocks
stem_base_filters = config.stem_base_filters
top_base_filters = config.top_base_filters
activation = tf_utils.get_activation(config.activation)
dropout_rate = config.dropout_rate
drop_connect_rate = config.drop_connect_rate
num_classes = config.num_classes
input_channels = config.input_channels
rescale_input = config.rescale_input
data_format = tf.keras.backend.image_data_format()
dtype = config.dtype
weight_decay = config.weight_decay
x = image_input
if data_format == 'channels_first':
# Happens on GPU/TPU if available.
x = tf.keras.layers.Permute((3, 1, 2))(x)
if rescale_input:
x = preprocessing.normalize_images(
x, num_channels=input_channels, dtype=dtype, data_format=data_format)
# Build stem
x = conv2d_block(
x,
round_filters(stem_base_filters, config),
config,
kernel_size=[3, 3],
strides=[2, 2],
activation=activation,
name='stem')
# Build blocks
num_blocks_total = sum(
round_repeats(block.num_repeat, depth_coefficient) for block in blocks)
block_num = 0
for stack_idx, block in enumerate(blocks):
assert block.num_repeat > 0
# Update block input and output filters based on depth multiplier
block = block.replace(
input_filters=round_filters(block.input_filters, config),
output_filters=round_filters(block.output_filters, config),
num_repeat=round_repeats(block.num_repeat, depth_coefficient))
# The first block needs to take care of stride and filter size increase
drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
config = config.replace(drop_connect_rate=drop_rate)
block_prefix = 'stack_{}/block_0/'.format(stack_idx)
x = mb_conv_block(x, block, config, block_prefix)
block_num += 1
if block.num_repeat > 1:
block = block.replace(input_filters=block.output_filters, strides=[1, 1])
for block_idx in range(block.num_repeat - 1):
drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
config = config.replace(drop_connect_rate=drop_rate)
block_prefix = 'stack_{}/block_{}/'.format(stack_idx, block_idx + 1)
x = mb_conv_block(x, block, config, prefix=block_prefix)
block_num += 1
# Build top
x = conv2d_block(
x,
round_filters(top_base_filters, config),
config,
activation=activation,
name='top')
# Build classifier
x = tf.keras.layers.GlobalAveragePooling2D(name='top_pool')(x)
if dropout_rate and dropout_rate > 0:
x = tf.keras.layers.Dropout(dropout_rate, name='top_dropout')(x)
x = tf.keras.layers.Dense(
num_classes,
kernel_initializer=DENSE_KERNEL_INITIALIZER,
kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
bias_regularizer=tf.keras.regularizers.l2(weight_decay),
name='logits')(
x)
x = tf.keras.layers.Activation('softmax', name='probs')(x)
return x
class EfficientNet(tf.keras.Model):
"""Wrapper class for an EfficientNet Keras model.
Contains helper methods to build, manage, and save metadata about the model.
"""
def __init__(self,
config: Optional[ModelConfig] = None,
overrides: Optional[Dict[Text, Any]] = None):
"""Create an EfficientNet model.
Args:
config: (optional) the main model parameters to create the model
overrides: (optional) a dict containing keys that can override config
"""
overrides = overrides or {}
config = config or ModelConfig()
self.config = config.replace(**overrides)
input_channels = self.config.input_channels
model_name = self.config.model_name
input_shape = (None, None, input_channels) # Should handle any size image
image_input = tf.keras.layers.Input(shape=input_shape)
output = efficientnet(image_input, self.config)
# Cast to float32 in case we have a different model dtype
output = tf.cast(output, tf.float32)
logging.info('Building model %s with params %s', model_name, self.config)
super(EfficientNet, self).__init__(
inputs=image_input, outputs=output, name=model_name)
@classmethod
def from_name(cls,
model_name: Text,
model_weights_path: Optional[Text] = None,
weights_format: Text = 'saved_model',
overrides: Optional[Dict[Text, Any]] = None):
"""Construct an EfficientNet model from a predefined model name.
E.g., `EfficientNet.from_name('efficientnet-b0')`.
Args:
model_name: the predefined model name
model_weights_path: the path to the weights (h5 file or saved model dir)
weights_format: the model weights format. One of 'saved_model', 'h5', or
'checkpoint'.
overrides: (optional) a dict containing keys that can override config
Returns:
A constructed EfficientNet instance.
"""
model_configs = dict(MODEL_CONFIGS)
overrides = dict(overrides) if overrides else {}
# One can define their own custom models if necessary
model_configs.update(overrides.pop('model_config', {}))
if model_name not in model_configs:
raise ValueError('Unknown model name {}'.format(model_name))
config = model_configs[model_name]
model = cls(config=config, overrides=overrides)
if model_weights_path:
common_modules.load_weights(
model, model_weights_path, weights_format=weights_format)
return model
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to export TF-Hub SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
from official.vision.image_classification.efficientnet import efficientnet_model
FLAGS = flags.FLAGS
flags.DEFINE_string("model_name", None, "EfficientNet model name.")
flags.DEFINE_string("model_path", None, "File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
"TF-Hub SavedModel destination path to export.")
def export_tfhub(model_path, hub_destination, model_name):
"""Restores a tf.keras.Model and saves for TF-Hub."""
model_configs = dict(efficientnet_model.MODEL_CONFIGS)
config = model_configs[model_name]
image_input = tf.keras.layers.Input(
shape=(None, None, 3), name="image_input", dtype=tf.float32)
x = image_input * 255.0
ouputs = efficientnet_model.efficientnet(x, config)
hub_model = tf.keras.Model(image_input, ouputs)
ckpt = tf.train.Checkpoint(model=hub_model)
ckpt.restore(model_path).assert_existing_objects_matched()
hub_model.save(
os.path.join(hub_destination, "classification"), include_optimizer=False)
feature_vector_output = hub_model.get_layer(name="top_pool").get_output_at(0)
hub_model2 = tf.keras.Model(image_input, feature_vector_output)
hub_model2.save(
os.path.join(hub_destination, "feature-vector"), include_optimizer=False)
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)
if __name__ == "__main__":
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Learning rate utilities for vision tasks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Mapping, Optional
import numpy as np
import tensorflow as tf
BASE_LEARNING_RATE = 0.1
class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
"""A wrapper for LearningRateSchedule that includes warmup steps."""
def __init__(self,
lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
warmup_steps: int,
warmup_lr: Optional[float] = None):
"""Add warmup decay to a learning rate schedule.
Args:
lr_schedule: base learning rate scheduler
warmup_steps: number of warmup steps
warmup_lr: an optional field for the final warmup learning rate. This
should be provided if the base `lr_schedule` does not contain this
field.
"""
super(WarmupDecaySchedule, self).__init__()
self._lr_schedule = lr_schedule
self._warmup_steps = warmup_steps
self._warmup_lr = warmup_lr
def __call__(self, step: int):
lr = self._lr_schedule(step)
if self._warmup_steps:
if self._warmup_lr is not None:
initial_learning_rate = tf.convert_to_tensor(
self._warmup_lr, name="initial_learning_rate")
else:
initial_learning_rate = tf.convert_to_tensor(
self._lr_schedule.initial_learning_rate,
name="initial_learning_rate")
dtype = initial_learning_rate.dtype
global_step_recomp = tf.cast(step, dtype)
warmup_steps = tf.cast(self._warmup_steps, dtype)
warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
lr = tf.cond(global_step_recomp < warmup_steps, lambda: warmup_lr,
lambda: lr)
return lr
def get_config(self) -> Mapping[str, Any]:
config = self._lr_schedule.get_config()
config.update({
"warmup_steps": self._warmup_steps,
"warmup_lr": self._warmup_lr,
})
return config
class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Class to generate learning rate tensor."""
def __init__(self, batch_size: int, total_steps: int, warmup_steps: int):
"""Creates the consine learning rate tensor with linear warmup.
Args:
batch_size: The training batch size used in the experiment.
total_steps: Total training steps.
warmup_steps: Steps for the warm up period.
"""
super(CosineDecayWithWarmup, self).__init__()
base_lr_batch_size = 256
self._total_steps = total_steps
self._init_learning_rate = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
self._warmup_steps = warmup_steps
def __call__(self, global_step: int):
global_step = tf.cast(global_step, dtype=tf.float32)
warmup_steps = self._warmup_steps
init_lr = self._init_learning_rate
total_steps = self._total_steps
linear_warmup = global_step / warmup_steps * init_lr
cosine_learning_rate = init_lr * (tf.cos(np.pi *
(global_step - warmup_steps) /
(total_steps - warmup_steps)) +
1.0) / 2.0
learning_rate = tf.where(global_step < warmup_steps, linear_warmup,
cosine_learning_rate)
return learning_rate
def get_config(self):
return {
"total_steps": self._total_steps,
"warmup_learning_rate": self._warmup_learning_rate,
"warmup_steps": self._warmup_steps,
"init_learning_rate": self._init_learning_rate,
}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for learning_rate."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.vision.image_classification import learning_rate
class LearningRateTests(tf.test.TestCase):
def test_warmup_decay(self):
"""Basic computational test for warmup decay."""
initial_lr = 0.01
decay_steps = 100
decay_rate = 0.01
warmup_steps = 10
base_lr = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=initial_lr,
decay_steps=decay_steps,
decay_rate=decay_rate)
lr = learning_rate.WarmupDecaySchedule(
lr_schedule=base_lr, warmup_steps=warmup_steps)
for step in range(warmup_steps - 1):
config = lr.get_config()
self.assertEqual(config['warmup_steps'], warmup_steps)
self.assertAllClose(
self.evaluate(lr(step)), step / warmup_steps * initial_lr)
def test_cosine_decay_with_warmup(self):
"""Basic computational test for cosine decay with warmup."""
expected_lrs = [0.0, 0.1, 0.05, 0.0]
lr = learning_rate.CosineDecayWithWarmup(
batch_size=256, total_steps=3, warmup_steps=1)
for step in [0, 1, 2, 3]:
self.assertAllClose(lr(step), expected_lrs[step])
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs a simple model on the MNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
from official.common import distribute_utils
from official.utils.flags import core as flags_core
from official.utils.misc import model_helpers
from official.vision.image_classification.resnet import common
FLAGS = flags.FLAGS
def build_model():
"""Constructs the ML model used to predict handwritten digits."""
image = tf.keras.layers.Input(shape=(28, 28, 1))
y = tf.keras.layers.Conv2D(filters=32,
kernel_size=5,
padding='same',
activation='relu')(image)
y = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
strides=(2, 2),
padding='same')(y)
y = tf.keras.layers.Conv2D(filters=32,
kernel_size=5,
padding='same',
activation='relu')(y)
y = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
strides=(2, 2),
padding='same')(y)
y = tf.keras.layers.Flatten()(y)
y = tf.keras.layers.Dense(1024, activation='relu')(y)
y = tf.keras.layers.Dropout(0.4)(y)
probs = tf.keras.layers.Dense(10, activation='softmax')(y)
model = tf.keras.models.Model(image, probs, name='mnist')
return model
@tfds.decode.make_decoder(output_dtype=tf.float32)
def decode_image(example, feature):
"""Convert image to float32 and normalize from [0, 255] to [0.0, 1.0]."""
return tf.cast(feature.decode_example(example), dtype=tf.float32) / 255
def run(flags_obj, datasets_override=None, strategy_override=None):
"""Run MNIST model training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
datasets_override: A pair of `tf.data.Dataset` objects to train the model,
representing the train and test sets.
strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns:
Dictionary of training and eval stats.
"""
# Start TF profiler server.
tf.profiler.experimental.server.start(flags_obj.profiler_port)
strategy = strategy_override or distribute_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
tpu_address=flags_obj.tpu)
strategy_scope = distribute_utils.get_strategy_scope(strategy)
mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
if flags_obj.download:
mnist.download_and_prepare()
mnist_train, mnist_test = datasets_override or mnist.as_dataset(
split=['train', 'test'],
decoders={'image': decode_image()}, # pylint: disable=no-value-for-parameter
as_supervised=True)
train_input_dataset = mnist_train.cache().repeat().shuffle(
buffer_size=50000).batch(flags_obj.batch_size)
eval_input_dataset = mnist_test.cache().repeat().batch(flags_obj.batch_size)
with strategy_scope:
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
0.05, decay_steps=100000, decay_rate=0.96)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
model = build_model()
model.compile(
optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
num_train_examples = mnist.info.splits['train'].num_examples
train_steps = num_train_examples // flags_obj.batch_size
train_epochs = flags_obj.train_epochs
ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True),
tf.keras.callbacks.TensorBoard(log_dir=flags_obj.model_dir),
]
num_eval_examples = mnist.info.splits['test'].num_examples
num_eval_steps = num_eval_examples // flags_obj.batch_size
history = model.fit(
train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
callbacks=callbacks,
validation_steps=num_eval_steps,
validation_data=eval_input_dataset,
validation_freq=flags_obj.epochs_between_evals)
export_path = os.path.join(flags_obj.model_dir, 'saved_model')
model.save(export_path, include_optimizer=False)
eval_output = model.evaluate(
eval_input_dataset, steps=num_eval_steps, verbose=2)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_mnist_flags():
"""Define command line flags for MNIST model."""
flags_core.define_base(
clean=True,
num_gpu=True,
train_epochs=True,
epochs_between_evals=True,
distribution_strategy=True)
flags_core.define_device()
flags_core.define_distribution()
flags.DEFINE_bool('download', True,
'Whether to download data to `--data_dir`.')
flags.DEFINE_integer('profiler_port', 9012,
'Port to start profiler server on.')
FLAGS.set_default('batch_size', 1024)
def main(_):
model_helpers.apply_clean(FLAGS)
stats = run(flags.FLAGS)
logging.info('Run stats:\n%s', stats)
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
define_mnist_flags()
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the Keras MNIST model on GPU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.utils.testing import integration
from official.vision.image_classification import mnist_main
mnist_main.define_mnist_flags()
def eager_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
"""Unit tests for sample Keras MNIST model."""
_tempdir = None
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasMnistTest, cls).setUpClass()
def tearDown(self):
super(KerasMnistTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
@combinations.generate(eager_strategy_combinations())
def test_end_to_end(self, distribution):
"""Test Keras MNIST model with `strategy`."""
extra_flags = [
"-train_epochs",
"1",
# Let TFDS find the metadata folder automatically
"--data_dir="
]
dummy_data = (
tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32),
tf.range(10),
)
datasets = (
tf.data.Dataset.from_tensor_slices(dummy_data),
tf.data.Dataset.from_tensor_slices(dummy_data),
)
run = functools.partial(
mnist_main.run,
datasets_override=datasets,
strategy_override=distribution)
integration.run_synthetic(
main=run,
synth=False,
tmp_root=self.create_tempdir().full_path,
extra_flags=extra_flags)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for optimizer_factory."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from official.vision.image_classification import optimizer_factory
from official.vision.image_classification.configs import base_configs
class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
def build_toy_model(self) -> tf.keras.Model:
"""Creates a toy `tf.Keras.Model`."""
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_shape=(1,)))
return model
@parameterized.named_parameters(
('sgd', 'sgd', 0., False), ('momentum', 'momentum', 0., False),
('rmsprop', 'rmsprop', 0., False), ('adam', 'adam', 0., False),
('adamw', 'adamw', 0., False),
('momentum_lookahead', 'momentum', 0., True),
('sgd_ema', 'sgd', 0.999, False),
('momentum_ema', 'momentum', 0.999, False),
('rmsprop_ema', 'rmsprop', 0.999, False))
def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
"""Smoke test to be sure no syntax errors."""
model = self.build_toy_model()
params = {
'learning_rate': 0.001,
'rho': 0.09,
'momentum': 0.,
'epsilon': 1e-07,
'moving_average_decay': moving_average_decay,
'lookahead': lookahead,
}
optimizer = optimizer_factory.build_optimizer(
optimizer_name=optimizer_name,
base_learning_rate=params['learning_rate'],
params=params,
model=model)
self.assertTrue(issubclass(type(optimizer), tf.keras.optimizers.Optimizer))
def test_unknown_optimizer(self):
with self.assertRaises(ValueError):
optimizer_factory.build_optimizer(
optimizer_name='this_optimizer_does_not_exist',
base_learning_rate=None,
params=None)
def test_learning_rate_without_decay_or_warmups(self):
params = base_configs.LearningRateConfig(
name='exponential',
initial_lr=0.01,
decay_rate=0.01,
decay_epochs=None,
warmup_epochs=None,
scale_by_batch_size=0.01,
examples_per_epoch=1,
boundaries=[0],
multipliers=[0, 1])
batch_size = 1
train_steps = 1
lr = optimizer_factory.build_learning_rate(
params=params, batch_size=batch_size, train_steps=train_steps)
self.assertTrue(
issubclass(
type(lr), tf.keras.optimizers.schedules.LearningRateSchedule))
@parameterized.named_parameters(('exponential', 'exponential'),
('cosine_with_warmup', 'cosine_with_warmup'))
def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
"""Basic smoke test for syntax."""
params = base_configs.LearningRateConfig(
name=lr_decay_type,
initial_lr=0.01,
decay_rate=0.01,
decay_epochs=1,
warmup_epochs=1,
scale_by_batch_size=0.01,
examples_per_epoch=1,
boundaries=[0],
multipliers=[0, 1])
batch_size = 1
train_epochs = 1
train_steps = 1
lr = optimizer_factory.build_learning_rate(
params=params,
batch_size=batch_size,
train_epochs=train_epochs,
train_steps=train_steps)
self.assertTrue(
issubclass(
type(lr), tf.keras.optimizers.schedules.LearningRateSchedule))
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册