未验证 提交 cbb8df78 编写于 作者: F Feng Ni 提交者: GitHub

[New API] add API paddle.vision.transforms.RandomAffine and...

[New API] add API paddle.vision.transforms.RandomAffine and paddle.vision.transforms.affine (#42278)

* add affine codes

* adjustment codes

* fix test case

* fix F_cv2.affine

* clean codes, add UT

* fix UT

* fix UT

* fix UT shear

* add functional test_errors

* fix typos and coments, test=develop
上级 5d55ebde
......@@ -123,6 +123,44 @@ class TestTransformsCV2(unittest.TestCase):
])
self.do_transform(trans)
def test_affine(self):
trans = transforms.Compose([
transforms.RandomAffine(90),
transforms.RandomAffine(
[-10, 10], translate=[0.1, 0.3]),
transforms.RandomAffine(
45, translate=[0.2, 0.2], scale=[0.2, 0.5]),
transforms.RandomAffine(
10, translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 10]),
transforms.RandomAffine(
10,
translate=[0.5, 0.3],
scale=[0.7, 1.3],
shear=[-10, 10, 20, 40]),
transforms.RandomAffine(
10,
translate=[0.5, 0.3],
scale=[0.7, 1.3],
shear=[-10, 10, 20, 40],
interpolation='bilinear'),
transforms.RandomAffine(
10,
translate=[0.5, 0.3],
scale=[0.7, 1.3],
shear=[-10, 10, 20, 40],
interpolation='bilinear',
fill=114),
transforms.RandomAffine(
10,
translate=[0.5, 0.3],
scale=[0.7, 1.3],
shear=[-10, 10, 20, 40],
interpolation='bilinear',
fill=114,
center=(60, 80)),
])
self.do_transform(trans)
def test_rotate(self):
trans = transforms.Compose([
transforms.RandomRotation(90),
......@@ -278,6 +316,35 @@ class TestTransformsCV2(unittest.TestCase):
tensor_img = paddle.rand((3, 100, 100))
F.pad(tensor_img, [1.0, 2.0, 3.0])
with self.assertRaises(ValueError):
transforms.RandomAffine(-10)
with self.assertRaises(ValueError):
transforms.RandomAffine([-30, 60], translate=[2, 2])
with self.assertRaises(ValueError):
transforms.RandomAffine(10, translate=[0.2, 0.2], scale=[1, 2, 3]),
with self.assertRaises(ValueError):
transforms.RandomAffine(
10, translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[1, 2, 3]),
with self.assertRaises(ValueError):
transforms.RandomAffine(
10,
translate=[0.5, 0.3],
scale=[0.7, 1.3],
shear=[-10, 10, 0, 20, 40])
with self.assertRaises(ValueError):
transforms.RandomAffine(
10,
translate=[0.5, 0.3],
scale=[0.7, 1.3],
shear=[-10, 10, 20, 40],
fill=114,
center=(1, 2, 3))
with self.assertRaises(ValueError):
transforms.RandomRotation(-2)
......@@ -479,6 +546,29 @@ class TestTransformsTensor(TestTransformsCV2):
tensor_img = paddle.rand((3, 100, 100))
F.pad(tensor_img, [1.0, 2.0, 3.0])
with self.assertRaises(ValueError):
transforms.RandomAffine(-10)
with self.assertRaises(ValueError):
transforms.RandomAffine([-30, 60], translate=[2, 2])
with self.assertRaises(ValueError):
transforms.RandomAffine(10, translate=[0.2, 0.2], scale=[-2, -1]),
with self.assertRaises(ValueError):
transforms.RandomAffine(10, translate=[0.2, 0.2], scale=[1, 2, 3]),
with self.assertRaises(ValueError):
transforms.RandomAffine(
10, translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[1, 2, 3]),
with self.assertRaises(ValueError):
transforms.RandomAffine(
10,
translate=[0.5, 0.3],
scale=[0.7, 1.3],
shear=[-10, 10, 0, 20, 40])
with self.assertRaises(ValueError):
transforms.RandomRotation(-2)
......@@ -547,6 +637,36 @@ class TestFunctional(unittest.TestCase):
with self.assertRaises(TypeError):
F.adjust_saturation(1, 0.1)
with self.assertRaises(TypeError):
F.affine('45')
with self.assertRaises(TypeError):
F.affine(45, translate=0.3)
with self.assertRaises(TypeError):
F.affine(45, translate=[0.2, 0.2, 0.3])
with self.assertRaises(TypeError):
F.affine(45, translate=[0.2, 0.2], scale=-0.5)
with self.assertRaises(TypeError):
F.affine(45, translate=[0.2, 0.2], scale=0.5, shear=10)
with self.assertRaises(TypeError):
F.affine(45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 0, 10])
with self.assertRaises(TypeError):
F.affine(
45,
translate=[0.2, 0.2],
scale=0.5,
shear=[-10, 10],
interpolation=2)
with self.assertRaises(TypeError):
F.affine(
45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10], center=0)
with self.assertRaises(TypeError):
F.rotate(1, 0.1)
......@@ -785,6 +905,31 @@ class TestFunctional(unittest.TestCase):
os.remove(path)
def test_affine(self):
np_img = (np.random.rand(32, 26, 3) * 255).astype('uint8')
pil_img = Image.fromarray(np_img).convert('RGB')
tensor_img = F.to_tensor(pil_img, data_format='CHW') * 255
np.testing.assert_almost_equal(
np_img, tensor_img.transpose((1, 2, 0)), decimal=4)
np_affined_img = F.affine(
np_img, 45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10])
pil_affined_img = F.affine(
pil_img, 45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10])
tensor_affined_img = F.affine(
tensor_img, 45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10])
np.testing.assert_equal(np_affined_img.shape,
np.array(pil_affined_img).shape)
np.testing.assert_equal(np_affined_img.shape,
tensor_affined_img.transpose((1, 2, 0)).shape)
np.testing.assert_almost_equal(
np.array(pil_affined_img),
tensor_affined_img.numpy().transpose((1, 2, 0)),
decimal=4)
def test_rotate(self):
np_img = (np.random.rand(28, 28, 3) * 255).astype('uint8')
pil_img = Image.fromarray(np_img).convert('RGB')
......
......@@ -28,6 +28,7 @@ from .transforms import HueTransform # noqa: F401
from .transforms import ColorJitter # noqa: F401
from .transforms import RandomCrop # noqa: F401
from .transforms import Pad # noqa: F401
from .transforms import RandomAffine # noqa: F401
from .transforms import RandomRotation # noqa: F401
from .transforms import Grayscale # noqa: F401
from .transforms import ToTensor # noqa: F401
......@@ -37,6 +38,7 @@ from .functional import hflip # noqa: F401
from .functional import vflip # noqa: F401
from .functional import resize # noqa: F401
from .functional import pad # noqa: F401
from .functional import affine # noqa: F401
from .functional import rotate # noqa: F401
from .functional import to_grayscale # noqa: F401
from .functional import crop # noqa: F401
......@@ -64,6 +66,7 @@ __all__ = [ #noqa
'ColorJitter',
'RandomCrop',
'Pad',
'RandomAffine',
'RandomRotation',
'Grayscale',
'ToTensor',
......@@ -73,6 +76,7 @@ __all__ = [ #noqa
'vflip',
'resize',
'pad',
'affine',
'rotate',
'to_grayscale',
'crop',
......
......@@ -537,6 +537,166 @@ def adjust_hue(img, hue_factor):
return F_t.adjust_hue(img, hue_factor)
def _get_affine_matrix(center, angle, translate, scale, shear):
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
# Ihe inverse one is : M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
rot = math.radians(angle)
sx = math.radians(shear[0])
sy = math.radians(shear[1])
# Rotate and Shear without scaling
a = math.cos(rot - sy) / math.cos(sy)
b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
c = math.sin(rot - sy) / math.cos(sy)
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
# Center Translation
cx, cy = center
tx, ty = translate
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
matrix = [d, -b, 0.0, -c, a, 0.0]
matrix = [x / scale for x in matrix]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += cx
matrix[5] += cy
return matrix
def affine(img,
angle,
translate,
scale,
shear,
interpolation="nearest",
fill=0,
center=None):
"""Apply affine transformation on the image.
Args:
img (PIL.Image|np.array|paddle.Tensor): Image to be affined.
angle (int|float): The angle of the random rotation in clockwise order.
translate (list[float]): Maximum absolute fraction for horizontal and vertical translations.
scale (float): Scale factor for the image, scale should be positive.
shear (list[float]): Shear angle values which are parallel to the x-axis and y-axis in clockwise order.
interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST
according the backend.
When use pil backend, support method are as following:
- "nearest": Image.NEAREST,
- "bilinear": Image.BILINEAR,
- "bicubic": Image.BICUBIC
When use cv2 backend, support method are as following:
- "nearest": cv2.INTER_NEAREST,
- "bilinear": cv2.INTER_LINEAR,
- "bicubic": cv2.INTER_CUBIC
fill (int|list|tuple, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
center (2-tuple, optional): Optional center of rotation, (x, y).
Origin is the upper left corner.
Default is the center of the image.
Returns:
PIL.Image|np.array|paddle.Tensor: Affine Transformed image.
Examples:
.. code-block:: python
import paddle
from paddle.vision.transforms import functional as F
fake_img = paddle.randn((3, 256, 300)).astype(paddle.float32)
affined_img = F.affine(fake_img, 45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10])
print(affined_img.shape)
"""
if not (_is_pil_image(img) or _is_numpy_image(img) or
_is_tensor_image(img)):
raise TypeError(
'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'.
format(type(img)))
if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")
if not isinstance(translate, (list, tuple)):
raise TypeError("Argument translate should be a sequence")
if len(translate) != 2:
raise ValueError("Argument translate should be a sequence of length 2")
if scale <= 0.0:
raise ValueError("Argument scale should be positive")
if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError(
"Shear should be either a single value or a sequence of two values")
if not isinstance(interpolation, str):
raise TypeError("Argument interpolation should be a string")
if isinstance(angle, int):
angle = float(angle)
if isinstance(translate, tuple):
translate = list(translate)
if isinstance(shear, numbers.Number):
shear = [shear, 0.0]
if isinstance(shear, tuple):
shear = list(shear)
if len(shear) == 1:
shear = [shear[0], shear[0]]
if len(shear) != 2:
raise ValueError(
f"Shear should be a sequence containing two values. Got {shear}")
if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")
if _is_pil_image(img):
width, height = img.size
# center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of F_t.affine
if center is None:
center = [width * 0.5, height * 0.5]
matrix = _get_affine_matrix(center, angle, translate, scale, shear)
return F_pil.affine(img, matrix, interpolation, fill)
if _is_numpy_image(img):
# get affine_matrix in F_cv2.affine() using cv2's functions
width, height = img.shape[0:2]
# center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of F_t.affine
if center is None:
center = (width * 0.5, height * 0.5)
return F_cv2.affine(img, angle, translate, scale, shear, interpolation,
fill, center)
if _is_tensor_image(img):
center_f = [0.0, 0.0]
if center is not None:
height, width = img.shape[-1], img.shape[-2]
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [
1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])
]
translate_f = [1.0 * t for t in translate]
matrix = _get_affine_matrix(center_f, angle, translate_f, scale, shear)
return F_t.affine(img, matrix, interpolation, fill)
def rotate(img,
angle,
interpolation="nearest",
......
......@@ -411,6 +411,86 @@ def adjust_hue(img, hue_factor):
return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
def affine(img,
angle,
translate,
scale,
shear,
interpolation='nearest',
fill=0,
center=None):
"""Affine the image by matrix.
Args:
img (PIL.Image): Image to be affined.
translate (sequence or int): horizontal and vertical translations
scale (float): overall scale ratio
shear (sequence or float): shear angle value in degrees between -180 to 180, clockwise direction.
If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
the second value corresponds to a shear parallel to the y axis.
interpolation (int|str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to cv2.INTER_NEAREST.
when use cv2 backend, support method are as following:
- "nearest": cv2.INTER_NEAREST,
- "bilinear": cv2.INTER_LINEAR,
- "bicubic": cv2.INTER_CUBIC
fill (3-tuple or int): RGB pixel fill value for area outside the affined image.
If int, it is used for all channels respectively.
center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
Default is the center of the image.
Returns:
np.array: Affined image.
"""
cv2 = try_import('cv2')
_cv2_interp_from_str = {
'nearest': cv2.INTER_NEAREST,
'bilinear': cv2.INTER_LINEAR,
'area': cv2.INTER_AREA,
'bicubic': cv2.INTER_CUBIC,
'lanczos': cv2.INTER_LANCZOS4
}
h, w = img.shape[0:2]
if isinstance(fill, int):
fill = tuple([fill] * 3)
if center is None:
center = (w / 2.0, h / 2.0)
M = np.ones([2, 3])
# Rotate and Scale
R = cv2.getRotationMatrix2D(angle=angle, center=center, scale=scale)
# Shear
sx = math.tan(shear[0] * math.pi / 180)
sy = math.tan(shear[1] * math.pi / 180)
M[0] = R[0] + sy * R[1]
M[1] = R[1] + sx * R[0]
# Translation
tx, ty = translate
M[0, 2] = tx
M[1, 2] = ty
if len(img.shape) == 3 and img.shape[2] == 1:
return cv2.warpAffine(
img,
M,
dsize=(w, h),
flags=_cv2_interp_from_str[interpolation],
borderValue=fill)[:, :, np.newaxis]
else:
return cv2.warpAffine(
img,
M,
dsize=(w, h),
flags=_cv2_interp_from_str[interpolation],
borderValue=fill)
def rotate(img,
angle,
interpolation='nearest',
......
......@@ -410,6 +410,32 @@ def adjust_hue(img, hue_factor):
return img
def affine(img, matrix, interpolation="nearest", fill=0):
"""Affine the image by matrix.
Args:
img (PIL.Image): Image to be affined.
matrix (float or int): Affine matrix.
interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to PIL.Image.NEAREST . when use pil backend,
support method are as following:
- "nearest": Image.NEAREST,
- "bilinear": Image.BILINEAR,
- "bicubic": Image.BICUBIC
fill (3-tuple or int): RGB pixel fill value for area outside the affined image.
If int, it is used for all channels respectively.
Returns:
PIL.Image: Affined image.
"""
if isinstance(fill, int):
fill = tuple([fill] * 3)
return img.transform(img.size, Image.AFFINE, matrix,
_pil_interp_from_str[interpolation], fill)
def rotate(img,
angle,
interpolation="nearest",
......
......@@ -226,8 +226,8 @@ def _affine_grid(theta, w, h, ow, oh):
def _grid_transform(img, grid, mode, fill):
if img.shape[0] > 1:
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2],
grid.shape[3])
grid = grid.expand(
shape=[img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3]])
if fill is not None:
dummy = paddle.ones(
......@@ -255,6 +255,47 @@ def _grid_transform(img, grid, mode, fill):
return img
def affine(img, matrix, interpolation="nearest", fill=None, data_format='CHW'):
"""Affine to the image by matrix.
Args:
img (paddle.Tensor): Image to be rotated.
matrix (float or int): Affine matrix.
interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set NEAREST . when use pil backend,
support method are as following:
- "nearest"
- "bilinear"
- "bicubic"
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively.
data_format (str, optional): Data format of img, should be 'HWC' or
'CHW'. Default: 'CHW'.
Returns:
paddle.Tensor: Affined image.
"""
img = img.unsqueeze(0)
img = img if data_format.lower() == 'chw' else img.transpose((0, 3, 1, 2))
matrix = paddle.to_tensor(matrix, place=img.place)
matrix = matrix.reshape((1, 2, 3))
shape = img.shape
grid = _affine_grid(
matrix, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
if isinstance(fill, int):
fill = tuple([fill] * 3)
out = _grid_transform(img, grid, mode=interpolation, fill=fill)
out = out if data_format.lower() == 'chw' else out.transpose((0, 2, 3, 1))
return out.squeeze(0)
def rotate(img,
angle,
interpolation='nearest',
......
......@@ -1205,6 +1205,189 @@ class Pad(BaseTransform):
return F.pad(img, self.padding, self.fill, self.padding_mode)
def _check_sequence_input(x, name, req_sizes):
msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join(
[str(s) for s in req_sizes])
if not isinstance(x, Sequence):
raise TypeError(f"{name} should be a sequence of length {msg}.")
if len(x) not in req_sizes:
raise ValueError(f"{name} should be sequence of length {msg}.")
def _setup_angle(x, name, req_sizes=(2, )):
if isinstance(x, numbers.Number):
if x < 0:
raise ValueError(
f"If {name} is a single number, it must be positive.")
x = [-x, x]
else:
_check_sequence_input(x, name, req_sizes)
return [float(d) for d in x]
class RandomAffine(BaseTransform):
"""Random affine transformation of the image.
Args:
degrees (int|float|tuple): The angle interval of the random rotation.
If set as a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees) in clockwise order. If set 0, will not rotate.
translate (tuple, optional): Maximum absolute fraction for horizontal and vertical translations.
For example translate=(a, b), then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a
and vertical shift is randomly sampled in the range -img_height * b < dy < img_height * b.
Default is None, will not translate.
scale (tuple, optional): Scaling factor interval, e.g (a, b), then scale is randomly sampled from the range a <= scale <= b.
Default is None, will keep original scale and not scale.
shear (sequence or number, optional): Range of degrees to shear, ranges from -180 to 180 in clockwise order.
If set as a number, a shear parallel to the x axis in the range (-shear, +shear) will be applied.
Else if set as a sequence of 2 values a shear parallel to the x axis in the range (shear[0], shear[1]) will be applied.
Else if set as a sequence of 4 values, a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
Default is None, will not apply shear.
interpolation (str, optional): Interpolation method. If omitted, or if the
image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST
according the backend.
When use pil backend, support method are as following:
- "nearest": Image.NEAREST,
- "bilinear": Image.BILINEAR,
- "bicubic": Image.BICUBIC
When use cv2 backend, support method are as following:
- "nearest": cv2.INTER_NEAREST,
- "bilinear": cv2.INTER_LINEAR,
- "bicubic": cv2.INTER_CUBIC
fill (int|list|tuple, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
center (2-tuple, optional): Optional center of rotation, (x, y).
Origin is the upper left corner.
Default is the center of the image.
keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.
Shape:
- img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C).
- output(PIL.Image|np.ndarray|Paddle.Tensor): An affined image.
Returns:
A callable object of RandomAffine.
Examples:
.. code-block:: python
import paddle
from paddle.vision.transforms import RandomAffine
transform = RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 10])
fake_img = paddle.randn((3, 256, 300)).astype(paddle.float32)
fake_img = transform(fake_img)
print(fake_img.shape)
"""
def __init__(self,
degrees,
translate=None,
scale=None,
shear=None,
interpolation='nearest',
fill=0,
center=None,
keys=None):
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
super(RandomAffine, self).__init__(keys)
assert interpolation in ['nearest', 'bilinear', 'bicubic']
self.interpolation = interpolation
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2, ))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError(
"translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2, ))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear
if fill is None:
fill = 0
elif not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fill = fill
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2, ))
self.center = center
def _get_param(self,
img_size,
degrees,
translate=None,
scale_ranges=None,
shears=None):
"""Get parameters for affine transformation
Returns:
params to be passed to the affine transformation
"""
angle = random.uniform(degrees[0], degrees[1])
if translate is not None:
max_dx = float(translate[0] * img_size[0])
max_dy = float(translate[1] * img_size[1])
tx = int(random.uniform(-max_dx, max_dx))
ty = int(random.uniform(-max_dy, max_dy))
translations = (tx, ty)
else:
translations = (0, 0)
if scale_ranges is not None:
scale = random.uniform(scale_ranges[0], scale_ranges[1])
else:
scale = 1.0
shear_x, shear_y = 0.0, 0.0
if shears is not None:
shear_x = random.uniform(shears[0], shears[1])
if len(shears) == 4:
shear_y = random.uniform(shears[2], shears[3])
shear = (shear_x, shear_y)
return angle, translations, scale, shear
def _apply_image(self, img):
"""
Args:
img (PIL.Image|np.array): Image to be affine transformed.
Returns:
PIL.Image or np.array: Affine transformed image.
"""
w, h = _get_image_size(img)
img_size = [w, h]
ret = self._get_param(img_size, self.degrees, self.translate,
self.scale, self.shear)
return F.affine(
img,
*ret,
interpolation=self.interpolation,
fill=self.fill,
center=self.center)
class RandomRotation(BaseTransform):
"""Rotates the image by angle.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册