未验证 提交 754820fe 编写于 作者: F Feng Ni 提交者: GitHub

[New API] add API paddle.vision.transforms.RandomPerspective and...

[New API] add API paddle.vision.transforms.RandomPerspective and paddle.vision.transforms.perspective (#42390)

* add RandomPerspective and perspective

* fix UT, clean codes

* fix UT

* add batch transform

* remove batch in tensor func

* fix typos and coments, test=develop
上级 cbb8df78
......@@ -172,6 +172,14 @@ class TestTransformsCV2(unittest.TestCase):
])
self.do_transform(trans)
def test_perspective(self):
trans = transforms.Compose([
transforms.RandomPerspective(prob=1.0),
transforms.RandomPerspective(
prob=1.0, distortion_scale=0.9),
])
self.do_transform(trans)
def test_pad(self):
trans = transforms.Compose([transforms.Pad(2)])
self.do_transform(trans)
......@@ -964,6 +972,37 @@ class TestFunctional(unittest.TestCase):
np.testing.assert_equal(rotated_np_img.shape,
np.array(rotated_pil_img).shape)
def test_perspective(self):
np_img = (np.random.rand(32, 26, 3) * 255).astype('uint8')
pil_img = Image.fromarray(np_img).convert('RGB')
tensor_img = F.to_tensor(pil_img, data_format='CHW') * 255
np.testing.assert_almost_equal(
np_img, tensor_img.transpose((1, 2, 0)), decimal=4)
startpoints = [[0, 0], [13, 0], [13, 15], [0, 15]]
endpoints = [[3, 2], [12, 3], [10, 14], [2, 15]]
np_perspectived_img = F.perspective(np_img, startpoints, endpoints)
pil_perspectived_img = F.perspective(pil_img, startpoints, endpoints)
tensor_perspectived_img = F.perspective(tensor_img, startpoints,
endpoints)
np.testing.assert_equal(np_perspectived_img.shape,
np.array(pil_perspectived_img).shape)
np.testing.assert_equal(np_perspectived_img.shape,
tensor_perspectived_img.transpose(
(1, 2, 0)).shape)
result_pil = np.array(pil_perspectived_img)
result_tensor = tensor_perspectived_img.numpy().transpose(
(1, 2, 0)).astype('uint8')
num_diff_pixels = (result_pil != result_tensor).sum() / 3.0
ratio_diff_pixels = num_diff_pixels / result_tensor.shape[
0] / result_tensor.shape[1]
# Tolerance : less than 6% of different pixels
assert ratio_diff_pixels < 0.06
if __name__ == '__main__':
unittest.main()
......@@ -30,6 +30,7 @@ from .transforms import RandomCrop # noqa: F401
from .transforms import Pad # noqa: F401
from .transforms import RandomAffine # noqa: F401
from .transforms import RandomRotation # noqa: F401
from .transforms import RandomPerspective # noqa: F401
from .transforms import Grayscale # noqa: F401
from .transforms import ToTensor # noqa: F401
from .transforms import RandomErasing # noqa: F401
......@@ -40,6 +41,7 @@ from .functional import resize # noqa: F401
from .functional import pad # noqa: F401
from .functional import affine # noqa: F401
from .functional import rotate # noqa: F401
from .functional import perspective # noqa: F401
from .functional import to_grayscale # noqa: F401
from .functional import crop # noqa: F401
from .functional import center_crop # noqa: F401
......@@ -68,6 +70,7 @@ __all__ = [ #noqa
'Pad',
'RandomAffine',
'RandomRotation',
'RandomPerspective',
'Grayscale',
'ToTensor',
'RandomErasing',
......@@ -78,6 +81,7 @@ __all__ = [ #noqa
'pad',
'affine',
'rotate',
'perspective',
'to_grayscale',
'crop',
'center_crop',
......
......@@ -767,6 +767,95 @@ def rotate(img,
return F_cv2.rotate(img, angle, interpolation, expand, center, fill)
def _get_perspective_coeffs(startpoints, endpoints):
"""
get coefficients (a, b, c, d, e, f, g, h) of the perspective transforms.
In Perspective Transform each pixel (x, y) in the original image gets transformed as,
(x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
Args:
startpoints (list[list[int]]): [top-left, top-right, bottom-right, bottom-left] of the original image,
endpoints (list[list[int]]): [top-left, top-right, bottom-right, bottom-left] of the transformed image.
Returns:
output (list): octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
"""
a_matrix = np.zeros((2 * len(startpoints), 8))
for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
a_matrix[2 * i, :] = [
p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]
]
a_matrix[2 * i + 1, :] = [
0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]
]
b_matrix = np.array(startpoints).reshape([8])
res = np.linalg.lstsq(a_matrix, b_matrix)[0]
output = list(res)
return output
def perspective(img, startpoints, endpoints, interpolation='nearest', fill=0):
"""Perform perspective transform of the given image.
Args:
img (PIL.Image|np.array|paddle.Tensor): Image to be transformed.
startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST
according the backend.
When use pil backend, support method are as following:
- "nearest": Image.NEAREST,
- "bilinear": Image.BILINEAR,
- "bicubic": Image.BICUBIC
When use cv2 backend, support method are as following:
- "nearest": cv2.INTER_NEAREST,
- "bilinear": cv2.INTER_LINEAR,
- "bicubic": cv2.INTER_CUBIC
fill (int|list|tuple, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
Returns:
PIL.Image|np.array|paddle.Tensor: transformed Image.
Examples:
.. code-block:: python
import paddle
from paddle.vision.transforms import functional as F
fake_img = paddle.randn((3, 256, 300)).astype(paddle.float32)
startpoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
endpoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
perspectived_img = F.perspective(fake_img, startpoints, endpoints)
print(perspectived_img.shape)
"""
if not (_is_pil_image(img) or _is_numpy_image(img) or
_is_tensor_image(img)):
raise TypeError(
'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'.
format(type(img)))
if _is_pil_image(img):
coeffs = _get_perspective_coeffs(startpoints, endpoints)
return F_pil.perspective(img, coeffs, interpolation, fill)
elif _is_tensor_image(img):
coeffs = _get_perspective_coeffs(startpoints, endpoints)
return F_t.perspective(img, coeffs, interpolation, fill)
else:
return F_cv2.perspective(img, startpoints, endpoints, interpolation,
fill)
def to_grayscale(img, num_output_channels=1):
"""Converts image to grayscale version of image.
......
......@@ -589,6 +589,56 @@ def rotate(img,
borderValue=fill)
def perspective(img, startpoints, endpoints, interpolation='nearest', fill=0):
"""Perspective the image.
Args:
img (np.array): Image to be perspectived.
startpoints (list[list[int]]): [top-left, top-right, bottom-right, bottom-left] of the original image,
endpoints (list[list[int]]): [top-left, top-right, bottom-right, bottom-left] of the transformed image.
interpolation (int|str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to cv2.INTER_NEAREST.
when use cv2 backend, support method are as following:
- "nearest": cv2.INTER_NEAREST,
- "bilinear": cv2.INTER_LINEAR,
- "bicubic": cv2.INTER_CUBIC
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively.
Returns:
np.array: Perspectived image.
"""
cv2 = try_import('cv2')
_cv2_interp_from_str = {
'nearest': cv2.INTER_NEAREST,
'bilinear': cv2.INTER_LINEAR,
'area': cv2.INTER_AREA,
'bicubic': cv2.INTER_CUBIC,
'lanczos': cv2.INTER_LANCZOS4
}
h, w = img.shape[0:2]
startpoints = np.array(startpoints, dtype="float32")
endpoints = np.array(endpoints, dtype="float32")
matrix = cv2.getPerspectiveTransform(startpoints, endpoints)
if len(img.shape) == 3 and img.shape[2] == 1:
return cv2.warpPerspective(
img,
matrix,
dsize=(w, h),
flags=_cv2_interp_from_str[interpolation],
borderValue=fill)[:, :, np.newaxis]
else:
return cv2.warpPerspective(
img,
matrix,
dsize=(w, h),
flags=_cv2_interp_from_str[interpolation],
borderValue=fill)
def to_grayscale(img, num_output_channels=1):
"""Converts image to grayscale version of image.
......
......@@ -479,6 +479,33 @@ def rotate(img,
fillcolor=fill)
def perspective(img, coeffs, interpolation="nearest", fill=0):
"""Perspective the image.
Args:
img (PIL.Image): Image to be perspectived.
coeffs (list[float]): coefficients (a, b, c, d, e, f, g, h) of the perspective transforms.
interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to PIL.Image.NEAREST . when use pil backend,
support method are as following:
- "nearest": Image.NEAREST,
- "bilinear": Image.BILINEAR,
- "bicubic": Image.BICUBIC
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively.
Returns:
PIL.Image: Perspectived image.
"""
if isinstance(fill, int):
fill = tuple([fill] * 3)
return img.transform(img.size, Image.PERSPECTIVE, coeffs,
_pil_interp_from_str[interpolation], fill)
def to_grayscale(img, num_output_channels=1):
"""Converts image to grayscale version of image.
......
......@@ -395,6 +395,69 @@ def rotate(img,
return out.squeeze(0)
def _perspective_grid(img, coeffs, ow, oh, dtype):
theta1 = coeffs[:6].reshape([1, 2, 3])
tmp = paddle.tile(coeffs[6:].reshape([1, 2]), repeat_times=[2, 1])
dummy = paddle.ones((2, 1), dtype=dtype)
theta2 = paddle.concat((tmp, dummy), axis=1).unsqueeze(0)
d = 0.5
base_grid = paddle.ones((1, oh, ow, 3), dtype=dtype)
x_grid = paddle.linspace(d, ow * 1.0 + d - 1.0, ow)
base_grid[..., 0] = x_grid
y_grid = paddle.linspace(d, oh * 1.0 + d - 1.0, oh).unsqueeze_(-1)
base_grid[..., 1] = y_grid
scaled_theta1 = theta1.transpose(
(0, 2, 1)) / paddle.to_tensor([0.5 * ow, 0.5 * oh])
output_grid1 = base_grid.reshape((1, oh * ow, 3)).bmm(scaled_theta1)
output_grid2 = base_grid.reshape(
(1, oh * ow, 3)).bmm(theta2.transpose((0, 2, 1)))
output_grid = output_grid1 / output_grid2 - 1.0
return output_grid.reshape((1, oh, ow, 2))
def perspective(img,
coeffs,
interpolation="nearest",
fill=None,
data_format='CHW'):
"""Perspective the image.
Args:
img (paddle.Tensor): Image to be rotated.
coeffs (list[float]): coefficients (a, b, c, d, e, f, g, h) of the perspective transforms.
interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set NEAREST. When use pil backend,
support method are as following:
- "nearest"
- "bilinear"
- "bicubic"
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively.
Returns:
paddle.Tensor: Perspectived image.
"""
img = img.unsqueeze(0)
img = img if data_format.lower() == 'chw' else img.transpose((0, 3, 1, 2))
ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if paddle.is_floating_point(img) else paddle.float32
coeffs = paddle.to_tensor(coeffs, place=img.place)
grid = _perspective_grid(img, coeffs, ow=ow, oh=oh, dtype=dtype)
out = _grid_transform(img, grid, mode=interpolation, fill=fill)
out = out if data_format.lower() == 'chw' else out.transpose((0, 2, 3, 1))
return out.squeeze(0)
def vflip(img, data_format='CHW'):
"""Vertically flips the given paddle tensor.
......
......@@ -1481,6 +1481,125 @@ class RandomRotation(BaseTransform):
self.center, self.fill)
class RandomPerspective(BaseTransform):
"""Random perspective transformation with a given probability.
Args:
prob (float, optional): Probability of using transformation, ranges from
0 to 1, default is 0.5.
distortion_scale (float, optional): Degree of distortion, ranges from
0 to 1, default is 0.5.
interpolation (str, optional): Interpolation method. If omitted, or if
the image has only one channel, it is set to PIL.Image.NEAREST or
cv2.INTER_NEAREST.
When use pil backend, support method are as following:
- "nearest": Image.NEAREST,
- "bilinear": Image.BILINEAR,
- "bicubic": Image.BICUBIC
When use cv2 backend, support method are as following:
- "nearest": cv2.INTER_NEAREST,
- "bilinear": cv2.INTER_LINEAR,
- "bicubic": cv2.INTER_CUBIC
fill (int|list|tuple, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.
Shape:
- img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C).
- output(PIL.Image|np.ndarray|Paddle.Tensor): A perspectived image.
Returns:
A callable object of RandomPerspective.
Examples:
.. code-block:: python
import paddle
from paddle.vision.transforms import RandomPerspective
transform = RandomPerspective(prob=1.0, distortion_scale=0.9)
fake_img = paddle.randn((3, 200, 150)).astype(paddle.float32)
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self,
prob=0.5,
distortion_scale=0.5,
interpolation='nearest',
fill=0,
keys=None):
super(RandomPerspective, self).__init__(keys)
assert 0 <= prob <= 1, "probability must be between 0 and 1"
assert 0 <= distortion_scale <= 1, "distortion_scale must be between 0 and 1"
assert interpolation in ['nearest', 'bilinear', 'bicubic']
assert isinstance(fill, (numbers.Number, str, list, tuple))
self.prob = prob
self.distortion_scale = distortion_scale
self.interpolation = interpolation
self.fill = fill
def get_params(self, width, height, distortion_scale):
"""
Returns:
startpoints (list[list[int]]): [top-left, top-right, bottom-right, bottom-left] of the original image,
endpoints (list[list[int]]): [top-left, top-right, bottom-right, bottom-left] of the transformed image.
"""
half_height = height // 2
half_width = width // 2
topleft = [
int(random.uniform(0, int(distortion_scale * half_width) + 1)),
int(random.uniform(0, int(distortion_scale * half_height) + 1)),
]
topright = [
int(
random.uniform(width - int(distortion_scale * half_width) - 1,
width)),
int(random.uniform(0, int(distortion_scale * half_height) + 1)),
]
botright = [
int(
random.uniform(width - int(distortion_scale * half_width) - 1,
width)),
int(
random.uniform(height - int(distortion_scale * half_height) - 1,
height)),
]
botleft = [
int(random.uniform(0, int(distortion_scale * half_width) + 1)),
int(
random.uniform(height - int(distortion_scale * half_height) - 1,
height)),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1],
[0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
return startpoints, endpoints
def _apply_image(self, img):
"""
Args:
img (PIL.Image|np.array|paddle.Tensor): Image to be Perspectively transformed.
Returns:
PIL.Image|np.array|paddle.Tensor: Perspectively transformed image.
"""
width, height = _get_image_size(img)
if random.random() < self.prob:
startpoints, endpoints = self.get_params(width, height,
self.distortion_scale)
return F.perspective(img, startpoints, endpoints,
self.interpolation, self.fill)
return img
class Grayscale(BaseTransform):
"""Converts image to grayscale.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册