未验证 提交 ec2fad4d 编写于 作者: L LielinJiang 提交者: GitHub

Fix rotation bug when using the cv2 backend (#29933)

* fix cv2 rotation
上级 95861223
...@@ -444,6 +444,16 @@ class TestFunctional(unittest.TestCase): ...@@ -444,6 +444,16 @@ class TestFunctional(unittest.TestCase):
os.remove(path) os.remove(path)
def test_rotate(self):
    """Rotate one random image as an ndarray and as a PIL image and compare shapes.

    A random 28x28 3-channel uint8 image is rotated by 80 degrees with
    ``expand=True`` through ``F.rotate``, once as a numpy array and once as
    a PIL RGB image. Both results must have the same shape.
    """
    source_array = (np.random.rand(28, 28, 3) * 255).astype('uint8')
    source_pil = Image.fromarray(source_array).convert('RGB')

    # Same angle and expand flag for both inputs so only the input type differs.
    array_result = F.rotate(source_array, 80, expand=True)
    pil_result = F.rotate(source_pil, 80, expand=True)

    # Convert the PIL result to an ndarray so the shapes are directly comparable.
    pil_shape = np.array(pil_result).shape
    np.testing.assert_equal(array_result.shape, pil_shape)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -512,14 +512,19 @@ def adjust_hue(img, hue_factor): ...@@ -512,14 +512,19 @@ def adjust_hue(img, hue_factor):
return F_cv2.adjust_hue(img, hue_factor) return F_cv2.adjust_hue(img, hue_factor)
def rotate(img, angle, resample=False, expand=False, center=None, fill=0): def rotate(img,
angle,
interpolation="nearest",
expand=False,
center=None,
fill=0):
"""Rotates the image by angle. """Rotates the image by angle.
Args: Args:
img (PIL.Image|np.array): Image to be rotated. img (PIL.Image|np.array): Image to be rotated.
angle (float or int): In degrees degrees counter clockwise order. angle (float or int): In degrees degrees counter clockwise order.
resample (int|str, optional): An optional resampling filter. If omitted, or if the interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST
according the backend. when use pil backend, support method are as following: according the backend. when use pil backend, support method are as following:
- "nearest": Image.NEAREST, - "nearest": Image.NEAREST,
...@@ -564,9 +569,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0): ...@@ -564,9 +569,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
format(type(img))) format(type(img)))
if _is_pil_image(img): if _is_pil_image(img):
return F_pil.rotate(img, angle, resample, expand, center, fill) return F_pil.rotate(img, angle, interpolation, expand, center, fill)
else: else:
return F_cv2.rotate(img, angle, resample, expand, center, fill) return F_cv2.rotate(img, angle, interpolation, expand, center, fill)
def to_grayscale(img, num_output_channels=1): def to_grayscale(img, num_output_channels=1):
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import division from __future__ import division
import sys import sys
import math
import numbers import numbers
import warnings import warnings
import collections import collections
...@@ -407,13 +408,18 @@ def adjust_hue(img, hue_factor): ...@@ -407,13 +408,18 @@ def adjust_hue(img, hue_factor):
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype) return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
def rotate(img, angle, resample=False, expand=False, center=None, fill=0): def rotate(img,
angle,
interpolation='nearest',
expand=False,
center=None,
fill=0):
"""Rotates the image by angle. """Rotates the image by angle.
Args: Args:
img (np.array): Image to be rotated. img (np.array): Image to be rotated.
angle (float or int): In degrees degrees counter clockwise order. angle (float or int): In degrees degrees counter clockwise order.
resample (int|str, optional): An optional resampling filter. If omitted, or if the interpolation (int|str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to cv2.INTER_NEAREST. image has only one channel, it is set to cv2.INTER_NEAREST.
when use cv2 backend, support method are as following: when use cv2 backend, support method are as following:
- "nearest": cv2.INTER_NEAREST, - "nearest": cv2.INTER_NEAREST,
...@@ -434,15 +440,70 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0): ...@@ -434,15 +440,70 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
""" """
cv2 = try_import('cv2') cv2 = try_import('cv2')
_cv2_interp_from_str = {
'nearest': cv2.INTER_NEAREST,
'bilinear': cv2.INTER_LINEAR,
'area': cv2.INTER_AREA,
'bicubic': cv2.INTER_CUBIC,
'lanczos': cv2.INTER_LANCZOS4
}
rows, cols = img.shape[0:2] h, w = img.shape[0:2]
if center is None: if center is None:
center = (cols / 2, rows / 2) center = (w / 2.0, h / 2.0)
M = cv2.getRotationMatrix2D(center, angle, 1) M = cv2.getRotationMatrix2D(center, angle, 1)
if expand:
def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f
# calculate output size
xx = []
yy = []
angle = -math.radians(angle)
expand_matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]
post_trans = (0, 0)
expand_matrix[2], expand_matrix[5] = transform(
-center[0] - post_trans[0], -center[1] - post_trans[1],
expand_matrix)
expand_matrix[2] += center[0]
expand_matrix[5] += center[1]
for x, y in ((0, 0), (w, 0), (w, h), (0, h)):
x, y = transform(x, y, expand_matrix)
xx.append(x)
yy.append(y)
nw = math.ceil(max(xx)) - math.floor(min(xx))
nh = math.ceil(max(yy)) - math.floor(min(yy))
M[0, 2] += (nw - w) * 0.5
M[1, 2] += (nh - h) * 0.5
w, h = int(nw), int(nh)
if len(img.shape) == 3 and img.shape[2] == 1: if len(img.shape) == 3 and img.shape[2] == 1:
return cv2.warpAffine(img, M, (cols, rows))[:, :, np.newaxis] return cv2.warpAffine(
img,
M, (w, h),
flags=_cv2_interp_from_str[interpolation],
borderValue=fill)[:, :, np.newaxis]
else: else:
return cv2.warpAffine(img, M, (cols, rows)) return cv2.warpAffine(
img,
M, (w, h),
flags=_cv2_interp_from_str[interpolation],
borderValue=fill)
def to_grayscale(img, num_output_channels=1): def to_grayscale(img, num_output_channels=1):
......
...@@ -396,13 +396,18 @@ def adjust_hue(img, hue_factor): ...@@ -396,13 +396,18 @@ def adjust_hue(img, hue_factor):
return img return img
def rotate(img, angle, resample=False, expand=False, center=None, fill=0): def rotate(img,
angle,
interpolation="nearest",
expand=False,
center=None,
fill=0):
"""Rotates the image by angle. """Rotates the image by angle.
Args: Args:
img (PIL.Image): Image to be rotated. img (PIL.Image): Image to be rotated.
angle (float or int): In degrees degrees counter clockwise order. angle (float or int): In degrees degrees counter clockwise order.
resample (int|str, optional): An optional resampling filter. If omitted, or if the interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to PIL.Image.NEAREST . when use pil backend, image has only one channel, it is set to PIL.Image.NEAREST . when use pil backend,
support method are as following: support method are as following:
- "nearest": Image.NEAREST, - "nearest": Image.NEAREST,
...@@ -426,7 +431,12 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0): ...@@ -426,7 +431,12 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
if isinstance(fill, int): if isinstance(fill, int):
fill = tuple([fill] * 3) fill = tuple([fill] * 3)
return img.rotate(angle, resample, expand, center, fillcolor=fill) return img.rotate(
angle,
_pil_interp_from_str[interpolation],
expand,
center,
fillcolor=fill)
def to_grayscale(img, num_output_channels=1): def to_grayscale(img, num_output_channels=1):
......
...@@ -1093,8 +1093,7 @@ class RandomRotation(BaseTransform): ...@@ -1093,8 +1093,7 @@ class RandomRotation(BaseTransform):
degrees (sequence or float or int): Range of degrees to select from. degrees (sequence or float or int): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees) clockwise order. will be (-degrees, +degrees) clockwise order.
interpolation (int|str, optional): Interpolation method. Default: 'bilinear'. interpolation (str, optional): Interpolation method. If omitted, or if the
resample (int|str, optional): An optional resampling filter. If omitted, or if the
image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST
according the backend. when use pil backend, support method are as following: according the backend. when use pil backend, support method are as following:
- "nearest": Image.NEAREST, - "nearest": Image.NEAREST,
...@@ -1131,7 +1130,7 @@ class RandomRotation(BaseTransform): ...@@ -1131,7 +1130,7 @@ class RandomRotation(BaseTransform):
def __init__(self, def __init__(self,
degrees, degrees,
resample=False, interpolation='nearest',
expand=False, expand=False,
center=None, center=None,
fill=0, fill=0,
...@@ -1148,7 +1147,7 @@ class RandomRotation(BaseTransform): ...@@ -1148,7 +1147,7 @@ class RandomRotation(BaseTransform):
self.degrees = degrees self.degrees = degrees
super(RandomRotation, self).__init__(keys) super(RandomRotation, self).__init__(keys)
self.resample = resample self.interpolation = interpolation
self.expand = expand self.expand = expand
self.center = center self.center = center
self.fill = fill self.fill = fill
...@@ -1169,8 +1168,8 @@ class RandomRotation(BaseTransform): ...@@ -1169,8 +1168,8 @@ class RandomRotation(BaseTransform):
angle = self._get_param(self.degrees) angle = self._get_param(self.degrees)
return F.rotate(img, angle, self.resample, self.expand, self.center, return F.rotate(img, angle, self.interpolation, self.expand,
self.fill) self.center, self.fill)
class Grayscale(BaseTransform): class Grayscale(BaseTransform):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册