未验证 提交 8dea7bed 编写于 作者: L LielinJiang 提交者: GitHub

Add some transform apis (#25357)

* add more vision transfrom apis
上级 417b2439
......@@ -23,6 +23,7 @@ import numpy as np
from paddle.incubate.hapi.datasets import DatasetFolder
from paddle.incubate.hapi.vision.transforms import transforms
import paddle.incubate.hapi.vision.transforms.functional as F
class TestTransforms(unittest.TestCase):
......@@ -100,6 +101,78 @@ class TestTransforms(unittest.TestCase):
])
self.do_transform(trans)
def test_rotate(self):
trans = transforms.Compose([
transforms.RandomRotate(90),
transforms.RandomRotate([-10, 10]),
transforms.RandomRotate(
45, expand=True),
transforms.RandomRotate(
10, expand=True, center=(60, 80)),
])
self.do_transform(trans)
def test_pad(self):
trans = transforms.Compose([transforms.Pad(2)])
self.do_transform(trans)
fake_img = np.random.rand(200, 150, 3).astype('float32')
trans_pad = transforms.Pad(10)
fake_img_padded = trans_pad(fake_img)
np.testing.assert_equal(fake_img_padded.shape, (220, 170, 3))
trans_pad1 = transforms.Pad([1, 2])
trans_pad2 = transforms.Pad([1, 2, 3, 4])
img = trans_pad1(fake_img)
img = trans_pad2(img)
def test_erase(self):
trans = transforms.Compose(
[transforms.RandomErasing(), transforms.RandomErasing(value=0.0)])
self.do_transform(trans)
def test_random_crop(self):
trans = transforms.Compose([
transforms.RandomCrop(200),
transforms.RandomCrop((140, 160)),
])
self.do_transform(trans)
trans_random_crop1 = transforms.RandomCrop(224)
trans_random_crop2 = transforms.RandomCrop((140, 160))
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img_crop1 = trans_random_crop1(fake_img)
fake_img_crop2 = trans_random_crop2(fake_img_crop1)
np.testing.assert_equal(fake_img_crop1.shape, (224, 224, 3))
np.testing.assert_equal(fake_img_crop2.shape, (140, 160, 3))
trans_random_crop_same = transforms.RandomCrop((140, 160))
img = trans_random_crop_same(fake_img_crop2)
trans_random_crop_bigger = transforms.RandomCrop((180, 200))
img = trans_random_crop_bigger(img)
trans_random_crop_pad = transforms.RandomCrop((224, 256), 2, True)
img = trans_random_crop_pad(img)
def test_grayscale(self):
trans = transforms.Compose([transforms.Grayscale()])
self.do_transform(trans)
trans_gray = transforms.Grayscale()
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img_gray = trans_gray(fake_img)
np.testing.assert_equal(len(fake_img_gray.shape), 2)
np.testing.assert_equal(fake_img_gray.shape[0], 500)
np.testing.assert_equal(fake_img_gray.shape[1], 400)
trans_gray3 = transforms.Grayscale(3)
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img_gray = trans_gray3(fake_img)
def test_exception(self):
trans = transforms.Compose([transforms.Resize(-1)])
......@@ -123,6 +196,36 @@ class TestTransforms(unittest.TestCase):
with self.assertRaises(ValueError):
transforms.BrightnessTransform(-1.0)
with self.assertRaises(ValueError):
transforms.Pad([1.0, 2.0, 3.0])
with self.assertRaises(TypeError):
fake_img = np.random.rand(100, 120, 3).astype('float32')
F.pad(fake_img, '1')
with self.assertRaises(TypeError):
fake_img = np.random.rand(100, 120, 3).astype('float32')
F.pad(fake_img, 1, {})
with self.assertRaises(TypeError):
fake_img = np.random.rand(100, 120, 3).astype('float32')
F.pad(fake_img, 1, padding_mode=-1)
with self.assertRaises(ValueError):
fake_img = np.random.rand(100, 120, 3).astype('float32')
F.pad(fake_img, [1.0, 2.0, 3.0])
with self.assertRaises(ValueError):
transforms.RandomRotate(-2)
with self.assertRaises(ValueError):
transforms.RandomRotate([1, 2, 3])
with self.assertRaises(ValueError):
trans_gray = transforms.Grayscale(5)
fake_img = np.random.rand(100, 120, 3).astype('float32')
trans_gray(fake_img)
def test_info(self):
str(transforms.Compose([transforms.Resize((224, 224))]))
str(transforms.BatchCompose([transforms.Resize((224, 224))]))
......
......@@ -15,8 +15,10 @@
import sys
import collections
import random
import math
import cv2
import numbers
import numpy as np
if sys.version_info < (3, 3):
......@@ -26,7 +28,7 @@ else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
__all__ = ['flip', 'resize']
__all__ = ['flip', 'resize', 'pad', 'rotate', 'to_grayscale']
def flip(image, code):
......@@ -99,3 +101,202 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR):
return cv2.resize(img, (ow, oh), interpolation=interpolation)
else:
return cv2.resize(img, size[::-1], interpolation=interpolation)
def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'):
"""Pads the given CV Image on all sides with speficified padding mode and fill value.
Args:
img (np.ndarray): Image to be padded.
padding (int|tuple): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill (int|tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
``constant`` means padding with a constant value, this value is specified with fill.
``edge`` means padding with the last value at the edge of the image.
``reflect`` means padding with reflection of image (without repeating the last value on the edge)
padding ``[1, 2, 3, 4]`` with 2 elements on both sides in reflect mode
will result in ``[3, 2, 1, 2, 3, 4, 3, 2]``.
``symmetric`` menas pads with reflection of image (repeating the last value on the edge)
padding ``[1, 2, 3, 4]`` with 2 elements on both sides in symmetric mode
will result in ``[2, 1, 1, 2, 3, 4, 4, 3]``.
Returns:
numpy ndarray: Padded image.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms.functional import pad
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = pad(fake_img, 2)
print(fake_img.shape)
"""
if not isinstance(padding, (numbers.Number, list, tuple)):
raise TypeError('Got inappropriate padding arg')
if not isinstance(fill, (numbers.Number, str, list, tuple)):
raise TypeError('Got inappropriate fill arg')
if not isinstance(padding_mode, str):
raise TypeError('Got inappropriate padding_mode arg')
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
raise ValueError(
"Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
'Expected padding mode be either constant, edge, reflect or symmetric, but got {}'.format(padding_mode)
PAD_MOD = {
'constant': cv2.BORDER_CONSTANT,
'edge': cv2.BORDER_REPLICATE,
'reflect': cv2.BORDER_DEFAULT,
'symmetric': cv2.BORDER_REFLECT
}
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, collections.Sequence) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, collections.Sequence) and len(padding) == 4:
pad_left, pad_top, pad_right, pad_bottom = padding
if isinstance(fill, numbers.Number):
fill = (fill, ) * (2 * len(img.shape) - 3)
if padding_mode == 'constant':
assert (len(fill) == 3 and len(img.shape) == 3) or (len(fill) == 1 and len(img.shape) == 2), \
'channel of image is {} but length of fill is {}'.format(img.shape[-1], len(fill))
img = cv2.copyMakeBorder(
src=img,
top=pad_top,
bottom=pad_bottom,
left=pad_left,
right=pad_right,
borderType=PAD_MOD[padding_mode],
value=fill)
return img
def rotate(img,
angle,
interpolation=cv2.INTER_LINEAR,
expand=False,
center=None):
"""Rotates the image by angle.
Args:
img (numpy.ndarray): Image to be rotated.
angle (float|int): In degrees clockwise order.
interpolation (int, optional):
interpolation: Interpolation method.
expand (bool|optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation.
center (2-tuple|optional): Optional center of rotation.
Origin is the upper left corner.
Default is the center of the image.
Returns:
numpy ndarray: Rotated image.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms.functional import rotate
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = rotate(fake_img, 10)
print(fake_img.shape)
"""
dtype = img.dtype
h, w, _ = img.shape
point = center or (w / 2, h / 2)
M = cv2.getRotationMatrix2D(point, angle=-angle, scale=1)
if expand:
if center is None:
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
nW = int((h * sin) + (w * cos))
nH = int((h * cos) + (w * sin))
M[0, 2] += (nW / 2) - point[0]
M[1, 2] += (nH / 2) - point[1]
dst = cv2.warpAffine(img, M, (nW, nH))
else:
xx = []
yy = []
for point in (np.array([0, 0, 1]), np.array([w - 1, 0, 1]),
np.array([w - 1, h - 1, 1]), np.array([0, h - 1, 1])):
target = np.dot(M, point)
xx.append(target[0])
yy.append(target[1])
nh = int(math.ceil(max(yy)) - math.floor(min(yy)))
nw = int(math.ceil(max(xx)) - math.floor(min(xx)))
M[0, 2] += (nw - w) / 2
M[1, 2] += (nh - h) / 2
dst = cv2.warpAffine(img, M, (nw, nh), flags=interpolation)
else:
dst = cv2.warpAffine(img, M, (w, h), flags=interpolation)
return dst.astype(dtype)
def to_grayscale(img, num_output_channels=1):
"""Converts image to grayscale version of image.
Args:
img (numpy.ndarray): Image to be converted to grayscale.
Returns:
numpy.ndarray: Grayscale version of the image.
if num_output_channels == 1, returned image is single channel
if num_output_channels == 3, returned image is 3 channel with r == g == b
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms.functional import to_grayscale
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = to_grayscale(fake_img)
print(fake_img.shape)
"""
if num_output_channels == 1:
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
elif num_output_channels == 3:
img = cv2.cvtColor(
cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
else:
raise ValueError('num_output_channels should be either 1 or 3')
return img
......@@ -52,6 +52,11 @@ __all__ = [
"ContrastTransform",
"HueTransform",
"ColorJitter",
"RandomCrop",
"RandomErasing",
"Pad",
"RandomRotate",
"Grayscale",
]
......@@ -756,17 +761,13 @@ class ColorJitter(object):
Args:
brightness: How much to jitter brightness.
Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non negative numbers.
Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. Should be non negative numbers.
contrast: How much to jitter contrast.
Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non negative numbers.
Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. Should be non negative numbers.
saturation: How much to jitter saturation.
Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non negative numbers.
Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. Should be non negative numbers.
hue: How much to jitter hue.
Chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
Chosen uniformly from [-hue, hue]. Should have 0<= hue <= 0.5.
Examples:
......@@ -800,3 +801,342 @@ class ColorJitter(object):
def __call__(self, img):
return self.transforms(img)
class RandomCrop(object):
"""Crops the given CV Image at a random location.
Args:
size (sequence|int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
padding (int|sequence|optional): Optional padding on each border
of the image. If a sequence of length 4 is provided, it is used to pad left,
top, right, bottom borders respectively. Default: 0.
pad_if_needed (boolean|optional): It will pad the image if smaller than the
desired size to avoid raising an exception. Default: False.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import RandomCrop
transform = RandomCrop(224)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, size, padding=0, pad_if_needed=False):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
def _get_params(self, img, output_size):
"""Get parameters for ``crop`` for a random crop.
Args:
img (numpy.ndarray): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
h, w, _ = img.shape
th, tw = output_size
if w == tw and h == th:
return 0, 0, h, w
try:
i = random.randint(0, h - th)
except ValueError:
i = random.randint(h - th, 0)
try:
j = random.randint(0, w - tw)
except ValueError:
j = random.randint(w - tw, 0)
return i, j, th, tw
def __call__(self, img):
"""
Args:
img (numpy.ndarray): Image to be cropped.
Returns:
numpy.ndarray: Cropped image.
"""
if self.padding > 0:
img = F.pad(img, self.padding)
# pad the width if needed
if self.pad_if_needed and img.shape[1] < self.size[1]:
img = F.pad(img, (int((1 + self.size[1] - img.shape[1]) / 2), 0))
# pad the height if needed
if self.pad_if_needed and img.shape[0] < self.size[0]:
img = F.pad(img, (0, int((1 + self.size[0] - img.shape[0]) / 2)))
i, j, h, w = self._get_params(img, self.size)
return img[i:i + h, j:j + w]
class RandomErasing(object):
"""Randomly selects a rectangle region in an image and erases its pixels.
``Random Erasing Data Augmentation`` by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf
Args:
prob (float): probability that the random erasing operation will be performed.
scale (tuple): range of proportion of erased area against input image. Should be (min, max).
ratio (float): range of aspect ratio of erased area.
value (float|list|tuple): erasing value. If a single int, it is used to
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively. Default: 0.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import RandomCrop
transform = RandomCrop(224)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self,
prob=0.5,
scale=(0.02, 0.4),
ratio=0.3,
value=[0., 0., 0.]):
assert isinstance(value, (
float, Sequence
)), "Expected type of value in [float, list, tupue], but got {}".format(
type(value))
assert scale[0] <= scale[1], "scale range should be of kind (min, max)!"
if isinstance(value, float):
self.value = [value, value, value]
else:
self.value = value
self.p = prob
self.scale = scale
self.ratio = ratio
def __call__(self, img):
if random.uniform(0, 1) > self.p:
return img
for _ in range(100):
area = img.shape[0] * img.shape[1]
target_area = random.uniform(self.scale[0], self.scale[1]) * area
aspect_ratio = random.uniform(self.ratio, 1 / self.ratio)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img.shape[1] and h < img.shape[0]:
x1 = random.randint(0, img.shape[0] - h)
y1 = random.randint(0, img.shape[1] - w)
if len(img.shape) == 3 and img.shape[2] == 3:
img[x1:x1 + h, y1:y1 + w, 0] = self.value[0]
img[x1:x1 + h, y1:y1 + w, 1] = self.value[1]
img[x1:x1 + h, y1:y1 + w, 2] = self.value[2]
else:
img[x1:x1 + h, y1:y1 + w] = self.value[1]
return img
return img
class Pad(object):
"""Pads the given CV Image on all sides with the given "pad" value.
Args:
padding (int|list|tuple): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill (int|list|tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
``constant`` means pads with a constant value, this value is specified with fill.
``edge`` means pads with the last value at the edge of the image.
``reflect`` means pads with reflection of image (without repeating the last value on the edge)
padding ``[1, 2, 3, 4]`` with 2 elements on both sides in reflect mode
will result in ``[3, 2, 1, 2, 3, 4, 3, 2]``.
``symmetric`` menas pads with reflection of image (repeating the last value on the edge)
padding ``[1, 2, 3, 4]`` with 2 elements on both sides in symmetric mode
will result in ``[2, 1, 1, 2, 3, 4, 4, 3]``.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import Pad
transform = Pad(2)
fake_img = np.random.rand(500, 500, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, padding, fill=0, padding_mode='constant'):
assert isinstance(padding, (numbers.Number, list, tuple))
assert isinstance(fill, (numbers.Number, str, list, tuple))
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
if isinstance(padding,
collections.Sequence) and len(padding) not in [2, 4]:
raise ValueError(
"Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
self.padding = padding
self.fill = fill
self.padding_mode = padding_mode
def __call__(self, img):
"""
Args:
img (numpy.ndarray): Image to be padded.
Returns:
numpy.ndarray: Padded image.
"""
return F.pad(img, self.padding, self.fill, self.padding_mode)
class RandomRotate(object):
"""Rotates the image by angle.
Args:
degrees (sequence or float or int): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees) clockwise order.
interpolation (int|optional): Interpolation mode of resize. Default: cv2.INTER_LINEAR.
expand (bool|optional): Optional expansion flag. Default: False.
If true, expands the output to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation.
center (2-tuple|optional): Optional center of rotation.
Origin is the upper left corner.
Default is the center of the image.
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import RandomRotate
transform = RandomRotate(90)
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self,
degrees,
interpolation=cv2.INTER_LINEAR,
expand=False,
center=None):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError(
"If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
if len(degrees) != 2:
raise ValueError(
"If degrees is a sequence, it must be of len 2.")
self.degrees = degrees
self.interpolation = interpolation
self.expand = expand
self.center = center
def _get_params(self, degrees):
"""Get parameters for ``rotate`` for a random rotation.
Returns:
sequence: params to be passed to ``rotate`` for random rotation.
"""
angle = random.uniform(degrees[0], degrees[1])
return angle
def __call__(self, img):
"""
img (np.ndarray): Image to be rotated.
Returns:
np.ndarray: Rotated image.
"""
angle = self._get_params(self.degrees)
return F.rotate(img, angle, self.interpolation, self.expand,
self.center)
class Grayscale(object):
"""Converts image to grayscale.
Args:
output_channels (int): (1 or 3) number of channels desired for output image
Returns:
CV Image: Grayscale version of the input.
- If output_channels == 1 : returned image is single channel
- If output_channels == 3 : returned image is 3 channel with r == g == b
Examples:
.. code-block:: python
import numpy as np
from paddle.incubate.hapi.vision.transforms import Grayscale
transform = Grayscale()
fake_img = np.random.rand(500, 400, 3).astype('float32')
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self, output_channels=1):
self.output_channels = output_channels
def __call__(self, img):
"""
Args:
img (numpy.ndarray): Image to be converted to grayscale.
Returns:
numpy.ndarray: Randomly grayscaled image.
"""
return F.to_grayscale(img, num_output_channels=self.output_channels)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册