提交 c9a7e0b2 编写于 作者: A A. Unique TensorFlower

Add builder that applies bounding box-specific ops for RandAugment

PiperOrigin-RevId: 421439862
上级 49a5706c
......@@ -58,7 +58,6 @@ class Parser(hyperparams.Config):
skip_crowd_during_training: bool = True
max_num_instances: int = 100
# Can choose AutoAugment and RandAugment.
# TODO(b/205346436) Support RandAugment.
aug_type: Optional[common.Augmentation] = None
# Keep for backward compatibility. Not used.
......
......@@ -75,7 +75,7 @@ class Parser(parser.Parser):
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment. The latter is not supported, and will raise ValueError.
RandAugment.
aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for
......@@ -122,8 +122,16 @@ class Parser(parser.Parser):
augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
elif aug_type.type == 'randaug':
logging.info('Using RandAugment.')
self._augmenter = augment.RandAugment.build_for_detection(
num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply,
exclude_ops=aug_type.randaug.exclude_ops)
else:
# TODO(b/205346436) Support RandAugment.
raise ValueError(f'Augmentation policy {aug_type.type} not supported.')
# Deprecated. Data Augmentation with AutoAugment.
......@@ -162,7 +170,6 @@ class Parser(parser.Parser):
# Apply autoaug or randaug.
if self._augmenter is not None:
image, boxes = self._augmenter.distort_with_boxes(image, boxes)
image_shape = tf.shape(input=image)[0:2]
# Normalizes image with mean and std pixel values.
......
......@@ -1950,6 +1950,37 @@ class RandAugment(ImageAugment):
op for op in self.available_ops if op not in exclude_ops
]
@classmethod
def build_for_detection(cls,
num_layers: int = 2,
magnitude: float = 10.,
cutout_const: float = 40.,
translate_const: float = 100.,
magnitude_std: float = 0.0,
prob_to_apply: Optional[float] = None,
exclude_ops: Optional[List[str]] = None):
"""Builds a RandAugment that modifies bboxes for geometric transforms."""
augmenter = cls(
num_layers=num_layers,
magnitude=magnitude,
cutout_const=cutout_const,
translate_const=translate_const,
magnitude_std=magnitude_std,
prob_to_apply=prob_to_apply,
exclude_ops=exclude_ops)
box_aware_ops_by_base_name = {
'Rotate': 'Rotate_BBox',
'ShearX': 'ShearX_BBox',
'ShearY': 'ShearY_BBox',
'TranslateX': 'TranslateX_BBox',
'TranslateY': 'TranslateY_BBox',
}
augmenter.available_ops = [
box_aware_ops_by_base_name.get(op_name) or op_name
for op_name in augmenter.available_ops
]
return augmenter
def _distort_common(
self,
image: tf.Tensor,
......
......@@ -140,6 +140,23 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertEqual((2, 4), aug_bboxes.shape)
def test_randaug_build_for_detection(self):
"""Smoke test to be sure there are no syntax errors built for detection."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
augmenter = augment.RandAugment.build_for_detection()
self.assertCountEqual(augmenter.available_ops, [
'AutoContrast', 'Equalize', 'Invert', 'Posterize', 'Solarize', 'Color',
'Contrast', 'Brightness', 'Sharpness', 'Cutout', 'SolarizeAdd',
'Rotate_BBox', 'ShearX_BBox', 'ShearY_BBox', 'TranslateX_BBox',
'TranslateY_BBox'
])
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertEqual((2, 4), aug_bboxes.shape)
def test_all_policy_ops(self):
"""Smoke test to be sure all augmentation functions can execute."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册