Commit 1c79ece9 authored by A. Unique TensorFlower

Merge pull request #10227 from sigeisler:master

PiperOrigin-RevId: 397161611
......@@ -16,7 +16,7 @@
"""Common configurations."""
import dataclasses
from typing import Optional
from typing import List, Optional
# Import libraries
......@@ -60,7 +60,9 @@ class RandAugment(hyperparams.Config):
magnitude: float = 10
cutout_const: float = 40
translate_const: float = 10
magnitude_std: float = 0.0
prob_to_apply: Optional[float] = None
exclude_ops: List[str] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
......@@ -71,6 +73,29 @@ class AutoAugment(hyperparams.Config):
translate_const: float = 250
@dataclasses.dataclass
class RandomErasing(hyperparams.Config):
"""Configuration for RandomErasing."""
probability: float = 0.25
min_area: float = 0.02
max_area: float = 1 / 3
min_aspect: float = 0.3
max_aspect: Optional[float] = None
min_count: int = 1
max_count: int = 1
trials: int = 10
@dataclasses.dataclass
class MixupAndCutmix(hyperparams.Config):
"""Configuration for MixupAndCutmix."""
mixup_alpha: float = .8
cutmix_alpha: float = 1.
prob: float = 1.0
switch_prob: float = 0.5
label_smoothing: float = 0.1
@dataclasses.dataclass
class Augmentation(hyperparams.OneOfConfig):
"""Configuration for input data augmentation.
......
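For orientation, here is a minimal sketch (not part of the diff) of how the new augmentation configs defined above could be constructed; the `common` import path is the one used by the experiment configs later in this commit, and the field values are illustrative.

```python
# Illustrative only: builds the new config objects added above. These would be
# assigned to DataConfig.aug_type, DataConfig.random_erasing and
# DataConfig.mixup_and_cutmix, as the following hunks show.
from official.vision.beta.configs import common

aug_type = common.Augmentation(
    type='randaug',
    randaug=common.RandAugment(magnitude=9, exclude_ops=['Cutout']))
random_erasing = common.RandomErasing(probability=0.25, max_area=1 / 3)
mixup_and_cutmix = common.MixupAndCutmix(
    mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1)
```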
......@@ -39,10 +39,13 @@ class DataConfig(cfg.DataConfig):
aug_rand_hflip: bool = True
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
color_jitter: float = 0.
random_erasing: Optional[common.RandomErasing] = None
file_type: str = 'tfrecord'
image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label'
decode_jpeg_only: bool = True
mixup_and_cutmix: Optional[common.MixupAndCutmix] = None
decoder: Optional[common.DataDecoder] = common.DataDecoder()
# Keep for backward compatibility.
......@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn=False)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False
kernel_initializer: str = 'random_uniform'
@dataclasses.dataclass
......@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
one_hot: bool = True
label_smoothing: float = 0.0
l2_weight_decay: float = 0.0
soft_labels: bool = False
@dataclasses.dataclass
......
......@@ -69,6 +69,8 @@ class Parser(parser.Parser):
decode_jpeg_only: bool = True,
aug_rand_hflip: bool = True,
aug_type: Optional[common.Augmentation] = None,
color_jitter: float = 0.,
random_erasing: Optional[common.RandomErasing] = None,
is_multilabel: bool = False,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset.
......@@ -85,6 +87,11 @@ class Parser(parser.Parser):
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
color_jitter: Magnitude of color jitter. If > 0, the value is used to
generate random scale factors for brightness, contrast, and saturation.
See `preprocess_ops.color_jitter` for more details.
random_erasing: If not None, the input image is augmented with random
erasing. See `augment.RandomErasing` for more details.
is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
......@@ -113,13 +120,27 @@ class Parser(parser.Parser):
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply)
prob_to_apply=aug_type.randaug.prob_to_apply,
exclude_ops=aug_type.randaug.exclude_ops)
else:
raise ValueError('Augmentation policy {} not supported.'.format(
aug_type.type))
else:
self._augmenter = None
self._label_field_key = label_field_key
self._color_jitter = color_jitter
if random_erasing:
self._random_erasing = augment.RandomErasing(
probability=random_erasing.probability,
min_area=random_erasing.min_area,
max_area=random_erasing.max_area,
min_aspect=random_erasing.min_aspect,
max_aspect=random_erasing.max_aspect,
min_count=random_erasing.min_count,
max_count=random_erasing.max_count,
trials=random_erasing.trials)
else:
self._random_erasing = None
self._is_multilabel = is_multilabel
self._decode_jpeg_only = decode_jpeg_only
......@@ -173,6 +194,12 @@ class Parser(parser.Parser):
if self._aug_rand_hflip:
image = tf.image.random_flip_left_right(image)
# Color jitter.
if self._color_jitter > 0:
image = preprocess_ops.color_jitter(image, self._color_jitter,
self._color_jitter,
self._color_jitter)
# Resizes image.
image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
......@@ -187,6 +214,10 @@ class Parser(parser.Parser):
offset=MEAN_RGB,
scale=STDDEV_RGB)
# Random erasing after the image has been normalized
if self._random_erasing is not None:
image = self._random_erasing.distort(image)
# Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype)
......
......@@ -56,6 +56,7 @@ def build_classification_model(
num_classes=model_config.num_classes,
input_specs=input_specs,
dropout_rate=model_config.dropout_rate,
kernel_initializer=model_config.kernel_initializer,
kernel_regularizer=l2_regularizer,
add_head_batch_norm=model_config.add_head_batch_norm,
use_sync_bn=norm_activation_config.use_sync_bn,
......
......@@ -12,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoAugment and RandAugment policies for enhanced image/video preprocessing.
"""Augmentation policies for enhanced image/video preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models
"""
import math
from typing import Any, List, Iterable, Optional, Text, Tuple
......@@ -295,10 +303,26 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
cutout_center_width = tf.random.uniform(
shape=[], minval=0, maxval=image_width, dtype=tf.int32)
lower_pad = tf.maximum(0, cutout_center_height - pad_size)
upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
left_pad = tf.maximum(0, cutout_center_width - pad_size)
right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
image = _fill_rectangle(image, cutout_center_width, cutout_center_height,
pad_size, pad_size, replace)
return image
def _fill_rectangle(image,
center_width,
center_height,
half_width,
half_height,
replace=None):
"""Fill blank area."""
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
lower_pad = tf.maximum(0, center_height - half_height)
upper_pad = tf.maximum(0, image_height - center_height - half_height)
left_pad = tf.maximum(0, center_width - half_width)
right_pad = tf.maximum(0, image_width - center_width - half_width)
cutout_shape = [
image_height - (lower_pad + upper_pad),
......@@ -311,9 +335,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
constant_values=1)
mask = tf.expand_dims(mask, -1)
mask = tf.tile(mask, [1, 1, 3])
image = tf.where(
tf.equal(mask, 0),
tf.ones_like(image, dtype=image.dtype) * replace, image)
if replace is None:
fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
elif isinstance(replace, tf.Tensor):
fill = replace
else:
fill = tf.ones_like(image, dtype=image.dtype) * replace
image = tf.where(tf.equal(mask, 0), fill, image)
return image
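As a side note, the following self-contained sketch (illustrative shapes, not part of the diff) shows the pad-mask-and-`tf.where` pattern that `_fill_rectangle` uses, including the new `replace=None` behaviour of filling the erased region with Gaussian noise.

```python
import tensorflow as tf

image = tf.ones([8, 8, 3], dtype=tf.float32)
cutout = tf.zeros([2, 2], dtype=image.dtype)                # region to erase
mask = tf.pad(cutout, [[3, 3], [3, 3]], constant_values=1)  # 1 = keep pixel
mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, 3])
noise = tf.random.normal(tf.shape(image), dtype=image.dtype)
erased = tf.where(tf.equal(mask, 0), noise, image)          # noise inside the box
```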
......@@ -803,11 +833,20 @@ def level_to_arg(cutout_const: float, translate_const: float):
return args
def _parse_policy_info(name: Text, prob: float, level: float,
replace_value: List[int], cutout_const: float,
translate_const: float) -> Tuple[Any, float, Any]:
def _parse_policy_info(name: Text,
prob: float,
level: float,
replace_value: List[int],
cutout_const: float,
translate_const: float,
level_std: float = 0.) -> Tuple[Any, float, Any]:
"""Return the function that corresponds to `name` and update `level` param."""
func = NAME_TO_FUNC[name]
if level_std > 0:
level += tf.random.normal([], stddev=level_std, dtype=tf.float32)
level = tf.clip_by_value(level, 0., _MAX_LEVEL)
args = level_to_arg(cutout_const, translate_const)[name](level)
if name in REPLACE_FUNCS:
......@@ -1184,7 +1223,9 @@ class RandAugment(ImageAugment):
magnitude: float = 10.,
cutout_const: float = 40.,
translate_const: float = 100.,
prob_to_apply: Optional[float] = None):
magnitude_std: float = 0.0,
prob_to_apply: Optional[float] = None,
exclude_ops: Optional[List[str]] = None):
"""Applies the RandAugment policy to images.
Args:
......@@ -1196,8 +1237,11 @@ class RandAugment(ImageAugment):
[5, 10].
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
magnitude_std: Standard deviation used to randomize the magnitude, as
proposed by the authors of the timm library.
prob_to_apply: The probability of applying the selected augmentation at
each layer.
exclude_ops: Names of operations to exclude from the available ops.
"""
super(RandAugment, self).__init__()
......@@ -1212,6 +1256,11 @@ class RandAugment(ImageAugment):
'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
]
self.magnitude_std = magnitude_std
if exclude_ops:
self.available_ops = [
op for op in self.available_ops if op not in exclude_ops
]
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies the RandAugment policy to `image`.
......@@ -1246,7 +1295,8 @@ class RandAugment(ImageAugment):
dtype=tf.float32)
func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
replace_value, self.cutout_const,
self.translate_const)
self.translate_const,
self.magnitude_std)
branch_fns.append((
i,
# pylint:disable=g-long-lambda
......@@ -1267,3 +1317,271 @@ class RandAugment(ImageAugment):
image = tf.cast(image, dtype=input_image_type)
return image
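A hypothetical usage sketch of the extended RandAugment constructor (the module path matches the import used elsewhere in this commit; the values are arbitrary):

```python
import tensorflow as tf
from official.vision.beta.ops import augment

augmenter = augment.RandAugment(
    num_layers=2, magnitude=9., magnitude_std=0.5, exclude_ops=['Cutout'])
image = tf.cast(
    tf.random.uniform([224, 224, 3], maxval=256, dtype=tf.int32), tf.uint8)
augmented = augmenter.distort(image)  # same shape and dtype as the input
```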
class RandomErasing(ImageAugment):
"""Applies RandomErasing to a single image.
Reference: https://arxiv.org/abs/1708.04896
Implementation is inspired by https://github.com/rwightman/pytorch-image-models
"""
def __init__(self,
probability: float = 0.25,
min_area: float = 0.02,
max_area: float = 1 / 3,
min_aspect: float = 0.3,
max_aspect=None,
min_count=1,
max_count=1,
trials=10):
"""Applies RandomErasing to a single image.
Args:
probability (float, optional): Probability of augmenting the image.
Defaults to 0.25.
min_area (float, optional): Minimum area of the random erasing rectangle.
Defaults to 0.02.
max_area (float, optional): Maximum area of the random erasing rectangle.
Defaults to 1/3.
min_aspect (float, optional): Minimum aspect ratio of the random erasing
rectangle. Defaults to 0.3.
max_aspect (float, optional): Maximum aspect ratio of the random erasing
rectangle. Defaults to None.
min_count (int, optional): Minimum number of erased rectangles. Defaults
to 1.
max_count (int, optional): Maximum number of erased rectangles. Defaults
to 1.
trials (int, optional): Maximum number of trials to randomly sample a
rectangle that fulfills the constraints. Defaults to 10.
"""
self._probability = probability
self._min_area = float(min_area)
self._max_area = float(max_area)
self._min_log_aspect = math.log(min_aspect)
self._max_log_aspect = math.log(max_aspect or 1 / min_aspect)
self._min_count = min_count
self._max_count = max_count
self._trials = trials
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies RandomErasing to single `image`.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
Returns:
tf.Tensor: The augmented version of `image`.
"""
uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0)
mirror_cond = tf.less(uniform_random, self._probability)
image = tf.cond(mirror_cond, lambda: self._erase(image), lambda: image)
return image
@tf.function
def _erase(self, image: tf.Tensor) -> tf.Tensor:
"""Erase an area."""
if self._min_count == self._max_count:
count = self._min_count
else:
count = tf.random.uniform(
shape=[],
minval=int(self._min_count),
maxval=int(self._max_count - self._min_count + 1),
dtype=tf.int32)
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
area = tf.cast(image_width * image_height, tf.float32)
for _ in range(count):
# Workaround: `break` is not supported in tf.function.
is_trial_successful = False
for _ in range(self._trials):
if not is_trial_successful:
erase_area = tf.random.uniform(
shape=[],
minval=area * self._min_area,
maxval=area * self._max_area)
aspect_ratio = tf.math.exp(
tf.random.uniform(
shape=[],
minval=self._min_log_aspect,
maxval=self._max_log_aspect))
half_height = tf.cast(
tf.math.round(tf.math.sqrt(erase_area * aspect_ratio) / 2),
dtype=tf.int32)
half_width = tf.cast(
tf.math.round(tf.math.sqrt(erase_area / aspect_ratio) / 2),
dtype=tf.int32)
if 2 * half_height < image_height and 2 * half_width < image_width:
center_height = tf.random.uniform(
shape=[],
minval=0,
maxval=int(image_height - 2 * half_height),
dtype=tf.int32)
center_width = tf.random.uniform(
shape=[],
minval=0,
maxval=int(image_width - 2 * half_width),
dtype=tf.int32)
image = _fill_rectangle(
image,
center_width,
center_height,
half_width,
half_height,
replace=None)
is_trial_successful = True
return image
class MixupAndCutmix:
"""Applies Mixup and/or Cutmix to a batch of images.
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
Implementation is inspired by https://github.com/rwightman/pytorch-image-models
"""
def __init__(self,
mixup_alpha: float = .8,
cutmix_alpha: float = 1.,
prob: float = 1.0,
switch_prob: float = 0.5,
label_smoothing: float = 0.1,
num_classes: int = 1001):
"""Applies Mixup and/or Cutmix to a batch of images.
Args:
mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero, Mixup is deactivated.
Defaults to 0.8.
cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero, Cutmix is deactivated.
Defaults to 1.0.
prob (float, optional): Probability of augmenting the batch. Defaults to 1.0.
switch_prob (float, optional): Probability of applying Cutmix for the
batch. Defaults to 0.5.
label_smoothing (float, optional): Constant for label smoothing. Defaults
to 0.1.
num_classes (int, optional): Number of classes. Defaults to 1001.
"""
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mode = 'batch'
self.mixup_enabled = True
if self.mixup_alpha and not self.cutmix_alpha:
self.switch_prob = -1
elif not self.mixup_alpha and self.cutmix_alpha:
self.switch_prob = 1
def __call__(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
return self.distort(images, labels)
def distort(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Applies Mixup and/or Cutmix to batch of images and transforms labels.
Args:
images (tf.Tensor): Of shape [batch_size,height, width, 3] representing a
batch of image.
labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
each image of the batch.
Returns:
Tuple[tf.Tensor, tf.Tensor]: The augmented versions of `images` and
`labels`.
"""
augment_cond = tf.less(
tf.random.uniform(shape=[], minval=0., maxval=1.0), self.mix_prob)
# pylint: disable=g-long-lambda
augment_a = lambda: self._update_labels(*tf.cond(
tf.less(
tf.random.uniform(shape=[], minval=0., maxval=1.0), self.switch_prob
), lambda: self._cutmix(images, labels), lambda: self._mixup(
images, labels)))
augment_b = lambda: (images, self._smooth_labels(labels))
# pylint: enable=g-long-lambda
return tf.cond(augment_cond, augment_a, augment_b)
@staticmethod
def _sample_from_beta(alpha, beta, shape):
sample_alpha = tf.random.gamma(shape, 1., beta=alpha)
sample_beta = tf.random.gamma(shape, 1., beta=beta)
return sample_alpha / (sample_alpha + sample_beta)
def _cutmix(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Apply cutmix."""
lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha,
labels.shape)
ratio = tf.math.sqrt(1 - lam)
batch_size = tf.shape(images)[0]
image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
cut_height = tf.cast(
ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
cut_width = tf.cast(
ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
random_center_height = tf.random.uniform(
shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32)
random_center_width = tf.random.uniform(
shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32)
bbox_area = cut_height * cut_width
lam = 1. - bbox_area / (image_height * image_width)
lam = tf.cast(lam, dtype=tf.float32)
images = tf.map_fn(
lambda x: _fill_rectangle(*x),
(images, random_center_width, random_center_height, cut_width // 2,
cut_height // 2, tf.reverse(images, [0])),
dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.float32),
fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32))
return images, labels, lam
def _mixup(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
labels.shape)
lam = tf.reshape(lam, [-1, 1, 1, 1])
images = lam * images + (1. - lam) * tf.reverse(images, [0])
return images, labels, tf.squeeze(lam)
def _smooth_labels(self, labels: tf.Tensor) -> tf.Tensor:
off_value = self.label_smoothing / self.num_classes
on_value = 1. - self.label_smoothing + off_value
smooth_labels = tf.one_hot(
labels, self.num_classes, on_value=on_value, off_value=off_value)
return smooth_labels
def _update_labels(self, images: tf.Tensor, labels: tf.Tensor,
lam: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
labels_1 = self._smooth_labels(labels)
labels_2 = tf.reverse(labels_1, [0])
lam = tf.reshape(lam, [-1, 1])
labels = lam * labels_1 + (1. - lam) * labels_2
return images, labels
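A small sanity sketch (assumed usage, not part of the diff): since `_smooth_labels` returns rows that sum to 1 and `_update_labels` takes a convex combination of two such rows, the mixed label rows remain valid probability distributions.

```python
import tensorflow as tf
from official.vision.beta.ops import augment

mixer = augment.MixupAndCutmix(num_classes=10, label_smoothing=0.1)
images = tf.random.normal([8, 32, 32, 3])
labels = tf.range(8)
aug_images, aug_labels = mixer(images, labels)
row_sums = tf.reduce_sum(aug_labels, axis=-1)  # each entry is approximately 1.0
```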
......@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
augmenter.distort(image)
class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
def test_random_erase_replaces_some_pixels(self):
image = tf.zeros((224, 224, 3), dtype=tf.float32)
augmenter = augment.RandomErasing(probability=1., max_count=10)
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertNotEqual(0, tf.reduce_max(aug_image))
class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
def test_mixup_and_cutmix_smoothes_labels(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e-4)  # With tolerance
def test_mixup_changes_image(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e-4)  # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_cutmix_changes_image(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e-4)  # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
if __name__ == '__main__':
tf.test.main()
......@@ -15,12 +15,13 @@
"""Preprocessing ops."""
import math
from typing import Optional
from six.moves import range
import tensorflow as tf
from official.vision.beta.ops import augment
from official.vision.beta.ops import box_ops
CENTER_CROP_FRACTION = 0.875
......@@ -557,6 +558,107 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
return image, normalized_boxes, masks
def color_jitter(image: tf.Tensor,
brightness: Optional[float] = 0.,
contrast: Optional[float] = 0.,
saturation: Optional[float] = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Applies color jitter to an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
brightness (float, optional): Magnitude for brightness jitter. Defaults to
0.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
saturation (float, optional): Magnitude for saturation jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
image = tf.cast(image, dtype=tf.uint8)
image = random_brightness(image, brightness, seed=seed)
image = random_contrast(image, contrast, seed=seed)
image = random_saturation(image, saturation, seed=seed)
return image
def random_brightness(image: tf.Tensor,
brightness: float = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters brightness of an image.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
brightness (float, optional): Magnitude for brightness jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert brightness >= 0, '`brightness` must be non-negative'
brightness = tf.random.uniform([],
max(0, 1 - brightness),
1 + brightness,
seed=seed,
dtype=tf.float32)
return augment.brightness(image, brightness)
def random_contrast(image: tf.Tensor,
contrast: float = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters contrast of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert contrast >= 0, '`contrast` must be non-negative'
contrast = tf.random.uniform([],
max(0, 1 - contrast),
1 + contrast,
seed=seed,
dtype=tf.float32)
return augment.contrast(image, contrast)
def random_saturation(image: tf.Tensor,
saturation: float = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters saturation of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
saturation (float, optional): Magnitude for saturation jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert saturation >= 0, '`saturation` must be non-negative'
saturation = tf.random.uniform([],
max(0, 1 - saturation),
1 + saturation,
seed=seed,
dtype=tf.float32)
return _saturation(image, saturation)
def _saturation(image: tf.Tensor,
saturation: Optional[float] = 0.) -> tf.Tensor:
return augment.blend(
tf.repeat(tf.image.rgb_to_grayscale(image), 3, axis=-1), image,
saturation)
def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
aspect_ratio_range,
min_overlap_params, max_retry):
......
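An assumed usage sketch of the new op (illustrative values): with a magnitude of 0.4, each of brightness, contrast, and saturation is scaled by an independent factor drawn uniformly from [max(0, 1 - 0.4), 1 + 0.4].

```python
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops

image = tf.cast(
    tf.random.uniform([224, 224, 3], maxval=256, dtype=tf.int32), tf.uint8)
jittered = preprocess_ops.color_jitter(image, 0.4, 0.4, 0.4)  # uint8 output
```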
......@@ -197,6 +197,19 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
_ = preprocess_ops.random_crop_image_v2(
image_bytes, tf.constant([input_height, input_width, 3], tf.int32))
@parameterized.parameters((400, 600, 0), (400, 600, 0.4), (600, 400, 1.4))
def testColorJitter(self, input_height, input_width, color_jitter):
image = tf.convert_to_tensor(np.random.rand(input_height, input_width, 3))
jittered_image = preprocess_ops.color_jitter(image, color_jitter,
color_jitter, color_jitter)
assert jittered_image.shape == image.shape
@parameterized.parameters((400, 600, 0), (400, 600, 0.4), (600, 400, 1))
def testSaturation(self, input_height, input_width, saturation):
image = tf.convert_to_tensor(np.random.rand(input_height, input_width, 3))
jittered_image = preprocess_ops._saturation(image, saturation)
assert jittered_image.shape == image.shape
@parameterized.parameters((640, 640, 20), (1280, 1280, 30))
def test_random_crop(self, input_height, input_width, num_boxes):
image = tf.convert_to_tensor(np.random.rand(input_height, input_width, 3))
......
# Vision Transformer (ViT)
# Vision Transformer (ViT) and Data-Efficient Image Transformer (DEIT)
**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.
[![Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
- [![ViT Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
- [![DEIT Paper](http://img.shields.io/badge/Paper-arXiv.2012.12877-B3181B?logo=arXiv)](https://arxiv.org/abs/2012.12877)
This repository is the implementations of Vision Transformer (ViT) in
TensorFlow 2.
This repository contains implementations of the Vision Transformer (ViT) and
the Data-Efficient Image Transformer (DEIT) in TensorFlow 2.
* Paper title:
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
\ No newline at end of file
- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
- [Training data-efficient image transformers & distillation through attention](https://arxiv.org/pdf/2012.12877.pdf).
......@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
hidden_size: int = 1
patch_size: int = 16
transformer: Transformer = Transformer()
init_stochastic_depth_rate: float = 0.0
original_init: bool = True
@dataclasses.dataclass
......
......@@ -44,6 +44,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn=False)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False
kernel_initializer: str = 'random_uniform'
@dataclasses.dataclass
......@@ -51,6 +52,7 @@ class Losses(hyperparams.Config):
one_hot: bool = True
label_smoothing: float = 0.0
l2_weight_decay: float = 0.0
soft_labels: bool = False
@dataclasses.dataclass
......@@ -79,6 +81,87 @@ task_factory.register_task_cls(ImageClassificationTask)(
image_classification.ImageClassificationTask)
@exp_factory.register_config_factory('deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096  # Originally 1024, but 4096 works better on TPU v3-32.
eval_batch_size = 4096  # Originally 1024, but 4096 works better on TPU v3-32.
num_classes = 1001
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=num_classes,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(
l2_weight_decay=0.0,
label_smoothing=label_smoothing,
one_hot=False,
soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
aug_type=common.Augmentation(
type='randaug',
randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout'])),
mixup_and_cutmix=common.MixupAndCutmix(
label_smoothing=label_smoothing)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
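An assumed usage sketch: once this module is imported, the registered DEIT experiment can be retrieved by name through the experiment factory and overridden like any other ExperimentConfig.

```python
from official.core import exp_factory

config = exp_factory.get_exp_config('deit_imagenet_pretrain')
config.task.train_data.global_batch_size = 1024  # illustrative override
config.validate()
config.lock()
```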
@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
......@@ -90,6 +173,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
......@@ -116,12 +200,13 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
'adamw': {
'weight_decay_rate': 0.3,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.003,
'initial_learning_rate': 0.003 * train_batch_size / 4096,
'decay_steps': 300 * steps_per_epoch,
}
},
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf
from official.nlp import keras_nlp
from official.vision.beta.modeling.layers.nn_layers import StochasticDepth
class TransformerEncoderBlock(keras_nlp.layers.TransformerEncoderBlock):
"""TransformerEncoderBlock layer with stochastic depth."""
def __init__(self, *args, stochastic_depth_drop_rate=0.0, **kwargs):
"""Initializes TransformerEncoderBlock."""
super().__init__(*args, **kwargs)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
def build(self, input_shape):
if self._stochastic_depth_drop_rate:
self._stochastic_depth = StochasticDepth(self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
super().build(input_shape)
def get_config(self):
config = {"stochastic_depth_drop_rate": self._stochastic_depth_drop_rate}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
"""Transformer self-attention encoder block call."""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError("Unexpected inputs to %s with length at %d" %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range:
if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + self._stochastic_depth(
attention_output, training=training)
else:
attention_output = self._attention_layer_norm(
target_tensor +
self._stochastic_depth(attention_output, training=training))
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
return source_attention_output + self._stochastic_depth(
layer_output, training=training)
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm(
layer_output +
self._stochastic_depth(attention_output, training=training))
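An assumed usage sketch of the subclassed block (shapes and hyperparameters are illustrative; the constructor otherwise accepts the same arguments as `keras_nlp.layers.TransformerEncoderBlock`):

```python
import tensorflow as tf
from official.vision.beta.projects.vit.modeling import nn_blocks

block = nn_blocks.TransformerEncoderBlock(
    num_attention_heads=3,
    inner_dim=768,
    inner_activation='gelu',
    norm_first=True,
    stochastic_depth_drop_rate=0.1)
tokens = tf.random.normal([2, 197, 192])  # e.g. ViT-Ti/16 patches + class token
outputs = block(tokens, training=True)    # shape [2, 197, 192]
```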
......@@ -17,17 +17,24 @@
import tensorflow as tf
from official.modeling import activations
from official.nlp import keras_nlp
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.projects.vit.modeling import nn_blocks
layers = tf.keras.layers
VIT_SPECS = {
'vit-testing':
'vit-ti16':
dict(
hidden_size=1,
hidden_size=192,
patch_size=16,
transformer=dict(mlp_dim=1, num_heads=1, num_layers=1),
transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
),
'vit-s16':
dict(
hidden_size=384,
patch_size=16,
transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
),
'vit-b16':
dict(
......@@ -112,6 +119,8 @@ class Encoder(tf.keras.layers.Layer):
attention_dropout_rate=0.1,
kernel_regularizer=None,
inputs_positions=None,
init_stochastic_depth_rate=0.0,
kernel_initializer='glorot_uniform',
**kwargs):
super().__init__(**kwargs)
self._num_layers = num_layers
......@@ -121,6 +130,8 @@ class Encoder(tf.keras.layers.Layer):
self._attention_dropout_rate = attention_dropout_rate
self._kernel_regularizer = kernel_regularizer
self._inputs_positions = inputs_positions
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._kernel_initializer = kernel_initializer
def build(self, input_shape):
self._pos_embed = AddPositionEmbs(
......@@ -131,15 +142,18 @@ class Encoder(tf.keras.layers.Layer):
self._encoder_layers = []
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
for _ in range(self._num_layers):
encoder_layer = keras_nlp.layers.TransformerEncoderBlock(
for i in range(self._num_layers):
encoder_layer = nn_blocks.TransformerEncoderBlock(
inner_activation=activations.gelu,
num_attention_heads=self._num_heads,
inner_dim=self._mlp_dim,
output_dropout=self._dropout_rate,
attention_dropout=self._attention_dropout_rate,
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
norm_first=True,
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i + 1, self._num_layers),
norm_epsilon=1e-6)
self._encoder_layers.append(encoder_layer)
self._norm = layers.LayerNormalization(epsilon=1e-6)
......@@ -164,12 +178,14 @@ class VisionTransformer(tf.keras.Model):
num_layers=12,
attention_dropout_rate=0.0,
dropout_rate=0.1,
init_stochastic_depth_rate=0.0,
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
patch_size=16,
hidden_size=768,
representation_size=0,
classifier='token',
kernel_regularizer=None):
kernel_regularizer=None,
original_init=True):
"""VisionTransformer initialization function."""
inputs = tf.keras.Input(shape=input_specs.shape[1:])
......@@ -178,7 +194,8 @@ class VisionTransformer(tf.keras.Model):
kernel_size=patch_size,
strides=patch_size,
padding='valid',
kernel_regularizer=kernel_regularizer)(
kernel_regularizer=kernel_regularizer,
kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
inputs)
if tf.keras.backend.image_data_format() == 'channels_last':
rows_axis, cols_axis = (1, 2)
......@@ -203,7 +220,10 @@ class VisionTransformer(tf.keras.Model):
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
kernel_regularizer=kernel_regularizer)(
kernel_regularizer=kernel_regularizer,
kernel_initializer='glorot_uniform' if original_init else dict(
class_name='TruncatedNormal', config=dict(stddev=.02)),
init_stochastic_depth_rate=init_stochastic_depth_rate)(
x)
if classifier == 'token':
......@@ -215,7 +235,8 @@ class VisionTransformer(tf.keras.Model):
x = tf.keras.layers.Dense(
representation_size,
kernel_regularizer=kernel_regularizer,
name='pre_logits')(
name='pre_logits',
kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
x)
x = tf.nn.tanh(x)
else:
......@@ -247,9 +268,11 @@ def build_vit(input_specs,
num_layers=backbone_cfg.transformer.num_layers,
attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
dropout_rate=backbone_cfg.transformer.dropout_rate,
init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
input_specs=input_specs,
patch_size=backbone_cfg.patch_size,
hidden_size=backbone_cfg.hidden_size,
representation_size=backbone_cfg.representation_size,
classifier=backbone_cfg.classifier,
kernel_regularizer=l2_regularizer)
kernel_regularizer=l2_regularizer,
original_init=backbone_cfg.original_init)
......@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):
@exp_factory.register_config_factory('darknet_classification')
def image_classification() -> cfg.ExperimentConfig:
def darknet_classification() -> cfg.ExperimentConfig:
"""Image classification general."""
return cfg.ExperimentConfig(
task=ImageClassificationTask(),
......
......@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.modeling import factory
from official.vision.beta.ops import augment
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
......@@ -103,14 +104,26 @@ class ImageClassificationTask(base_task.Task):
decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip,
aug_type=params.aug_type,
color_jitter=params.color_jitter,
random_erasing=params.random_erasing,
is_multilabel=is_multilabel,
dtype=params.dtype)
postprocess_fn = None
if params.mixup_and_cutmix:
postprocess_fn = augment.MixupAndCutmix(
mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
prob=params.mixup_and_cutmix.prob,
label_smoothing=params.mixup_and_cutmix.label_smoothing,
num_classes=num_classes)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
parser_fn=parser.parse_fn(params.is_training),
postprocess_fn=postprocess_fn)
dataset = reader.read(input_context=input_context)
......@@ -140,6 +153,9 @@ class ImageClassificationTask(base_task.Task):
model_outputs,
from_logits=True,
label_smoothing=losses_config.label_smoothing)
elif losses_config.soft_labels:
total_loss = tf.nn.softmax_cross_entropy_with_logits(
labels, model_outputs)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
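For context, a tiny illustration (assumed shapes, not from the diff) of why the soft-label branch is needed: MixupAndCutmix produces dense probability vectors rather than integer class ids, so the dense cross-entropy is used instead of the sparse or one-hot variants.

```python
import tensorflow as tf

logits = tf.random.normal([4, 10])
soft_labels = tf.fill([4, 10], 0.1)  # e.g. heavily mixed / smoothed labels
loss = tf.nn.softmax_cross_entropy_with_logits(labels=soft_labels, logits=logits)
```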
......@@ -161,7 +177,8 @@ class ImageClassificationTask(base_task.Task):
is_multilabel = self.task_config.train_data.is_multilabel
if not is_multilabel:
k = self.task_config.evaluation.top_k
if self.task_config.losses.one_hot:
if (self.task_config.losses.one_hot or
self.task_config.losses.soft_labels):
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
......@@ -223,7 +240,9 @@ class ImageClassificationTask(base_task.Task):
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
model_outputs=outputs,
labels=labels,
aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
......@@ -266,14 +285,18 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs.
"""
features, labels = inputs
one_hot = self.task_config.losses.one_hot
soft_labels = self.task_config.losses.soft_labels
is_multilabel = self.task_config.train_data.is_multilabel
if self.task_config.losses.one_hot and not is_multilabel:
if (one_hot or soft_labels) and not is_multilabel:
labels = tf.one_hot(labels, self.task_config.model.num_classes)
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses)
loss = self.build_losses(
model_outputs=outputs,
labels=labels,
aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
......