提交 848b2513 编写于 作者: D Dan Ringwalt 提交者: TensorFlower Gardener

Add a tf.contrib.image.compose_transforms function (#781).

Change: 150570754
上级 b089f96e
......@@ -49,6 +49,7 @@ py_library(
":image_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
......
......@@ -16,11 +16,13 @@
### API
This module provides functions for image manipulation; currently, only projective
transforms (including rotation) are supported.
This module provides functions for image manipulation; currently, only
projective transforms (including rotation) are supported.
## Image `Ops`
@@angles_to_projective_transforms
@@compose_transforms
@@rotate
@@transform
"""
......@@ -29,6 +31,8 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms
from tensorflow.contrib.image.python.ops.image_ops import compose_transforms
from tensorflow.contrib.image.python.ops.image_ops import rotate
from tensorflow.contrib.image.python.ops.image_ops import transform
......
......@@ -32,11 +32,10 @@ _DTYPES = set(
[dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])
class ImageOpsTestCpu(test_util.TensorFlowTestCase):
_use_gpu = False
class ImageOpsTest(test_util.TensorFlowTestCase):
def test_zeros(self):
with self.test_session(use_gpu=self._use_gpu):
with self.test_session():
for dtype in _DTYPES:
for shape in [(5, 5), (24, 24), (2, 24, 24, 3)]:
for angle in [0, 1, np.pi / 2.0]:
......@@ -46,7 +45,7 @@ class ImageOpsTestCpu(test_util.TensorFlowTestCase):
np.zeros(shape, dtype.as_numpy_dtype()))
def test_rotate_even(self):
with self.test_session(use_gpu=self._use_gpu):
with self.test_session():
for dtype in _DTYPES:
image = array_ops.reshape(
math_ops.cast(math_ops.range(36), dtype), (6, 6))
......@@ -68,7 +67,7 @@ class ImageOpsTestCpu(test_util.TensorFlowTestCase):
[1, 7, 13, 19, 25, 31], [0, 6, 12, 18, 24, 30]]])
def test_rotate_odd(self):
with self.test_session(use_gpu=self._use_gpu):
with self.test_session():
for dtype in _DTYPES:
image = array_ops.reshape(
math_ops.cast(math_ops.range(25), dtype), (5, 5))
......@@ -87,9 +86,29 @@ class ImageOpsTestCpu(test_util.TensorFlowTestCase):
[22, 17, 12, 7, 2], [23, 18, 13, 8, 3],
[24, 19, 14, 9, 4]]])
class ImageOpsTestGpu(ImageOpsTestCpu):
_use_gpu = True
def test_compose(self):
  """Composes a pi / 2 rotation with a translation and checks the result.

  For every dtype in `_DTYPES`, a fixed 4x4 image is transformed with the
  composition of the two transforms and compared against a fixed expected
  4x4 result.
  """
  expected = [[0, 0, 0, 0],
              [0, 1, 0, 1],
              [0, 1, 0, 1],
              [0, 1, 1, 1]]
  with self.test_session():
    for dtype in _DTYPES:
      source = constant_op.constant(
          [[1, 1, 1, 0],
           [1, 0, 0, 0],
           [1, 1, 1, 0],
           [0, 0, 0, 0]], dtype=dtype)
      # Counter-clockwise rotation by pi / 2 for a 4x4 image.
      quarter_turn = image_ops.angles_to_projective_transforms(
          np.pi / 2, 4, 4)
      # Shift right by 1; the transformation matrix is always inverted,
      # which is why the offset entry is -1.
      shift_right = constant_op.constant(
          [1, 0, -1,
           0, 1, 0,
           0, 0],
          dtype=dtypes.float32)
      combined = image_ops.compose_transforms(quarter_turn, shift_right)
      result = image_ops.transform(source, combined)
      self.assertAllEqual(result.eval(), expected)
if __name__ == "__main__":
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib.util import loader
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
......@@ -52,8 +53,6 @@ def rotate(images, angles):
TypeError: If `image` is an invalid type.
"""
image_or_images = ops.convert_to_tensor(images, name="images")
angle_or_angles = ops.convert_to_tensor(
angles, name="angles", dtype=dtypes.float32)
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
if len(image_or_images.get_shape()) == 2:
......@@ -65,14 +64,40 @@ def rotate(images, angles):
else:
raise TypeError("Images should have rank between 2 and 4.")
image_height = math_ops.cast(array_ops.shape(images)[1], dtypes.float32)[None]
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
output = transform(
images,
angles_to_projective_transforms(angles, image_width, image_height))
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
return output[0, :, :, :]
else:
return output
def angles_to_projective_transforms(angles, image_height, image_width):
"""Returns projective transform(s) for the given angle(s).
Args:
angles: A scalar angle to rotate all images by, or (for batches of images)
a vector with an angle to rotate each image in the batch.
image_height: Height of the image(s) to be transformed.
image_width: Width of the image(s) to be transformed.
Returns:
A tensor of shape (num_images, 8). Projective transforms which can be given
to `tf.contrib.image.transform`.
"""
angle_or_angles = ops.convert_to_tensor(
angles, name="angles", dtype=dtypes.float32)
if len(angle_or_angles.get_shape()) == 0: # pylint: disable=g-explicit-length-test
angles = angle_or_angles[None]
elif len(angle_or_angles.get_shape()) == 1:
angles = angle_or_angles
else:
raise TypeError("Angles should have rank 0 or 1.")
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
image_height = math_ops.cast(array_ops.shape(images)[1], dtypes.float32)[None]
x_offset = ((image_width - 1) - (math_ops.cos(angles) *
(image_width - 1) - math_ops.sin(angles) *
(image_height - 1))) / 2.0
......@@ -80,7 +105,7 @@ def rotate(images, angles):
(image_width - 1) + math_ops.cos(angles) *
(image_height - 1))) / 2.0
num_angles = array_ops.shape(angles)[0]
transforms = array_ops.concat(
return array_ops.concat(
values=[
math_ops.cos(angles)[:, None],
-math_ops.sin(angles)[:, None],
......@@ -91,14 +116,6 @@ def rotate(images, angles):
array_ops.zeros((num_angles, 2), dtypes.float32),
],
axis=1)
# pylint: disable=protected-access
output = transform(images, transforms)
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
return output[0, :, :, :]
else:
return output
def transform(images, transforms):
......@@ -113,7 +130,8 @@ def transform(images, transforms):
[a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
`(x, y)` to a transformed *input* point
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
where `k = c0 x + c1 y + 1`.
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
the transform mapping input points to output points.
Returns:
Image(s) with the same type and shape as `images`, with the given
......@@ -153,4 +171,46 @@ def transform(images, transforms):
return output
def compose_transforms(*transforms):
  """Combines several projective transforms into a single one.

  Args:
    *transforms: One or more image projective transforms to compose. Each is
      either a single transform of length 8 or a batch of shape (N, 8). All
      inputs must share the same shape, and at least one must be supplied.

  Returns:
    The composed transform tensor. Passing it to
    `tf.contrib.image.transform` is equivalent to applying each of the given
    transforms to the image in order.
  """
  assert transforms, "transforms cannot be empty"
  first, remaining = transforms[0], transforms[1:]
  product = _flat_transforms_to_matrices(first)
  for flat_transform in remaining:
    # Batched matrix product of the running result with the next transform.
    next_matrices = _flat_transforms_to_matrices(flat_transform)
    product = math_ops.matmul(product, next_matrices)
  return _transform_matrices_to_flat(product)
def _flat_transforms_to_matrices(transforms):
  """Expands flat 8-element transforms into a batch of 3x3 matrices.

  Args:
    transforms: A single transform of length 8 or a batch of them; it is
      reshaped to (-1, 8) so both forms are handled alike.

  Returns:
    A tensor reshaped to (-1, 3, 3), where each row of 8 values has had a 1
    appended as its ninth entry.
  """
  # Force a 2D layout so a lone transform becomes a batch of one.
  flat = array_ops.reshape(transforms, constant_op.constant([-1, 8]))
  batch_size = array_ops.shape(flat)[0]
  # The last matrix entry is implicit in the flat form; supply it as ones.
  ones_column = array_ops.ones([batch_size, 1])
  padded = array_ops.concat([flat, ones_column], axis=1)
  return array_ops.reshape(padded, constant_op.constant([-1, 3, 3]))
def _transform_matrices_to_flat(transform_matrices):
  """Collapses transform matrices back into flat 8-element transforms.

  Args:
    transform_matrices: Transform matrices; reshaped to (-1, 9) so each
      matrix becomes one row.

  Returns:
    The first 8 columns of each row after dividing the row by its last
    (ninth) entry.
  """
  # One row of 9 values per matrix.
  flattened = array_ops.reshape(
      transform_matrices, constant_op.constant([-1, 9]))
  # Normalize so that the last entry (normally 1) is 1 again.
  last_entries = flattened[:, 8:9]
  normalized = flattened / last_entries
  return normalized[:, :8]
ops.NotDifferentiable("ImageProjectiveTransform")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册