提交 1d61880a 编写于 作者: Z Zhenyu Tan 提交者: TensorFlower Gardener

Create Image Preproc RandomFlip layer.

PiperOrigin-RevId: 285843812
Change-Id: Ifd8183f916d673a7e661ca5dca84312c5cc74078
上级 ea0c4e61
......@@ -313,6 +313,75 @@ class Rescaling(Layer):
return dict(list(base_config.items()) + list(config.items()))
class RandomFlip(Layer):
"""Randomly flip each image horizontally and vertically.
This layer will by default flip the images horizontally and then vertically
during training time.
`RandomFlip(horizontal=True)` will only flip the input horizontally.
`RandomFlip(vertical=True)` will only flip the input vertically.
During inference time, the output will be identical to input. Call the layer
with `training=True` to flip the input.
Input shape:
4D tensor with shape:
`(samples, height, width, channels)`, data_format='channels_last'.
Output shape:
4D tensor with shape:
`(samples, height, width, channels)`, data_format='channels_last'.
Attributes:
horizontal: Bool, whether to randomly flip horizontally.
width: Bool, whether to randomly flip vertically.
seed: Integer. Used to create a random seed.
"""
def __init__(self, horizontal=None, vertical=None, seed=None, **kwargs):
# If both arguments are None, set both to True.
if horizontal is None and vertical is None:
self.horizontal = True
self.vertical = True
else:
self.horizontal = horizontal or False
self.vertical = vertical or False
self.seed = seed
self._rng = make_generator(self.seed)
self.input_spec = InputSpec(ndim=4)
super(RandomFlip, self).__init__(**kwargs)
def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
def random_flipped_inputs():
flipped_outputs = inputs
if self.horizontal:
flipped_outputs = image_ops.random_flip_up_down(flipped_outputs,
self.seed)
if self.vertical:
flipped_outputs = image_ops.random_flip_left_right(
flipped_outputs, self.seed)
return flipped_outputs
output = tf_utils.smart_cond(training, random_flipped_inputs,
lambda: inputs)
output.set_shape(inputs.shape)
return output
def compute_output_shape(self, input_shape):
return input_shape
def get_config(self):
config = {
'horizontal': self.horizontal,
'vertical': self.vertical,
'seed': self.seed,
}
base_config = super(RandomFlip, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def make_generator(seed=None):
if seed:
return stateful_random_ops.Generator.from_seed(seed)
......
......@@ -286,5 +286,93 @@ class RescalingTest(keras_parameterized.TestCase):
self.assertEqual(layer_1.name, layer.name)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class RandomFlipTest(keras_parameterized.TestCase):
def _run_test(self,
flip_horizontal,
flip_vertical,
expected_output=None,
mock_random=None):
np.random.seed(1337)
num_samples = 2
orig_height = 5
orig_width = 8
channels = 3
if mock_random is None:
mock_random = [1 for _ in range(num_samples)]
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
inp = np.random.random((num_samples, orig_height, orig_width, channels))
if expected_output is None:
expected_output = inp
if flip_horizontal:
expected_output = np.flip(expected_output, axis=1)
if flip_vertical:
expected_output = np.flip(expected_output, axis=2)
with test.mock.patch.object(
random_ops, 'random_uniform', return_value=mock_random):
with tf_test_util.use_gpu():
layer = image_preprocessing.RandomFlip(flip_horizontal, flip_vertical)
actual_output = layer(inp, training=1)
self.assertAllClose(expected_output, actual_output)
@parameterized.named_parameters(('random_flip_horizontal', True, False),
('random_flip_vertical', False, True),
('random_flip_both', True, True),
('random_flip_neither', False, False))
def test_random_flip(self, flip_horizontal, flip_vertical):
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
self._run_test(flip_horizontal, flip_vertical)
def test_random_flip_horizontal_half(self):
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
np.random.seed(1337)
mock_random = [1, 0]
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
expected_output = input_images.copy()
expected_output[0, :, :, :] = np.flip(input_images[0, :, :, :], axis=0)
self._run_test(True, False, expected_output, mock_random)
def test_random_flip_vertical_half(self):
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
np.random.seed(1337)
mock_random = [1, 0]
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
expected_output = input_images.copy()
expected_output[0, :, :, :] = np.flip(input_images[0, :, :, :], axis=1)
self._run_test(False, True, expected_output, mock_random)
def test_random_flip_inference(self):
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
expected_output = input_images
with tf_test_util.use_gpu():
layer = image_preprocessing.RandomFlip(True, True)
actual_output = layer(input_images, training=0)
self.assertAllClose(expected_output, actual_output)
def test_random_flip_default(self):
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
expected_output = np.flip(np.flip(input_images, axis=1), axis=2)
mock_random = [1, 1]
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
with test.mock.patch.object(
random_ops, 'random_uniform', return_value=mock_random):
with self.cached_session(use_gpu=True):
layer = image_preprocessing.RandomFlip()
actual_output = layer(input_images, training=1)
self.assertAllClose(expected_output, actual_output)
@tf_test_util.run_v2_only
def test_config_with_custom_name(self):
layer = image_preprocessing.RandomFlip(5, 5, name='image_preproc')
config = layer.get_config()
layer_1 = image_preprocessing.RandomFlip.from_config(config)
self.assertEqual(layer_1.name, layer.name)
if __name__ == '__main__':
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册