From a9d330a390b7151fdc363413537afc30b8d3bcd1 Mon Sep 17 00:00:00 2001 From: Wenyu Date: Thu, 6 May 2021 15:46:29 +0800 Subject: [PATCH] [cherry-pick pr31970] Support transforms for paddle tensor image (#32705) * add to_grayscale, normalize * add rotate * add vfip and hflip * add crop center_crop * add utils * add utils * update utils, add raise for some cases * add padding, support constant, reflect, replicate, circular same as paddle.pad * update rotate * using utils func in [v|h]flip * add get-image-[n,c,w,h] axis utils * add get-image-[n,c,w,h] axis utils * align * update * remove default value in utils func * add assert for pad * update assert paddle image * support rotate fill func * raise valueerror for pad * remove typing, py2 dont support * init uinttest for transforms tensor * add resize op * register [normalize hflip crop center_crop resize transpose] imagenet * register [normalize hflip crop center_crop resize transpose] imagenet * fix bugs, (w, h) getter and import * add _get_image_size for tensor image * add pad vflip for tensor image * add unittest for tensor transforms * update transforms unittest for converage CI probelms, test=develop * update * update * update * fix `get_shape` for tensor backend * update * update * add more resize tests * update * update for ci test * update * remove redundancy code * update uinttest, and set tensor image to hwc by default * add tensor backend * fix copyright doc, rm comment code, add pil unittest * update data_format to `chw` for tensor * coverage notest,test=coverage * update * update --- python/paddle/tests/test_transforms.py | 230 ++++++++- python/paddle/vision/image.py | 10 +- python/paddle/vision/transforms/functional.py | 75 ++- .../vision/transforms/functional_tensor.py | 488 +++++++++++++++++- python/paddle/vision/transforms/transforms.py | 5 + 5 files changed, 764 insertions(+), 44 deletions(-) diff --git a/python/paddle/tests/test_transforms.py b/python/paddle/tests/test_transforms.py index 5086a12d945..c84950fdbc5 100644 --- a/python/paddle/tests/test_transforms.py +++ b/python/paddle/tests/test_transforms.py @@ -56,7 +56,10 @@ class TestTransformsCV2(unittest.TestCase): 'uint8')) def get_shape(self, img): - if self.backend == 'pil': + if isinstance(img, paddle.Tensor): + return img.shape + + elif self.backend == 'pil': return np.array(img).shape return img.shape @@ -253,6 +256,22 @@ class TestTransformsCV2(unittest.TestCase): fake_img = self.create_image((100, 120, 3)) F.pad(fake_img, [1.0, 2.0, 3.0]) + with self.assertRaises(TypeError): + tensor_img = paddle.rand((3, 100, 100)) + F.pad(tensor_img, '1') + + with self.assertRaises(TypeError): + tensor_img = paddle.rand((3, 100, 100)) + F.pad(tensor_img, 1, {}) + + with self.assertRaises(TypeError): + tensor_img = paddle.rand((3, 100, 100)) + F.pad(tensor_img, 1, padding_mode=-1) + + with self.assertRaises(ValueError): + tensor_img = paddle.rand((3, 100, 100)) + F.pad(tensor_img, [1.0, 2.0, 3.0]) + with self.assertRaises(ValueError): transforms.RandomRotation(-2) @@ -290,6 +309,159 @@ class TestTransformsPIL(TestTransformsCV2): return 'pil' +class TestTransformsTensor(TestTransformsCV2): + def get_backend(self): + return 'tensor' + + def create_image(self, shape): + return paddle.to_tensor(np.random.rand(*shape)).transpose( + (2, 0, 1)) # hwc->chw + + def do_transform(self, trans): + trans.transforms.insert(0, transforms.ToTensor(data_format='CHW')) + trans.transforms.append(transforms.Transpose(order=(1, 2, 0))) + dataset_folder = DatasetFolder(self.data_dir, transform=trans) + for _ in dataset_folder: + pass + + def test_trans_all(self): + normalize = transforms.Normalize( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.120, 57.375], ) + trans = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + normalize, + ]) + self.do_transform(trans) + + def test_grayscale(self): + trans = transforms.Compose([transforms.Grayscale()]) + self.do_transform(trans) + + trans_gray = transforms.Grayscale() + fake_img = self.create_image((500, 400, 3)) + fake_img_gray = trans_gray(fake_img) + + np.testing.assert_equal(self.get_shape(fake_img_gray)[1], 500) + np.testing.assert_equal(self.get_shape(fake_img_gray)[2], 400) + + trans_gray3 = transforms.Grayscale(3) + fake_img = self.create_image((500, 400, 3)) + fake_img_gray = trans_gray3(fake_img) + + def test_normalize(self): + normalize = transforms.Normalize(mean=0.5, std=0.5) + trans = transforms.Compose([normalize]) + self.do_transform(trans) + + def test_pad(self): + trans = transforms.Compose([transforms.Pad(2)]) + self.do_transform(trans) + + fake_img = self.create_image((200, 150, 3)) + trans_pad = transforms.Compose([transforms.Pad(10)]) + fake_img_padded = trans_pad(fake_img) + np.testing.assert_equal(self.get_shape(fake_img_padded), (3, 220, 170)) + trans_pad1 = transforms.Pad([1, 2]) + trans_pad2 = transforms.Pad([1, 2, 3, 4]) + trans_pad4 = transforms.Pad(1, padding_mode='edge') + img = trans_pad1(fake_img) + img = trans_pad2(img) + img = trans_pad4(img) + + def test_random_crop(self): + trans = transforms.Compose([ + transforms.RandomCrop(200), + transforms.RandomCrop((140, 160)), + ]) + self.do_transform(trans) + + trans_random_crop1 = transforms.RandomCrop(224) + trans_random_crop2 = transforms.RandomCrop((140, 160)) + + fake_img = self.create_image((500, 400, 3)) + fake_img_crop1 = trans_random_crop1(fake_img) + fake_img_crop2 = trans_random_crop2(fake_img_crop1) + + np.testing.assert_equal(self.get_shape(fake_img_crop1), (3, 224, 224)) + + np.testing.assert_equal(self.get_shape(fake_img_crop2), (3, 140, 160)) + + trans_random_crop_same = transforms.RandomCrop((140, 160)) + img = trans_random_crop_same(fake_img_crop2) + + trans_random_crop_bigger = transforms.RandomCrop( + (180, 200), pad_if_needed=True) + img = trans_random_crop_bigger(img) + + trans_random_crop_pad = transforms.RandomCrop((224, 256), 2, True) + img = trans_random_crop_pad(img) + + def test_exception(self): + trans = transforms.Compose([transforms.Resize(-1)]) + + trans_batch = transforms.Compose([transforms.Resize(-1)]) + + with self.assertRaises(Exception): + self.do_transform(trans) + + with self.assertRaises(Exception): + self.do_transform(trans_batch) + + with self.assertRaises(ValueError): + transforms.Pad([1.0, 2.0, 3.0]) + + with self.assertRaises(TypeError): + fake_img = self.create_image((100, 120, 3)) + F.pad(fake_img, '1') + + with self.assertRaises(TypeError): + fake_img = self.create_image((100, 120, 3)) + F.pad(fake_img, 1, {}) + + with self.assertRaises(TypeError): + fake_img = self.create_image((100, 120, 3)) + F.pad(fake_img, 1, padding_mode=-1) + + with self.assertRaises(ValueError): + fake_img = self.create_image((100, 120, 3)) + F.pad(fake_img, [1.0, 2.0, 3.0]) + + with self.assertRaises(TypeError): + tensor_img = paddle.rand((3, 100, 100)) + F.pad(tensor_img, '1') + + with self.assertRaises(TypeError): + tensor_img = paddle.rand((3, 100, 100)) + F.pad(tensor_img, 1, {}) + + with self.assertRaises(TypeError): + tensor_img = paddle.rand((3, 100, 100)) + F.pad(tensor_img, 1, padding_mode=-1) + + with self.assertRaises(ValueError): + tensor_img = paddle.rand((3, 100, 100)) + F.pad(tensor_img, [1.0, 2.0, 3.0]) + + with self.assertRaises(ValueError): + transforms.RandomRotation(-2) + + with self.assertRaises(ValueError): + transforms.RandomRotation([1, 2, 3]) + + with self.assertRaises(ValueError): + trans_gray = transforms.Grayscale(5) + fake_img = self.create_image((100, 120, 3)) + trans_gray(fake_img) + + with self.assertRaises(TypeError): + transform = transforms.RandomResizedCrop(64) + transform(1) + + test_color_jitter = None + + class TestFunctional(unittest.TestCase): def test_errors(self): with self.assertRaises(TypeError): @@ -300,6 +472,14 @@ class TestFunctional(unittest.TestCase): 'uint8')) F.to_tensor(fake_img, data_format=1) + with self.assertRaises(ValueError): + fake_img = paddle.rand((3, 100, 100)) + F.pad(fake_img, 1, padding_mode='symmetric') + + with self.assertRaises(TypeError): + fake_img = paddle.rand((3, 100, 100)) + F.resize(fake_img, {1: 1}) + with self.assertRaises(TypeError): fake_img = Image.fromarray((np.random.rand(28, 28, 3) * 255).astype( 'uint8')) @@ -354,31 +534,50 @@ class TestFunctional(unittest.TestCase): std = [0.5, 0.5, 0.5] normalized_img = F.normalize(tensor_img, mean, std) - normalized_img = F.normalize( + normalized_img_tensor = F.normalize( tensor_img_hwc, mean, std, data_format='HWC') - normalized_img = F.normalize(pil_img, mean, std, data_format='HWC') - normalized_img = F.normalize( + normalized_img_pil = F.normalize(pil_img, mean, std, data_format='HWC') + normalized_img_np = F.normalize( np_img, mean, std, data_format='HWC', to_rgb=True) + np.testing.assert_almost_equal( + np.array(normalized_img_pil), normalized_img_np) + np.testing.assert_almost_equal(normalized_img_tensor.numpy(), + normalized_img_np) + def test_center_crop(self): np_img = (np.random.rand(28, 24, 3)).astype('uint8') pil_img = Image.fromarray(np_img) + tensor_img = F.to_tensor(pil_img, data_format='CHW') np_cropped_img = F.center_crop(np_img, 4) pil_cropped_img = F.center_crop(pil_img, 4) + tensor_cropped_img = F.center_crop(tensor_img, 4) np.testing.assert_almost_equal(np_cropped_img, np.array(pil_cropped_img)) + np.testing.assert_almost_equal(np_cropped_img, + tensor_cropped_img.numpy().transpose( + (1, 2, 0))) def test_pad(self): np_img = (np.random.rand(28, 24, 3)).astype('uint8') pil_img = Image.fromarray(np_img) + tensor_img = F.to_tensor(pil_img, 'CHW') np_padded_img = F.pad(np_img, [1, 2], padding_mode='reflect') pil_padded_img = F.pad(pil_img, [1, 2], padding_mode='reflect') + tensor_padded_img = F.pad(tensor_img, [1, 2], padding_mode='reflect') np.testing.assert_almost_equal(np_padded_img, np.array(pil_padded_img)) + np.testing.assert_almost_equal(np_padded_img, + tensor_padded_img.numpy().transpose( + (1, 2, 0))) + + tensor_padded_img = F.pad(tensor_img, 1, padding_mode='reflect') + tensor_padded_img = F.pad(tensor_img, [1, 2, 1, 2], + padding_mode='reflect') pil_p_img = pil_img.convert('P') pil_padded_img = F.pad(pil_p_img, [1, 2]) @@ -387,12 +586,21 @@ class TestFunctional(unittest.TestCase): def test_resize(self): np_img = (np.zeros([28, 24, 3])).astype('uint8') pil_img = Image.fromarray(np_img) + tensor_img = F.to_tensor(pil_img, 'CHW') np_reseized_img = F.resize(np_img, 40) pil_reseized_img = F.resize(pil_img, 40) + tensor_reseized_img = F.resize(tensor_img, 40) + tensor_reseized_img2 = F.resize(tensor_img, (46, 40)) np.testing.assert_almost_equal(np_reseized_img, np.array(pil_reseized_img)) + np.testing.assert_almost_equal(np_reseized_img, + tensor_reseized_img.numpy().transpose( + (1, 2, 0))) + np.testing.assert_almost_equal(np_reseized_img, + tensor_reseized_img2.numpy().transpose( + (1, 2, 0))) gray_img = (np.zeros([28, 32])).astype('uint8') gray_resize_img = F.resize(gray_img, 40) @@ -447,12 +655,24 @@ class TestFunctional(unittest.TestCase): def test_rotate(self): np_img = (np.random.rand(28, 28, 3) * 255).astype('uint8') pil_img = Image.fromarray(np_img).convert('RGB') - rotated_np_img = F.rotate(np_img, 80, expand=True) rotated_pil_img = F.rotate(pil_img, 80, expand=True) + tensor_img = F.to_tensor(pil_img, 'CHW') + + rotated_tensor_img1 = F.rotate(tensor_img, 80, expand=True) + + rotated_tensor_img2 = F.rotate( + tensor_img, + 80, + interpolation='bilinear', + center=(10, 10), + expand=False) + np.testing.assert_equal(rotated_np_img.shape, np.array(rotated_pil_img).shape) + np.testing.assert_equal(rotated_np_img.shape, + rotated_tensor_img1.transpose((1, 2, 0)).shape) def test_rotate1(self): np_img = (np.random.rand(28, 28, 3) * 255).astype('uint8') diff --git a/python/paddle/vision/image.py b/python/paddle/vision/image.py index 3d5ea3a73af..19986816b7c 100644 --- a/python/paddle/vision/image.py +++ b/python/paddle/vision/image.py @@ -80,9 +80,9 @@ def set_image_backend(backend): shutil.rmtree(temp_dir) """ global _image_backend - if backend not in ['pil', 'cv2']: + if backend not in ['pil', 'cv2', 'tensor']: raise ValueError( - "Expected backend are one of ['pil', 'cv2'], but got {}" + "Expected backend are one of ['pil', 'cv2', 'tensor'], but got {}" .format(backend)) _image_backend = backend @@ -150,13 +150,13 @@ def image_load(path, backend=None): if backend is None: backend = _image_backend - if backend not in ['pil', 'cv2']: + if backend not in ['pil', 'cv2', 'tensor']: raise ValueError( - "Expected backend are one of ['pil', 'cv2'], but got {}" + "Expected backend are one of ['pil', 'cv2', 'tensor'], but got {}" .format(backend)) if backend == 'pil': return Image.open(path) - else: + elif backend == 'cv2': cv2 = try_import('cv2') return cv2.imread(path) diff --git a/python/paddle/vision/transforms/functional.py b/python/paddle/vision/transforms/functional.py index c0e72877ffc..18a35915c99 100644 --- a/python/paddle/vision/transforms/functional.py +++ b/python/paddle/vision/transforms/functional.py @@ -25,13 +25,6 @@ from PIL import Image from numpy import sin, cos, tan import paddle -if sys.version_info < (3, 3): - Sequence = collections.Sequence - Iterable = collections.Iterable -else: - Sequence = collections.abc.Sequence - Iterable = collections.abc.Iterable - from . import functional_pil as F_pil from . import functional_cv2 as F_cv2 from . import functional_tensor as F_t @@ -83,14 +76,18 @@ def to_tensor(pic, data_format='CHW'): print(tensor.shape) """ - if not (_is_pil_image(pic) or _is_numpy_image(pic)): - raise TypeError('pic should be PIL Image or ndarray. Got {}'.format( - type(pic))) + if not (_is_pil_image(pic) or _is_numpy_image(pic) or + _is_tensor_image(pic)): + raise TypeError( + 'pic should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(pic))) if _is_pil_image(pic): return F_pil.to_tensor(pic, data_format) - else: + elif _is_numpy_image(pic): return F_cv2.to_tensor(pic, data_format) + else: + return pic if data_format.lower() == 'chw' else pic.transpose((1, 2, 0)) def resize(img, size, interpolation='bilinear'): @@ -135,13 +132,16 @@ def resize(img, size, interpolation='bilinear'): converted_img = F.resize(fake_img, (200, 150)) print(converted_img.size) """ - if not (_is_pil_image(img) or _is_numpy_image(img)): + if not (_is_pil_image(img) or _is_numpy_image(img) or + _is_tensor_image(img)): raise TypeError( - 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + 'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. format(type(img))) if _is_pil_image(img): return F_pil.resize(img, size, interpolation) + elif _is_tensor_image(img): + return F_t.resize(img, size, interpolation) else: return F_cv2.resize(img, size, interpolation) @@ -196,13 +196,16 @@ def pad(img, padding, fill=0, padding_mode='constant'): padded_img = F.pad(fake_img, padding=(2, 1)) print(padded_img.size) """ - if not (_is_pil_image(img) or _is_numpy_image(img)): + if not (_is_pil_image(img) or _is_numpy_image(img) or + _is_tensor_image(img)): raise TypeError( - 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + 'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. format(type(img))) if _is_pil_image(img): return F_pil.pad(img, padding, fill, padding_mode) + elif _is_tensor_image(img): + return F_t.pad(img, padding, fill, padding_mode) else: return F_cv2.pad(img, padding, fill, padding_mode) @@ -236,13 +239,16 @@ def crop(img, top, left, height, width): print(cropped_img.size) """ - if not (_is_pil_image(img) or _is_numpy_image(img)): + if not (_is_pil_image(img) or _is_numpy_image(img) or + _is_tensor_image(img)): raise TypeError( - 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + 'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. format(type(img))) if _is_pil_image(img): return F_pil.crop(img, top, left, height, width) + elif _is_tensor_image(img): + return F_t.crop(img, top, left, height, width) else: return F_cv2.crop(img, top, left, height, width) @@ -272,13 +278,16 @@ def center_crop(img, output_size): cropped_img = F.center_crop(fake_img, (150, 100)) print(cropped_img.size) """ - if not (_is_pil_image(img) or _is_numpy_image(img)): + if not (_is_pil_image(img) or _is_numpy_image(img) or + _is_tensor_image(img)): raise TypeError( - 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + 'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. format(type(img))) if _is_pil_image(img): return F_pil.center_crop(img, output_size) + elif _is_tensor_image(img): + return F_t.center_crop(img, output_size) else: return F_cv2.center_crop(img, output_size) @@ -307,13 +316,16 @@ def hflip(img): print(flpped_img.size) """ - if not (_is_pil_image(img) or _is_numpy_image(img)): + if not (_is_pil_image(img) or _is_numpy_image(img) or + _is_tensor_image(img)): raise TypeError( - 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + 'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. format(type(img))) if _is_pil_image(img): return F_pil.hflip(img) + elif _is_tensor_image(img): + return F_t.hflip(img) else: return F_cv2.hflip(img) @@ -342,13 +354,16 @@ def vflip(img): print(flpped_img.size) """ - if not (_is_pil_image(img) or _is_numpy_image(img)): + if not (_is_pil_image(img) or _is_numpy_image(img) or + _is_tensor_image(img)): raise TypeError( - 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + 'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. format(type(img))) if _is_pil_image(img): return F_pil.vflip(img) + elif _is_tensor_image(img): + return F_t.vflip(img) else: return F_cv2.vflip(img) @@ -563,9 +578,10 @@ def rotate(img, print(rotated_img.size) """ - if not (_is_pil_image(img) or _is_numpy_image(img)): + if not (_is_pil_image(img) or _is_numpy_image(img) or + _is_tensor_image(img)): raise TypeError( - 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + 'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. format(type(img))) if isinstance(center, list): @@ -575,6 +591,8 @@ def rotate(img, if _is_pil_image(img): return F_pil.rotate(img, angle, interpolation, expand, center, fill) + elif _is_tensor_image(img): + return F_t.rotate(img, angle, interpolation, expand, center, fill) else: return F_cv2.rotate(img, angle, interpolation, expand, center, fill) @@ -606,13 +624,16 @@ def to_grayscale(img, num_output_channels=1): print(gray_img.size) """ - if not (_is_pil_image(img) or _is_numpy_image(img)): + if not (_is_pil_image(img) or _is_numpy_image(img) or + _is_tensor_image(img)): raise TypeError( - 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + 'img should be PIL Image or Tensor Image or ndarray with dim=[2 or 3]. Got {}'. format(type(img))) if _is_pil_image(img): return F_pil.to_grayscale(img, num_output_channels) + elif _is_tensor_image(img): + return F_t.to_grayscale(img, num_output_channels) else: return F_cv2.to_grayscale(img, num_output_channels) diff --git a/python/paddle/vision/transforms/functional_tensor.py b/python/paddle/vision/transforms/functional_tensor.py index e8b70820dd9..7f490d57916 100644 --- a/python/paddle/vision/transforms/functional_tensor.py +++ b/python/paddle/vision/transforms/functional_tensor.py @@ -14,11 +14,78 @@ from __future__ import division +import math +import numbers + import paddle +import paddle.nn.functional as F + +import sys +import collections + + +def _assert_image_tensor(img, data_format): + if not isinstance( + img, paddle.Tensor) or img.ndim != 3 or not data_format.lower() in ( + 'chw', 'hwc'): + raise RuntimeError( + 'not support [type={}, ndim={}, data_format={}] paddle image'. + format(type(img), img.ndim, data_format)) + + +def _get_image_h_axis(data_format): + if data_format.lower() == 'chw': + return -2 + elif data_format.lower() == 'hwc': + return -3 + + +def _get_image_w_axis(data_format): + if data_format.lower() == 'chw': + return -1 + elif data_format.lower() == 'hwc': + return -2 + + +def _get_image_c_axis(data_format): + if data_format.lower() == 'chw': + return -3 + elif data_format.lower() == 'hwc': + return -1 + + +def _get_image_n_axis(data_format): + if len(data_format) == 3: + return None + elif len(data_format) == 4: + return 0 + + +def _is_channel_last(data_format): + return _get_image_c_axis(data_format) == -1 + + +def _is_channel_first(data_format): + return _get_image_c_axis(data_format) == -3 + + +def _get_image_num_batches(img, data_format): + if _get_image_n_axis(data_format): + return img.shape[_get_image_n_axis(data_format)] + return None + + +def _get_image_num_channels(img, data_format): + return img.shape[_get_image_c_axis(data_format)] + + +def _get_image_size(img, data_format): + return img.shape[_get_image_w_axis(data_format)], img.shape[ + _get_image_h_axis(data_format)] def normalize(img, mean, std, data_format='CHW'): - """Normalizes a tensor image with mean and standard deviation. + """Normalizes a tensor image given mean and standard deviation. Args: img (paddle.Tensor): input data to be normalized. @@ -31,10 +98,417 @@ def normalize(img, mean, std, data_format='CHW'): Tensor: Normalized mage. """ - if data_format == 'CHW': - mean = paddle.to_tensor(mean).reshape([-1, 1, 1]) - std = paddle.to_tensor(std).reshape([-1, 1, 1]) - else: - mean = paddle.to_tensor(mean) - std = paddle.to_tensor(std) + _assert_image_tensor(img, data_format) + + mean = paddle.to_tensor(mean, place=img.place) + std = paddle.to_tensor(std, place=img.place) + + if _is_channel_first(data_format): + mean = mean.reshape([-1, 1, 1]) + std = std.reshape([-1, 1, 1]) + return (img - mean) / std + + +def to_grayscale(img, num_output_channels=1, data_format='CHW'): + """Converts image to grayscale version of image. + + Args: + img (paddel.Tensor): Image to be converted to grayscale. + num_output_channels (int, optionl[1, 3]): + if num_output_channels = 1 : returned image is single channel + if num_output_channels = 3 : returned image is 3 channel + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + + Returns: + paddle.Tensor: Grayscale version of the image. + """ + _assert_image_tensor(img, data_format) + + if num_output_channels not in (1, 3): + raise ValueError('num_output_channels should be either 1 or 3') + + rgb_weights = paddle.to_tensor( + [0.2989, 0.5870, 0.1140], place=img.place).astype(img.dtype) + + if _is_channel_first(data_format): + rgb_weights = rgb_weights.reshape((-1, 1, 1)) + + _c_index = _get_image_c_axis(data_format) + + img = (img * rgb_weights).sum(axis=_c_index, keepdim=True) + _shape = img.shape + _shape[_c_index] = num_output_channels + + return img.expand(_shape) + + +def _affine_grid(theta, w, h, ow, oh): + d = 0.5 + base_grid = paddle.ones((1, oh, ow, 3), dtype=theta.dtype) + + x_grid = paddle.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, ow) + base_grid[..., 0] = x_grid + y_grid = paddle.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, oh).unsqueeze_(-1) + base_grid[..., 1] = y_grid + + scaled_theta = theta.transpose( + (0, 2, 1)) / paddle.to_tensor([0.5 * w, 0.5 * h]) + output_grid = base_grid.reshape((1, oh * ow, 3)).bmm(scaled_theta) + + return output_grid.reshape((1, oh, ow, 2)) + + +def _grid_transform(img, grid, mode, fill): + if img.shape[0] > 1: + grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], + grid.shape[3]) + + if fill is not None: + dummy = paddle.ones( + (img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype) + img = paddle.concat((img, dummy), axis=1) + + img = F.grid_sample( + img, grid, mode=mode, padding_mode="zeros", align_corners=False) + + # Fill with required color + if fill is not None: + mask = img[:, -1:, :, :] # n 1 h w + img = img[:, :-1, :, :] # n c h w + mask = mask.expand_as(img) + len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 + fill_img = paddle.to_tensor(fill).reshape( + (1, len_fill, 1, 1)).expand_as(img) + + if mode == 'nearest': + mask = paddle.cast(mask < 0.5, img.dtype) + img = img * (1. - mask) + mask * fill_img + else: # 'bilinear' + img = img * mask + (1.0 - mask) * fill_img + + return img + + +def rotate(img, + angle, + interpolation='nearest', + expand=False, + center=None, + fill=None, + data_format='CHW'): + """Rotates the image by angle. + + Args: + img (paddle.Tensor): Image to be rotated. + angle (float or int): In degrees degrees counter clockwise order. + interpolation (str, optional): Interpolation method. If omitted, or if the + image has only one channel, it is set NEAREST . when use pil backend, + support method are as following: + - "nearest" + - "bilinear" + - "bicubic" + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + fill (3-tuple or int): RGB pixel fill value for area outside the rotated image. + If int, it is used for all channels respectively. + + Returns: + paddle.Tensor: Rotated image. + + """ + + angle = -angle % 360 + img = img.unsqueeze(0) + + # n, c, h, w = img.shape + w, h = _get_image_size(img, data_format=data_format) + + img = img if data_format.lower() == 'chw' else img.transpose((0, 3, 1, 2)) + + post_trans = [0, 0] + + if center is None: + rotn_center = [0, 0] + else: + rotn_center = [(p - s * 0.5) for p, s in zip(center, [w, h])] + + angle = math.radians(angle) + matrix = [ + math.cos(angle), + math.sin(angle), + 0.0, + -math.sin(angle), + math.cos(angle), + 0.0, + ] + + matrix[2] += matrix[0] * (-rotn_center[0] - post_trans[0]) + matrix[1] * ( + -rotn_center[1] - post_trans[1]) + matrix[5] += matrix[3] * (-rotn_center[0] - post_trans[0]) + matrix[4] * ( + -rotn_center[1] - post_trans[1]) + + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + + matrix = paddle.to_tensor(matrix, place=img.place) + matrix = matrix.reshape((1, 2, 3)) + + if expand: + # calculate output size + corners = paddle.to_tensor( + [[-0.5 * w, -0.5 * h, 1.0], [-0.5 * w, 0.5 * h, 1.0], + [0.5 * w, 0.5 * h, 1.0], [0.5 * w, -0.5 * h, 1.0]], + place=matrix.place).astype(matrix.dtype) + + _pos = corners.reshape( + (1, -1, 3)).bmm(matrix.transpose((0, 2, 1))).reshape((1, -1, 2)) + _min = _pos.min(axis=-2).floor() + _max = _pos.max(axis=-2).ceil() + + npos = _max - _min + nw = npos[0][0] + nh = npos[0][1] + + ow, oh = int(nw.numpy()[0]), int(nh.numpy()[0]) + + else: + ow, oh = w, h + + grid = _affine_grid(matrix, w, h, ow, oh) + + out = _grid_transform(img, grid, mode=interpolation, fill=fill) + + out = out if data_format.lower() == 'chw' else out.transpose((0, 2, 3, 1)) + + return out.squeeze(0) + + +def vflip(img, data_format='CHW'): + """Vertically flips the given paddle tensor. + + Args: + img (paddle.Tensor): Image to be flipped. + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + + Returns: + paddle.Tensor: Vertically flipped image. + + """ + _assert_image_tensor(img, data_format) + + h_axis = _get_image_h_axis(data_format) + + return img.flip(axis=[h_axis]) + + +def hflip(img, data_format='CHW'): + """Horizontally flips the given paddle.Tensor Image. + + Args: + img (paddle.Tensor): Image to be flipped. + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + + Returns: + paddle.Tensor: Horizontall flipped image. + + """ + _assert_image_tensor(img, data_format) + + w_axis = _get_image_w_axis(data_format) + + return img.flip(axis=[w_axis]) + + +def crop(img, top, left, height, width, data_format='CHW'): + """Crops the given paddle.Tensor Image. + + Args: + img (paddle.Tensor): Image to be cropped. (0,0) denotes the top left + corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + Returns: + paddle.Tensor: Cropped image. + + """ + _assert_image_tensor(img, data_format) + + if _is_channel_first(data_format): + return img[:, top:top + height, left:left + width] + else: + return img[top:top + height, left:left + width, :] + + +def center_crop(img, output_size, data_format='CHW'): + """Crops the given paddle.Tensor Image and resize it to desired size. + + Args: + img (paddle.Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. + output_size (sequence or int): (height, width) of the crop box. If int, + it is used for both directions + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + Returns: + paddle.Tensor: Cropped image. + + """ + _assert_image_tensor(img, data_format) + + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + + image_width, image_height = _get_image_size(img, data_format) + crop_height, crop_width = output_size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop( + img, + crop_top, + crop_left, + crop_height, + crop_width, + data_format=data_format) + + +def pad(img, padding, fill=0, padding_mode='constant', data_format='CHW'): + """ + Pads the given paddle.Tensor on all sides with specified padding mode and fill value. + + Args: + img (paddle.Tensor): Image to be padded. + padding (int|list|tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill (float, optional): Pixel fill value for constant fill. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. Default: 0. + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default: 'constant'. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + paddle.Tensor: Padded image. + + """ + _assert_image_tensor(img, data_format) + + if not isinstance(padding, (numbers.Number, list, tuple)): + raise TypeError('Got inappropriate padding arg') + if not isinstance(fill, (numbers.Number, str, list, tuple)): + raise TypeError('Got inappropriate fill arg') + if not isinstance(padding_mode, str): + raise TypeError('Got inappropriate padding_mode arg') + + if isinstance(padding, (list, tuple)) and len(padding) not in [2, 4]: + raise ValueError( + "Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ + 'Padding mode should be either constant, edge, reflect or symmetric' + + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + else: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + padding = [pad_left, pad_right, pad_top, pad_bottom] + + if padding_mode == 'edge': + padding_mode = 'replicate' + elif padding_mode == 'symmetric': + raise ValueError('Do not support symmetric mdoe') + + img = img.unsqueeze(0) + # 'constant', 'reflect', 'replicate', 'circular' + img = F.pad(img, + pad=padding, + mode=padding_mode, + value=float(fill), + data_format='N' + data_format) + + return img.squeeze(0) + + +def resize(img, size, interpolation='bilinear', data_format='CHW'): + """ + Resizes the image to given size + + Args: + input (paddle.Tensor): Image to be resized. + size (int|list|tuple): Target size of input data, with (height, width) shape. + interpolation (int|str, optional): Interpolation method. when use paddle backend, + support method are as following: + - "nearest" + - "bilinear" + - "bicubic" + - "trilinear" + - "area" + - "linear" + data_format (str, optional): paddle.Tensor format + - 'CHW' + - 'HWC' + Returns: + paddle.Tensor: Resized image. + + """ + _assert_image_tensor(img, data_format) + + if not (isinstance(size, int) or + (isinstance(size, (tuple, list)) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + + if isinstance(size, int): + w, h = _get_image_size(img, data_format) + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + else: + oh, ow = size + + img = img.unsqueeze(0) + img = F.interpolate( + img, + size=(oh, ow), + mode=interpolation.lower(), + data_format='N' + data_format.upper()) + + return img.squeeze(0) diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index 6eeb726fcee..00e12689c4d 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -49,6 +49,8 @@ def _get_image_size(img): return img.size elif F._is_numpy_image(img): return img.shape[:2][::-1] + elif F._is_tensor_image(img): + return img.shape[1:][::-1] # chw else: raise TypeError("Unexpected type {}".format(type(img))) @@ -690,6 +692,9 @@ class Transpose(BaseTransform): self.order = order def _apply_image(self, img): + if F._is_tensor_image(img): + return img.transpose(self.order) + if F._is_pil_image(img): img = np.asarray(img) -- GitLab