提交 5b7cd01f 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 416886349
上级 782e39e8
......@@ -55,9 +55,14 @@ class Parser(hyperparams.Config):
aug_rand_hflip: bool = False
aug_scale_min: float = 1.0
aug_scale_max: float = 1.0
aug_policy: Optional[str] = None
skip_crowd_during_training: bool = True
max_num_instances: int = 100
# Can choose AutoAugment and RandAugment.
# TODO(b/205346436) Support RandAugment.
aug_type: Optional[common.Augmentation] = None
# Keep for backward compatibility. Not used.
aug_policy: Optional[str] = None
@dataclasses.dataclass
......
......@@ -19,11 +19,13 @@ into (image, labels) tuple for RetinaNet.
"""
# Import libraries
from absl import logging
import tensorflow as tf
from official.vision.beta.dataloaders import parser
from official.vision.beta.dataloaders import utils
from official.vision.beta.ops import anchor
from official.vision.beta.ops import augment
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops
......@@ -40,6 +42,7 @@ class Parser(parser.Parser):
anchor_size,
match_threshold=0.5,
unmatched_threshold=0.5,
aug_type=None,
aug_rand_hflip=False,
aug_scale_min=1.0,
aug_scale_max=1.0,
......@@ -71,6 +74,8 @@ class Parser(parser.Parser):
unmatched_threshold: `float` number between 0 and 1 representing the
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment. The latter is not supported, and will raise ValueError.
aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for
......@@ -108,7 +113,20 @@ class Parser(parser.Parser):
self._aug_scale_min = aug_scale_min
self._aug_scale_max = aug_scale_max
# Data Augmentation with AutoAugment.
# Data augmentation with AutoAugment or RandAugment.
self._augmenter = None
if aug_type is not None:
if aug_type.type == 'autoaug':
logging.info('Using AutoAugment.')
self._augmenter = augment.AutoAugment(
augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
else:
# TODO(b/205346436) Support RandAugment.
raise ValueError(f'Augmentation policy {aug_type.type} not supported.')
# Deprecated. Data Augmentation with AutoAugment.
self._use_autoaugment = use_autoaugment
self._autoaugment_policy_name = autoaugment_policy_name
......@@ -138,9 +156,13 @@ class Parser(parser.Parser):
for k, v in attributes.items():
attributes[k] = tf.gather(v, indices)
# Gets original image and its size.
# Gets original image.
image = data['image']
# Apply autoaug or randaug.
if self._augmenter is not None:
image, boxes = self._augmenter.distort_with_boxes(image, boxes)
image_shape = tf.shape(input=image)[0:2]
# Normalizes image with mean and std pixel values.
......
此差异已折叠。
......@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'reduced_cifar10',
'svhn',
'reduced_imagenet',
]
AVAILABLE_POLICIES = [
'v0',
'test',
'simple',
'reduced_cifar10',
'svhn',
'reduced_imagenet',
'detection_v0',
]
def test_autoaugment(self):
......@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((224, 224, 3), aug_image.shape)
def test_autoaugment_with_bboxes(self):
  """Checks AutoAugment preserves image and box shapes for each policy.

  Every name in `self.AVAILABLE_POLICIES` is used to build an
  `augment.AutoAugment`, whose `distort_with_boxes` is applied to a single
  zero image and a pair of all-ones boxes.
  """
  expected_image_shape = (224, 224, 3)
  expected_boxes_shape = (2, 4)
  source_image = tf.zeros(expected_image_shape, dtype=tf.uint8)
  source_boxes = tf.ones(expected_boxes_shape, dtype=tf.float32)

  for policy_name in self.AVAILABLE_POLICIES:
    policy_augmenter = augment.AutoAugment(augmentation_name=policy_name)
    distorted = policy_augmenter.distort_with_boxes(source_image, source_boxes)
    distorted_image, distorted_boxes = distorted

    # Shapes are checked per policy so a failure points at a specific policy.
    self.assertEqual((224, 224, 3), distorted_image.shape)
    self.assertEqual((2, 4), distorted_boxes.shape)
def test_randaug(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
......@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug_with_bboxes(self):
  """Checks default RandAugment preserves image and box shapes.

  A default-constructed `augment.RandAugment` is applied through
  `distort_with_boxes` to a single zero image and a pair of all-ones boxes.
  """
  source_image = tf.zeros((224, 224, 3), dtype=tf.uint8)
  source_boxes = tf.ones((2, 4), dtype=tf.float32)

  distorted = augment.RandAugment().distort_with_boxes(
      source_image, source_boxes)
  distorted_image, distorted_boxes = distorted

  self.assertEqual((224, 224, 3), distorted_image.shape)
  self.assertEqual((2, 4), distorted_boxes.shape)
def test_all_policy_ops(self):
"""Smoke test to be sure all augmentation functions can execute."""
......@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8)
bboxes = None
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((224, 224, 3), image.shape)
self.assertIsNone(bboxes)
def test_all_policy_ops_with_bboxes(self):
  """Smoke test: every op in `augment.NAME_TO_FUNC` runs with image and boxes.

  Each op name is resolved through `augment._parse_policy_info` and applied in
  sequence; the image and boxes returned by one op are fed to the next. The
  shapes after the last op must equal the input shapes.
  """
  prob = 1
  magnitude = 10
  replace_value = [128] * 3
  cutout_const = 100
  translate_const = 250

  image = tf.ones((224, 224, 3), dtype=tf.uint8)
  bboxes = tf.ones((2, 4), dtype=tf.float32)

  for op_name in augment.NAME_TO_FUNC:
    func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
                                               replace_value, cutout_const,
                                               translate_const)
    # Ops take (image, bboxes, *args) and return an (image, bboxes) pair. The
    # older image-only call `func(image, *args)` is intentionally not made:
    # it would pass the first policy argument in the bboxes position.
    image, bboxes = func(image, bboxes, *args)

  self.assertEqual((224, 224, 3), image.shape)
  self.assertEqual((2, 4), bboxes.shape)
def test_autoaugment_video(self):
"""Smoke test with video to be sure there are no syntax errors."""
......@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((2, 224, 224, 3), aug_image.shape)
def test_autoaugment_video_with_boxes(self):
  """Checks AutoAugment preserves 4-D image and 3-D box shapes per policy.

  Every name in `self.AVAILABLE_POLICIES` is used to build an
  `augment.AutoAugment`, whose `distort_with_boxes` is applied to a
  (2, 224, 224, 3) zero tensor and a (2, 2, 4) all-ones box tensor.
  """
  expected_video_shape = (2, 224, 224, 3)
  expected_boxes_shape = (2, 2, 4)
  source_video = tf.zeros(expected_video_shape, dtype=tf.uint8)
  source_boxes = tf.ones(expected_boxes_shape, dtype=tf.float32)

  for policy_name in self.AVAILABLE_POLICIES:
    policy_augmenter = augment.AutoAugment(augmentation_name=policy_name)
    distorted = policy_augmenter.distort_with_boxes(source_video, source_boxes)
    distorted_video, distorted_boxes = distorted

    # Shapes are checked per policy so a failure points at a specific policy.
    self.assertEqual((2, 224, 224, 3), distorted_video.shape)
    self.assertEqual((2, 2, 4), distorted_boxes.shape)
def test_randaug_video(self):
"""Smoke test with video to be sure there are no syntax errors."""
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8)
......@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const = 250
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8)
bboxes = None
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((2, 224, 224, 3), image.shape)
self.assertIsNone(bboxes)
def test_all_policy_ops_video_with_bboxes(self):
  """Smoke test: every op in `augment.NAME_TO_FUNC` handles video with boxes.

  Each op name is resolved through `augment._parse_policy_info` and called
  with a (2, 224, 224, 3) image tensor and a (2, 2, 4) box tensor. The ops
  listed in `ops_expected_to_raise` must raise `ValueError` for this input;
  all other ops are applied in sequence, and the shapes after the last op must
  equal the input shapes.
  """
  prob = 1
  magnitude = 10
  replace_value = [128] * 3
  cutout_const = 100
  translate_const = 250

  image = tf.ones((2, 224, 224, 3), dtype=tf.uint8)
  bboxes = tf.ones((2, 2, 4), dtype=tf.float32)

  # Built once outside the loop. NOTE(review): presumably these box-aware
  # geometric ops do not support the extra leading dimension -- confirm
  # against the op implementations.
  ops_expected_to_raise = frozenset({
      'Rotate_BBox',
      'ShearX_BBox',
      'ShearY_BBox',
      'TranslateX_BBox',
      'TranslateY_BBox',
      'TranslateY_Only_BBoxes',
  })

  for op_name in augment.NAME_TO_FUNC:
    func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
                                               replace_value, cutout_const,
                                               translate_const)
    # Ops take (image, bboxes, *args). The older image-only call
    # `func(image, *args)` is intentionally not made: it would pass the first
    # policy argument in the bboxes position.
    if op_name in ops_expected_to_raise:
      with self.assertRaises(ValueError):
        func(image, bboxes, *args)
    else:
      image, bboxes = func(image, bboxes, *args)

  self.assertEqual((2, 224, 224, 3), image.shape)
  self.assertEqual((2, 2, 4), bboxes.shape)
def _generate_test_policy(self):
"""Generate a test policy at random."""
......
......@@ -119,6 +119,7 @@ class RetinaNetTask(base_task.Task):
dtype=params.dtype,
match_threshold=params.parser.match_threshold,
unmatched_threshold=params.parser.unmatched_threshold,
aug_type=params.parser.aug_type,
aug_rand_hflip=params.parser.aug_rand_hflip,
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册