diff --git a/python/paddle/tests/test_transforms.py b/python/paddle/tests/test_transforms.py index 119b1037278f6b3f2d451d994eaad86cf522002c..242680bc7c738c6a165cfa4a73e9b3ea78863edf 100644 --- a/python/paddle/tests/test_transforms.py +++ b/python/paddle/tests/test_transforms.py @@ -175,6 +175,12 @@ class TestTransformsCV2(unittest.TestCase): trans_random_crop_pad = transforms.RandomCrop((224, 256), 2, True) img = trans_random_crop_pad(img) + def test_erase(self): + trans = transforms.Compose([ + transforms.RandomErasing(), transforms.RandomErasing(value="random") + ]) + self.do_transform(trans) + def test_grayscale(self): trans = transforms.Compose([transforms.Grayscale()]) self.do_transform(trans) @@ -299,6 +305,24 @@ class TestTransformsCV2(unittest.TestCase): with self.assertRaises(NotImplementedError): transform = transforms.BrightnessTransform('0.1', keys='a') + with self.assertRaises(Exception): + transform = transforms.RandomErasing(scale=0.5) + + with self.assertRaises(Exception): + transform = transforms.RandomErasing(ratio=0.8) + + with self.assertRaises(Exception): + transform = transforms.RandomErasing(scale=(10, 0.4)) + + with self.assertRaises(Exception): + transform = transforms.RandomErasing(ratio=(3.3, 0.3)) + + with self.assertRaises(Exception): + transform = transforms.RandomErasing(prob=1.5) + + with self.assertRaises(Exception): + transform = transforms.RandomErasing(value="0") + def test_info(self): str(transforms.Compose([transforms.Resize((224, 224))])) str(transforms.Compose([transforms.Resize((224, 224))])) @@ -402,6 +426,13 @@ class TestTransformsTensor(TestTransformsCV2): trans_random_crop_pad = transforms.RandomCrop((224, 256), 2, True) img = trans_random_crop_pad(img) + def test_erase(self): + trans = transforms.Compose([ + transforms.RandomErasing(value=(0.5, )), + transforms.RandomErasing(value="random") + ]) + self.do_transform(trans) + def test_exception(self): trans = transforms.Compose([transforms.Resize(-1)]) @@ -694,6 +725,47 
@@ class TestFunctional(unittest.TestCase): pil_img = Image.fromarray(np_img).convert('YCbCr') pil_tensor = F.to_tensor(pil_img) + def test_erase(self): + np_img = (np.random.rand(28, 28, 3) * 255).astype('uint8') + pil_img = Image.fromarray(np_img).convert('RGB') + + expected = np_img.copy() + expected[10:15, 10:15, :] = 0 + + F.erase(np_img, 10, 10, 5, 5, 0, inplace=True) + np.testing.assert_equal(np_img, expected) + + pil_result = F.erase(pil_img, 10, 10, 5, 5, 0) + np.testing.assert_equal(np.array(pil_result), expected) + + np_data = np.random.rand(3, 28, 28).astype('float32') + places = ['cpu'] + if paddle.device.is_compiled_with_cuda(): + places.append('gpu') + for place in places: + paddle.set_device(place) + tensor_img = paddle.to_tensor(np_data) + expected_tensor = tensor_img.clone() + expected_tensor[:, 10:15, 10:15] = paddle.to_tensor([0.88]) + + tensor_result = F.erase(tensor_img, 10, 10, 5, 5, + paddle.to_tensor([0.88])) + np.testing.assert_equal(tensor_result.numpy(), + expected_tensor.numpy()) + + def test_erase_backward(self): + img = paddle.randn((3, 14, 14), dtype=np.float32) + img.stop_gradient = False + erased = F.erase( + img, 3, 3, 5, 5, paddle.ones( + (1, 1, 1), dtype='float32')) + loss = erased.sum() + loss.backward() + + expected_grad = np.ones((3, 14, 14), dtype=np.float32) + expected_grad[:, 3:8, 3:8] = 0. 
+ np.testing.assert_equal(img.grad.numpy(), expected_grad) + def test_image_load(self): fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype( 'uint8')) diff --git a/python/paddle/vision/transforms/__init__.py b/python/paddle/vision/transforms/__init__.py index 413f09f78699ee995f490e86a94006cd1a48c6a0..b255e663e6876a0c46ab18a6a0a393919755ac53 100644 --- a/python/paddle/vision/transforms/__init__.py +++ b/python/paddle/vision/transforms/__init__.py @@ -31,6 +31,7 @@ from .transforms import Pad # noqa: F401 from .transforms import RandomRotation # noqa: F401 from .transforms import Grayscale # noqa: F401 from .transforms import ToTensor # noqa: F401 +from .transforms import RandomErasing # noqa: F401 from .functional import to_tensor # noqa: F401 from .functional import hflip # noqa: F401 from .functional import vflip # noqa: F401 @@ -44,6 +45,7 @@ from .functional import adjust_brightness # noqa: F401 from .functional import adjust_contrast # noqa: F401 from .functional import adjust_hue # noqa: F401 from .functional import normalize # noqa: F401 +from .functional import erase # noqa: F401 __all__ = [ #noqa 'BaseTransform', @@ -65,6 +67,7 @@ __all__ = [ #noqa 'RandomRotation', 'Grayscale', 'ToTensor', + 'RandomErasing', 'to_tensor', 'hflip', 'vflip', @@ -77,5 +80,6 @@ __all__ = [ #noqa 'adjust_brightness', 'adjust_contrast', 'adjust_hue', - 'normalize' + 'normalize', + 'erase', ] diff --git a/python/paddle/vision/transforms/functional.py b/python/paddle/vision/transforms/functional.py index 1afac6e48be16958d9aa92ea01056cbb97130d83..5a8c2cc09f884900cb5d95f5e4579e9229a1b663 100644 --- a/python/paddle/vision/transforms/functional.py +++ b/python/paddle/vision/transforms/functional.py @@ -689,3 +689,39 @@ def normalize(img, mean, std, data_format='CHW', to_rgb=False): img = np.array(img).astype(np.float32) return F_cv2.normalize(img, mean, std, data_format, to_rgb) + + +def erase(img, i, j, h, w, v, inplace=False): + """Erase the pixels of selected 
area in input image with given value. + + Args: + img (paddle.Tensor | np.array | PIL.Image): input image. + For Tensor input, the shape should be (C, H, W). For np.array input, + the shape should be (H, W, C). + i (int): y coordinate of the top-left point of erased region. + j (int): x coordinate of the top-left point of erased region. + h (int): Height of the erased region. + w (int): Width of the erased region. + v (paddle.Tensor | np.array): value used to replace the pixels in erased region. It + should be np.array when img is np.array or PIL.Image. + inplace (bool, optional): Whether this transform is inplace. Default: False. + + Returns: + paddle.Tensor | np.array | PIL.Image: Erased image. The type is the same as the input image. + + Examples: + .. code-block:: python + + import paddle + + fake_img = paddle.randn((3, 10, 10)).astype(paddle.float32) + values = paddle.zeros((1,1,1), dtype=paddle.float32) + result = paddle.vision.transforms.erase(fake_img, 4, 4, 3, 3, values) + + """ + if _is_tensor_image(img): + return F_t.erase(img, i, j, h, w, v, inplace=inplace) + elif _is_pil_image(img): + return F_pil.erase(img, i, j, h, w, v, inplace=inplace) + else: + return F_cv2.erase(img, i, j, h, w, v, inplace=inplace) diff --git a/python/paddle/vision/transforms/functional_cv2.py b/python/paddle/vision/transforms/functional_cv2.py index 38b50898be606787977c0ac0b32d7e4d6aafa050..8343a8c340ffb33eafe74d9dd6be17210ffbd425 100644 --- a/python/paddle/vision/transforms/functional_cv2.py +++ b/python/paddle/vision/transforms/functional_cv2.py @@ -564,3 +564,26 @@ def normalize(img, mean, std, data_format='CHW', to_rgb=False): img = (img - mean) / std return img + + +def erase(img, i, j, h, w, v, inplace=False): + """Erase the pixels of selected area in input image array with given value. + + Args: + img (np.array): input image array, whose shape is (H, W, C). + i (int): y coordinate of the top-left point of erased region. 
+ j (int): x coordinate of the top-left point of erased region. + h (int): Height of the erased region. + w (int): Width of the erased region. + v (np.array): value used to replace the pixels in erased region. + inplace (bool, optional): Whether this transform is inplace. Default: False. + + Returns: + np.array: Erased image. + + """ + if not inplace: + img = img.copy() + + img[i:i + h, j:j + w, ...] = v + return img diff --git a/python/paddle/vision/transforms/functional_pil.py b/python/paddle/vision/transforms/functional_pil.py index 32f65fa1f846f4749d1d5fed77e296e6b042da86..71f7759f11b665e57980a3d8569793f43316efa0 100644 --- a/python/paddle/vision/transforms/functional_pil.py +++ b/python/paddle/vision/transforms/functional_pil.py @@ -480,3 +480,26 @@ def to_grayscale(img, num_output_channels=1): raise ValueError('num_output_channels should be either 1 or 3') return img + + +def erase(img, i, j, h, w, v, inplace=False): + """Erase the pixels of selected area in input image with given value. PIL images do + not support in-place erasing, so a new image is always returned. + + Args: + img (PIL.Image): input image; it is converted to an array of shape (H, W, C) before erasing. + i (int): y coordinate of the top-left point of erased region. + j (int): x coordinate of the top-left point of erased region. + h (int): Height of the erased region. + w (int): Width of the erased region. + v (np.array): value used to replace the pixels in erased region. + inplace (bool, optional): Whether this transform is inplace. Default: False. + + Returns: + PIL.Image: Erased image. + + """ + np_img = np.array(img, dtype=np.uint8) + np_img[i:i + h, j:j + w, ...] 
= v + img = Image.fromarray(np_img, 'RGB') + return img diff --git a/python/paddle/vision/transforms/functional_tensor.py b/python/paddle/vision/transforms/functional_tensor.py index 2d6dc125d42dabd7c72ddbeb0b3cd8a9a6aaa2d5..2e276883cd3765dbf793115711791caa16627c35 100644 --- a/python/paddle/vision/transforms/functional_tensor.py +++ b/python/paddle/vision/transforms/functional_tensor.py @@ -416,6 +416,30 @@ def crop(img, top, left, height, width, data_format='CHW'): return img[top:top + height, left:left + width, :] +def erase(img, i, j, h, w, v, inplace=False): + """Erase the pixels of selected area in input Tensor image with given value. + + Args: + img (paddle.Tensor): input Tensor image. + i (int): y coordinate of the top-left point of erased region. + j (int): x coordinate of the top-left point of erased region. + h (int): Height of the erased region. + w (int): Width of the erased region. + v (paddle.Tensor): value used to replace the pixels in erased region. + inplace (bool, optional): Whether this transform is inplace. Default: False. + + Returns: + paddle.Tensor: Erased image. + + """ + _assert_image_tensor(img, 'CHW') + if not inplace: + img = img.clone() + + img[..., i:i + h, j:j + w] = v + return img + + def center_crop(img, output_size, data_format='CHW'): """Crops the given paddle.Tensor Image and resize it to desired size. diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index a22f8a2ab40499c01ead49a8020fa73e8aef7a10..828a0d9b0936d7a2c04226d7d97e2d78f267334e 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -25,6 +25,7 @@ import collections import warnings import traceback +import paddle from paddle.utils import try_import from . import functional as F @@ -1342,3 +1343,143 @@ class Grayscale(BaseTransform): PIL Image: Randomly grayscaled image. 
""" return F.to_grayscale(img, self.num_output_channels) + + +class RandomErasing(BaseTransform): + """Erase the pixels in a rectangle region selected randomly. + + Args: + prob (float, optional): Probability of the input data being erased. Default: 0.5. + scale (sequence, optional): The proportional range of the erased area to the input image. + Default: (0.02, 0.33). + ratio (sequence, optional): Aspect ratio range of the erased area. Default: (0.3, 3.3). + value (int|float|sequence|str, optional): The value each pixel in erased area will be replaced with. + If value is a single number, all pixels will be erased with this value. + If value is a sequence with length 3, the R, G, B channels will be erased + respectively. If value is set to "random", each pixel will be erased with + random values. Default: 0. + inplace (bool, optional): Whether this transform is inplace. Default: False. + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + + Shape: + - img(paddle.Tensor | np.array | PIL.Image): The input image. For Tensor input, the shape should be (C, H, W). + For np.array input, the shape should be (H, W, C). + - output(paddle.Tensor | np.array | PIL.Image): A randomly erased image. + + Returns: + A callable object of RandomErasing. + + Examples: + + .. 
code-block:: python + + import paddle + + fake_img = paddle.randn((3, 10, 10)).astype(paddle.float32) + transform = paddle.vision.transforms.RandomErasing() + result = transform(fake_img) + """ + + def __init__(self, + prob=0.5, + scale=(0.02, 0.33), + ratio=(0.3, 3.3), + value=0, + inplace=False, + keys=None): + super(RandomErasing, self).__init__(keys) + assert isinstance(scale, + (tuple, list)), "scale should be a tuple or list" + assert (scale[0] >= 0 and scale[1] <= 1 and scale[0] <= scale[1] + ), "scale should be of kind (min, max) and in range [0, 1]" + assert isinstance(ratio, + (tuple, list)), "ratio should be a tuple or list" + assert (ratio[0] >= 0 and + ratio[0] <= ratio[1]), "ratio should be of kind (min, max)" + assert (prob >= 0 and + prob <= 1), "The probability should be in range [0, 1]" + assert isinstance( + value, (numbers.Number, str, tuple, + list)), "value should be a number, tuple, list or str" + if isinstance(value, str) and value != "random": + raise ValueError("value must be 'random' when type is str") + + self.prob = prob + self.scale = scale + self.ratio = ratio + self.value = value + self.inplace = inplace + + def _get_param(self, img, scale, ratio, value): + """Get parameters for ``erase`` for a random erasing. + + Args: + img (paddle.Tensor | np.array | PIL.Image): Image to be erased. + scale (sequence, optional): The proportional range of the erased area to the input image. + ratio (sequence, optional): Aspect ratio range of the erased area. + value (sequence | None): The value each pixel in erased area will be replaced with. + If value is a sequence with length 3, the R, G, B channels will be erased + respectively. If value is None, each pixel will be erased with random values. + + Returns: + tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erase. 
+ """ + if F._is_pil_image(img): + shape = np.asarray(img).astype(np.uint8).shape + h, w, c = shape[-3], shape[-2], shape[-1] + elif F._is_numpy_image(img): + h, w, c = img.shape[-3], img.shape[-2], img.shape[-1] + elif F._is_tensor_image(img): + c, h, w = img.shape[-3], img.shape[-2], img.shape[-1] + + img_area = h * w + log_ratio = np.log(ratio) + for _ in range(10): + erase_area = np.random.uniform(*scale) * img_area + aspect_ratio = np.exp(np.random.uniform(*log_ratio)) + erase_h = int(round(np.sqrt(erase_area * aspect_ratio))) + erase_w = int(round(np.sqrt(erase_area / aspect_ratio))) + if erase_h >= h or erase_w >= w: + continue + if F._is_tensor_image(img): + if value is None: + v = paddle.normal( + shape=[c, erase_h, erase_w]).astype(img.dtype) + else: + v = paddle.to_tensor(value, dtype=img.dtype)[:, None, None] + else: + if value is None: + v = np.random.normal(size=[erase_h, erase_w, c]) * 255 + else: + v = np.array(value)[None, None, :] + top = np.random.randint(0, h - erase_h + 1) + left = np.random.randint(0, w - erase_w + 1) + + return top, left, erase_h, erase_w, v + + return 0, 0, h, w, img + + def _apply_image(self, img): + """ + Args: + img (paddle.Tensor | np.array | PIL.Image): Image to be Erased. + + Returns: + output (paddle.Tensor np.array | PIL.Image): A random erased image. + """ + + if random.random() < self.prob: + if isinstance(self.value, numbers.Number): + value = [self.value] + elif isinstance(self.value, str): + value = None + else: + value = self.value + if value is not None and not (len(value) == 1 or len(value) == 3): + raise ValueError( + "Value should be a single number or a sequence with length equals to image's channel." + ) + top, left, erase_h, erase_w, v = self._get_param(img, self.scale, + self.ratio, value) + return F.erase(img, top, left, erase_h, erase_w, v, self.inplace) + return img