未验证 提交 e906eb5b 编写于 作者: J JYChen 提交者: GitHub

add batch tensor support for some vision transforms functions (#42701)

上级 43d70bcc
......@@ -458,6 +458,20 @@ class TestTransformsTensor(TestTransformsCV2):
trans = transforms.Compose([transforms.ColorJitter(1.1, 2.2, 0.8, 0.1)])
self.do_transform(trans)
color_jitter_trans = transforms.ColorJitter(1.2, 0.2, 0.5, 0.2)
batch_input = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
result = color_jitter_trans(batch_input)
def test_perspective(self):
trans = transforms.RandomPerspective(prob=1.0, distortion_scale=0.7)
batch_input = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
result = trans(batch_input)
def test_affine(self):
trans = transforms.RandomAffine(15, translate=[0.1, 0.1])
batch_input = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
result = trans(batch_input)
def test_pad(self):
trans = transforms.Compose([transforms.Pad(2)])
self.do_transform(trans)
......@@ -508,6 +522,10 @@ class TestTransformsTensor(TestTransformsCV2):
])
self.do_transform(trans)
erase_trans = transforms.RandomErasing(value=(0.5, 0.2, 0.01))
batch_input = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
result = erase_trans(batch_input)
def test_exception(self):
trans = transforms.Compose([transforms.Resize(-1)])
......@@ -1003,6 +1021,113 @@ class TestFunctional(unittest.TestCase):
# Tolerance : less than 6% of different pixels
assert ratio_diff_pixels < 0.06
def test_batch_input(self):
paddle.seed(777)
batch_tensor = paddle.rand((2, 3, 8, 8), dtype=paddle.float32)
def test_erase(batch_tensor):
input1, input2 = paddle.unbind(batch_tensor, axis=0)
target_result = paddle.stack([
F.erase(input1, 1, 1, 2, 2, 0.5),
F.erase(input2, 1, 1, 2, 2, 0.5)
])
batch_result = F.erase(batch_tensor, 1, 1, 2, 2, 0.5)
return paddle.allclose(batch_result, target_result)
self.assertTrue(test_erase(batch_tensor))
def test_affine(batch_tensor):
input1, input2 = paddle.unbind(batch_tensor, axis=0)
target_result = paddle.stack([
F.affine(
input1,
45,
translate=[0.2, 0.2],
scale=0.5,
shear=[-10, 10]), F.affine(
input2,
45,
translate=[0.2, 0.2],
scale=0.5,
shear=[-10, 10])
])
batch_result = F.affine(
batch_tensor,
45,
translate=[0.2, 0.2],
scale=0.5,
shear=[-10, 10])
return paddle.allclose(batch_result, target_result)
self.assertTrue(test_affine(batch_tensor))
def test_perspective(batch_tensor):
input1, input2 = paddle.unbind(batch_tensor, axis=0)
startpoints = [[0, 0], [3, 0], [4, 5], [6, 7]]
endpoints = [[0, 1], [3, 1], [4, 4], [5, 7]]
target_result = paddle.stack([
F.perspective(input1, startpoints, endpoints),
F.perspective(input2, startpoints, endpoints)
])
batch_result = F.perspective(batch_tensor, startpoints, endpoints)
return paddle.allclose(batch_result, target_result)
self.assertTrue(test_perspective(batch_tensor))
def test_adjust_brightness(batch_tensor):
input1, input2 = paddle.unbind(batch_tensor, axis=0)
target_result = paddle.stack([
F.adjust_brightness(input1, 2.1),
F.adjust_brightness(input2, 2.1)
])
batch_result = F.adjust_brightness(batch_tensor, 2.1)
return paddle.allclose(batch_result, target_result)
self.assertTrue(test_adjust_brightness(batch_tensor))
def test_adjust_contrast(batch_tensor):
input1, input2 = paddle.unbind(batch_tensor, axis=0)
target_result = paddle.stack([
F.adjust_contrast(input1, 0.3), F.adjust_contrast(input2, 0.3)
])
batch_result = F.adjust_contrast(batch_tensor, 0.3)
return paddle.allclose(batch_result, target_result)
self.assertTrue(test_adjust_contrast(batch_tensor))
def test_adjust_saturation(batch_tensor):
input1, input2 = paddle.unbind(batch_tensor, axis=0)
target_result = paddle.stack([
F.adjust_saturation(input1, 1.1),
F.adjust_saturation(input2, 1.1)
])
batch_result = F.adjust_saturation(batch_tensor, 1.1)
return paddle.allclose(batch_result, target_result)
self.assertTrue(test_adjust_saturation(batch_tensor))
def test_adjust_hue(batch_tensor):
input1, input2 = paddle.unbind(batch_tensor, axis=0)
target_result = paddle.stack(
[F.adjust_hue(input1, -0.2), F.adjust_hue(input2, -0.2)])
batch_result = F.adjust_hue(batch_tensor, -0.2)
return paddle.allclose(batch_result, target_result)
self.assertTrue(test_adjust_hue(batch_tensor))
if __name__ == '__main__':
unittest.main()
......@@ -28,8 +28,9 @@ __all__ = []
def _assert_image_tensor(img, data_format):
if not isinstance(
img, paddle.Tensor) or img.ndim != 3 or not data_format.lower() in (
'chw', 'hwc'):
img, paddle.Tensor
) or img.ndim < 3 or img.ndim > 4 or not data_format.lower() in ('chw',
'hwc'):
raise RuntimeError(
'not support [type={}, ndim={}, data_format={}] paddle image'.
format(type(img), img.ndim, data_format))
......@@ -276,7 +277,10 @@ def affine(img, matrix, interpolation="nearest", fill=None, data_format='CHW'):
paddle.Tensor: Affined image.
"""
ndim = len(img.shape)
if ndim == 3:
img = img.unsqueeze(0)
img = img if data_format.lower() == 'chw' else img.transpose((0, 3, 1, 2))
matrix = paddle.to_tensor(matrix, place=img.place)
......@@ -292,8 +296,9 @@ def affine(img, matrix, interpolation="nearest", fill=None, data_format='CHW'):
out = _grid_transform(img, grid, mode=interpolation, fill=fill)
out = out if data_format.lower() == 'chw' else out.transpose((0, 2, 3, 1))
out = out.squeeze(0) if ndim == 3 else out
return out.squeeze(0)
return out
def rotate(img,
......@@ -443,6 +448,8 @@ def perspective(img,
"""
ndim = len(img.shape)
if ndim == 3:
img = img.unsqueeze(0)
img = img if data_format.lower() == 'chw' else img.transpose((0, 3, 1, 2))
......@@ -454,8 +461,9 @@ def perspective(img,
out = _grid_transform(img, grid, mode=interpolation, fill=fill)
out = out if data_format.lower() == 'chw' else out.transpose((0, 2, 3, 1))
out = out.squeeze(0) if ndim == 3 else out
return out.squeeze(0)
return out
def vflip(img, data_format='CHW'):
......
......@@ -45,7 +45,14 @@ def _get_image_size(img):
elif F._is_numpy_image(img):
return img.shape[:2][::-1]
elif F._is_tensor_image(img):
return img.shape[1:][::-1] # chw
if len(img.shape) == 3:
return img.shape[1:][::-1] # chw -> wh
elif len(img.shape) == 4:
return img.shape[2:][::-1] # nchw -> wh
else:
raise ValueError(
"The dim for input Tensor should be 3-D or 4-D, but received {}".
format(len(img.shape)))
else:
raise TypeError("Unexpected type {}".format(type(img)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册