diff --git a/official/vision/beta/configs/retinanet.py b/official/vision/beta/configs/retinanet.py index e0c793496278c0a6f44db207becb71ee4d76d8e0..5c1dee4f29ef0e53514b4c220ab4346d102fc285 100644 --- a/official/vision/beta/configs/retinanet.py +++ b/official/vision/beta/configs/retinanet.py @@ -58,7 +58,6 @@ class Parser(hyperparams.Config): skip_crowd_during_training: bool = True max_num_instances: int = 100 # Can choose AutoAugment and RandAugment. - # TODO(b/205346436) Support RandAugment. aug_type: Optional[common.Augmentation] = None # Keep for backward compatibility. Not used. diff --git a/official/vision/beta/dataloaders/retinanet_input.py b/official/vision/beta/dataloaders/retinanet_input.py index 846a0137593ee80d57812c569affa20e6cf65253..4f734ac7ecf123b1a740f206e750e1090aa9224f 100644 --- a/official/vision/beta/dataloaders/retinanet_input.py +++ b/official/vision/beta/dataloaders/retinanet_input.py @@ -75,7 +75,7 @@ class Parser(parser.Parser): upper-bound threshold to assign negative labels for anchors. An anchor with a score below the threshold is labeled negative. aug_type: An optional Augmentation object to choose from AutoAugment and - RandAugment. The latter is not supported, and will raise ValueError. + RandAugment. aug_rand_hflip: `bool`, if True, augment training with random horizontal flip. aug_scale_min: `float`, the minimum scale applied to `output_size` for @@ -122,8 +122,16 @@ class Parser(parser.Parser): augmentation_name=aug_type.autoaug.augmentation_name, cutout_const=aug_type.autoaug.cutout_const, translate_const=aug_type.autoaug.translate_const) + elif aug_type.type == 'randaug': + logging.info('Using RandAugment.') + self._augmenter = augment.RandAugment.build_for_detection( + num_layers=aug_type.randaug.num_layers, + magnitude=aug_type.randaug.magnitude, + cutout_const=aug_type.randaug.cutout_const, + translate_const=aug_type.randaug.translate_const, + prob_to_apply=aug_type.randaug.prob_to_apply, + exclude_ops=aug_type.randaug.exclude_ops) else: - # TODO(b/205346436) Support RandAugment. raise ValueError(f'Augmentation policy {aug_type.type} not supported.') # Deprecated. Data Augmentation with AutoAugment. @@ -162,7 +170,6 @@ class Parser(parser.Parser): # Apply autoaug or randaug. if self._augmenter is not None: image, boxes = self._augmenter.distort_with_boxes(image, boxes) - image_shape = tf.shape(input=image)[0:2] # Normalizes image with mean and std pixel values. diff --git a/official/vision/beta/ops/augment.py b/official/vision/beta/ops/augment.py index 2ec7519f5d50889e43f40279fc82160d85c341b3..32395d883eccb88a5107e967e438534853bd0684 100644 --- a/official/vision/beta/ops/augment.py +++ b/official/vision/beta/ops/augment.py @@ -1950,6 +1950,37 @@ class RandAugment(ImageAugment): op for op in self.available_ops if op not in exclude_ops ] + @classmethod + def build_for_detection(cls, + num_layers: int = 2, + magnitude: float = 10., + cutout_const: float = 40., + translate_const: float = 100., + magnitude_std: float = 0.0, + prob_to_apply: Optional[float] = None, + exclude_ops: Optional[List[str]] = None): + """Builds a RandAugment that modifies bboxes for geometric transforms.""" + augmenter = cls( + num_layers=num_layers, + magnitude=magnitude, + cutout_const=cutout_const, + translate_const=translate_const, + magnitude_std=magnitude_std, + prob_to_apply=prob_to_apply, + exclude_ops=exclude_ops) + box_aware_ops_by_base_name = { + 'Rotate': 'Rotate_BBox', + 'ShearX': 'ShearX_BBox', + 'ShearY': 'ShearY_BBox', + 'TranslateX': 'TranslateX_BBox', + 'TranslateY': 'TranslateY_BBox', + } + augmenter.available_ops = [ + box_aware_ops_by_base_name.get(op_name) or op_name + for op_name in augmenter.available_ops + ] + return augmenter + def _distort_common( self, image: tf.Tensor, diff --git a/official/vision/beta/ops/augment_test.py b/official/vision/beta/ops/augment_test.py index 45d248464781217df16e8dc060eb61cc0a61e736..f5deb77f6959d76aecbbcd36539459fc446ba03c 100644 --- a/official/vision/beta/ops/augment_test.py +++ b/official/vision/beta/ops/augment_test.py @@ -140,6 +140,23 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): self.assertEqual((224, 224, 3), aug_image.shape) self.assertEqual((2, 4), aug_bboxes.shape) + def test_randaug_build_for_detection(self): + """Smoke test to be sure there are no syntax errors built for detection.""" + image = tf.zeros((224, 224, 3), dtype=tf.uint8) + bboxes = tf.ones((2, 4), dtype=tf.float32) + + augmenter = augment.RandAugment.build_for_detection() + self.assertCountEqual(augmenter.available_ops, [ + 'AutoContrast', 'Equalize', 'Invert', 'Posterize', 'Solarize', 'Color', + 'Contrast', 'Brightness', 'Sharpness', 'Cutout', 'SolarizeAdd', + 'Rotate_BBox', 'ShearX_BBox', 'ShearY_BBox', 'TranslateX_BBox', + 'TranslateY_BBox' + ]) + + aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes) + self.assertEqual((224, 224, 3), aug_image.shape) + self.assertEqual((2, 4), aug_bboxes.shape) + def test_all_policy_ops(self): """Smoke test to be sure all augmentation functions can execute."""