diff --git a/python/paddle/tests/test_transforms.py b/python/paddle/tests/test_transforms.py
index 242680bc7c738c6a165cfa4a73e9b3ea78863edf..38cad05bfcb89f0e4d13c77aba96421e264259b9 100644
--- a/python/paddle/tests/test_transforms.py
+++ b/python/paddle/tests/test_transforms.py
@@ -123,6 +123,44 @@ class TestTransformsCV2(unittest.TestCase):
         ])
         self.do_transform(trans)
 
+    def test_affine(self):
+        trans = transforms.Compose([
+            transforms.RandomAffine(90),
+            transforms.RandomAffine(
+                [-10, 10], translate=[0.1, 0.3]),
+            transforms.RandomAffine(
+                45, translate=[0.2, 0.2], scale=[0.2, 0.5]),
+            transforms.RandomAffine(
+                10, translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 10]),
+            transforms.RandomAffine(
+                10,
+                translate=[0.5, 0.3],
+                scale=[0.7, 1.3],
+                shear=[-10, 10, 20, 40]),
+            transforms.RandomAffine(
+                10,
+                translate=[0.5, 0.3],
+                scale=[0.7, 1.3],
+                shear=[-10, 10, 20, 40],
+                interpolation='bilinear'),
+            transforms.RandomAffine(
+                10,
+                translate=[0.5, 0.3],
+                scale=[0.7, 1.3],
+                shear=[-10, 10, 20, 40],
+                interpolation='bilinear',
+                fill=114),
+            transforms.RandomAffine(
+                10,
+                translate=[0.5, 0.3],
+                scale=[0.7, 1.3],
+                shear=[-10, 10, 20, 40],
+                interpolation='bilinear',
+                fill=114,
+                center=(60, 80)),
+        ])
+        self.do_transform(trans)
+
     def test_rotate(self):
         trans = transforms.Compose([
             transforms.RandomRotation(90),
@@ -278,6 +316,35 @@ class TestTransformsCV2(unittest.TestCase):
             tensor_img = paddle.rand((3, 100, 100))
             F.pad(tensor_img, [1.0, 2.0, 3.0])
 
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(-10)
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine([-30, 60], translate=[2, 2])
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(10, translate=[0.2, 0.2], scale=[1, 2, 3]),
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(
+                10, translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[1, 2, 3]),
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(
+                10,
+                translate=[0.5, 0.3],
+                scale=[0.7, 1.3],
+                shear=[-10, 10, 0, 20, 40])
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(
+                10,
+                translate=[0.5, 0.3],
+                scale=[0.7, 1.3],
+                shear=[-10, 10, 20, 40],
+                fill=114,
+                center=(1, 2, 3))
+
         with self.assertRaises(ValueError):
             transforms.RandomRotation(-2)
 
@@ -479,6 +546,29 @@ class TestTransformsTensor(TestTransformsCV2):
             tensor_img = paddle.rand((3, 100, 100))
             F.pad(tensor_img, [1.0, 2.0, 3.0])
 
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(-10)
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine([-30, 60], translate=[2, 2])
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(10, translate=[0.2, 0.2], scale=[-2, -1]),
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(10, translate=[0.2, 0.2], scale=[1, 2, 3]),
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(
+                10, translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[1, 2, 3]),
+
+        with self.assertRaises(ValueError):
+            transforms.RandomAffine(
+                10,
+                translate=[0.5, 0.3],
+                scale=[0.7, 1.3],
+                shear=[-10, 10, 0, 20, 40])
+
         with self.assertRaises(ValueError):
             transforms.RandomRotation(-2)
 
@@ -547,6 +637,36 @@ class TestFunctional(unittest.TestCase):
         with self.assertRaises(TypeError):
             F.adjust_saturation(1, 0.1)
 
+        with self.assertRaises(TypeError):
+            F.affine('45')
+
+        with self.assertRaises(TypeError):
+            F.affine(45, translate=0.3)
+
+        with self.assertRaises(TypeError):
+            F.affine(45, translate=[0.2, 0.2, 0.3])
+
+        with self.assertRaises(TypeError):
+            F.affine(45, translate=[0.2, 0.2], scale=-0.5)
+
+        with self.assertRaises(TypeError):
+            F.affine(45, translate=[0.2, 0.2], scale=0.5, shear=10)
+
+        with self.assertRaises(TypeError):
+            F.affine(45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 0, 10])
+
+        with self.assertRaises(TypeError):
+            F.affine(
+                45,
+                translate=[0.2, 0.2],
+                scale=0.5,
+                shear=[-10, 10],
+                interpolation=2)
+
+        with self.assertRaises(TypeError):
+            F.affine(
+                45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10], center=0)
+
         with self.assertRaises(TypeError):
             F.rotate(1, 0.1)
 
@@ -785,6 +905,31 @@ class TestFunctional(unittest.TestCase):
 
         os.remove(path)
 
+    def test_affine(self):
+        np_img = (np.random.rand(32, 26, 3) * 255).astype('uint8')
+        pil_img = Image.fromarray(np_img).convert('RGB')
+        tensor_img = F.to_tensor(pil_img, data_format='CHW') * 255
+
+        np.testing.assert_almost_equal(
+            np_img, tensor_img.transpose((1, 2, 0)), decimal=4)
+
+        np_affined_img = F.affine(
+            np_img, 45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10])
+        pil_affined_img = F.affine(
+            pil_img, 45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10])
+        tensor_affined_img = F.affine(
+            tensor_img, 45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10])
+
+        np.testing.assert_equal(np_affined_img.shape,
+                                np.array(pil_affined_img).shape)
+        np.testing.assert_equal(np_affined_img.shape,
+                                tensor_affined_img.transpose((1, 2, 0)).shape)
+
+        np.testing.assert_almost_equal(
+            np.array(pil_affined_img),
+            tensor_affined_img.numpy().transpose((1, 2, 0)),
+            decimal=4)
+
     def test_rotate(self):
         np_img = (np.random.rand(28, 28, 3) * 255).astype('uint8')
         pil_img = Image.fromarray(np_img).convert('RGB')
diff --git a/python/paddle/vision/transforms/__init__.py b/python/paddle/vision/transforms/__init__.py
index b255e663e6876a0c46ab18a6a0a393919755ac53..41e9b188e34ed1a22674fbcb32c9b25282b7b437 100644
--- a/python/paddle/vision/transforms/__init__.py
+++ b/python/paddle/vision/transforms/__init__.py
@@ -28,6 +28,7 @@ from .transforms import HueTransform  # noqa: F401
 from .transforms import ColorJitter  # noqa: F401
 from .transforms import RandomCrop  # noqa: F401
 from .transforms import Pad  # noqa: F401
+from .transforms import RandomAffine  # noqa: F401
 from .transforms import RandomRotation  # noqa: F401
 from .transforms import Grayscale  # noqa: F401
 from .transforms import ToTensor  # noqa: F401
@@ -37,6 +38,7 @@ from .functional import hflip  # noqa: F401
 from .functional import vflip  # noqa: F401
 from .functional import resize  # noqa: F401
 from .functional import pad  # noqa: F401
+from .functional import affine  # noqa: F401
 from .functional import rotate  # noqa: F401
 from .functional import to_grayscale  # noqa: F401
 from .functional import crop  # noqa: F401
@@ -64,6 +66,7 @@ __all__ = [ #noqa
     'ColorJitter',
     'RandomCrop',
     'Pad',
+    'RandomAffine',
     'RandomRotation',
     'Grayscale',
     'ToTensor',
@@ -73,6 +76,7 @@ __all__ = [ #noqa
     'vflip',
     'resize',
     'pad',
+    'affine',
     'rotate',
     'to_grayscale',
     'crop',
diff --git a/python/paddle/vision/transforms/functional.py b/python/paddle/vision/transforms/functional.py
index 29a857ba570f615def8a278b30b795595d0a9d72..83f756e6ed2a629f757bab9bdeaa6d93a619349e 100644
--- a/python/paddle/vision/transforms/functional.py
+++ b/python/paddle/vision/transforms/functional.py
@@ -537,6 +537,166 @@ def adjust_hue(img, hue_factor):
     return F_t.adjust_hue(img, hue_factor)
 
 
+def _get_affine_matrix(center, angle, translate, scale, shear):
+    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
+    # The inverse one is : M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
+    rot = math.radians(angle)
+    sx = math.radians(shear[0])
+    sy = math.radians(shear[1])
+
+    # Rotate and Shear without scaling
+    a = math.cos(rot - sy) / math.cos(sy)
+    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
+    c = math.sin(rot - sy) / math.cos(sy)
+    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
+
+    # Center Translation
+    cx, cy = center
+    tx, ty = translate
+
+    # Inverted rotation matrix with scale and shear
+    # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+    matrix = [d, -b, 0.0, -c, a, 0.0]
+    matrix = [x / scale for x in matrix]
+    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+    matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
+    matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
+    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+    matrix[2] += cx
+    matrix[5] += cy
+
+    return matrix
+
+
+def affine(img,
+           angle,
+           translate,
+           scale,
+           shear,
+           interpolation="nearest",
+           fill=0,
+           center=None):
+    """Apply affine transformation on the image.
+
+    Args:
+        img (PIL.Image|np.array|paddle.Tensor): Image to be affined.
+        angle (int|float): The angle of rotation in clockwise order.
+        translate (list[float]): Maximum absolute fraction for horizontal and vertical translations.
+        scale (float): Scale factor for the image, scale should be positive.
+        shear (list[float]): Shear angle values which are parallel to the x-axis and y-axis in clockwise order.
+        interpolation (str, optional): Interpolation method. If omitted, or if the
+            image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST
+            according to the backend.
+            When using the pil backend, supported methods are as follows:
+            - "nearest": Image.NEAREST,
+            - "bilinear": Image.BILINEAR,
+            - "bicubic": Image.BICUBIC
+            When using the cv2 backend, supported methods are as follows:
+            - "nearest": cv2.INTER_NEAREST,
+            - "bilinear": cv2.INTER_LINEAR,
+            - "bicubic": cv2.INTER_CUBIC
+        fill (int|list|tuple, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands.
+        center (2-tuple, optional): Optional center of rotation, (x, y).
+            Origin is the upper left corner.
+            Default is the center of the image.
+
+    Returns:
+        PIL.Image|np.array|paddle.Tensor: Affine transformed image.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from paddle.vision.transforms import functional as F
+
+            fake_img = paddle.randn((3, 256, 300)).astype(paddle.float32)
+
+            affined_img = F.affine(fake_img, 45, translate=[0.2, 0.2], scale=0.5, shear=[-10, 10])
+            print(affined_img.shape)
+    """
+
+    if not (_is_pil_image(img) or _is_numpy_image(img) or
+            _is_tensor_image(img)):
+        raise TypeError(
+            'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'.
+            format(type(img)))
+
+    if not isinstance(angle, (int, float)):
+        raise TypeError("Argument angle should be int or float")
+
+    if not isinstance(translate, (list, tuple)):
+        raise TypeError("Argument translate should be a sequence")
+
+    if len(translate) != 2:
+        raise ValueError("Argument translate should be a sequence of length 2")
+
+    if scale <= 0.0:
+        raise ValueError("Argument scale should be positive")
+
+    if not isinstance(shear, (numbers.Number, (list, tuple))):
+        raise TypeError(
+            "Shear should be either a single value or a sequence of two values")
+
+    if not isinstance(interpolation, str):
+        raise TypeError("Argument interpolation should be a string")
+
+    if isinstance(angle, int):
+        angle = float(angle)
+
+    if isinstance(translate, tuple):
+        translate = list(translate)
+
+    if isinstance(shear, numbers.Number):
+        shear = [shear, 0.0]
+
+    if isinstance(shear, tuple):
+        shear = list(shear)
+
+    if len(shear) == 1:
+        shear = [shear[0], shear[0]]
+
+    if len(shear) != 2:
+        raise ValueError(
+            f"Shear should be a sequence containing two values. Got {shear}")
+
+    if center is not None and not isinstance(center, (list, tuple)):
+        raise TypeError("Argument center should be a sequence")
+
+    if _is_pil_image(img):
+        width, height = img.size
+        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
+        # it is visually better to estimate the center without 0.5 offset
+        # otherwise image rotated by 90 degrees is shifted vs output image of F_t.affine
+        if center is None:
+            center = [width * 0.5, height * 0.5]
+        matrix = _get_affine_matrix(center, angle, translate, scale, shear)
+        return F_pil.affine(img, matrix, interpolation, fill)
+
+    if _is_numpy_image(img):
+        # get affine_matrix in F_cv2.affine() using cv2's functions
+        height, width = img.shape[0:2]
+        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
+        # it is visually better to estimate the center without 0.5 offset
+        # otherwise image rotated by 90 degrees is shifted vs output image of F_t.affine
+        if center is None:
+            center = (width * 0.5, height * 0.5)
+        return F_cv2.affine(img, angle, translate, scale, shear, interpolation,
+                            fill, center)
+
+    if _is_tensor_image(img):
+        center_f = [0.0, 0.0]
+        if center is not None:
+            width, height = img.shape[-1], img.shape[-2]
+            # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
+            center_f = [
+                1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])
+            ]
+        translate_f = [1.0 * t for t in translate]
+        matrix = _get_affine_matrix(center_f, angle, translate_f, scale, shear)
+        return F_t.affine(img, matrix, interpolation, fill)
+
+
 def rotate(img,
            angle,
            interpolation="nearest",
diff --git a/python/paddle/vision/transforms/functional_cv2.py b/python/paddle/vision/transforms/functional_cv2.py
index 8343a8c340ffb33eafe74d9dd6be17210ffbd425..d20bf3e60d907a069bd019a589038b25b7620715 100644
--- a/python/paddle/vision/transforms/functional_cv2.py
+++ b/python/paddle/vision/transforms/functional_cv2.py
@@ -411,6 +411,86 @@ def adjust_hue(img, hue_factor):
     return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype)
 
 
+def affine(img,
+           angle,
+           translate,
+           scale,
+           shear,
+           interpolation='nearest',
+           fill=0,
+           center=None):
+    """Apply the affine transformation to the image.
+
+    Args:
+        img (np.array): Image to be affined.
+        translate (sequence or int): horizontal and vertical translations.
+        scale (float): overall scale ratio.
+        shear (sequence or float): shear angle value in degrees between -180 to 180, clockwise direction.
+            If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
+            the second value corresponds to a shear parallel to the y axis.
+        interpolation (int|str, optional): Interpolation method. If omitted, or if the
+            image has only one channel, it is set to cv2.INTER_NEAREST.
+            When using the cv2 backend, supported methods are as follows:
+            - "nearest": cv2.INTER_NEAREST,
+            - "bilinear": cv2.INTER_LINEAR,
+            - "bicubic": cv2.INTER_CUBIC
+        fill (3-tuple or int): RGB pixel fill value for area outside the affined image.
+            If int, it is used for all channels.
+        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
+            Default is the center of the image.
+
+    Returns:
+        np.array: Affined image.
+
+    """
+    cv2 = try_import('cv2')
+    _cv2_interp_from_str = {
+        'nearest': cv2.INTER_NEAREST,
+        'bilinear': cv2.INTER_LINEAR,
+        'area': cv2.INTER_AREA,
+        'bicubic': cv2.INTER_CUBIC,
+        'lanczos': cv2.INTER_LANCZOS4
+    }
+
+    h, w = img.shape[0:2]
+
+    if isinstance(fill, int):
+        fill = tuple([fill] * 3)
+
+    if center is None:
+        center = (w / 2.0, h / 2.0)
+
+    M = np.ones([2, 3])
+    # Rotate and Scale
+    R = cv2.getRotationMatrix2D(angle=angle, center=center, scale=scale)
+
+    # Shear
+    sx = math.tan(shear[0] * math.pi / 180)
+    sy = math.tan(shear[1] * math.pi / 180)
+    M[0] = R[0] + sy * R[1]
+    M[1] = R[1] + sx * R[0]
+
+    # Translation
+    tx, ty = translate
+    M[0, 2] = tx
+    M[1, 2] = ty
+
+    if len(img.shape) == 3 and img.shape[2] == 1:
+        return cv2.warpAffine(
+            img,
+            M,
+            dsize=(w, h),
+            flags=_cv2_interp_from_str[interpolation],
+            borderValue=fill)[:, :, np.newaxis]
+    else:
+        return cv2.warpAffine(
+            img,
+            M,
+            dsize=(w, h),
+            flags=_cv2_interp_from_str[interpolation],
+            borderValue=fill)
+
+
 def rotate(img,
            angle,
            interpolation='nearest',
diff --git a/python/paddle/vision/transforms/functional_pil.py b/python/paddle/vision/transforms/functional_pil.py
index 71f7759f11b665e57980a3d8569793f43316efa0..4c342e31b7f89a523ef9b19ac94e6e9fe199e0ce 100644
--- a/python/paddle/vision/transforms/functional_pil.py
+++ b/python/paddle/vision/transforms/functional_pil.py
@@ -410,6 +410,32 @@ def adjust_hue(img, hue_factor):
     return img
 
 
+def affine(img, matrix, interpolation="nearest", fill=0):
+    """Affine the image by matrix.
+
+    Args:
+        img (PIL.Image): Image to be affined.
+        matrix (list[float]): Affine matrix.
+        interpolation (str, optional): Interpolation method. If omitted, or if the
+            image has only one channel, it is set to PIL.Image.NEAREST. When using the pil backend,
+            supported methods are as follows:
+            - "nearest": Image.NEAREST,
+            - "bilinear": Image.BILINEAR,
+            - "bicubic": Image.BICUBIC
+        fill (3-tuple or int): RGB pixel fill value for area outside the affined image.
+            If int, it is used for all channels.
+
+    Returns:
+        PIL.Image: Affined image.
+
+    """
+    if isinstance(fill, int):
+        fill = tuple([fill] * 3)
+
+    return img.transform(img.size, Image.AFFINE, matrix,
+                         _pil_interp_from_str[interpolation], fill)
+
+
 def rotate(img,
            angle,
            interpolation="nearest",
diff --git a/python/paddle/vision/transforms/functional_tensor.py b/python/paddle/vision/transforms/functional_tensor.py
index 2e276883cd3765dbf793115711791caa16627c35..cafb2655659b0d11ae36312f9940edb5a20b0ced 100644
--- a/python/paddle/vision/transforms/functional_tensor.py
+++ b/python/paddle/vision/transforms/functional_tensor.py
@@ -226,8 +226,8 @@ def _affine_grid(theta, w, h, ow, oh):
 
 def _grid_transform(img, grid, mode, fill):
     if img.shape[0] > 1:
-        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2],
-                           grid.shape[3])
+        grid = grid.expand(
+            shape=[img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3]])
 
     if fill is not None:
         dummy = paddle.ones(
@@ -255,6 +255,47 @@ def _grid_transform(img, grid, mode, fill):
     return img
 
 
+def affine(img, matrix, interpolation="nearest", fill=None, data_format='CHW'):
+    """Affine the image by matrix.
+
+    Args:
+        img (paddle.Tensor): Image to be affined.
+        matrix (list[float]): Affine matrix.
+        interpolation (str, optional): Interpolation method. If omitted, or if the
+            image has only one channel, it is set to NEAREST. Supported methods
+            are as follows:
+            - "nearest"
+            - "bilinear"
+            - "bicubic"
+        fill (3-tuple or int): RGB pixel fill value for area outside the affined image.
+            If int, it is used for all channels.
+        data_format (str, optional): Data format of img, should be 'HWC' or
+            'CHW'. Default: 'CHW'.
+
+    Returns:
+        paddle.Tensor: Affined image.
+
+    """
+    img = img.unsqueeze(0)
+    img = img if data_format.lower() == 'chw' else img.transpose((0, 3, 1, 2))
+
+    matrix = paddle.to_tensor(matrix, place=img.place)
+    matrix = matrix.reshape((1, 2, 3))
+    shape = img.shape
+
+    grid = _affine_grid(
+        matrix, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
+
+    if isinstance(fill, int):
+        fill = tuple([fill] * 3)
+
+    out = _grid_transform(img, grid, mode=interpolation, fill=fill)
+
+    out = out if data_format.lower() == 'chw' else out.transpose((0, 2, 3, 1))
+
+    return out.squeeze(0)
+
+
 def rotate(img,
            angle,
            interpolation='nearest',
diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py
index ce356449c594e3a5950160b5ae65844803bdd5a0..42dfd6dfa4f81e899c3544c4dae50993886def50 100644
--- a/python/paddle/vision/transforms/transforms.py
+++ b/python/paddle/vision/transforms/transforms.py
@@ -1205,6 +1205,189 @@ class Pad(BaseTransform):
         return F.pad(img, self.padding, self.fill, self.padding_mode)
 
 
+def _check_sequence_input(x, name, req_sizes):
+    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join(
+        [str(s) for s in req_sizes])
+    if not isinstance(x, Sequence):
+        raise TypeError(f"{name} should be a sequence of length {msg}.")
+    if len(x) not in req_sizes:
+        raise ValueError(f"{name} should be a sequence of length {msg}.")
+
+
+def _setup_angle(x, name, req_sizes=(2, )):
+    if isinstance(x, numbers.Number):
+        if x < 0:
+            raise ValueError(
+                f"If {name} is a single number, it must be positive.")
+        x = [-x, x]
+    else:
+        _check_sequence_input(x, name, req_sizes)
+
+    return [float(d) for d in x]
+
+
+class RandomAffine(BaseTransform):
+    """Random affine transformation of the image.
+
+    Args:
+        degrees (int|float|tuple): The angle interval of the random rotation.
+            If set as a number instead of sequence like (min, max), the range of degrees
+            will be (-degrees, +degrees) in clockwise order. If set to 0, no rotation will be applied.
+        translate (tuple, optional): Maximum absolute fraction for horizontal and vertical translations.
+            For example translate=(a, b), then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a
+            and vertical shift is randomly sampled in the range -img_height * b < dy < img_height * b.
+            Default is None, which means no translation.
+        scale (tuple, optional): Scaling factor interval, e.g (a, b), then scale is randomly sampled from the range a <= scale <= b.
+            Default is None, which keeps the original scale.
+        shear (sequence or number, optional): Range of degrees to shear, ranges from -180 to 180 in clockwise order.
+            If set as a number, a shear parallel to the x axis in the range (-shear, +shear) will be applied.
+            Else if set as a sequence of 2 values, a shear parallel to the x axis in the range (shear[0], shear[1]) will be applied.
+            Else if set as a sequence of 4 values, an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
+            Default is None, which means no shear.
+        interpolation (str, optional): Interpolation method. If omitted, or if the
+            image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST
+            according to the backend.
+            When using the pil backend, supported methods are as follows:
+            - "nearest": Image.NEAREST,
+            - "bilinear": Image.BILINEAR,
+            - "bicubic": Image.BICUBIC
+            When using the cv2 backend, supported methods are as follows:
+            - "nearest": cv2.INTER_NEAREST,
+            - "bilinear": cv2.INTER_LINEAR,
+            - "bicubic": cv2.INTER_CUBIC
+        fill (int|list|tuple, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands.
+        center (2-tuple, optional): Optional center of rotation, (x, y).
+            Origin is the upper left corner.
+            Default is the center of the image.
+        keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.
+
+    Shape:
+        - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C).
+        - output(PIL.Image|np.ndarray|Paddle.Tensor): An affined image.
+
+    Returns:
+        A callable object of RandomAffine.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            from paddle.vision.transforms import RandomAffine
+
+            transform = RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 10])
+
+            fake_img = paddle.randn((3, 256, 300)).astype(paddle.float32)
+
+            fake_img = transform(fake_img)
+            print(fake_img.shape)
+    """
+
+    def __init__(self,
+                 degrees,
+                 translate=None,
+                 scale=None,
+                 shear=None,
+                 interpolation='nearest',
+                 fill=0,
+                 center=None,
+                 keys=None):
+        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
+
+        super(RandomAffine, self).__init__(keys)
+        assert interpolation in ['nearest', 'bilinear', 'bicubic']
+        self.interpolation = interpolation
+
+        if translate is not None:
+            _check_sequence_input(translate, "translate", req_sizes=(2, ))
+            for t in translate:
+                if not (0.0 <= t <= 1.0):
+                    raise ValueError(
+                        "translation values should be between 0 and 1")
+        self.translate = translate
+
+        if scale is not None:
+            _check_sequence_input(scale, "scale", req_sizes=(2, ))
+            for s in scale:
+                if s <= 0:
+                    raise ValueError("scale values should be positive")
+        self.scale = scale
+
+        if shear is not None:
+            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
+        else:
+            self.shear = shear
+
+        if fill is None:
+            fill = 0
+        elif not isinstance(fill, (Sequence, numbers.Number)):
+            raise TypeError("Fill should be either a sequence or a number.")
+        self.fill = fill
+
+        if center is not None:
+            _check_sequence_input(center, "center", req_sizes=(2, ))
+        self.center = center
+
+    def _get_param(self,
+                   img_size,
+                   degrees,
+                   translate=None,
+                   scale_ranges=None,
+                   shears=None):
+        """Get parameters for affine transformation.
+
+        Returns:
+            params to be passed to the affine transformation
+        """
+        angle = random.uniform(degrees[0], degrees[1])
+
+        if translate is not None:
+            max_dx = float(translate[0] * img_size[0])
+            max_dy = float(translate[1] * img_size[1])
+            tx = int(random.uniform(-max_dx, max_dx))
+            ty = int(random.uniform(-max_dy, max_dy))
+            translations = (tx, ty)
+        else:
+            translations = (0, 0)
+
+        if scale_ranges is not None:
+            scale = random.uniform(scale_ranges[0], scale_ranges[1])
+        else:
+            scale = 1.0
+
+        shear_x, shear_y = 0.0, 0.0
+        if shears is not None:
+            shear_x = random.uniform(shears[0], shears[1])
+            if len(shears) == 4:
+                shear_y = random.uniform(shears[2], shears[3])
+        shear = (shear_x, shear_y)
+
+        return angle, translations, scale, shear
+
+    def _apply_image(self, img):
+        """
+        Args:
+            img (PIL.Image|np.array): Image to be affine transformed.
+
+        Returns:
+            PIL.Image or np.array: Affine transformed image.
+        """
+
+        w, h = _get_image_size(img)
+        img_size = [w, h]
+
+        ret = self._get_param(img_size, self.degrees, self.translate,
+                              self.scale, self.shear)
+
+        return F.affine(
+            img,
+            *ret,
+            interpolation=self.interpolation,
+            fill=self.fill,
+            center=self.center)
+
+
 class RandomRotation(BaseTransform):
     """Rotates the image by angle.