From e906eb5bce215f08a2cf22fc71ea85903c25433e Mon Sep 17 00:00:00 2001
From: JYChen
Date: Thu, 12 May 2022 15:18:52 +0800
Subject: [PATCH] add batch tensor support for some vision transforms
 functions (#42701)

---
 python/paddle/tests/test_transforms.py        | 125 ++++++++++++++++++
 .../vision/transforms/functional_tensor.py    |  20 ++-
 python/paddle/vision/transforms/transforms.py |   9 +-
 3 files changed, 147 insertions(+), 7 deletions(-)

diff --git a/python/paddle/tests/test_transforms.py b/python/paddle/tests/test_transforms.py
index 82ae3cb6b68..e07ac47a0f8 100644
--- a/python/paddle/tests/test_transforms.py
+++ b/python/paddle/tests/test_transforms.py
@@ -458,6 +458,20 @@ class TestTransformsTensor(TestTransformsCV2):
         trans = transforms.Compose([transforms.ColorJitter(1.1, 2.2, 0.8, 0.1)])
         self.do_transform(trans)
 
+        color_jitter_trans = transforms.ColorJitter(1.2, 0.2, 0.5, 0.2)
+        batch_input = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
+        result = color_jitter_trans(batch_input)
+
+    def test_perspective(self):
+        trans = transforms.RandomPerspective(prob=1.0, distortion_scale=0.7)
+        batch_input = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
+        result = trans(batch_input)
+
+    def test_affine(self):
+        trans = transforms.RandomAffine(15, translate=[0.1, 0.1])
+        batch_input = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
+        result = trans(batch_input)
+
     def test_pad(self):
         trans = transforms.Compose([transforms.Pad(2)])
         self.do_transform(trans)
@@ -508,6 +522,10 @@ class TestTransformsTensor(TestTransformsCV2):
         ])
         self.do_transform(trans)
 
+        erase_trans = transforms.RandomErasing(value=(0.5, 0.2, 0.01))
+        batch_input = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
+        result = erase_trans(batch_input)
+
     def test_exception(self):
         trans = transforms.Compose([transforms.Resize(-1)])
 
@@ -1003,6 +1021,113 @@ class TestFunctional(unittest.TestCase):
         # Tolerance : less than 6% of different pixels
         assert ratio_diff_pixels < 0.06
 
+    def test_batch_input(self):
+        paddle.seed(777)
+        batch_tensor = paddle.rand((2, 3, 8, 8), dtype=paddle.float32)
+
+        def test_erase(batch_tensor):
+            input1, input2 = paddle.unbind(batch_tensor, axis=0)
+            target_result = paddle.stack([
+                F.erase(input1, 1, 1, 2, 2, 0.5),
+                F.erase(input2, 1, 1, 2, 2, 0.5)
+            ])
+
+            batch_result = F.erase(batch_tensor, 1, 1, 2, 2, 0.5)
+
+            return paddle.allclose(batch_result, target_result)
+
+        self.assertTrue(test_erase(batch_tensor))
+
+        def test_affine(batch_tensor):
+            input1, input2 = paddle.unbind(batch_tensor, axis=0)
+            target_result = paddle.stack([
+                F.affine(
+                    input1,
+                    45,
+                    translate=[0.2, 0.2],
+                    scale=0.5,
+                    shear=[-10, 10]), F.affine(
+                        input2,
+                        45,
+                        translate=[0.2, 0.2],
+                        scale=0.5,
+                        shear=[-10, 10])
+            ])
+            batch_result = F.affine(
+                batch_tensor,
+                45,
+                translate=[0.2, 0.2],
+                scale=0.5,
+                shear=[-10, 10])
+
+            return paddle.allclose(batch_result, target_result)
+
+        self.assertTrue(test_affine(batch_tensor))
+
+        def test_perspective(batch_tensor):
+            input1, input2 = paddle.unbind(batch_tensor, axis=0)
+            startpoints = [[0, 0], [3, 0], [4, 5], [6, 7]]
+            endpoints = [[0, 1], [3, 1], [4, 4], [5, 7]]
+            target_result = paddle.stack([
+                F.perspective(input1, startpoints, endpoints),
+                F.perspective(input2, startpoints, endpoints)
+            ])
+
+            batch_result = F.perspective(batch_tensor, startpoints, endpoints)
+
+            return paddle.allclose(batch_result, target_result)
+
+        self.assertTrue(test_perspective(batch_tensor))
+
+        def test_adjust_brightness(batch_tensor):
+            input1, input2 = paddle.unbind(batch_tensor, axis=0)
+            target_result = paddle.stack([
+                F.adjust_brightness(input1, 2.1),
+                F.adjust_brightness(input2, 2.1)
+            ])
+
+            batch_result = F.adjust_brightness(batch_tensor, 2.1)
+
+            return paddle.allclose(batch_result, target_result)
+
+        self.assertTrue(test_adjust_brightness(batch_tensor))
+
+        def test_adjust_contrast(batch_tensor):
+            input1, input2 = paddle.unbind(batch_tensor, axis=0)
+            target_result = paddle.stack([
+                F.adjust_contrast(input1, 0.3), F.adjust_contrast(input2, 0.3)
+            ])
+
+            batch_result = F.adjust_contrast(batch_tensor, 0.3)
+
+            return paddle.allclose(batch_result, target_result)
+
+        self.assertTrue(test_adjust_contrast(batch_tensor))
+
+        def test_adjust_saturation(batch_tensor):
+            input1, input2 = paddle.unbind(batch_tensor, axis=0)
+            target_result = paddle.stack([
+                F.adjust_saturation(input1, 1.1),
+                F.adjust_saturation(input2, 1.1)
+            ])
+
+            batch_result = F.adjust_saturation(batch_tensor, 1.1)
+
+            return paddle.allclose(batch_result, target_result)
+
+        self.assertTrue(test_adjust_saturation(batch_tensor))
+
+        def test_adjust_hue(batch_tensor):
+            input1, input2 = paddle.unbind(batch_tensor, axis=0)
+            target_result = paddle.stack(
+                [F.adjust_hue(input1, -0.2), F.adjust_hue(input2, -0.2)])
+
+            batch_result = F.adjust_hue(batch_tensor, -0.2)
+
+            return paddle.allclose(batch_result, target_result)
+
+        self.assertTrue(test_adjust_hue(batch_tensor))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/vision/transforms/functional_tensor.py b/python/paddle/vision/transforms/functional_tensor.py
index df2529d1224..27f83029bab 100644
--- a/python/paddle/vision/transforms/functional_tensor.py
+++ b/python/paddle/vision/transforms/functional_tensor.py
@@ -28,8 +28,9 @@ __all__ = []
 
 def _assert_image_tensor(img, data_format):
     if not isinstance(
-            img, paddle.Tensor) or img.ndim != 3 or not data_format.lower() in (
-                'chw', 'hwc'):
+            img, paddle.Tensor
+    ) or img.ndim < 3 or img.ndim > 4 or not data_format.lower() in ('chw',
+                                                                     'hwc'):
         raise RuntimeError(
             'not support [type={}, ndim={}, data_format={}] paddle image'.
             format(type(img), img.ndim, data_format))
@@ -276,7 +277,10 @@ def affine(img, matrix, interpolation="nearest", fill=None, data_format='CHW'):
         paddle.Tensor: Affined image.
 
     """
-    img = img.unsqueeze(0)
+    ndim = len(img.shape)
+    if ndim == 3:
+        img = img.unsqueeze(0)
+
     img = img if data_format.lower() == 'chw' else img.transpose((0, 3, 1, 2))
 
     matrix = paddle.to_tensor(matrix, place=img.place)
@@ -292,8 +296,9 @@ def affine(img, matrix, interpolation="nearest", fill=None, data_format='CHW'):
     out = _grid_transform(img, grid, mode=interpolation, fill=fill)
 
     out = out if data_format.lower() == 'chw' else out.transpose((0, 2, 3, 1))
+    out = out.squeeze(0) if ndim == 3 else out
 
-    return out.squeeze(0)
+    return out
 
 
 def rotate(img,
@@ -443,7 +448,9 @@ def perspective(img,
 
     """
 
-    img = img.unsqueeze(0)
+    ndim = len(img.shape)
+    if ndim == 3:
+        img = img.unsqueeze(0)
     img = img if data_format.lower() == 'chw' else img.transpose((0, 3, 1, 2))
 
     ow, oh = img.shape[-1], img.shape[-2]
@@ -454,8 +461,9 @@ def perspective(img,
     out = _grid_transform(img, grid, mode=interpolation, fill=fill)
 
     out = out if data_format.lower() == 'chw' else out.transpose((0, 2, 3, 1))
+    out = out.squeeze(0) if ndim == 3 else out
 
-    return out.squeeze(0)
+    return out
 
 
 def vflip(img, data_format='CHW'):
diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py
index 79d3b1bc92e..fea2efb1fb2 100644
--- a/python/paddle/vision/transforms/transforms.py
+++ b/python/paddle/vision/transforms/transforms.py
@@ -45,7 +45,14 @@ def _get_image_size(img):
     elif F._is_numpy_image(img):
         return img.shape[:2][::-1]
     elif F._is_tensor_image(img):
-        return img.shape[1:][::-1]  # chw
+        if len(img.shape) == 3:
+            return img.shape[1:][::-1]  # chw -> wh
+        elif len(img.shape) == 4:
+            return img.shape[2:][::-1]  # nchw -> wh
+        else:
+            raise ValueError(
+                "The dim for input Tensor should be 3-D or 4-D, but received {}".
+                format(len(img.shape)))
     else:
         raise TypeError("Unexpected type {}".format(type(img)))
 
--
GitLab