提交 08c87366 编写于 作者: V vkk800 提交者: François Chollet

Refactor ImageDataGenerator (#10130)

* Create get_random_transform and refactor

* Fix style and add tests

* Add more tests

* Fix documentation error

* Fix documentation style issue

* add apply_affine_transform

* document transformation dictionary

* Doc style fix
上级 25a8973d
......@@ -62,14 +62,9 @@ def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0,
# Returns
Rotated Numpy image tensor.
"""
theta = np.deg2rad(np.random.uniform(-rg, rg))
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
[np.sin(theta), np.cos(theta), 0],
[0, 0, 1]])
h, w = x.shape[row_axis], x.shape[col_axis]
transform_matrix = transform_matrix_offset_center(rotation_matrix, h, w)
x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval)
theta = np.random.uniform(-rg, rg)
x = apply_affine_transform(x, theta=theta, channel_axis=channel_axis,
fill_mode=fill_mode, cval=cval)
return x
......@@ -96,12 +91,8 @@ def random_shift(x, wrg, hrg, row_axis=1, col_axis=2, channel_axis=0,
h, w = x.shape[row_axis], x.shape[col_axis]
tx = np.random.uniform(-hrg, hrg) * h
ty = np.random.uniform(-wrg, wrg) * w
translation_matrix = np.array([[1, 0, tx],
[0, 1, ty],
[0, 0, 1]])
transform_matrix = translation_matrix # no need to do offset
x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval)
x = apply_affine_transform(x, tx=tx, ty=ty, channel_axis=channel_axis,
fill_mode=fill_mode, cval=cval)
return x
......@@ -124,14 +115,9 @@ def random_shear(x, intensity, row_axis=1, col_axis=2, channel_axis=0,
# Returns
Sheared Numpy image tensor.
"""
shear = np.deg2rad(np.random.uniform(-intensity, intensity))
shear_matrix = np.array([[1, -np.sin(shear), 0],
[0, np.cos(shear), 0],
[0, 0, 1]])
h, w = x.shape[row_axis], x.shape[col_axis]
transform_matrix = transform_matrix_offset_center(shear_matrix, h, w)
x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval)
shear = np.random.uniform(-intensity, intensity)
x = apply_affine_transform(x, shear=shear, channel_axis=channel_axis,
fill_mode=fill_mode, cval=cval)
return x
......@@ -165,18 +151,13 @@ def random_zoom(x, zoom_range, row_axis=1, col_axis=2, channel_axis=0,
zx, zy = 1, 1
else:
zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
zoom_matrix = np.array([[zx, 0, 0],
[0, zy, 0],
[0, 0, 1]])
h, w = x.shape[row_axis], x.shape[col_axis]
transform_matrix = transform_matrix_offset_center(zoom_matrix, h, w)
x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval)
x = apply_affine_transform(x, zx=zx, zy=zy, channel_axis=channel_axis,
fill_mode=fill_mode, cval=cval)
return x
def random_channel_shift(x, intensity, channel_axis=0):
"""Performs a random channel shift.
def apply_channel_shift(x, intensity, channel_axis=0):
"""Performs a channel shift.
# Arguments
x: Input tensor. Must be 3D.
......@@ -190,7 +171,7 @@ def random_channel_shift(x, intensity, channel_axis=0):
x = np.rollaxis(x, channel_axis, 0)
min_x, max_x = np.min(x), np.max(x)
channel_images = [
np.clip(x_channel + np.random.uniform(-intensity, intensity),
np.clip(x_channel + intensity,
min_x,
max_x)
for x_channel in x]
......@@ -199,6 +180,42 @@ def random_channel_shift(x, intensity, channel_axis=0):
return x
def random_channel_shift(x, intensity_range, channel_axis=0):
    """Shifts the channels of an image by a randomly drawn amount.

    # Arguments
        x: Input tensor. Must be 3D.
        intensity_range: Bound of the shift. The applied intensity is
            sampled uniformly between `-intensity_range` and
            `intensity_range`.
        channel_axis: Index of axis for channels in the input tensor.

    # Returns
        Numpy image tensor.
    """
    # Draw the shift once, then delegate to the deterministic variant.
    return apply_channel_shift(
        x,
        np.random.uniform(-intensity_range, intensity_range),
        channel_axis=channel_axis)
def apply_brightness_shift(x, brightness):
    """Performs a brightness shift.

    # Arguments
        x: Input tensor. Must be 3D.
        brightness: Float. Factor handed to
            `ImageEnhance.Brightness(...).enhance()`.

    # Returns
        Numpy image tensor.
    """
    # The enhancer operates on PIL images, so round-trip through PIL.
    img = array_to_img(x)
    enhancer = ImageEnhance.Brightness(img)
    return img_to_array(enhancer.enhance(brightness))
def random_brightness(x, brightness_range):
"""Performs a random brightness shift.
......@@ -212,19 +229,14 @@ def random_brightness(x, brightness_range):
# Raises
ValueError if `brightness_range` isn't a tuple.
"""
if len(brightness_range) != 2:
raise ValueError(
'`brightness_range should be tuple or list of two floats. '
'Received: %s' % brightness_range)
x = array_to_img(x)
x = imgenhancer_Brightness = ImageEnhance.Brightness(x)
u = np.random.uniform(brightness_range[0], brightness_range[1])
x = imgenhancer_Brightness.enhance(u)
x = img_to_array(x)
return x
return apply_brightness_shift(x, u)
def transform_matrix_offset_center(matrix, x, y):
......@@ -236,17 +248,22 @@ def transform_matrix_offset_center(matrix, x, y):
return transform_matrix
def apply_transform(x,
transform_matrix,
channel_axis=0,
fill_mode='nearest',
cval=0.):
"""Applies the image transformation specified by a matrix.
def apply_affine_transform(x, theta=0, tx=0, ty=0, shear=0, zx=1, zy=1,
row_axis=0, col_axis=1, channel_axis=2,
fill_mode='nearest', cval=0.):
"""Applies an affine transformation specified by the parameters given.
# Arguments
x: 2D numpy array, single image.
transform_matrix: Numpy array specifying the geometric transformation.
channel_axis: Index of axis for channels in the input tensor.
theta: Rotation angle in degrees.
tx: Shift along the row axis (height shift).
ty: Shift along the column axis (width shift).
shear: Shear angle in degrees.
zx: Zoom in x direction.
zy: Zoom in y direction
row_axis: Index of axis for rows in the input image.
col_axis: Index of axis for columns in the input image.
channel_axis: Index of axis for channels in the input image.
fill_mode: Points outside the boundaries of the input
are filled according to the given mode
(one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
......@@ -256,18 +273,50 @@ def apply_transform(x,
# Returns
The transformed version of the input.
"""
x = np.rollaxis(x, channel_axis, 0)
final_affine_matrix = transform_matrix[:2, :2]
final_offset = transform_matrix[:2, 2]
channel_images = [ndi.interpolation.affine_transform(
x_channel,
final_affine_matrix,
final_offset,
order=1,
mode=fill_mode,
cval=cval) for x_channel in x]
x = np.stack(channel_images, axis=0)
x = np.rollaxis(x, 0, channel_axis + 1)
transform_matrix = None
if theta != 0:
theta = np.deg2rad(theta)
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
[np.sin(theta), np.cos(theta), 0],
[0, 0, 1]])
transform_matrix = rotation_matrix
if tx != 0 or ty != 0:
shift_matrix = np.array([[1, 0, tx],
[0, 1, ty],
[0, 0, 1]])
transform_matrix = shift_matrix if transform_matrix is None else np.dot(transform_matrix, shift_matrix)
if shear != 0:
shear = np.deg2rad(shear)
shear_matrix = np.array([[1, -np.sin(shear), 0],
[0, np.cos(shear), 0],
[0, 0, 1]])
transform_matrix = shear_matrix if transform_matrix is None else np.dot(transform_matrix, shear_matrix)
if zx != 1 or zy != 1:
zoom_matrix = np.array([[zx, 0, 0],
[0, zy, 0],
[0, 0, 1]])
transform_matrix = zoom_matrix if transform_matrix is None else np.dot(transform_matrix, zoom_matrix)
if transform_matrix is not None:
h, w = x.shape[row_axis], x.shape[col_axis]
transform_matrix = transform_matrix_offset_center(
transform_matrix, h, w)
x = np.rollaxis(x, channel_axis, 0)
final_affine_matrix = transform_matrix[:2, :2]
final_offset = transform_matrix[:2, 2]
channel_images = [ndi.interpolation.affine_transform(
x_channel,
final_affine_matrix,
final_offset,
order=1,
mode=fill_mode,
cval=cval) for x_channel in x]
x = np.stack(channel_images, axis=0)
x = np.rollaxis(x, 0, channel_axis + 1)
return x
......@@ -920,30 +969,27 @@ class ImageDataGenerator(object):
'first by calling `.fit(numpy_data)`.')
return x
def random_transform(self, x, seed=None):
"""Randomly augments a single image tensor.
def get_random_transform(self, img_shape, seed=None):
"""Generates random parameters for a transformation.
# Arguments
x: 3D tensor, single image.
seed: Random seed.
img_shape: Tuple of integers. Shape of the image that is transformed.
# Returns
A randomly transformed version of the input (same shape).
A dictionary containing randomly chosen parameters describing the
transformation.
"""
# x is a single image, so it doesn't have image number at index 0
img_row_axis = self.row_axis - 1
img_col_axis = self.col_axis - 1
img_channel_axis = self.channel_axis - 1
if seed is not None:
np.random.seed(seed)
# Use composition of homographies
# to generate final transform that needs to be applied
if self.rotation_range:
theta = np.deg2rad(np.random.uniform(
theta = np.random.uniform(
-self.rotation_range,
self.rotation_range))
self.rotation_range)
else:
theta = 0
......@@ -955,7 +1001,7 @@ class ImageDataGenerator(object):
tx = np.random.uniform(-self.height_shift_range,
self.height_shift_range)
if np.max(self.height_shift_range) < 1:
tx *= x.shape[img_row_axis]
tx *= img_shape[img_row_axis]
else:
tx = 0
......@@ -967,14 +1013,14 @@ class ImageDataGenerator(object):
ty = np.random.uniform(-self.width_shift_range,
self.width_shift_range)
if np.max(self.width_shift_range) < 1:
ty *= x.shape[img_col_axis]
ty *= img_shape[img_col_axis]
else:
ty = 0
if self.shear_range:
shear = np.deg2rad(np.random.uniform(
shear = np.random.uniform(
-self.shear_range,
self.shear_range))
self.shear_range)
else:
shear = 0
......@@ -985,55 +1031,103 @@ class ImageDataGenerator(object):
self.zoom_range[0],
self.zoom_range[1],
2)
transform_matrix = None
if theta != 0:
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
[np.sin(theta), np.cos(theta), 0],
[0, 0, 1]])
transform_matrix = rotation_matrix
if tx != 0 or ty != 0:
shift_matrix = np.array([[1, 0, tx],
[0, 1, ty],
[0, 0, 1]])
transform_matrix = shift_matrix if transform_matrix is None else np.dot(transform_matrix, shift_matrix)
if shear != 0:
shear_matrix = np.array([[1, -np.sin(shear), 0],
[0, np.cos(shear), 0],
[0, 0, 1]])
transform_matrix = shear_matrix if transform_matrix is None else np.dot(transform_matrix, shear_matrix)
if zx != 1 or zy != 1:
zoom_matrix = np.array([[zx, 0, 0],
[0, zy, 0],
[0, 0, 1]])
transform_matrix = zoom_matrix if transform_matrix is None else np.dot(transform_matrix, zoom_matrix)
if transform_matrix is not None:
h, w = x.shape[img_row_axis], x.shape[img_col_axis]
transform_matrix = transform_matrix_offset_center(
transform_matrix, h, w)
x = apply_transform(x, transform_matrix, img_channel_axis,
fill_mode=self.fill_mode, cval=self.cval)
flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip
flip_vertical = (np.random.random() < 0.5) * self.vertical_flip
channel_shift_intensity = None
if self.channel_shift_range != 0:
x = random_channel_shift(x,
self.channel_shift_range,
img_channel_axis)
if self.horizontal_flip:
if np.random.random() < 0.5:
x = flip_axis(x, img_col_axis)
if self.vertical_flip:
if np.random.random() < 0.5:
x = flip_axis(x, img_row_axis)
channel_shift_intensity = np.random.uniform(-self.channel_shift_range,
self.channel_shift_range)
brightness = None
if self.brightness_range is not None:
x = random_brightness(x, self.brightness_range)
if len(self.brightness_range) != 2:
raise ValueError(
'`brightness_range should be tuple or list of two floats. '
'Received: %s' % brightness_range)
brightness = np.random.uniform(self.brightness_range[0],
self.brightness_range[1])
transform_parameters = {'theta': theta,
'tx': tx,
'ty': ty,
'shear': shear,
'zx': zx,
'zy': zy,
'flip_horizontal': flip_horizontal,
'flip_vertical': flip_vertical,
'channel_shift_intensity': channel_shift_intensity,
'brightness': brightness}
return transform_parameters
def apply_transform(self, x, transform_parameters):
    """Applies a transformation to an image according to given parameters.

    # Arguments
        x: 3D tensor, single image.
        transform_parameters: Dictionary with string - parameter pairs
            describing the transformation. A missing key leaves the
            corresponding transformation disabled. Currently, the
            following parameters from the dictionary are used:
            - `'theta'`: Float. Rotation angle in degrees.
            - `'tx'`: Float. Shift along the row axis (height).
            - `'ty'`: Float. Shift along the column axis (width).
            - `'shear'`: Float. Shear angle in degrees.
            - `'zx'`: Float. Zoom in the x direction.
            - `'zy'`: Float. Zoom in the y direction.
            - `'flip_horizontal'`: Boolean. Horizontal flip.
            - `'flip_vertical'`: Boolean. Vertical flip.
            - `'channel_shift_intensity'`: Float. Channel shift intensity.
            - `'brightness'`: Float. Brightness shift intensity.

    # Returns
        A transformed version of the input.
    """
    # x is a single image, so it doesn't have image number at index 0
    img_row_axis = self.row_axis - 1
    img_col_axis = self.col_axis - 1
    img_channel_axis = self.channel_axis - 1

    # Pass the affine parameters by keyword so they cannot be silently
    # mismatched with the positional order of `apply_affine_transform`.
    x = apply_affine_transform(
        x,
        theta=transform_parameters.get('theta', 0),
        tx=transform_parameters.get('tx', 0),
        ty=transform_parameters.get('ty', 0),
        shear=transform_parameters.get('shear', 0),
        zx=transform_parameters.get('zx', 1),
        zy=transform_parameters.get('zy', 1),
        row_axis=img_row_axis,
        col_axis=img_col_axis,
        channel_axis=img_channel_axis,
        fill_mode=self.fill_mode,
        cval=self.cval)

    # `None` (or an absent key) means "no channel shift".
    channel_shift_intensity = transform_parameters.get(
        'channel_shift_intensity')
    if channel_shift_intensity is not None:
        x = apply_channel_shift(x,
                                channel_shift_intensity,
                                img_channel_axis)

    if transform_parameters.get('flip_horizontal', False):
        x = flip_axis(x, img_col_axis)

    if transform_parameters.get('flip_vertical', False):
        x = flip_axis(x, img_row_axis)

    # `None` (or an absent key) means "no brightness change".
    brightness = transform_parameters.get('brightness')
    if brightness is not None:
        x = apply_brightness_shift(x, brightness)

    return x
def random_transform(self, x, seed=None):
    """Draws random transformation parameters and applies them to an image.

    # Arguments
        x: 3D tensor, single image.
        seed: Random seed, forwarded to `get_random_transform`.

    # Returns
        The result of `apply_transform` on `x` with the drawn parameters.
    """
    # Sample the parameters for this image's shape and apply them at once.
    return self.apply_transform(x, self.get_random_transform(x.shape, seed))
def fit(self, x,
augment=False,
rounds=1,
......@@ -1314,8 +1408,9 @@ class NumpyArrayIterator(Iterator):
dtype=K.floatx())
for i, j in enumerate(index_array):
x = self.x[j]
x = self.image_data_generator.random_transform(
x.astype(K.floatx()))
params = self.image_data_generator.get_random_transform(x.shape)
x = self.image_data_generator.apply_transform(
x.astype(K.floatx()), params)
x = self.image_data_generator.standardize(x)
batch_x[i] = x
......@@ -1621,7 +1716,8 @@ class DirectoryIterator(Iterator):
target_size=self.target_size,
interpolation=self.interpolation)
x = img_to_array(img, data_format=self.data_format)
x = self.image_data_generator.random_transform(x)
params = self.image_data_generator.get_random_transform(x.shape)
x = self.image_data_generator.apply_transform(x, params)
x = self.image_data_generator.standardize(x)
batch_x[i] = x
# optionally save augmented images to disk for debugging purposes
......
......@@ -452,6 +452,74 @@ class TestImage(object):
assert image.random_zoom(x, (5, 5)).shape == (2, 28, 28)
assert image.random_channel_shift(x, 20).shape == (2, 28, 28)
# Test get_random_transform with predefined seed
seed = 1
generator = image.ImageDataGenerator(
rotation_range=90.,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.5,
zoom_range=0.2,
channel_shift_range=0.1,
brightness_range=(1, 5),
horizontal_flip=True,
vertical_flip=True)
transform_dict = generator.get_random_transform(x.shape, seed)
transform_dict2 = generator.get_random_transform(x.shape, seed * 2)
assert transform_dict['theta'] != 0
assert transform_dict['theta'] != transform_dict2['theta']
assert transform_dict['tx'] != 0
assert transform_dict['tx'] != transform_dict2['tx']
assert transform_dict['ty'] != 0
assert transform_dict['ty'] != transform_dict2['ty']
assert transform_dict['shear'] != 0
assert transform_dict['shear'] != transform_dict2['shear']
assert transform_dict['zx'] != 0
assert transform_dict['zx'] != transform_dict2['zx']
assert transform_dict['zy'] != 0
assert transform_dict['zy'] != transform_dict2['zy']
assert transform_dict['channel_shift_intensity'] != 0
assert transform_dict['channel_shift_intensity'] != transform_dict2['channel_shift_intensity']
assert transform_dict['brightness'] != 0
assert transform_dict['brightness'] != transform_dict2['brightness']
# Test get_random_transform without any randomness
generator = image.ImageDataGenerator()
transform_dict = generator.get_random_transform(x.shape, seed)
assert transform_dict['theta'] == 0
assert transform_dict['tx'] == 0
assert transform_dict['ty'] == 0
assert transform_dict['shear'] == 0
assert transform_dict['zx'] == 1
assert transform_dict['zy'] == 1
assert transform_dict['channel_shift_intensity'] is None
assert transform_dict['brightness'] is None
def test_deterministic_transform(self):
    """Checks `apply_transform` with explicit, non-random parameters."""
    generator = image.ImageDataGenerator(
        rotation_range=90,
        fill_mode='constant')

    # Flips must match plain reversed slicing along the flipped axis.
    x = np.random.random((32, 32, 3))
    assert np.allclose(generator.apply_transform(x, {'flip_vertical': True}),
                       x[::-1, :, :])
    assert np.allclose(generator.apply_transform(x, {'flip_horizontal': True}),
                       x[:, ::-1, :])

    # A 45 degree rotation of an all-ones 3x3 image with constant fill.
    x = np.ones((3, 3, 3))
    x_rotated = np.array([[[0., 0., 0.],
                           [0., 0., 0.],
                           [1., 1., 1.]],
                          [[0., 0., 0.],
                           [1., 1., 1.],
                           [1., 1., 1.]],
                          [[0., 0., 0.],
                           [0., 0., 0.],
                           [1., 1., 1.]]])
    assert np.allclose(generator.apply_transform(x, {'theta': 45}),
                       x_rotated)
    # The module-level helper must agree with the generator method.
    assert np.allclose(image.apply_affine_transform(x, theta=45, channel_axis=2,
                                                    fill_mode='constant'),
                       x_rotated)
def test_batch_standardize(self):
# ImageDataGenerator.standardize should work on batches
for test_images in self.all_test_images:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册