From 74c8a811276a09f6a774dea98904656468ce56bf Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Tue, 20 Oct 2020 17:21:26 +0800 Subject: [PATCH] Add pil backend for vision transforms (#28035) * add pil backend --- python/paddle/tests/test_callbacks.py | 2 +- python/paddle/tests/test_transforms.py | 297 ++++- python/paddle/vision/__init__.py | 6 +- python/paddle/vision/datasets/folder.py | 21 +- python/paddle/vision/image.py | 162 +++ python/paddle/vision/transforms/functional.py | 747 +++++++---- .../vision/transforms/functional_cv2.py | 503 ++++++++ .../vision/transforms/functional_pil.py | 458 +++++++ .../vision/transforms/functional_tensor.py | 40 + python/paddle/vision/transforms/transforms.py | 1149 +++++++++-------- 10 files changed, 2556 insertions(+), 829 deletions(-) create mode 100644 python/paddle/vision/image.py create mode 100644 python/paddle/vision/transforms/functional_cv2.py create mode 100644 python/paddle/vision/transforms/functional_pil.py create mode 100644 python/paddle/vision/transforms/functional_tensor.py diff --git a/python/paddle/tests/test_callbacks.py b/python/paddle/tests/test_callbacks.py index b9442c46b8..5c349c5f1d 100644 --- a/python/paddle/tests/test_callbacks.py +++ b/python/paddle/tests/test_callbacks.py @@ -105,7 +105,7 @@ class TestCallbacks(unittest.TestCase): self.run_callback() def test_visualdl_callback(self): - # visualdl not support python3 + # visualdl not support python2 if sys.version_info < (3, ): return diff --git a/python/paddle/tests/test_transforms.py b/python/paddle/tests/test_transforms.py index 6c2944d1e7..ac21f8a619 100644 --- a/python/paddle/tests/test_transforms.py +++ b/python/paddle/tests/test_transforms.py @@ -18,14 +18,19 @@ import tempfile import cv2 import shutil import numpy as np +from PIL import Image +import paddle +from paddle.vision import get_image_backend, set_image_backend, image_load from paddle.vision.datasets import DatasetFolder from paddle.vision.transforms import transforms import paddle.vision.transforms.functional as F -class TestTransforms(unittest.TestCase): +class TestTransformsCV2(unittest.TestCase): def setUp(self): + self.backend = self.get_backend() + set_image_backend(self.backend) self.data_dir = tempfile.mkdtemp() for i in range(2): sub_dir = os.path.join(self.data_dir, 'class_' + str(i)) @@ -40,6 +45,22 @@ class TestTransforms(unittest.TestCase): (400, 300, 3)) * 255).astype('uint8') cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img) + def get_backend(self): + return 'cv2' + + def create_image(self, shape): + if self.backend == 'cv2': + return (np.random.rand(*shape) * 255).astype('uint8') + elif self.backend == 'pil': + return Image.fromarray((np.random.rand(*shape) * 255).astype( + 'uint8')) + + def get_shape(self, img): + if self.backend == 'pil': + return np.array(img).shape + + return img.shape + def tearDown(self): shutil.rmtree(self.data_dir) @@ -51,27 +72,29 @@ class TestTransforms(unittest.TestCase): def test_trans_all(self): normalize = transforms.Normalize( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375]) + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.120, 57.375], ) trans = transforms.Compose([ - transforms.RandomResizedCrop(224), transforms.GaussianNoise(), + transforms.RandomResizedCrop(224), transforms.ColorJitter( - brightness=0.4, contrast=0.4, saturation=0.4, - hue=0.4), transforms.RandomHorizontalFlip(), - transforms.Permute(mode='CHW'), normalize + brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4), + transforms.RandomHorizontalFlip(), + transforms.Transpose(), + normalize, ]) self.do_transform(trans) def test_normalize(self): normalize = transforms.Normalize(mean=0.5, std=0.5) - trans = transforms.Compose([transforms.Permute(mode='CHW'), normalize]) + trans = transforms.Compose([transforms.Transpose(), normalize]) self.do_transform(trans) def test_trans_resize(self): trans = transforms.Compose([ - transforms.Resize(300, [0, 1]), + transforms.Resize(300), transforms.RandomResizedCrop((280, 280)), - transforms.Resize(280, [0, 1]), + transforms.Resize(280), transforms.Resize((256, 200)), transforms.Resize((180, 160)), transforms.CenterCrop(128), @@ -79,13 +102,6 @@ class TestTransforms(unittest.TestCase): ]) self.do_transform(trans) - def test_trans_centerCrop(self): - trans = transforms.Compose([ - transforms.CenterCropResize(224), - transforms.CenterCropResize(128, 160), - ]) - self.do_transform(trans) - def test_flip(self): trans = transforms.Compose([ transforms.RandomHorizontalFlip(1.0), @@ -96,7 +112,7 @@ class TestTransforms(unittest.TestCase): self.do_transform(trans) def test_color_jitter(self): - trans = transforms.BatchCompose([ + trans = transforms.Compose([ transforms.BrightnessTransform(0.0), transforms.HueTransform(0.0), transforms.SaturationTransform(0.0), @@ -106,11 +122,11 @@ class TestTransforms(unittest.TestCase): def test_rotate(self): trans = transforms.Compose([ - transforms.RandomRotate(90), - transforms.RandomRotate([-10, 10]), - transforms.RandomRotate( + transforms.RandomRotation(90), + transforms.RandomRotation([-10, 10]), + transforms.RandomRotation( 45, expand=True), - transforms.RandomRotate( + transforms.RandomRotation( 10, expand=True, center=(60, 80)), ]) self.do_transform(trans) @@ -119,20 +135,15 @@ class TestTransforms(unittest.TestCase): trans = transforms.Compose([transforms.Pad(2)]) self.do_transform(trans) - fake_img = np.random.rand(200, 150, 3).astype('float32') + fake_img = self.create_image((200, 150, 3)) trans_pad = transforms.Pad(10) fake_img_padded = trans_pad(fake_img) - np.testing.assert_equal(fake_img_padded.shape, (220, 170, 3)) + np.testing.assert_equal(self.get_shape(fake_img_padded), (220, 170, 3)) trans_pad1 = transforms.Pad([1, 2]) trans_pad2 = transforms.Pad([1, 2, 3, 4]) img = trans_pad1(fake_img) img = trans_pad2(img) - def test_erase(self): - trans = transforms.Compose( - [transforms.RandomErasing(), transforms.RandomErasing(value=0.0)]) - self.do_transform(trans) - def test_random_crop(self): trans = transforms.Compose([ transforms.RandomCrop(200), @@ -143,18 +154,19 @@ class TestTransforms(unittest.TestCase): trans_random_crop1 = transforms.RandomCrop(224) trans_random_crop2 = transforms.RandomCrop((140, 160)) - fake_img = np.random.rand(500, 400, 3).astype('float32') + fake_img = self.create_image((500, 400, 3)) fake_img_crop1 = trans_random_crop1(fake_img) fake_img_crop2 = trans_random_crop2(fake_img_crop1) - np.testing.assert_equal(fake_img_crop1.shape, (224, 224, 3)) + np.testing.assert_equal(self.get_shape(fake_img_crop1), (224, 224, 3)) - np.testing.assert_equal(fake_img_crop2.shape, (140, 160, 3)) + np.testing.assert_equal(self.get_shape(fake_img_crop2), (140, 160, 3)) trans_random_crop_same = transforms.RandomCrop((140, 160)) img = trans_random_crop_same(fake_img_crop2) - trans_random_crop_bigger = transforms.RandomCrop((180, 200)) + trans_random_crop_bigger = transforms.RandomCrop( + (180, 200), pad_if_needed=True) img = trans_random_crop_bigger(img) trans_random_crop_pad = transforms.RandomCrop((224, 256), 2, True) @@ -165,21 +177,38 @@ class TestTransforms(unittest.TestCase): self.do_transform(trans) trans_gray = transforms.Grayscale() - fake_img = np.random.rand(500, 400, 3).astype('float32') + fake_img = self.create_image((500, 400, 3)) fake_img_gray = trans_gray(fake_img) - np.testing.assert_equal(len(fake_img_gray.shape), 3) - np.testing.assert_equal(fake_img_gray.shape[0], 500) - np.testing.assert_equal(fake_img_gray.shape[1], 400) + np.testing.assert_equal(self.get_shape(fake_img_gray)[0], 500) + np.testing.assert_equal(self.get_shape(fake_img_gray)[1], 400) trans_gray3 = transforms.Grayscale(3) - fake_img = np.random.rand(500, 400, 3).astype('float32') + fake_img = self.create_image((500, 400, 3)) fake_img_gray = trans_gray3(fake_img) + def test_tranpose(self): + trans = transforms.Compose([transforms.Transpose()]) + self.do_transform(trans) + + fake_img = self.create_image((50, 100, 3)) + converted_img = trans(fake_img) + + np.testing.assert_equal(self.get_shape(converted_img), (3, 50, 100)) + + def test_to_tensor(self): + trans = transforms.Compose([transforms.ToTensor()]) + fake_img = self.create_image((50, 100, 3)) + + tensor = trans(fake_img) + + assert isinstance(tensor, paddle.Tensor) + np.testing.assert_equal(tensor.shape, (3, 50, 100)) + def test_exception(self): trans = transforms.Compose([transforms.Resize(-1)]) - trans_batch = transforms.BatchCompose([transforms.Resize(-1)]) + trans_batch = transforms.Compose([transforms.Resize(-1)]) with self.assertRaises(Exception): self.do_transform(trans) @@ -203,35 +232,211 @@ class TestTransforms(unittest.TestCase): transforms.Pad([1.0, 2.0, 3.0]) with self.assertRaises(TypeError): - fake_img = np.random.rand(100, 120, 3).astype('float32') + fake_img = self.create_image((100, 120, 3)) F.pad(fake_img, '1') with self.assertRaises(TypeError): - fake_img = np.random.rand(100, 120, 3).astype('float32') + fake_img = self.create_image((100, 120, 3)) F.pad(fake_img, 1, {}) with self.assertRaises(TypeError): - fake_img = np.random.rand(100, 120, 3).astype('float32') + fake_img = self.create_image((100, 120, 3)) F.pad(fake_img, 1, padding_mode=-1) with self.assertRaises(ValueError): - fake_img = np.random.rand(100, 120, 3).astype('float32') + fake_img = self.create_image((100, 120, 3)) F.pad(fake_img, [1.0, 2.0, 3.0]) with self.assertRaises(ValueError): - transforms.RandomRotate(-2) + transforms.RandomRotation(-2) with self.assertRaises(ValueError): - transforms.RandomRotate([1, 2, 3]) + transforms.RandomRotation([1, 2, 3]) with self.assertRaises(ValueError): trans_gray = transforms.Grayscale(5) - fake_img = np.random.rand(100, 120, 3).astype('float32') + fake_img = self.create_image((100, 120, 3)) trans_gray(fake_img) + with self.assertRaises(TypeError): + transform = transforms.RandomResizedCrop(64) + transform(1) + + with self.assertRaises(ValueError): + transform = transforms.BrightnessTransform([-0.1, -0.2]) + + with self.assertRaises(TypeError): + transform = transforms.BrightnessTransform('0.1') + + with self.assertRaises(ValueError): + transform = transforms.BrightnessTransform('0.1', keys=1) + + with self.assertRaises(NotImplementedError): + transform = transforms.BrightnessTransform('0.1', keys='a') + def test_info(self): str(transforms.Compose([transforms.Resize((224, 224))])) - str(transforms.BatchCompose([transforms.Resize((224, 224))])) + str(transforms.Compose([transforms.Resize((224, 224))])) + + +class TestTransformsPIL(TestTransformsCV2): + def get_backend(self): + return 'pil' + + +class TestFunctional(unittest.TestCase): + def test_errors(self): + with self.assertRaises(TypeError): + F.to_tensor(1) + + with self.assertRaises(ValueError): + fake_img = Image.fromarray((np.random.rand(28, 28, 3) * 255).astype( + 'uint8')) + F.to_tensor(fake_img, data_format=1) + + with self.assertRaises(TypeError): + fake_img = Image.fromarray((np.random.rand(28, 28, 3) * 255).astype( + 'uint8')) + F.resize(fake_img, '1') + + with self.assertRaises(TypeError): + F.resize(1, 1) + + with self.assertRaises(TypeError): + F.pad(1, 1) + + with self.assertRaises(TypeError): + F.crop(1, 1, 1, 1, 1) + + with self.assertRaises(TypeError): + F.hflip(1) + + with self.assertRaises(TypeError): + F.vflip(1) + + with self.assertRaises(TypeError): + F.adjust_brightness(1, 0.1) + + with self.assertRaises(TypeError): + F.adjust_contrast(1, 0.1) + + with self.assertRaises(TypeError): + F.adjust_hue(1, 0.1) + + with self.assertRaises(TypeError): + F.adjust_saturation(1, 0.1) + + with self.assertRaises(TypeError): + F.rotate(1, 0.1) + + with self.assertRaises(TypeError): + F.to_grayscale(1) + + with self.assertRaises(ValueError): + set_image_backend(1) + + with self.assertRaises(ValueError): + image_load('tmp.jpg', backend=1) + + def test_normalize(self): + np_img = (np.random.rand(28, 24, 3)).astype('uint8') + pil_img = Image.fromarray(np_img) + tensor_img = F.to_tensor(pil_img) + tensor_img_hwc = F.to_tensor(pil_img, data_format='HWC') + + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + + normalized_img = F.normalize(tensor_img, mean, std) + normalized_img = F.normalize( + tensor_img_hwc, mean, std, data_format='HWC') + + normalized_img = F.normalize(pil_img, mean, std, data_format='HWC') + normalized_img = F.normalize( + np_img, mean, std, data_format='HWC', to_rgb=True) + + def test_center_crop(self): + np_img = (np.random.rand(28, 24, 3)).astype('uint8') + pil_img = Image.fromarray(np_img) + + np_cropped_img = F.center_crop(np_img, 4) + pil_cropped_img = F.center_crop(pil_img, 4) + + np.testing.assert_almost_equal(np_cropped_img, + np.array(pil_cropped_img)) + + def test_pad(self): + np_img = (np.random.rand(28, 24, 3)).astype('uint8') + pil_img = Image.fromarray(np_img) + + np_padded_img = F.pad(np_img, [1, 2], padding_mode='reflect') + pil_padded_img = F.pad(pil_img, [1, 2], padding_mode='reflect') + + np.testing.assert_almost_equal(np_padded_img, np.array(pil_padded_img)) + + pil_p_img = pil_img.convert('P') + pil_padded_img = F.pad(pil_p_img, [1, 2]) + pil_padded_img = F.pad(pil_p_img, [1, 2], padding_mode='reflect') + + def test_resize(self): + np_img = (np.zeros([28, 24, 3])).astype('uint8') + pil_img = Image.fromarray(np_img) + + np_reseized_img = F.resize(np_img, 40) + pil_reseized_img = F.resize(pil_img, 40) + + np.testing.assert_almost_equal(np_reseized_img, + np.array(pil_reseized_img)) + + gray_img = (np.zeros([28, 32])).astype('uint8') + gray_resize_img = F.resize(gray_img, 40) + + def test_to_tensor(self): + np_img = (np.random.rand(28, 28) * 255).astype('uint8') + pil_img = Image.fromarray(np_img) + + np_tensor = F.to_tensor(np_img, data_format='HWC') + pil_tensor = F.to_tensor(pil_img, data_format='HWC') + + np.testing.assert_allclose(np_tensor.numpy(), pil_tensor.numpy()) + + # test float dtype + float_img = np.random.rand(28, 28) + float_tensor = F.to_tensor(float_img) + + pil_img = Image.fromarray(np_img).convert('I') + pil_tensor = F.to_tensor(pil_img) + + pil_img = Image.fromarray(np_img).convert('I;16') + pil_tensor = F.to_tensor(pil_img) + + pil_img = Image.fromarray(np_img).convert('F') + pil_tensor = F.to_tensor(pil_img) + + pil_img = Image.fromarray(np_img).convert('1') + pil_tensor = F.to_tensor(pil_img) + + pil_img = Image.fromarray(np_img).convert('YCbCr') + pil_tensor = F.to_tensor(pil_img) + + def test_image_load(self): + fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype( + 'uint8')) + + path = 'temp.jpg' + fake_img.save(path) + + set_image_backend('pil') + + pil_img = image_load(path).convert('RGB') + + print(type(pil_img)) + + set_image_backend('cv2') + + np_img = image_load(path) + + os.remove(path) if __name__ == '__main__': diff --git a/python/paddle/vision/__init__.py b/python/paddle/vision/__init__.py index 7d28d567ce..db5a94f932 100644 --- a/python/paddle/vision/__init__.py +++ b/python/paddle/vision/__init__.py @@ -21,6 +21,10 @@ from .transforms import * from . import datasets from .datasets import * +from . import image +from .image import * + __all__ = models.__all__ \ + transforms.__all__ \ - + datasets.__all__ + + datasets.__all__ \ + + image.__all__ diff --git a/python/paddle/vision/datasets/folder.py b/python/paddle/vision/datasets/folder.py index 19d913504b..d005bc4f19 100644 --- a/python/paddle/vision/datasets/folder.py +++ b/python/paddle/vision/datasets/folder.py @@ -14,6 +14,7 @@ import os import sys +from PIL import Image import paddle from paddle.io import Dataset @@ -136,7 +137,7 @@ class DatasetFolder(Dataset): "Found 0 files in subfolders of: " + self.root + "\n" "Supported extensions are: " + ",".join(extensions))) - self.loader = cv2_loader if loader is None else loader + self.loader = default_loader if loader is None else loader self.extensions = extensions self.classes = classes @@ -193,9 +194,23 @@ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') +def pil_loader(path): + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + def cv2_loader(path): cv2 = try_import('cv2') - return cv2.imread(path) + return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) + + +def default_loader(path): + from paddle.vision import get_image_backend + if get_image_backend() == 'cv2': + return cv2_loader(path) + else: + return pil_loader(path) class ImageFolder(Dataset): @@ -280,7 +295,7 @@ class ImageFolder(Dataset): "Found 0 files in subfolders of: " + self.root + "\n" "Supported extensions are: " + ",".join(extensions))) - self.loader = cv2_loader if loader is None else loader + self.loader = default_loader if loader is None else loader self.extensions = extensions self.samples = samples self.transform = transform diff --git a/python/paddle/vision/image.py b/python/paddle/vision/image.py new file mode 100644 index 0000000000..3d5ea3a73a --- /dev/null +++ b/python/paddle/vision/image.py @@ -0,0 +1,162 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PIL import Image +from paddle.utils import try_import + +__all__ = ['set_image_backend', 'get_image_backend', 'image_load'] + +_image_backend = 'pil' + + +def set_image_backend(backend): + """ + Specifies the backend used to load images in class ``paddle.vision.datasets.ImageFolder`` + and ``paddle.vision.datasets.DatasetFolder`` . Now support backends are pillow and opencv. + If backend not set, will use 'pil' as default. + + Args: + backend (str): Name of the image load backend, should be one of {'pil', 'cv2'}. + + Examples: + + .. code-block:: python + + import os + import shutil + import tempfile + import numpy as np + from PIL import Image + + from paddle.vision import DatasetFolder + from paddle.vision import set_image_backend + + set_image_backend('pil') + + def make_fake_dir(): + data_dir = tempfile.mkdtemp() + + for i in range(2): + sub_dir = os.path.join(data_dir, 'class_' + str(i)) + if not os.path.exists(sub_dir): + os.makedirs(sub_dir) + for j in range(2): + fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8')) + fake_img.save(os.path.join(sub_dir, str(j) + '.png')) + return data_dir + + temp_dir = make_fake_dir() + + pil_data_folder = DatasetFolder(temp_dir) + + for items in pil_data_folder: + break + + # should get PIL.Image.Image + print(type(items[0])) + + # use opencv as backend + # set_image_backend('cv2') + + # cv2_data_folder = DatasetFolder(temp_dir) + + # for items in cv2_data_folder: + # break + + # should get numpy.ndarray + # print(type(items[0])) + + shutil.rmtree(temp_dir) + """ + global _image_backend + if backend not in ['pil', 'cv2']: + raise ValueError( + "Expected backend are one of ['pil', 'cv2'], but got {}" + .format(backend)) + _image_backend = backend + + +def get_image_backend(): + """ + Gets the name of the package used to load images + + Returns: + str: backend of image load. + + Examples: + + .. code-block:: python + + from paddle.vision import get_image_backend + + backend = get_image_backend() + print(backend) + + """ + return _image_backend + + +def image_load(path, backend=None): + """Load an image. + + Args: + path (str): Path of the image. + backend (str, optional): The image decoding backend type. Options are + `cv2`, `pil`, `None`. If backend is None, the global _imread_backend + specified by ``paddle.vision.set_image_backend`` will be used. Default: None. + + Returns: + PIL.Image or np.array: Loaded image. + + Examples: + + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision import image_load, set_image_backend + + fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8')) + + path = 'temp.png' + fake_img.save(path) + + set_image_backend('pil') + + pil_img = image_load(path).convert('RGB') + + # should be PIL.Image.Image + print(type(pil_img)) + + # use opencv as backend + # set_image_backend('cv2') + + # np_img = image_load(path) + # # should get numpy.ndarray + # print(type(np_img)) + + """ + + if backend is None: + backend = _image_backend + if backend not in ['pil', 'cv2']: + raise ValueError( + "Expected backend are one of ['pil', 'cv2'], but got {}" + .format(backend)) + + if backend == 'pil': + return Image.open(path) + else: + cv2 = try_import('cv2') + return cv2.imread(path) diff --git a/python/paddle/vision/transforms/functional.py b/python/paddle/vision/transforms/functional.py index acceb111e6..7391ae322e 100644 --- a/python/paddle/vision/transforms/functional.py +++ b/python/paddle/vision/transforms/functional.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import division + import sys -import collections -import random import math -import functools - import numbers -import numpy as np +import warnings +import collections -from paddle.utils import try_import +import numpy as np +from PIL import Image +from numpy import sin, cos, tan +import paddle if sys.version_info < (3, 3): Sequence = collections.Sequence @@ -30,314 +32,623 @@ else: Sequence = collections.abc.Sequence Iterable = collections.abc.Iterable -__all__ = ['flip', 'resize', 'pad', 'rotate', 'to_grayscale'] +from . import functional_pil as F_pil +from . import functional_cv2 as F_cv2 +from . import functional_tensor as F_t +__all__ = [ + 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', + 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', + 'to_grayscale', 'normalize' +] -def keepdims(func): - """Keep the dimension of input images unchanged""" - @functools.wraps(func) - def wrapper(image, *args, **kwargs): - if len(image.shape) != 3: - raise ValueError("Expect image have 3 dims, but got {} dims".format( - len(image.shape))) - ret = func(image, *args, **kwargs) - if len(ret.shape) == 2: - ret = ret[:, :, np.newaxis] - return ret +def _is_pil_image(img): + return isinstance(img, Image.Image) - return wrapper +def _is_tensor_image(img): + return isinstance(img, paddle.Tensor) -@keepdims -def flip(image, code): - """ - Accordding to the code (the type of flip), flip the input image + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def to_tensor(pic, data_format='CHW'): + """Converts a ``PIL.Image`` or ``numpy.ndarray`` to paddle.Tensor. + + See ``ToTensor`` for more details. Args: - image (np.ndarray): Input image, with (H, W, C) shape - code (int): Code that indicates the type of flip. - -1 : Flip horizontally and vertically - 0 : Flip vertically - 1 : Flip horizontally + pic (PIL.Image|np.ndarray): Image to be converted to tensor. + data_format (str, optional): Data format of input img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + + Returns: + Tensor: Converted image. Data format is same as input img. Examples: .. code-block:: python import numpy as np + from PIL import Image from paddle.vision.transforms import functional as F - fake_img = np.random.rand(224, 224, 3) + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') - # flip horizontally and vertically - F.flip(fake_img, -1) + fake_img = Image.fromarray(fake_img) - # flip vertically - F.flip(fake_img, 0) + tensor = F.to_tensor(fake_img) + print(tensor.shape) - # flip horizontally - F.flip(fake_img, 1) """ - cv2 = try_import('cv2') - return cv2.flip(image, flipCode=code) + if not (_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError('pic should be PIL Image or ndarray. Got {}'.format( + type(pic))) + + if _is_pil_image(pic): + return F_pil.to_tensor(pic, data_format) + else: + return F_cv2.to_tensor(pic, data_format) -@keepdims -def resize(img, size, interpolation=1): +def resize(img, size, interpolation='bilinear'): """ - resize the input data to given size + Resizes the image to given size Args: - input (np.ndarray): Input data, could be image or masks, with (H, W, C) shape + input (PIL.Image|np.ndarray): Image to be resized. size (int|list|tuple): Target size of input data, with (height, width) shape. - interpolation (int, optional): Interpolation method. - 0 : cv2.INTER_NEAREST - 1 : cv2.INTER_LINEAR - 2 : cv2.INTER_CUBIC - 3 : cv2.INTER_AREA - 4 : cv2.INTER_LANCZOS4 - 5 : cv2.INTER_LINEAR_EXACT - 7 : cv2.INTER_MAX - 8 : cv2.WARP_FILL_OUTLIERS - 16: cv2.WARP_INVERSE_MAP + interpolation (int|str, optional): Interpolation method. when use pil backend, + support method are as following: + - "nearest": Image.NEAREST, + - "bilinear": Image.BILINEAR, + - "bicubic": Image.BICUBIC, + - "box": Image.BOX, + - "lanczos": Image.LANCZOS, + - "hamming": Image.HAMMING + when use cv2 backend, support method are as following: + - "nearest": cv2.INTER_NEAREST, + - "bilinear": cv2.INTER_LINEAR, + - "area": cv2.INTER_AREA, + - "bicubic": cv2.INTER_CUBIC, + - "lanczos": cv2.INTER_LANCZOS4 + + Returns: + PIL.Image or np.array: Resized image. Examples: .. code-block:: python import numpy as np + from PIL import Image from paddle.vision.transforms import functional as F - fake_img = np.random.rand(256, 256, 3) + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') - F.resize(fake_img, 224) + fake_img = Image.fromarray(fake_img) - F.resize(fake_img, (200, 150)) + converted_img = F.resize(fake_img, 224) + print(converted_img.size) + + converted_img = F.resize(fake_img, (200, 150)) + print(converted_img.size) """ - cv2 = try_import('cv2') - if isinstance(interpolation, Sequence): - interpolation = random.choice(interpolation) - - if isinstance(size, int): - h, w = img.shape[:2] - if (w <= h and w == size) or (h <= w and h == size): - return img - if w < h: - ow = size - oh = int(size * h / w) - return cv2.resize(img, (ow, oh), interpolation=interpolation) - else: - oh = size - ow = int(size * w / h) - return cv2.resize(img, (ow, oh), interpolation=interpolation) + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.resize(img, size, interpolation) else: - return cv2.resize(img, size[::-1], interpolation=interpolation) + return F_cv2.resize(img, size, interpolation) -@keepdims -def pad(img, padding, fill=(0, 0, 0), padding_mode='constant'): - """Pads the given CV Image on all sides with speficified padding mode and fill value. +def pad(img, padding, fill=0, padding_mode='constant'): + """ + Pads the given PIL.Image or numpy.array on all sides with specified padding mode and fill value. Args: - img (np.ndarray): Image to be padded. - padding (int|tuple): Padding on each border. If a single int is provided this + img (PIL.Image|np.array): Image to be padded. + padding (int|list|tuple): Padding on each border. If a single int is provided this is used to pad all borders. If tuple of length 2 is provided this is the padding on left/right and top/bottom respectively. If a tuple of length 4 is provided this is the padding for the left, top, right and bottom borders respectively. - fill (int|tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + fill (float, optional): Pixel fill value for constant fill. If a tuple of length 3, it is used to fill R, G, B channels respectively. - This value is only used when the padding_mode is constant - padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. - ``constant`` means padding with a constant value, this value is specified with fill. - ``edge`` means padding with the last value at the edge of the image. - ``reflect`` means padding with reflection of image (without repeating the last value on the edge) - padding ``[1, 2, 3, 4]`` with 2 elements on both sides in reflect mode - will result in ``[3, 2, 1, 2, 3, 4, 3, 2]``. - ``symmetric`` menas pads with reflection of image (repeating the last value on the edge) - padding ``[1, 2, 3, 4]`` with 2 elements on both sides in symmetric mode - will result in ``[2, 1, 1, 2, 3, 4, 4, 3]``. + This value is only used when the padding_mode is constant. Default: 0. + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default: 'constant'. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] Returns: - numpy ndarray: Padded image. + PIL.Image or np.array: Padded image. + + Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + padded_img = F.pad(fake_img, padding=1) + print(padded_img.size) + + padded_img = F.pad(fake_img, padding=(2, 1)) + print(padded_img.size) + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.pad(img, padding, fill, padding_mode) + else: + return F_cv2.pad(img, padding, fill, padding_mode) + + +def crop(img, top, left, height, width): + """Crops the given Image. + + Args: + img (PIL.Image|np.array): Image to be cropped. (0,0) denotes the top left + corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + + Returns: + PIL.Image or np.array: Cropped image. Examples: - .. code-block:: python import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F - from paddle.vision.transforms.functional import pad + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray(fake_img) - fake_img = pad(fake_img, 2) - print(fake_img.shape) + cropped_img = F.crop(fake_img, 56, 150, 200, 100) + print(cropped_img.size) """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.crop(img, top, left, height, width) + else: + return F_cv2.crop(img, top, left, height, width) + + +def center_crop(img, output_size): + """Crops the given Image and resize it to desired size. + + Args: + img (PIL.Image|np.array): Image to be cropped. (0,0) denotes the top left corner of the image. + output_size (sequence or int): (height, width) of the crop box. If int, + it is used for both directions + + Returns: + PIL.Image or np.array: Cropped image. - if not isinstance(padding, (numbers.Number, list, tuple)): - raise TypeError('Got inappropriate padding arg') - if not isinstance(fill, (numbers.Number, str, list, tuple)): - raise TypeError('Got inappropriate fill arg') - if not isinstance(padding_mode, str): - raise TypeError('Got inappropriate padding_mode arg') - - if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: - raise ValueError( - "Padding must be an int or a 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) - - assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ - 'Expected padding mode be either constant, edge, reflect or symmetric, but got {}'.format(padding_mode) - - cv2 = try_import('cv2') - - PAD_MOD = { - 'constant': cv2.BORDER_CONSTANT, - 'edge': cv2.BORDER_REPLICATE, - 'reflect': cv2.BORDER_DEFAULT, - 'symmetric': cv2.BORDER_REFLECT - } - - if isinstance(padding, int): - pad_left = pad_right = pad_top = pad_bottom = padding - if isinstance(padding, collections.Sequence) and len(padding) == 2: - pad_left = pad_right = padding[0] - pad_top = pad_bottom = padding[1] - if isinstance(padding, collections.Sequence) and len(padding) == 4: - pad_left, pad_top, pad_right, pad_bottom = padding - - if isinstance(fill, numbers.Number): - fill = (fill, ) * (2 * len(img.shape) - 3) - - if padding_mode == 'constant': - assert (len(fill) == 3 and len(img.shape) == 3) or (len(fill) == 1 and len(img.shape) == 2), \ - 'channel of image is {} but length of fill is {}'.format(img.shape[-1], len(fill)) - - img = cv2.copyMakeBorder( - src=img, - top=pad_top, - bottom=pad_bottom, - left=pad_left, - right=pad_right, - borderType=PAD_MOD[padding_mode], - value=fill) - - return img - - -@keepdims -def rotate(img, angle, interpolation=1, expand=False, center=None): + Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + cropped_img = F.center_crop(fake_img, (150, 100)) + print(cropped_img.size) + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.center_crop(img, output_size) + else: + return F_cv2.center_crop(img, output_size) + + +def hflip(img, backend='pil'): + """Horizontally flips the given Image or np.array. + + Args: + img (PIL.Image|np.array): Image to be flipped. + backend (str, optional): The image proccess backend type. Options are `pil`, + `cv2`. Default: 'pil'. + + Returns: + PIL.Image or np.array: Horizontall flipped image. + + Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + flpped_img = F.hflip(fake_img) + print(flpped_img.size) + + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.hflip(img) + else: + return F_cv2.hflip(img) + + +def vflip(img): + """Vertically flips the given Image or np.array. + + Args: + img (PIL.Image|np.array): Image to be flipped. + + Returns: + PIL.Image or np.array: Vertically flipped image. + + Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + flpped_img = F.vflip(fake_img) + print(flpped_img.size) + + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.vflip(img) + else: + return F_cv2.vflip(img) + + +def adjust_brightness(img, brightness_factor): + """Adjusts brightness of an Image. + + Args: + img (PIL.Image|np.array): Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL.Image or np.array: Brightness adjusted image. + + Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + converted_img = F.adjust_brightness(fake_img, 0.4) + print(converted_img.size) + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.adjust_brightness(img, brightness_factor) + else: + return F_cv2.adjust_brightness(img, brightness_factor) + + +def adjust_contrast(img, contrast_factor): + """Adjusts contrast of an Image. + + Args: + img (PIL.Image|np.array): Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + PIL.Image or np.array: Contrast adjusted image. + + Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + converted_img = F.adjust_contrast(fake_img, 0.4) + print(converted_img.size) + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.adjust_contrast(img, contrast_factor) + else: + return F_cv2.adjust_contrast(img, contrast_factor) + + +def adjust_saturation(img, saturation_factor): + """Adjusts color saturation of an image. + + Args: + img (PIL.Image|np.array): Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + PIL.Image or np.array: Saturation adjusted image. + + Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + converted_img = F.adjust_saturation(fake_img, 0.4) + print(converted_img.size) + + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.adjust_saturation(img, saturation_factor) + else: + return F_cv2.adjust_saturation(img, saturation_factor) + + +def adjust_hue(img, hue_factor): + """Adjusts hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + Args: + img (PIL.Image|np.array): Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL.Image or np.array: Hue adjusted image. + + Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + converted_img = F.adjust_hue(fake_img, 0.4) + print(converted_img.size) + + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.adjust_hue(img, hue_factor) + else: + return F_cv2.adjust_hue(img, hue_factor) + + +def rotate(img, angle, resample=False, expand=False, center=None, fill=0): """Rotates the image by angle. + Args: - img (numpy.ndarray): Image to be rotated. - angle (float|int): In degrees clockwise order. - interpolation (int, optional): Interpolation method. Default: 1. - 0 : cv2.INTER_NEAREST - 1 : cv2.INTER_LINEAR - 2 : cv2.INTER_CUBIC - 3 : cv2.INTER_AREA - 4 : cv2.INTER_LANCZOS4 - 5 : cv2.INTER_LINEAR_EXACT - 7 : cv2.INTER_MAX - 8 : cv2.WARP_FILL_OUTLIERS - 16: cv2.WARP_INVERSE_MAP - expand (bool|optional): Optional expansion flag. + img (PIL.Image|np.array): Image to be rotated. + angle (float or int): In degrees degrees counter clockwise order. + resample (int|str, optional): An optional resampling filter. If omitted, or if the + image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST + according the backend. when use pil backend, support method are as following: + - "nearest": Image.NEAREST, + - "bilinear": Image.BILINEAR, + - "bicubic": Image.BICUBIC + when use cv2 backend, support method are as following: + - "nearest": cv2.INTER_NEAREST, + - "bilinear": cv2.INTER_LINEAR, + - "bicubic": cv2.INTER_CUBIC + expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. Note that the expand flag assumes rotation around the center and no translation. - center (2-tuple|optional): Optional center of rotation. + center (2-tuple, optional): Optional center of rotation. Origin is the upper left corner. Default is the center of the image. + fill (3-tuple or int): RGB pixel fill value for area outside the rotated image. + If int, it is used for all channels respectively. + Returns: - numpy ndarray: Rotated image. + PIL.Image or np.array: Rotated image. Examples: - .. code-block:: python import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') - from paddle.vision.transforms.functional import rotate + fake_img = Image.fromarray(fake_img) - fake_img = np.random.rand(500, 500, 3).astype('float32') + rotated_img = F.rotate(fake_img, 90) + print(rotated_img.size) - fake_img = rotate(fake_img, 10) - print(fake_img.shape) """ - cv2 = try_import('cv2') - - dtype = img.dtype - h, w, _ = img.shape - point = center or (w / 2, h / 2) - M = cv2.getRotationMatrix2D(point, angle=-angle, scale=1) - - if expand: - if center is None: - cos = np.abs(M[0, 0]) - sin = np.abs(M[0, 1]) - - nW = int((h * sin) + (w * cos)) - nH = int((h * cos) + (w * sin)) - - M[0, 2] += (nW / 2) - point[0] - M[1, 2] += (nH / 2) - point[1] - - dst = cv2.warpAffine(img, M, (nW, nH)) - else: - xx = [] - yy = [] - for point in (np.array([0, 0, 1]), np.array([w - 1, 0, 1]), - np.array([w - 1, h - 1, 1]), np.array([0, h - 1, 1])): - target = np.dot(M, point) - xx.append(target[0]) - yy.append(target[1]) - nh = int(math.ceil(max(yy)) - math.floor(min(yy))) - nw = int(math.ceil(max(xx)) - math.floor(min(xx))) - - M[0, 2] += (nw - w) / 2 - M[1, 2] += (nh - h) / 2 - dst = cv2.warpAffine(img, M, (nw, nh), flags=interpolation) + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.rotate(img, angle, resample, expand, center, fill) else: - dst = cv2.warpAffine(img, M, (w, h), flags=interpolation) - return dst.astype(dtype) + return F_cv2.rotate(img, angle, resample, expand, center, fill) -@keepdims def to_grayscale(img, num_output_channels=1): """Converts image to grayscale version of image. Args: - img (numpy.ndarray): Image to be converted to grayscale. + img (PIL.Image|np.array): Image to be converted to grayscale. + backend (str, optional): The image proccess backend type. Options are `pil`, + `cv2`. Default: 'pil'. Returns: - numpy.ndarray: Grayscale version of the image. - if num_output_channels == 1, returned image is single channel - if num_output_channels == 3, returned image is 3 channel with r == g == b + PIL.Image or np.array: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + + if num_output_channels = 3 : returned image is 3 channel with r = g = b Examples: + .. code-block:: python + + import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) + + gray_img = F.to_grayscale(fake_img) + print(gray_img.size) + + """ + if not (_is_pil_image(img) or _is_numpy_image(img)): + raise TypeError( + 'img should be PIL Image or ndarray with dim=[2 or 3]. Got {}'. + format(type(img))) + + if _is_pil_image(img): + return F_pil.to_grayscale(img, num_output_channels) + else: + return F_cv2.to_grayscale(img, num_output_channels) + + +def normalize(img, mean, std, data_format='CHW', to_rgb=False): + """Normalizes a tensor or image with mean and standard deviation. + + Args: + img (PIL.Image|np.array|paddle.Tensor): input data to be normalized. + mean (list|tuple): Sequence of means for each channel. + std (list|tuple): Sequence of standard deviations for each channel. + data_format (str, optional): Data format of input img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + to_rgb (bool, optional): Whether to convert to rgb. If input is tensor, + this option will be igored. Default: False. + + Returns: + Tensor: Normalized mage. Data format is same as input img. + Examples: .. code-block:: python import numpy as np + from PIL import Image + from paddle.vision.transforms import functional as F + + fake_img = (np.random.rand(256, 300, 3) * 255.).astype('uint8') + + fake_img = Image.fromarray(fake_img) - from paddle.vision.transforms.functional import to_grayscale + mean = [127.5, 127.5, 127.5] + std = [127.5, 127.5, 127.5] - fake_img = np.random.rand(500, 500, 3).astype('float32') + normalized_img = F.normalize(fake_img, mean, std, data_format='HWC') + print(normalized_img.max(), normalized_img.min()) - fake_img = to_grayscale(fake_img) - print(fake_img.shape) """ - cv2 = try_import('cv2') - if num_output_channels == 1: - img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) - elif num_output_channels == 3: - img = cv2.cvtColor( - cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB) + if _is_tensor_image(img): + return F_t.normalize(img, mean, std, data_format) else: - raise ValueError('num_output_channels should be either 1 or 3') + if _is_pil_image(img): + img = np.array(img).astype(np.float32) - return img + return F_cv2.normalize(img, mean, std, data_format, to_rgb) diff --git a/python/paddle/vision/transforms/functional_cv2.py b/python/paddle/vision/transforms/functional_cv2.py new file mode 100644 index 0000000000..5c2e8d61bc --- /dev/null +++ b/python/paddle/vision/transforms/functional_cv2.py @@ -0,0 +1,503 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import sys +import numbers +import warnings +import collections + +import numpy as np +from numpy import sin, cos, tan + +import paddle +from paddle.utils import try_import + +if sys.version_info < (3, 3): + Sequence = collections.Sequence + Iterable = collections.Iterable +else: + Sequence = collections.abc.Sequence + Iterable = collections.abc.Iterable + + +def to_tensor(pic, data_format='CHW'): + """Converts a ``numpy.ndarray`` to paddle.Tensor. + + See ``ToTensor`` for more details. + + Args: + pic (np.ndarray): Image to be converted to tensor. + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + + Returns: + Tensor: Converted image. + + """ + + if not data_format in ['CHW', 'HWC']: + raise ValueError('data_format should be CHW or HWC. Got {}'.format( + data_format)) + + if pic.ndim == 2: + pic = pic[:, :, None] + + if data_format == 'CHW': + img = paddle.to_tensor(pic.transpose((2, 0, 1))) + else: + img = paddle.to_tensor(pic) + + if paddle.fluid.data_feeder.convert_dtype(img.dtype) == 'uint8': + return paddle.cast(img, np.float32) / 255. + else: + return img + + +def resize(img, size, interpolation='bilinear'): + """ + Resizes the image to given size + + Args: + input (np.ndarray): Image to be resized. + size (int|list|tuple): Target size of input data, with (height, width) shape. + interpolation (int|str, optional): Interpolation method. when use cv2 backend, + support method are as following: + - "nearest": cv2.INTER_NEAREST, + - "bilinear": cv2.INTER_LINEAR, + - "area": cv2.INTER_AREA, + - "bicubic": cv2.INTER_CUBIC, + - "lanczos": cv2.INTER_LANCZOS4 + + Returns: + np.array: Resized image. + + """ + cv2 = try_import('cv2') + _cv2_interp_from_str = { + 'nearest': cv2.INTER_NEAREST, + 'bilinear': cv2.INTER_LINEAR, + 'area': cv2.INTER_AREA, + 'bicubic': cv2.INTER_CUBIC, + 'lanczos': cv2.INTER_LANCZOS4 + } + + if not (isinstance(size, int) or + (isinstance(size, Iterable) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + + h, w = img.shape[:2] + + if isinstance(size, int): + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + output = cv2.resize( + img, + dsize=(ow, oh), + interpolation=_cv2_interp_from_str[interpolation]) + else: + oh = size + ow = int(size * w / h) + output = cv2.resize( + img, + dsize=(ow, oh), + interpolation=_cv2_interp_from_str[interpolation]) + else: + output = cv2.resize( + img, + dsize=(size[1], size[0]), + interpolation=_cv2_interp_from_str[interpolation]) + if len(img.shape) == 3 and img.shape[2] == 1: + return output[:, :, np.newaxis] + else: + return output + + +def pad(img, padding, fill=0, padding_mode='constant'): + """ + Pads the given numpy.array on all sides with specified padding mode and fill value. + + Args: + img (np.array): Image to be padded. + padding (int|list|tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill (float, optional): Pixel fill value for constant fill. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. Default: 0. + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default: 'constant'. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + np.array: Padded image. + + """ + cv2 = try_import('cv2') + _cv2_pad_from_str = { + 'constant': cv2.BORDER_CONSTANT, + 'edge': cv2.BORDER_REPLICATE, + 'reflect': cv2.BORDER_REFLECT_101, + 'symmetric': cv2.BORDER_REFLECT + } + + if not isinstance(padding, (numbers.Number, list, tuple)): + raise TypeError('Got inappropriate padding arg') + if not isinstance(fill, (numbers.Number, str, list, tuple)): + raise TypeError('Got inappropriate fill arg') + if not isinstance(padding_mode, str): + raise TypeError('Got inappropriate padding_mode arg') + + if isinstance(padding, Sequence) and len(padding) not in [2, 4]: + raise ValueError( + "Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ + 'Padding mode should be either constant, edge, reflect or symmetric' + + if isinstance(padding, list): + padding = tuple(padding) + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + if isinstance(padding, Sequence) and len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + if isinstance(padding, Sequence) and len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + if len(img.shape) == 3 and img.shape[2] == 1: + return cv2.copyMakeBorder( + img, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + borderType=_cv2_pad_from_str[padding_mode], + value=fill)[:, :, np.newaxis] + else: + return cv2.copyMakeBorder( + img, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + borderType=_cv2_pad_from_str[padding_mode], + value=fill) + + +def crop(img, top, left, height, width): + """Crops the given image. + + Args: + img (np.array): Image to be cropped. (0,0) denotes the top left + corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + + Returns: + np.array: Cropped image. + + """ + + return img[top:top + height, left:left + width, :] + + +def center_crop(img, output_size): + """Crops the given image and resize it to desired size. + + Args: + img (np.array): Image to be cropped. (0,0) denotes the top left corner of the image. + output_size (sequence or int): (height, width) of the crop box. If int, + it is used for both directions + backend (str, optional): The image proccess backend type. Options are `pil`, `cv2`. Default: 'pil'. + + Returns: + np.array: Cropped image. + + """ + + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + + h, w = img.shape[0:2] + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return crop(img, i, j, th, tw) + + +def hflip(img): + """Horizontally flips the given image. + + Args: + img (np.array): Image to be flipped. + + Returns: + np.array: Horizontall flipped image. + + """ + cv2 = try_import('cv2') + + return cv2.flip(img, 1) + + +def vflip(img): + """Vertically flips the given np.array. + + Args: + img (np.array): Image to be flipped. + + Returns: + np.array: Vertically flipped image. + + """ + cv2 = try_import('cv2') + + if len(img.shape) == 3 and img.shape[2] == 1: + return cv2.flip(img, 0)[:, :, np.newaxis] + else: + return cv2.flip(img, 0) + + +def adjust_brightness(img, brightness_factor): + """Adjusts brightness of an image. + + Args: + img (np.array): Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + np.array: Brightness adjusted image. + + """ + cv2 = try_import('cv2') + + table = np.array([i * brightness_factor + for i in range(0, 256)]).clip(0, 255).astype('uint8') + + if len(img.shape) == 3 and img.shape[2] == 1: + return cv2.LUT(img, table)[:, :, np.newaxis] + else: + return cv2.LUT(img, table) + + +def adjust_contrast(img, contrast_factor): + """Adjusts contrast of an image. + + Args: + img (np.array): Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + np.array: Contrast adjusted image. + + """ + cv2 = try_import('cv2') + + table = np.array([(i - 74) * contrast_factor + 74 + for i in range(0, 256)]).clip(0, 255).astype('uint8') + if len(img.shape) == 3 and img.shape[2] == 1: + return cv2.LUT(img, table)[:, :, np.newaxis] + else: + return cv2.LUT(img, table) + + +def adjust_saturation(img, saturation_factor): + """Adjusts color saturation of an image. + + Args: + img (np.array): Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + np.array: Saturation adjusted image. + + """ + cv2 = try_import('cv2') + + dtype = img.dtype + img = img.astype(np.float32) + alpha = np.random.uniform( + max(0, 1 - saturation_factor), 1 + saturation_factor) + gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + gray_img = gray_img[..., np.newaxis] + img = img * alpha + gray_img * (1 - alpha) + return img.clip(0, 255).astype(dtype) + + +def adjust_hue(img, hue_factor): + """Adjusts hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + Args: + img (np.array): Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + np.array: Hue adjusted image. + + """ + cv2 = try_import('cv2') + + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + dtype = img.dtype + img = img.astype(np.uint8) + hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV_FULL) + h, s, v = cv2.split(hsv_img) + + alpha = np.random.uniform(hue_factor, hue_factor) + h = h.astype(np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over="ignore"): + h += np.uint8(alpha * 255) + hsv_img = cv2.merge([h, s, v]) + return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype) + + +def rotate(img, angle, resample=False, expand=False, center=None, fill=0): + """Rotates the image by angle. + + Args: + img (np.array): Image to be rotated. + angle (float or int): In degrees degrees counter clockwise order. + resample (int|str, optional): An optional resampling filter. If omitted, or if the + image has only one channel, it is set to cv2.INTER_NEAREST. + when use cv2 backend, support method are as following: + - "nearest": cv2.INTER_NEAREST, + - "bilinear": cv2.INTER_LINEAR, + - "bicubic": cv2.INTER_CUBIC + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + fill (3-tuple or int): RGB pixel fill value for area outside the rotated image. + If int, it is used for all channels respectively. + + Returns: + np.array: Rotated image. + + """ + cv2 = try_import('cv2') + + rows, cols = img.shape[0:2] + if center is None: + center = (cols / 2, rows / 2) + M = cv2.getRotationMatrix2D(center, angle, 1) + if len(img.shape) == 3 and img.shape[2] == 1: + return cv2.warpAffine(img, M, (cols, rows))[:, :, np.newaxis] + else: + return cv2.warpAffine(img, M, (cols, rows)) + + +def to_grayscale(img, num_output_channels=1): + """Converts image to grayscale version of image. + + Args: + img (np.array): Image to be converted to grayscale. + + Returns: + np.array: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + + if num_output_channels = 3 : returned image is 3 channel with r = g = b + + """ + cv2 = try_import('cv2') + + if num_output_channels == 1: + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis] + elif num_output_channels == 3: + # much faster than doing cvtColor to go back to gray + img = np.broadcast_to( + cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis], img.shape) + else: + raise ValueError('num_output_channels should be either 1 or 3') + + return img + + +def normalize(img, mean, std, data_format='CHW', to_rgb=False): + """Normalizes a ndarray imge or image with mean and standard deviation. + + Args: + img (np.array): input data to be normalized. + mean (list|tuple): Sequence of means for each channel. + std (list|tuple): Sequence of standard deviations for each channel. + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + to_rgb (bool, optional): Whether to convert to rgb. Default: False. + + Returns: + np.array: Normalized mage. + + """ + + if data_format == 'CHW': + mean = np.float32(np.array(mean).reshape(-1, 1, 1)) + std = np.float32(np.array(std).reshape(-1, 1, 1)) + else: + mean = np.float32(np.array(mean).reshape(1, 1, -1)) + std = np.float32(np.array(std).reshape(1, 1, -1)) + if to_rgb: + cv2 = try_import('cv2') + # inplace + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + + img = (img - mean) / std + return img diff --git a/python/paddle/vision/transforms/functional_pil.py b/python/paddle/vision/transforms/functional_pil.py new file mode 100644 index 0000000000..49b02fc049 --- /dev/null +++ b/python/paddle/vision/transforms/functional_pil.py @@ -0,0 +1,458 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import sys +import math +import numbers +import warnings +import collections +from PIL import Image, ImageOps, ImageEnhance + +import numpy as np +from numpy import sin, cos, tan +import paddle + +if sys.version_info < (3, 3): + Sequence = collections.Sequence + Iterable = collections.Iterable +else: + Sequence = collections.abc.Sequence + Iterable = collections.abc.Iterable + +_pil_interp_from_str = { + 'nearest': Image.NEAREST, + 'bilinear': Image.BILINEAR, + 'bicubic': Image.BICUBIC, + 'box': Image.BOX, + 'lanczos': Image.LANCZOS, + 'hamming': Image.HAMMING +} + + +def to_tensor(pic, data_format='CHW'): + """Converts a ``PIL.Image`` to paddle.Tensor. + + See ``ToTensor`` for more details. + + Args: + pic (PIL.Image): Image to be converted to tensor. + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + + Returns: + Tensor: Converted image. + + """ + + if not data_format in ['CHW', 'HWC']: + raise ValueError('data_format should be CHW or HWC. Got {}'.format( + data_format)) + + # PIL Image + if pic.mode == 'I': + img = paddle.to_tensor(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + # cast and reshape not support int16 + img = paddle.to_tensor(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'F': + img = paddle.to_tensor(np.array(pic, np.float32, copy=False)) + elif pic.mode == '1': + img = 255 * paddle.to_tensor(np.array(pic, np.uint8, copy=False)) + else: + img = paddle.to_tensor(np.array(pic, copy=False)) + + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + + dtype = paddle.fluid.data_feeder.convert_dtype(img.dtype) + if dtype == 'uint8': + img = paddle.cast(img, np.float32) / 255. + + img = img.reshape([pic.size[1], pic.size[0], nchannel]) + + if data_format == 'CHW': + img = img.transpose([2, 0, 1]) + + return img + + +def resize(img, size, interpolation='bilinear'): + """ + Resizes the image to given size + + Args: + input (PIL.Image): Image to be resized. + size (int|list|tuple): Target size of input data, with (height, width) shape. + interpolation (int|str, optional): Interpolation method. when use pil backend, + support method are as following: + - "nearest": Image.NEAREST, + - "bilinear": Image.BILINEAR, + - "bicubic": Image.BICUBIC, + - "box": Image.BOX, + - "lanczos": Image.LANCZOS, + - "hamming": Image.HAMMING + + Returns: + PIL.Image: Resized image. + + """ + + if not (isinstance(size, int) or + (isinstance(size, Iterable) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + + if isinstance(size, int): + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + return img.resize((ow, oh), _pil_interp_from_str[interpolation]) + else: + oh = size + ow = int(size * w / h) + return img.resize((ow, oh), _pil_interp_from_str[interpolation]) + else: + return img.resize(size[::-1], _pil_interp_from_str[interpolation]) + + +def pad(img, padding, fill=0, padding_mode='constant'): + """ + Pads the given PIL.Image on all sides with specified padding mode and fill value. + + Args: + img (PIL.Image): Image to be padded. + padding (int|list|tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill (float, optional): Pixel fill value for constant fill. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. Default: 0. + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default: 'constant'. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + PIL.Image: Padded image. + + """ + + if not isinstance(padding, (numbers.Number, list, tuple)): + raise TypeError('Got inappropriate padding arg') + if not isinstance(fill, (numbers.Number, str, list, tuple)): + raise TypeError('Got inappropriate fill arg') + if not isinstance(padding_mode, str): + raise TypeError('Got inappropriate padding_mode arg') + + if isinstance(padding, Sequence) and len(padding) not in [2, 4]: + raise ValueError( + "Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ + 'Padding mode should be either constant, edge, reflect or symmetric' + + if isinstance(padding, list): + padding = tuple(padding) + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + if isinstance(padding, Sequence) and len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + if isinstance(padding, Sequence) and len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + if padding_mode == 'constant': + if img.mode == 'P': + palette = img.getpalette() + image = ImageOps.expand(img, border=padding, fill=fill) + image.putpalette(palette) + return image + + return ImageOps.expand(img, border=padding, fill=fill) + else: + if img.mode == 'P': + palette = img.getpalette() + img = np.asarray(img) + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), + padding_mode) + img = Image.fromarray(img) + img.putpalette(palette) + return img + + img = np.asarray(img) + # RGB image + if len(img.shape) == 3: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), + (0, 0)), padding_mode) + # Grayscale image + if len(img.shape) == 2: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), + padding_mode) + + return Image.fromarray(img) + + +def crop(img, top, left, height, width): + """Crops the given PIL Image. + + Args: + img (PIL.Image): Image to be cropped. (0,0) denotes the top left + corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + + Returns: + PIL.Image: Cropped image. + + """ + return img.crop((left, top, left + width, top + height)) + + +def center_crop(img, output_size): + """Crops the given PIL Image and resize it to desired size. + + Args: + img (PIL.Image): Image to be cropped. (0,0) denotes the top left corner of the image. + output_size (sequence or int): (height, width) of the crop box. If int, + it is used for both directions + backend (str, optional): The image proccess backend type. Options are `pil`, `cv2`. Default: 'pil'. + + Returns: + PIL.Image: Cropped image. + + """ + + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + + image_width, image_height = img.size + crop_height, crop_width = output_size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, crop_top, crop_left, crop_height, crop_width) + + +def hflip(img): + """Horizontally flips the given PIL Image. + + Args: + img (PIL.Image): Image to be flipped. + + Returns: + PIL.Image: Horizontall flipped image. + + """ + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def vflip(img): + """Vertically flips the given PIL Image. + + Args: + img (PIL.Image): Image to be flipped. + + Returns: + PIL.Image: Vertically flipped image. + + """ + + return img.transpose(Image.FLIP_TOP_BOTTOM) + + +def adjust_brightness(img, brightness_factor): + """Adjusts brightness of an Image. + + Args: + img (PIL.Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL.Image: Brightness adjusted image. + + """ + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """Adjusts contrast of an Image. + + Args: + img (PIL.Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + PIL.Image: Contrast adjusted image. + + """ + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """Adjusts color saturation of an image. + + Args: + img (PIL.Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + PIL.Image: Saturation adjusted image. + + """ + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """Adjusts hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + Args: + img (PIL.Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL.Image: Hue adjusted image. + + """ + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + +def rotate(img, angle, resample=False, expand=False, center=None, fill=0): + """Rotates the image by angle. + + Args: + img (PIL.Image): Image to be rotated. + angle (float or int): In degrees degrees counter clockwise order. + resample (int|str, optional): An optional resampling filter. If omitted, or if the + image has only one channel, it is set to PIL.Image.NEAREST . when use pil backend, + support method are as following: + - "nearest": Image.NEAREST, + - "bilinear": Image.BILINEAR, + - "bicubic": Image.BICUBIC + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + fill (3-tuple or int): RGB pixel fill value for area outside the rotated image. + If int, it is used for all channels respectively. + + Returns: + PIL.Image: Rotated image. + + """ + + if isinstance(fill, int): + fill = tuple([fill] * 3) + + return img.rotate(angle, resample, expand, center, fillcolor=fill) + + +def to_grayscale(img, num_output_channels=1): + """Converts image to grayscale version of image. + + Args: + img (PIL.Image): Image to be converted to grayscale. + backend (str, optional): The image proccess backend type. Options are `pil`, + `cv2`. Default: 'pil'. + + Returns: + PIL.Image: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + + if num_output_channels = 3 : returned image is 3 channel with r = g = b + + """ + + if num_output_channels == 1: + img = img.convert('L') + elif num_output_channels == 3: + img = img.convert('L') + np_img = np.array(img, dtype=np.uint8) + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, 'RGB') + else: + raise ValueError('num_output_channels should be either 1 or 3') + + return img diff --git a/python/paddle/vision/transforms/functional_tensor.py b/python/paddle/vision/transforms/functional_tensor.py new file mode 100644 index 0000000000..e8b70820dd --- /dev/null +++ b/python/paddle/vision/transforms/functional_tensor.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import paddle + + +def normalize(img, mean, std, data_format='CHW'): + """Normalizes a tensor image with mean and standard deviation. + + Args: + img (paddle.Tensor): input data to be normalized. + mean (list|tuple): Sequence of means for each channel. + std (list|tuple): Sequence of standard deviations for each channel. + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + + Returns: + Tensor: Normalized mage. + + """ + if data_format == 'CHW': + mean = paddle.to_tensor(mean).reshape([-1, 1, 1]) + std = paddle.to_tensor(std).reshape([-1, 1, 1]) + else: + mean = paddle.to_tensor(mean) + std = paddle.to_tensor(std) + return (img - mean) / std diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index 9ea8282717..9079f91aac 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -36,30 +36,50 @@ else: Iterable = collections.abc.Iterable __all__ = [ - "Compose", - "BatchCompose", - "Resize", - "RandomResizedCrop", - "CenterCropResize", - "CenterCrop", - "RandomHorizontalFlip", - "RandomVerticalFlip", - "Permute", - "Normalize", - "GaussianNoise", - "BrightnessTransform", - "SaturationTransform", - "ContrastTransform", - "HueTransform", - "ColorJitter", - "RandomCrop", - "RandomErasing", - "Pad", - "RandomRotate", - "Grayscale", + "BaseTransform", "Compose", "Resize", "RandomResizedCrop", "CenterCrop", + "RandomHorizontalFlip", "RandomVerticalFlip", "Transpose", "Normalize", + "BrightnessTransform", "SaturationTransform", "ContrastTransform", + "HueTransform", "ColorJitter", "RandomCrop", "Pad", "RandomRotation", + "Grayscale", "ToTensor" ] +def _get_image_size(img): + if F._is_pil_image(img): + return img.size + elif F._is_numpy_image(img): + return img.shape[:2][::-1] + else: + raise TypeError("Unexpected type {}".format(type(img))) + + +def _check_input(value, + name, + center=1, + bound=(0, float('inf')), + clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + "If {} is a single number, it must be non negative.".format( + name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, + bound)) + else: + raise TypeError( + "{} should be a single number or a list/tuple with lenght 2.". + format(name)) + + if value[0] == value[1] == center: + value = None + return value + + class Compose(object): """ Composes several transforms together use for composing list of transforms @@ -91,15 +111,10 @@ class Compose(object): def __init__(self, transforms): self.transforms = transforms - def __call__(self, *data): + def __call__(self, data): for f in self.transforms: try: - # multi-fileds in a sample - if isinstance(data, Sequence): - data = f(*data) - # single field in a sample, call transform directly - else: - data = f(data) + data = f(data) except Exception as e: stack_info = traceback.format_exc() print("fail to perform transform [{}] with error: " @@ -116,96 +131,217 @@ class Compose(object): return format_string -class BatchCompose(object): - """Composes several batch transforms together +class BaseTransform(object): + """ + Base class of all transforms used in computer vision. - Args: - transforms (list): List of transforms to compose. - these transforms perform on batch data. + calling logic: + + if keys is None: + _get_params -> _apply_image() + else: + _get_params -> _apply_*() for * in keys + + If you want to implement a self-defined transform method for image, + rewrite _apply_* method in subclass. + Args: + keys (list[str]|tuple[str], optional): Input type. Input is a tuple contains different structures, + key is used to specify the type of input. For example, if your input + is image type, then the key can be None or ("image"). if your input + is (image, image) type, then the keys should be ("image", "image"). + if your input is (image, boxes), then the keys should be ("image", "boxes"). + + Current available strings & data type are describe below: + + - "image": input image, with shape of (H, W, C) + - "coords": coordinates, with shape of (N, 2) + - "boxes": bounding boxes, with shape of (N, 4), "xyxy" format, + + the 1st "xy" represents top left point of a box, + the 2nd "xy" represents right bottom point. + + - "mask": map used for segmentation, with shape of (H, W, 1) + + You can also customize your data types only if you implement the corresponding + _apply_*() methods, otherwise ``NotImplementedError`` will be raised. + Examples: .. code-block:: python import numpy as np - from paddle.io import DataLoader + from PIL import Image + import paddle.vision.transforms.functional as F + from paddle.vision.transforms import BaseTransform + + def _get_image_size(img): + if F._is_pil_image(img): + return img.size + elif F._is_numpy_image(img): + return img.shape[:2][::-1] + else: + raise TypeError("Unexpected type {}".format(type(img))) + + class CustomRandomFlip(BaseTransform): + def __init__(self, prob=0.5, keys=None): + super(CustomRandomFlip, self).__init__(keys) + self.prob = prob + + def _get_params(self, inputs): + image = inputs[self.keys.index('image')] + params = {} + params['flip'] = np.random.random() < self.prob + params['size'] = _get_image_size(image) + return params + + def _apply_image(self, image): + if self.params['flip']: + return F.hflip(image) + return image + + # if you only want to transform image, do not need to rewrite this function + def _apply_coords(self, coords): + if self.params['flip']: + w = self.params['size'][0] + coords[:, 0] = w - coords[:, 0] + return coords + + # if you only want to transform image, do not need to rewrite this function + def _apply_boxes(self, boxes): + idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten() + coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2) + coords = self._apply_coords(coords).reshape((-1, 4, 2)) + minxy = coords.min(axis=1) + maxxy = coords.max(axis=1) + trans_boxes = np.concatenate((minxy, maxxy), axis=1) + return trans_boxes + + # if you only want to transform image, do not need to rewrite this function + def _apply_mask(self, mask): + if self.params['flip']: + return F.hflip(mask) + return mask + + # create fake inputs + fake_img = Image.fromarray((np.random.rand(400, 500, 3) * 255.).astype('uint8')) + fake_boxes = np.array([[2, 3, 200, 300], [50, 60, 80, 100]]) + fake_mask = fake_img.convert('L') + + # only transform for image: + flip_transform = CustomRandomFlip(1.0) + converted_img = flip_transform(fake_img) + + # transform for image, boxes and mask + flip_transform = CustomRandomFlip(1.0, keys=('image', 'boxes', 'mask')) + (converted_img, converted_boxes, converted_mask) = flip_transform((fake_img, fake_boxes, fake_mask)) + print('converted boxes', converted_boxes) - from paddle import set_device - from paddle.vision.datasets import Flowers - from paddle.vision.transforms import Compose, BatchCompose, Resize - - class NormalizeBatch(object): - def __init__(self, - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], - scale=True, - channel_first=True): - - self.mean = mean - self.std = std - self.scale = scale - self.channel_first = channel_first - if not (isinstance(self.mean, list) and isinstance(self.std, list) and - isinstance(self.scale, bool)): - raise TypeError("{}: input type is invalid.".format(self)) - from functools import reduce - if reduce(lambda x, y: x * y, self.std) == 0: - raise ValueError('{}: std is invalid!'.format(self)) - - def __call__(self, samples): - for i in range(len(samples)): - samples[i] = list(samples[i]) - im = samples[i][0] - im = im.astype(np.float32, copy=False) - mean = np.array(self.mean)[np.newaxis, np.newaxis, :] - std = np.array(self.std)[np.newaxis, np.newaxis, :] - if self.scale: - im = im / 255.0 - im -= mean - im /= std - if self.channel_first: - im = im.transpose((2, 0, 1)) - samples[i][0] = im - return samples - - transform = Compose([Resize((500, 500))]) - flowers_dataset = Flowers(mode='test', transform=transform) - - device = set_device('cpu') - - collate_fn = BatchCompose([NormalizeBatch()]) - loader = DataLoader( - flowers_dataset, - batch_size=4, - places=device, - return_list=True, - collate_fn=collate_fn) - - for data in loader: - # do something - break """ - def __init__(self, transforms=[]): - self.transforms = transforms + def __init__(self, keys=None): + if keys is None: + keys = ("image", ) + elif not isinstance(keys, Sequence): + raise ValueError( + "keys should be a sequence, but got keys={}".format(keys)) + for k in keys: + if self._get_apply(k) is None: + raise NotImplementedError( + "{} is unsupported data structure".format(k)) + self.keys = keys + + # storage some params get from function get_params() + self.params = None + + def _get_params(self, inputs): + pass + + def __call__(self, inputs): + """Apply transform on single input data""" + if not isinstance(inputs, tuple): + inputs = (inputs, ) + + self.params = self._get_params(inputs) + + outputs = [] + for i in range(min(len(inputs), len(self.keys))): + apply_func = self._get_apply(self.keys[i]) + if apply_func is None: + outputs.append(inputs[i]) + else: + outputs.append(apply_func(inputs[i])) + if len(inputs) > len(self.keys): + outputs.extend(input[len(self.keys):]) + + if len(outputs) == 1: + outputs = outputs[0] + else: + outputs = tuple(outputs) + return outputs - def __call__(self, data): - for f in self.transforms: - try: - data = f(data) - except Exception as e: - stack_info = traceback.format_exc() - print("fail to perform batch transform [{}] with error: " - "{} and stack:\n{}".format(f, e, str(stack_info))) - raise e + def _get_apply(self, key): + return getattr(self, "_apply_{}".format(key), None) - # sample list to batch data - batch = list(zip(*data)) + def _apply_image(self, image): + raise NotImplementedError - return batch + def _apply_boxes(self, boxes): + raise NotImplementedError + def _apply_mask(self, mask): + raise NotImplementedError -class Resize(object): + +class ToTensor(BaseTransform): + """Convert a ``PIL.Image`` or ``numpy.ndarray`` to ``paddle.Tensor``. + + Converts a PIL.Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a paddle.Tensor of shape (C x H x W) in the range [0.0, 1.0] + if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) + or if the numpy.ndarray has dtype = np.uint8 + + In the other cases, tensors are returned without scaling. + + Args: + data_format (str, optional): Data format of input img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + + Examples: + + .. code-block:: python + + import numpy as np + from PIL import Image + + import paddle.vision.transforms as T + import paddle.vision.transforms.functional as F + + fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) + + transform = T.ToTensor() + + tensor = transform(fake_img) + + """ + + def __init__(self, data_format='CHW', keys=None): + super(ToTensor, self).__init__(keys) + self.data_format = data_format + + def _apply_image(self, img): + """ + Args: + img (PIL.Image|np.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return F.to_tensor(img, self.data_format) + + +class Resize(BaseTransform): """Resize the input Image to the given size. Args: @@ -214,97 +350,111 @@ class Resize(object): smaller edge of the image will be matched to this number. i.e, if height > width, then image will be rescaled to (size * height / width, size) - interpolation (int, optional): Interpolation mode of resize. Default: 1. - 0 : cv2.INTER_NEAREST - 1 : cv2.INTER_LINEAR - 2 : cv2.INTER_CUBIC - 3 : cv2.INTER_AREA - 4 : cv2.INTER_LANCZOS4 - 5 : cv2.INTER_LINEAR_EXACT - 7 : cv2.INTER_MAX - 8 : cv2.WARP_FILL_OUTLIERS - 16: cv2.WARP_INVERSE_MAP + interpolation (int|str, optional): Interpolation method. Default: 'bilinear'. + when use pil backend, support method are as following: + - "nearest": Image.NEAREST, + - "bilinear": Image.BILINEAR, + - "bicubic": Image.BICUBIC, + - "box": Image.BOX, + - "lanczos": Image.LANCZOS, + - "hamming": Image.HAMMING + when use cv2 backend, support method are as following: + - "nearest": cv2.INTER_NEAREST, + - "bilinear": cv2.INTER_LINEAR, + - "area": cv2.INTER_AREA, + - "bicubic": cv2.INTER_CUBIC, + - "lanczos": cv2.INTER_LANCZOS4 + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import Resize transform = Resize(size=224) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(100, 120, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(fake_img.size) """ - def __init__(self, size, interpolation=1): + def __init__(self, size, interpolation='bilinear', keys=None): + super(Resize, self).__init__(keys) assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) self.size = size self.interpolation = interpolation - def __call__(self, img): + def _apply_image(self, img): return F.resize(img, self.size, self.interpolation) -class RandomResizedCrop(object): +class RandomResizedCrop(BaseTransform): """Crop the input data to random size and aspect ratio. A crop of random size (default: of 0.08 to 1.0) of the original size and a random aspect ratio (default: of 3/4 to 1.33) of the original aspect ratio is made. After applying crop transfrom, the input data will be resized to given size. Args: - output_size (int|list|tuple): Target size of output image, with (height, width) shape. + size (int|list|tuple): Target size of output image, with (height, width) shape. scale (list|tuple): Range of size of the origin size cropped. Default: (0.08, 1.0) ratio (list|tuple): Range of aspect ratio of the origin aspect ratio cropped. Default: (0.75, 1.33) - interpolation (int, optional): Interpolation mode of resize. Default: 1. - 0 : cv2.INTER_NEAREST - 1 : cv2.INTER_LINEAR - 2 : cv2.INTER_CUBIC - 3 : cv2.INTER_AREA - 4 : cv2.INTER_LANCZOS4 - 5 : cv2.INTER_LINEAR_EXACT - 7 : cv2.INTER_MAX - 8 : cv2.WARP_FILL_OUTLIERS - 16: cv2.WARP_INVERSE_MAP + interpolation (int|str, optional): Interpolation method. Default: 'bilinear'. when use pil backend, + support method are as following: + - "nearest": Image.NEAREST, + - "bilinear": Image.BILINEAR, + - "bicubic": Image.BICUBIC, + - "box": Image.BOX, + - "lanczos": Image.LANCZOS, + - "hamming": Image.HAMMING + when use cv2 backend, support method are as following: + - "nearest": cv2.INTER_NEAREST, + - "bilinear": cv2.INTER_LINEAR, + - "area": cv2.INTER_AREA, + - "bicubic": cv2.INTER_CUBIC, + - "lanczos": cv2.INTER_LANCZOS4 + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import RandomResizedCrop transform = RandomResizedCrop(224) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(fake_img.size) + """ def __init__(self, - output_size, + size, scale=(0.08, 1.0), ratio=(3. / 4, 4. / 3), - interpolation=1): - if isinstance(output_size, int): - self.output_size = (output_size, output_size) + interpolation='bilinear', + keys=None): + super(RandomResizedCrop, self).__init__(keys) + if isinstance(size, int): + self.size = (size, size) else: - self.output_size = output_size + self.size = size assert (scale[0] <= scale[1]), "scale should be of kind (min, max)" assert (ratio[0] <= ratio[1]), "ratio should be of kind (min, max)" self.scale = scale self.ratio = ratio self.interpolation = interpolation - def _get_params(self, image, attempts=10): - height, width, _ = image.shape + def _get_param(self, image, attempts=10): + width, height = _get_image_size(image) area = height * width for _ in range(attempts): @@ -316,9 +466,9 @@ class RandomResizedCrop(object): h = int(round(math.sqrt(target_area / aspect_ratio))) if 0 < w <= width and 0 < h <= height: - x = np.random.randint(0, width - w + 1) - y = np.random.randint(0, height - h + 1) - return x, y, w, h + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w # Fallback to central crop in_ratio = float(width) / float(height) @@ -328,179 +478,123 @@ class RandomResizedCrop(object): elif in_ratio > max(self.ratio): h = height w = int(round(h * max(self.ratio))) - else: # whole image + else: + # return whole image w = width h = height - x = (width - w) // 2 - y = (height - h) // 2 - return x, y, w, h - - def __call__(self, img): - x, y, w, h = self._get_params(img) - cropped_img = img[y:y + h, x:x + w] - return F.resize(cropped_img, self.output_size, self.interpolation) - - -class CenterCropResize(object): - """Crops to center of image with padding then scales size. - - Args: - size (int|list|tuple): Target size of output image, with (height, width) shape. - crop_padding (int): Center crop with the padding. Default: 32. - interpolation (int, optional): Interpolation mode of resize. Default: 1. - 0 : cv2.INTER_NEAREST - 1 : cv2.INTER_LINEAR - 2 : cv2.INTER_CUBIC - 3 : cv2.INTER_AREA - 4 : cv2.INTER_LANCZOS4 - 5 : cv2.INTER_LINEAR_EXACT - 7 : cv2.INTER_MAX - 8 : cv2.WARP_FILL_OUTLIERS - 16: cv2.WARP_INVERSE_MAP - - Examples: - - .. code-block:: python - - import numpy as np + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w - from paddle.vision.transforms import CenterCropResize + def _apply_image(self, img): + i, j, h, w = self._get_param(img) - transform = CenterCropResize(224) - - fake_img = np.random.rand(500, 500, 3).astype('float32') - - fake_img = transform(fake_img) - print(fake_img.shape) - """ - - def __init__(self, size, crop_padding=32, interpolation=1): - if isinstance(size, int): - self.size = (size, size) - else: - self.size = size - self.crop_padding = crop_padding - self.interpolation = interpolation - - def _get_params(self, img): - h, w = img.shape[:2] - size = min(self.size) - c = int(size / (size + self.crop_padding) * min((h, w))) - x = (h + 1 - c) // 2 - y = (w + 1 - c) // 2 - return c, x, y - - def __call__(self, img): - c, x, y = self._get_params(img) - cropped_img = img[x:x + c, y:y + c, :] + cropped_img = F.crop(img, i, j, h, w) return F.resize(cropped_img, self.size, self.interpolation) -class CenterCrop(object): +class CenterCrop(BaseTransform): """Crops the given the input data at the center. Args: - output_size: Target size of output image, with (height, width) shape. - + size (int|list|tuple): Target size of output image, with (height, width) shape. + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import CenterCrop transform = CenterCrop(224) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(fake_img.size) """ - def __init__(self, output_size): - if isinstance(output_size, int): - self.output_size = (output_size, output_size) + def __init__(self, size, keys=None): + super(CenterCrop, self).__init__(keys) + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) else: - self.output_size = output_size - - def _get_params(self, img): - th, tw = self.output_size - h, w, _ = img.shape - assert th <= h and tw <= w, "output size is bigger than image size" - x = int(round((w - tw) / 2.0)) - y = int(round((h - th) / 2.0)) - return x, y + self.size = size - def __call__(self, img): - x, y = self._get_params(img) - th, tw = self.output_size - return img[y:y + th, x:x + tw] + def _apply_image(self, img): + return F.center_crop(img, self.size) -class RandomHorizontalFlip(object): +class RandomHorizontalFlip(BaseTransform): """Horizontally flip the input data randomly with a given probability. Args: - prob (float): Probability of the input data being flipped. Default: 0.5 + prob (float, optional): Probability of the input data being flipped. Default: 0.5 + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import RandomHorizontalFlip transform = RandomHorizontalFlip(224) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(fake_img.size) """ - def __init__(self, prob=0.5): + def __init__(self, prob=0.5, keys=None): + super(RandomHorizontalFlip, self).__init__(keys) self.prob = prob - def __call__(self, img): - if np.random.random() < self.prob: - return F.flip(img, code=1) + def _apply_image(self, img): + if random.random() < self.prob: + return F.hflip(img) return img -class RandomVerticalFlip(object): +class RandomVerticalFlip(BaseTransform): """Vertically flip the input data randomly with a given probability. Args: - prob (float): Probability of the input data being flipped. Default: 0.5 + prob (float, optional): Probability of the input data being flipped. Default: 0.5 + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import RandomVerticalFlip transform = RandomVerticalFlip(224) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(fake_img.size) + """ - def __init__(self, prob=0.5): + def __init__(self, prob=0.5, keys=None): + super(RandomVerticalFlip, self).__init__(keys) self.prob = prob - def __call__(self, img): - if np.random.random() < self.prob: - return F.flip(img, code=0) + def _apply_image(self, img): + if random.random() < self.prob: + return F.vflip(img) return img -class Normalize(object): +class Normalize(BaseTransform): """Normalize the input data with mean and standard deviation. Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform will normalize each channel of the input data. @@ -509,286 +603,240 @@ class Normalize(object): Args: mean (int|float|list): Sequence of means for each channel. std (int|float|list): Sequence of standard deviations for each channel. - + data_format (str, optional): Data format of img, should be 'HWC' or + 'CHW'. Default: 'CHW'. + to_rgb (bool, optional): Whether to convert to rgb. Default: False. + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import Normalize - normalize = Normalize(mean=[0.5, 0.5, 0.5], - std=[0.5, 0.5, 0.5]) + normalize = Normalize(mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + data_format='HWC') - fake_img = np.random.rand(3, 500, 500).astype('float32') + fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8)) fake_img = normalize(fake_img) print(fake_img.shape) + print(fake_img.max, fake_img.max) """ - def __init__(self, mean=0.0, std=1.0): + def __init__(self, + mean=0.0, + std=1.0, + data_format='CHW', + to_rgb=False, + keys=None): + super(Normalize, self).__init__(keys) if isinstance(mean, numbers.Number): mean = [mean, mean, mean] if isinstance(std, numbers.Number): std = [std, std, std] - self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1) - self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1) + self.mean = mean + self.std = std + self.data_format = data_format + self.to_rgb = to_rgb - def __call__(self, img): - return (img - self.mean) / self.std + def _apply_image(self, img): + return F.normalize(img, self.mean, self.std, self.data_format, + self.to_rgb) -class Permute(object): - """Change input data to a target mode. +class Transpose(BaseTransform): + """Transpose input data to a target format. For example, most transforms use HWC mode image, while the Neural Network might use CHW mode input tensor. - Input image should be HWC mode and an instance of numpy.ndarray. + output image will be an instance of numpy.ndarray. Args: - mode (str): Output mode of input. Default: "CHW". - to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True. - + order (list|tuple, optional): Target order of input data. Default: (2, 0, 1). + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Examples: .. code-block:: python import numpy as np + from PIL import Image + from paddle.vision.transforms import Transpose - from paddle.vision.transforms import Permute + transform = Transpose() - transform = Permute() - - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) print(fake_img.shape) - """ - - def __init__(self, mode="CHW", to_rgb=True): - assert mode in [ - "CHW" - ], "Only support 'CHW' mode, but received mode: {}".format(mode) - self.mode = mode - self.to_rgb = to_rgb - - def __call__(self, img): - if self.to_rgb: - img = img[..., ::-1] - if self.mode == "CHW": - return img.transpose((2, 0, 1)) - return img - - -class GaussianNoise(object): - """Add random gaussian noise to the input data. - Gaussian noise is generated with given mean and std. - - Args: - mean (float): Gaussian mean used to generate noise. - std (float): Gaussian standard deviation used to generate noise. - - Examples: - .. code-block:: python - - import numpy as np - - from paddle.vision.transforms import GaussianNoise - - transform = GaussianNoise() - - fake_img = np.random.rand(500, 500, 3).astype('float32') - - fake_img = transform(fake_img) - print(fake_img.shape) """ - def __init__(self, mean=0.0, std=1.0): - self.mean = np.array(mean, dtype=np.float32) - self.std = np.array(std, dtype=np.float32) + def __init__(self, order=(2, 0, 1), keys=None): + super(Transpose, self).__init__(keys) + self.order = order + + def _apply_image(self, img): + if F._is_pil_image(img): + img = np.asarray(img) - def __call__(self, img): - dtype = img.dtype - noise = np.random.normal(self.mean, self.std, img.shape) * 255 - img = img + noise.astype(np.float32) - return np.clip(img, 0, 255).astype(dtype) + return img.transpose(self.order) -class BrightnessTransform(object): +class BrightnessTransform(BaseTransform): """Adjust brightness of the image. Args: value (float): How much to adjust the brightness. Can be any non negative number. 0 gives the original image + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import BrightnessTransform transform = BrightnessTransform(0.4) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + """ - def __init__(self, value): - if value < 0: - raise ValueError("brightness value should be non-negative") - self.value = value + def __init__(self, value, keys=None): + super(BrightnessTransform, self).__init__(keys) + self.value = _check_input(value, 'brightness') - def __call__(self, img): - if self.value == 0: + def _apply_image(self, img): + if self.value is None: return img - dtype = img.dtype - img = img.astype(np.float32) - alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value) - img = img * alpha - return img.clip(0, 255).astype(dtype) + brightness_factor = random.uniform(self.value[0], self.value[1]) + return F.adjust_brightness(img, brightness_factor) -class ContrastTransform(object): +class ContrastTransform(BaseTransform): """Adjust contrast of the image. Args: value (float): How much to adjust the contrast. Can be any non negative number. 0 gives the original image + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import ContrastTransform transform = ContrastTransform(0.4) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + """ - def __init__(self, value): + def __init__(self, value, keys=None): + super(ContrastTransform, self).__init__(keys) if value < 0: raise ValueError("contrast value should be non-negative") - self.value = value + self.value = _check_input(value, 'contrast') - def __call__(self, img): - if self.value == 0: + def _apply_image(self, img): + if self.value is None: return img - cv2 = try_import('cv2') - dtype = img.dtype - img = img.astype(np.float32) - alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value) - img = img * alpha + cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).mean() * ( - 1 - alpha) - return img.clip(0, 255).astype(dtype) + contrast_factor = random.uniform(self.value[0], self.value[1]) + return F.adjust_contrast(img, contrast_factor) -class SaturationTransform(object): +class SaturationTransform(BaseTransform): """Adjust saturation of the image. Args: value (float): How much to adjust the saturation. Can be any non negative number. 0 gives the original image + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import SaturationTransform transform = SaturationTransform(0.4) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + """ - def __init__(self, value): - if value < 0: - raise ValueError("saturation value should be non-negative") - self.value = value + def __init__(self, value, keys=None): + super(SaturationTransform, self).__init__(keys) + self.value = _check_input(value, 'saturation') - def __call__(self, img): - if self.value == 0: + def _apply_image(self, img): + if self.value is None: return img - cv2 = try_import('cv2') + saturation_factor = random.uniform(self.value[0], self.value[1]) + return F.adjust_saturation(img, saturation_factor) - dtype = img.dtype - img = img.astype(np.float32) - alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value) - gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - gray_img = gray_img[..., np.newaxis] - img = img * alpha + gray_img * (1 - alpha) - return img.clip(0, 255).astype(dtype) - -class HueTransform(object): +class HueTransform(BaseTransform): """Adjust hue of the image. Args: value (float): How much to adjust the hue. Can be any number between 0 and 0.5, 0 gives the original image + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import HueTransform transform = HueTransform(0.4) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + """ - def __init__(self, value): - if value < 0 or value > 0.5: - raise ValueError("hue value should be in [0.0, 0.5]") - self.value = value + def __init__(self, value, keys=None): + super(HueTransform, self).__init__(keys) + self.value = _check_input( + value, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) - def __call__(self, img): - if self.value == 0: + def _apply_image(self, img): + if self.value is None: return img - cv2 = try_import('cv2') - dtype = img.dtype - img = img.astype(np.uint8) - hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV_FULL) - h, s, v = cv2.split(hsv_img) - - alpha = np.random.uniform(-self.value, self.value) - h = h.astype(np.uint8) - # uint8 addition take cares of rotation across boundaries - with np.errstate(over="ignore"): - h += np.uint8(alpha * 255) - hsv_img = cv2.merge([h, s, v]) - return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR_FULL).astype(dtype) + hue_factor = random.uniform(self.value[0], self.value[1]) + return F.adjust_hue(img, hue_factor) -class ColorJitter(object): +class ColorJitter(BaseTransform): """Randomly change the brightness, contrast, saturation and hue of an image. Args: @@ -800,42 +848,74 @@ class ColorJitter(object): Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. Should be non negative numbers. hue: How much to jitter hue. Chosen uniformly from [-hue, hue]. Should have 0<= hue <= 0.5. + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import ColorJitter - transform = ColorJitter(0.4) + transform = ColorJitter(0.4, 0.4, 0.4, 0.4) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + """ - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, + keys=None): + super(ColorJitter, self).__init__(keys) + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def _get_param(self, brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + + Arguments are same as that of __init__. + + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. + """ transforms = [] - if brightness != 0: - transforms.append(BrightnessTransform(brightness)) - if contrast != 0: - transforms.append(ContrastTransform(contrast)) - if saturation != 0: - transforms.append(SaturationTransform(saturation)) - if hue != 0: - transforms.append(HueTransform(hue)) + + if brightness is not None: + transforms.append(BrightnessTransform(brightness, self.keys)) + + if contrast is not None: + transforms.append(ContrastTransform(contrast, self.keys)) + + if saturation is not None: + transforms.append(SaturationTransform(saturation, self.keys)) + + if hue is not None: + transforms.append(HueTransform(hue, self.keys)) random.shuffle(transforms) - self.transforms = Compose(transforms) + transform = Compose(transforms) - def __call__(self, img): - return self.transforms(img) + return transform + def _apply_image(self, img): + """ + Args: + img (PIL Image): Input image. -class RandomCrop(object): + Returns: + PIL Image: Color jittered image. + """ + transform = self._get_param(self.brightness, self.contrast, + self.saturation, self.hue) + return transform(img) + + +class RandomCrop(BaseTransform): """Crops the given CV Image at a random location. Args: @@ -847,159 +927,88 @@ class RandomCrop(object): top, right, bottom borders respectively. Default: 0. pad_if_needed (boolean|optional): It will pad the image if smaller than the desired size to avoid raising an exception. Default: False. - + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import RandomCrop transform = RandomCrop(224) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(324, 300, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(fake_img.size) """ - def __init__(self, size, padding=0, pad_if_needed=False): + def __init__(self, + size, + padding=None, + pad_if_needed=False, + fill=0, + padding_mode='constant', + keys=None): + super(RandomCrop, self).__init__(keys) if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size self.padding = padding self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode - def _get_params(self, img, output_size): + def _get_param(self, img, output_size): """Get parameters for ``crop`` for a random crop. Args: - img (numpy.ndarray): Image to be cropped. + img (PIL Image): Image to be cropped. output_size (tuple): Expected output size of the crop. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. - """ - h, w, _ = img.shape + w, h = _get_image_size(img) th, tw = output_size if w == tw and h == th: return 0, 0, h, w - try: - i = random.randint(0, h - th) - except ValueError: - i = random.randint(h - th, 0) - try: - j = random.randint(0, w - tw) - except ValueError: - j = random.randint(w - tw, 0) + i = random.randint(0, h - th) + j = random.randint(0, w - tw) return i, j, th, tw - def __call__(self, img): + def _apply_image(self, img): """ - Args: - img (numpy.ndarray): Image to be cropped. - Returns: - numpy.ndarray: Cropped image. + img (PIL Image): Image to be cropped. + Returns: + PIL Image: Cropped image. """ - if self.padding > 0: - img = F.pad(img, self.padding) + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + w, h = _get_image_size(img) # pad the width if needed - if self.pad_if_needed and img.shape[1] < self.size[1]: - img = F.pad(img, (int((1 + self.size[1] - img.shape[1]) / 2), 0)) + if self.pad_if_needed and w < self.size[1]: + img = F.pad(img, (self.size[1] - w, 0), self.fill, + self.padding_mode) # pad the height if needed - if self.pad_if_needed and img.shape[0] < self.size[0]: - img = F.pad(img, (0, int((1 + self.size[0] - img.shape[0]) / 2))) - - i, j, h, w = self._get_params(img, self.size) - - return img[i:i + h, j:j + w] - - -class RandomErasing(object): - """Randomly selects a rectangle region in an image and erases its pixels. - ``Random Erasing Data Augmentation`` by Zhong et al. - See https://arxiv.org/pdf/1708.04896.pdf - - Args: - prob (float): probability that the random erasing operation will be performed. - scale (tuple): range of proportion of erased area against input image. Should be (min, max). - ratio (float): range of aspect ratio of erased area. - value (float|list|tuple): erasing value. If a single int, it is used to - erase all pixels. If a tuple of length 3, it is used to erase - R, G, B channels respectively. Default: 0. - - Examples: - - .. code-block:: python - - import numpy as np - - from paddle.vision.transforms import RandomCrop - - transform = RandomCrop(224) - - fake_img = np.random.rand(500, 500, 3).astype('float32') - - fake_img = transform(fake_img) - print(fake_img.shape) - """ - - def __init__(self, - prob=0.5, - scale=(0.02, 0.4), - ratio=0.3, - value=[0., 0., 0.]): - assert isinstance(value, ( - float, Sequence - )), "Expected type of value in [float, list, tupue], but got {}".format( - type(value)) - assert scale[0] <= scale[1], "scale range should be of kind (min, max)!" - - if isinstance(value, float): - self.value = [value, value, value] - else: - self.value = value - - self.p = prob - self.scale = scale - self.ratio = ratio - - def __call__(self, img): - if random.uniform(0, 1) > self.p: - return img - - for _ in range(100): - area = img.shape[0] * img.shape[1] - - target_area = random.uniform(self.scale[0], self.scale[1]) * area - aspect_ratio = random.uniform(self.ratio, 1 / self.ratio) - - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) + if self.pad_if_needed and h < self.size[0]: + img = F.pad(img, (0, self.size[0] - h), self.fill, + self.padding_mode) - if w < img.shape[1] and h < img.shape[0]: - x1 = random.randint(0, img.shape[0] - h) - y1 = random.randint(0, img.shape[1] - w) + i, j, h, w = self._get_param(img, self.size) - if len(img.shape) == 3 and img.shape[2] == 3: - img[x1:x1 + h, y1:y1 + w, 0] = self.value[0] - img[x1:x1 + h, y1:y1 + w, 1] = self.value[1] - img[x1:x1 + h, y1:y1 + w, 2] = self.value[2] - else: - img[x1:x1 + h, y1:y1 + w] = self.value[1] - return img - - return img + return F.crop(img, i, j, h, w) -class Pad(object): +class Pad(BaseTransform): """Pads the given CV Image on all sides with the given "pad" value. Args: @@ -1020,64 +1029,73 @@ class Pad(object): ``symmetric`` menas pads with reflection of image (repeating the last value on the edge) padding ``[1, 2, 3, 4]`` with 2 elements on both sides in symmetric mode will result in ``[2, 1, 1, 2, 3, 4, 4, 3]``. - + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Examples: .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import Pad transform = Pad(2) - fake_img = np.random.rand(500, 500, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(fake_img.size) """ - def __init__(self, padding, fill=0, padding_mode='constant'): + def __init__(self, padding, fill=0, padding_mode='constant', keys=None): assert isinstance(padding, (numbers.Number, list, tuple)) assert isinstance(fill, (numbers.Number, str, list, tuple)) assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] - if isinstance(padding, - collections.Sequence) and len(padding) not in [2, 4]: + + if isinstance(padding, list): + padding = tuple(padding) + if isinstance(fill, list): + fill = tuple(fill) + + if isinstance(padding, Sequence) and len(padding) not in [2, 4]: raise ValueError( "Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) + super(Pad, self).__init__(keys) self.padding = padding self.fill = fill self.padding_mode = padding_mode - def __call__(self, img): + def _apply_image(self, img): """ Args: - img (numpy.ndarray): Image to be padded. + img (PIL Image): Image to be padded. + Returns: - numpy.ndarray: Padded image. + PIL Image: Padded image. """ return F.pad(img, self.padding, self.fill, self.padding_mode) -class RandomRotate(object): +class RandomRotation(BaseTransform): """Rotates the image by angle. Args: degrees (sequence or float or int): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees) clockwise order. - interpolation (int, optional): Interpolation mode of resize. Default: 1. - 0 : cv2.INTER_NEAREST - 1 : cv2.INTER_LINEAR - 2 : cv2.INTER_CUBIC - 3 : cv2.INTER_AREA - 4 : cv2.INTER_LANCZOS4 - 5 : cv2.INTER_LINEAR_EXACT - 7 : cv2.INTER_MAX - 8 : cv2.WARP_FILL_OUTLIERS - 16: cv2.WARP_INVERSE_MAP + interpolation (int|str, optional): Interpolation method. Default: 'bilinear'. + resample (int|str, optional): An optional resampling filter. If omitted, or if the + image has only one channel, it is set to PIL.Image.NEAREST or cv2.INTER_NEAREST + according the backend. when use pil backend, support method are as following: + - "nearest": Image.NEAREST, + - "bilinear": Image.BILINEAR, + - "bicubic": Image.BICUBIC + when use cv2 backend, support method are as following: + - "nearest": cv2.INTER_NEAREST, + - "bilinear": cv2.INTER_LINEAR, + - "bicubic": cv2.INTER_CUBIC expand (bool|optional): Optional expansion flag. Default: False. If true, expands the output to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1085,24 +1103,31 @@ class RandomRotate(object): center (2-tuple|optional): Optional center of rotation. Origin is the upper left corner. Default is the center of the image. - + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Examples: .. code-block:: python import numpy as np + from PIL import Image + from paddle.vision.transforms import RandomRotation - from paddle.vision.transforms import RandomRotate - - transform = RandomRotate(90) + transform = RandomRotation(90) - fake_img = np.random.rand(500, 400, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(200, 150, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(fake_img.size) """ - def __init__(self, degrees, interpolation=1, expand=False, center=None): + def __init__(self, + degrees, + resample=False, + expand=False, + center=None, + fill=0, + keys=None): if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError( @@ -1114,37 +1139,39 @@ class RandomRotate(object): "If degrees is a sequence, it must be of len 2.") self.degrees = degrees - self.interpolation = interpolation + super(RandomRotation, self).__init__(keys) + self.resample = resample self.expand = expand self.center = center + self.fill = fill - def _get_params(self, degrees): - """Get parameters for ``rotate`` for a random rotation. - Returns: - sequence: params to be passed to ``rotate`` for random rotation. - """ + def _get_param(self, degrees): angle = random.uniform(degrees[0], degrees[1]) return angle - def __call__(self, img): + def _apply_image(self, img): """ - img (np.ndarray): Image to be rotated. + Args: + img (PIL.Image|np.array): Image to be rotated. + Returns: - np.ndarray: Rotated image. + PIL.Image or np.array: Rotated image. """ - angle = self._get_params(self.degrees) + angle = self._get_param(self.degrees) - return F.rotate(img, angle, self.interpolation, self.expand, - self.center) + return F.rotate(img, angle, self.resample, self.expand, self.center, + self.fill) -class Grayscale(object): +class Grayscale(BaseTransform): """Converts image to grayscale. Args: - output_channels (int): (1 or 3) number of channels desired for output image + num_output_channels (int): (1 or 3) number of channels desired for output image + keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Returns: CV Image: Grayscale version of the input. - If output_channels == 1 : returned image is single channel @@ -1155,25 +1182,27 @@ class Grayscale(object): .. code-block:: python import numpy as np - + from PIL import Image from paddle.vision.transforms import Grayscale transform = Grayscale() - fake_img = np.random.rand(500, 400, 3).astype('float32') + fake_img = Image.fromarray((np.random.rand(224, 224, 3) * 255.).astype(np.uint8)) fake_img = transform(fake_img) - print(fake_img.shape) + print(np.array(fake_img).shape) """ - def __init__(self, output_channels=1): - self.output_channels = output_channels + def __init__(self, num_output_channels=1, keys=None): + super(Grayscale, self).__init__(keys) + self.num_output_channels = num_output_channels - def __call__(self, img): + def _apply_image(self, img): """ Args: - img (numpy.ndarray): Image to be converted to grayscale. + img (PIL Image): Image to be converted to grayscale. + Returns: - numpy.ndarray: Randomly grayscaled image. + PIL Image: Randomly grayscaled image. """ - return F.to_grayscale(img, num_output_channels=self.output_channels) + return F.to_grayscale(img, self.num_output_channels) -- GitLab