未验证 提交 20286ae7 编写于 作者: J JYChen 提交者: GitHub

[New API] add API paddle.vision.transforms.RandomErasing and...

[New API] add API paddle.vision.transforms.RandomErasing and paddle.vision.transforms.erase (#42280)

* add api RandomErasing
* add numpy/PIL backend support and UT
* fix doc and optimize UT
* add seed
上级 a5de44f5
......@@ -175,6 +175,12 @@ class TestTransformsCV2(unittest.TestCase):
trans_random_crop_pad = transforms.RandomCrop((224, 256), 2, True)
img = trans_random_crop_pad(img)
def test_erase(self):
trans = transforms.Compose([
transforms.RandomErasing(), transforms.RandomErasing(value="random")
])
self.do_transform(trans)
def test_grayscale(self):
trans = transforms.Compose([transforms.Grayscale()])
self.do_transform(trans)
......@@ -299,6 +305,24 @@ class TestTransformsCV2(unittest.TestCase):
with self.assertRaises(NotImplementedError):
transform = transforms.BrightnessTransform('0.1', keys='a')
with self.assertRaises(Exception):
transform = transforms.RandomErasing(scale=0.5)
with self.assertRaises(Exception):
transform = transforms.RandomErasing(ratio=0.8)
with self.assertRaises(Exception):
transform = transforms.RandomErasing(scale=(10, 0.4))
with self.assertRaises(Exception):
transform = transforms.RandomErasing(ratio=(3.3, 0.3))
with self.assertRaises(Exception):
transform = transforms.RandomErasing(prob=1.5)
with self.assertRaises(Exception):
transform = transforms.RandomErasing(value="0")
def test_info(self):
str(transforms.Compose([transforms.Resize((224, 224))]))
str(transforms.Compose([transforms.Resize((224, 224))]))
......@@ -402,6 +426,13 @@ class TestTransformsTensor(TestTransformsCV2):
trans_random_crop_pad = transforms.RandomCrop((224, 256), 2, True)
img = trans_random_crop_pad(img)
def test_erase(self):
trans = transforms.Compose([
transforms.RandomErasing(value=(0.5, )),
transforms.RandomErasing(value="random")
])
self.do_transform(trans)
def test_exception(self):
trans = transforms.Compose([transforms.Resize(-1)])
......@@ -694,6 +725,47 @@ class TestFunctional(unittest.TestCase):
pil_img = Image.fromarray(np_img).convert('YCbCr')
pil_tensor = F.to_tensor(pil_img)
def test_erase(self):
np_img = (np.random.rand(28, 28, 3) * 255).astype('uint8')
pil_img = Image.fromarray(np_img).convert('RGB')
expected = np_img.copy()
expected[10:15, 10:15, :] = 0
F.erase(np_img, 10, 10, 5, 5, 0, inplace=True)
np.testing.assert_equal(np_img, expected)
pil_result = F.erase(pil_img, 10, 10, 5, 5, 0)
np.testing.assert_equal(np.array(pil_result), expected)
np_data = np.random.rand(3, 28, 28).astype('float32')
places = ['cpu']
if paddle.device.is_compiled_with_cuda():
places.append('gpu')
for place in places:
paddle.set_device(place)
tensor_img = paddle.to_tensor(np_data)
expected_tensor = tensor_img.clone()
expected_tensor[:, 10:15, 10:15] = paddle.to_tensor([0.88])
tensor_result = F.erase(tensor_img, 10, 10, 5, 5,
paddle.to_tensor([0.88]))
np.testing.assert_equal(tensor_result.numpy(),
expected_tensor.numpy())
def test_erase_backward(self):
img = paddle.randn((3, 14, 14), dtype=np.float32)
img.stop_gradient = False
erased = F.erase(
img, 3, 3, 5, 5, paddle.ones(
(1, 1, 1), dtype='float32'))
loss = erased.sum()
loss.backward()
expected_grad = np.ones((3, 14, 14), dtype=np.float32)
expected_grad[:, 3:8, 3:8] = 0.
np.testing.assert_equal(img.grad.numpy(), expected_grad)
def test_image_load(self):
fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype(
'uint8'))
......
......@@ -31,6 +31,7 @@ from .transforms import Pad # noqa: F401
from .transforms import RandomRotation # noqa: F401
from .transforms import Grayscale # noqa: F401
from .transforms import ToTensor # noqa: F401
from .transforms import RandomErasing # noqa: F401
from .functional import to_tensor # noqa: F401
from .functional import hflip # noqa: F401
from .functional import vflip # noqa: F401
......@@ -44,6 +45,7 @@ from .functional import adjust_brightness # noqa: F401
from .functional import adjust_contrast # noqa: F401
from .functional import adjust_hue # noqa: F401
from .functional import normalize # noqa: F401
from .functional import erase # noqa: F401
__all__ = [ #noqa
'BaseTransform',
......@@ -65,6 +67,7 @@ __all__ = [ #noqa
'RandomRotation',
'Grayscale',
'ToTensor',
'RandomErasing',
'to_tensor',
'hflip',
'vflip',
......@@ -77,5 +80,6 @@ __all__ = [ #noqa
'adjust_brightness',
'adjust_contrast',
'adjust_hue',
'normalize'
'normalize',
'erase',
]
......@@ -689,3 +689,39 @@ def normalize(img, mean, std, data_format='CHW', to_rgb=False):
img = np.array(img).astype(np.float32)
return F_cv2.normalize(img, mean, std, data_format, to_rgb)
def erase(img, i, j, h, w, v, inplace=False):
"""Erase the pixels of selected area in input image with given value.
Args:
img (paddle.Tensor | np.array | PIL.Image): input Tensor image.
For Tensor input, the shape should be (C, H, W). For np.array input,
the shape should be (H, W, C).
i (int): y coordinate of the top-left point of erased region.
j (int): x coordinate of the top-left point of erased region.
h (int): Height of the erased region.
w (int): Width of the erased region.
v (paddle.Tensor | np.array): value used to replace the pixels in erased region. It
should be np.array when img is np.array or PIL.Image.
inplace (bool, optional): Whether this transform is inplace. Default: False.
Returns:
paddle.Tensor | np.array | PIL.Image: Erased image. The type is same with input image.
Examples:
.. code-block:: python
import paddle
fake_img = paddle.randn((3, 10, 10)).astype(paddle.float32)
values = paddle.zeros((1,1,1), dtype=paddle.float32)
result = paddle.vision.transforms.erase(fake_img, 4, 4, 3, 3, values)
"""
if _is_tensor_image(img):
return F_t.erase(img, i, j, h, w, v, inplace=inplace)
elif _is_pil_image(img):
return F_pil.erase(img, i, j, h, w, v, inplace=inplace)
else:
return F_cv2.erase(img, i, j, h, w, v, inplace=inplace)
......@@ -564,3 +564,26 @@ def normalize(img, mean, std, data_format='CHW', to_rgb=False):
img = (img - mean) / std
return img
def erase(img, i, j, h, w, v, inplace=False):
"""Erase the pixels of selected area in input image array with given value.
Args:
img (np.array): input image array, which shape is (H, W, C).
i (int): y coordinate of the top-left point of erased region.
j (int): x coordinate of the top-left point of erased region.
h (int): Height of the erased region.
w (int): Width of the erased region.
v (np.array): value used to replace the pixels in erased region.
inplace (bool, optional): Whether this transform is inplace. Default: False.
Returns:
np.array: Erased image.
"""
if not inplace:
img = img.copy()
img[i:i + h, j:j + w, ...] = v
return img
......@@ -480,3 +480,26 @@ def to_grayscale(img, num_output_channels=1):
raise ValueError('num_output_channels should be either 1 or 3')
return img
def erase(img, i, j, h, w, v, inplace=False):
"""Erase the pixels of selected area in input image with given value. PIL format is
not support inplace.
Args:
img (PIL.Image): input image, which shape is (C, H, W).
i (int): y coordinate of the top-left point of erased region.
j (int): x coordinate of the top-left point of erased region.
h (int): Height of the erased region.
w (int): Width of the erased region.
v (np.array): value used to replace the pixels in erased region.
inplace (bool, optional): Whether this transform is inplace. Default: False.
Returns:
PIL.Image: Erased image.
"""
np_img = np.array(img, dtype=np.uint8)
np_img[i:i + h, j:j + w, ...] = v
img = Image.fromarray(np_img, 'RGB')
return img
......@@ -416,6 +416,30 @@ def crop(img, top, left, height, width, data_format='CHW'):
return img[top:top + height, left:left + width, :]
def erase(img, i, j, h, w, v, inplace=False):
"""Erase the pixels of selected area in input Tensor image with given value.
Args:
img (paddle.Tensor): input Tensor image.
i (int): y coordinate of the top-left point of erased region.
j (int): x coordinate of the top-left point of erased region.
h (int): Height of the erased region.
w (int): Width of the erased region.
v (paddle.Tensor): value used to replace the pixels in erased region.
inplace (bool, optional): Whether this transform is inplace. Default: False.
Returns:
paddle.Tensor: Erased image.
"""
_assert_image_tensor(img, 'CHW')
if not inplace:
img = img.clone()
img[..., i:i + h, j:j + w] = v
return img
def center_crop(img, output_size, data_format='CHW'):
"""Crops the given paddle.Tensor Image and resize it to desired size.
......
......@@ -25,6 +25,7 @@ import collections
import warnings
import traceback
import paddle
from paddle.utils import try_import
from . import functional as F
......@@ -1342,3 +1343,143 @@ class Grayscale(BaseTransform):
PIL Image: Randomly grayscaled image.
"""
return F.to_grayscale(img, self.num_output_channels)
class RandomErasing(BaseTransform):
"""Erase the pixels in a rectangle region selected randomly.
Args:
prob (float, optional): Probability of the input data being erased. Default: 0.5.
scale (sequence, optional): The proportional range of the erased area to the input image.
Default: (0.02, 0.33).
ratio (sequence, optional): Aspect ratio range of the erased area. Default: (0.3, 3.3).
value (int|float|sequence|str, optional): The value each pixel in erased area will be replaced with.
If value is a single number, all pixels will be erased with this value.
If value is a sequence with length 3, the R, G, B channels will be ereased
respectively. If value is set to "random", each pixel will be erased with
random values. Default: 0.
inplace (bool, optional): Whether this transform is inplace. Default: False.
keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None.
Shape:
- img(paddle.Tensor | np.array | PIL.Image): The input image. For Tensor input, the shape should be (C, H, W).
For np.array input, the shape should be (H, W, C).
- output(paddle.Tensor | np.array | PIL.Image): A random erased image.
Returns:
A callable object of RandomErasing.
Examples:
.. code-block:: python
import paddle
fake_img = paddle.randn((3, 10, 10)).astype(paddle.float32)
transform = paddle.vision.transforms.RandomErasing()
result = transform(fake_img)
"""
def __init__(self,
prob=0.5,
scale=(0.02, 0.33),
ratio=(0.3, 3.3),
value=0,
inplace=False,
keys=None):
super(RandomErasing, self).__init__(keys)
assert isinstance(scale,
(tuple, list)), "scale should be a tuple or list"
assert (scale[0] >= 0 and scale[1] <= 1 and scale[0] <= scale[1]
), "scale should be of kind (min, max) and in range [0, 1]"
assert isinstance(ratio,
(tuple, list)), "ratio should be a tuple or list"
assert (ratio[0] >= 0 and
ratio[0] <= ratio[1]), "ratio should be of kind (min, max)"
assert (prob >= 0 and
prob <= 1), "The probability should be in range [0, 1]"
assert isinstance(
value, (numbers.Number, str, tuple,
list)), "value should be a number, tuple, list or str"
if isinstance(value, str) and value != "random":
raise ValueError("value must be 'random' when type is str")
self.prob = prob
self.scale = scale
self.ratio = ratio
self.value = value
self.inplace = inplace
def _get_param(self, img, scale, ratio, value):
"""Get parameters for ``erase`` for a random erasing.
Args:
img (paddle.Tensor | np.array | PIL.Image): Image to be erased.
scale (sequence, optional): The proportional range of the erased area to the input image.
ratio (sequence, optional): Aspect ratio range of the erased area.
value (sequence | None): The value each pixel in erased area will be replaced with.
If value is a sequence with length 3, the R, G, B channels will be ereased
respectively. If value is None, each pixel will be erased with random values.
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erase.
"""
if F._is_pil_image(img):
shape = np.asarray(img).astype(np.uint8).shape
h, w, c = shape[-3], shape[-2], shape[-1]
elif F._is_numpy_image(img):
h, w, c = img.shape[-3], img.shape[-2], img.shape[-1]
elif F._is_tensor_image(img):
c, h, w = img.shape[-3], img.shape[-2], img.shape[-1]
img_area = h * w
log_ratio = np.log(ratio)
for _ in range(10):
erase_area = np.random.uniform(*scale) * img_area
aspect_ratio = np.exp(np.random.uniform(*log_ratio))
erase_h = int(round(np.sqrt(erase_area * aspect_ratio)))
erase_w = int(round(np.sqrt(erase_area / aspect_ratio)))
if erase_h >= h or erase_w >= w:
continue
if F._is_tensor_image(img):
if value is None:
v = paddle.normal(
shape=[c, erase_h, erase_w]).astype(img.dtype)
else:
v = paddle.to_tensor(value, dtype=img.dtype)[:, None, None]
else:
if value is None:
v = np.random.normal(size=[erase_h, erase_w, c]) * 255
else:
v = np.array(value)[None, None, :]
top = np.random.randint(0, h - erase_h + 1)
left = np.random.randint(0, w - erase_w + 1)
return top, left, erase_h, erase_w, v
return 0, 0, h, w, img
def _apply_image(self, img):
"""
Args:
img (paddle.Tensor | np.array | PIL.Image): Image to be Erased.
Returns:
output (paddle.Tensor np.array | PIL.Image): A random erased image.
"""
if random.random() < self.prob:
if isinstance(self.value, numbers.Number):
value = [self.value]
elif isinstance(self.value, str):
value = None
else:
value = self.value
if value is not None and not (len(value) == 1 or len(value) == 3):
raise ValueError(
"Value should be a single number or a sequence with length equals to image's channel."
)
top, left, erase_h, erase_w, v = self._get_param(img, self.scale,
self.ratio, value)
return F.erase(img, top, left, erase_h, erase_w, v, self.inplace)
return img
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册