"""Data-augmentation transforms registered in the ``TRANSFORMS`` registry.

Besides registering a few stock ``paddle.vision.transforms`` classes, this
module defines "paired" variants that sample their random parameters once per
call and reuse them for every input key, so that (for example) an input image
and its target receive the same crop / flip.
"""
import sys
import random
import numbers
import collections

import numpy as np
import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F

from .builder import TRANSFORMS

# Module-level aliases for the abstract container types (possibly imported by
# other modules). Before Python 3.3 the ABCs live directly on ``collections``;
# from 3.3 on they live in the ``collections.abc`` submodule, which must be
# imported explicitly: a plain ``import collections`` does not guarantee the
# ``abc`` attribute is bound.
if sys.version_info < (3, 3):
    Sequence = collections.Sequence
    Iterable = collections.Iterable
else:
    import collections.abc
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable

TRANSFORMS.register(T.Resize)
TRANSFORMS.register(T.RandomCrop)
TRANSFORMS.register(T.RandomHorizontalFlip)
TRANSFORMS.register(T.Normalize)
TRANSFORMS.register(T.Transpose)


@TRANSFORMS.register()
class PairedRandomCrop(T.RandomCrop):
    """Randomly crop every input with one shared crop window.

    The window is sampled from the input whose key is ``'image'`` and then
    applied unchanged to all inputs, so paired images stay aligned.

    Args:
        size (int | float | sequence): Output size. A scalar gives a square
            ``(size, size)`` crop; a sequence is used as-is (unpacked as
            ``(h, w)`` by the base ``_get_param`` — confirm for your paddle
            version).
        keys (sequence | None): Per-input keys forwarded to the base
            transform. ``self.keys.index('image')`` is used to locate the
            reference input, so the keys must contain ``'image'`` (otherwise
            ``ValueError`` is raised). When ``None``, the base-class default
            applies (presumably ``('image',)`` — confirm).
    """

    def __init__(self, size, keys=None):
        super().__init__(size, keys=keys)
        # Accept any real scalar, not only ``int``: a float or numpy integer
        # size would otherwise be stored raw and break the ``(h, w)`` unpack
        # performed when the crop window is sampled.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def _get_params(self, inputs):
        """Sample one crop window from the ``'image'`` input.

        Returns:
            dict: ``{'crop_prams': (i, j, h, w)}``, read back by
            ``_apply_image`` for every key.
        """
        image = inputs[self.keys.index('image')]
        params = {}
        # 'crop_prams' (sic) is kept unchanged: it is the key _apply_image
        # reads, and renaming it could break code outside this file that
        # inspects ``self.params``.
        params['crop_prams'] = self._get_param(image, self.size)
        return params

    def _apply_image(self, img):
        """Crop ``img`` with the window sampled in ``_get_params``."""
        # NOTE(review): relies on the base transform having stored the
        # result of _get_params in ``self.params`` before this call.
        i, j, h, w = self.params['crop_prams']
        return F.crop(img, i, j, h, w)


@TRANSFORMS.register()
class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontally flip all inputs together with probability ``prob``.

    A single coin flip is drawn per call in ``_get_params`` and reused for
    every key, so paired images are either all flipped or all left as-is.

    Args:
        prob (float): Probability of flipping. Default: 0.5.
        keys (sequence | None): Per-input keys forwarded to the base
            transform.
    """

    def __init__(self, prob=0.5, keys=None):
        super().__init__(prob, keys=keys)

    def _get_params(self, inputs):
        """Draw the shared flip decision: ``{'flip': bool}``."""
        params = {}
        params['flip'] = random.random() < self.prob
        return params

    def _apply_image(self, image):
        """Return ``image`` flipped horizontally if the shared draw says so."""
        if self.params['flip']:
            return F.hflip(image)
        return image