image.py 4.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image
from paddle.utils import try_import

__all__ = [
    'set_image_backend',
    'get_image_backend',
    'image_load',
]

# Backend name used whenever no explicit backend is requested;
# rebound by set_image_backend().
_image_backend = 'pil'


def set_image_backend(backend):
    """
    Specifies the backend used to load images in class ``paddle.vision.datasets.ImageFolder``
    and ``paddle.vision.datasets.DatasetFolder``. Accepted backend names are 'pil'
    (Pillow), 'cv2' (OpenCV) and 'tensor'. If no backend is set, 'pil' is used
    as the default.

    Args:
        backend (str): Name of the image load backend, should be one of
            {'pil', 'cv2', 'tensor'}.

    Raises:
        ValueError: If ``backend`` is not one of the accepted names. The
            previously configured backend is left unchanged in that case.

    Examples:

        .. code-block:: python

            import os
            import shutil
            import tempfile
            import numpy as np
            from PIL import Image

            from paddle.vision import DatasetFolder
            from paddle.vision import set_image_backend

            set_image_backend('pil')

            def make_fake_dir():
                data_dir = tempfile.mkdtemp()

                for i in range(2):
                    sub_dir = os.path.join(data_dir, 'class_' + str(i))
                    if not os.path.exists(sub_dir):
                        os.makedirs(sub_dir)
                    for j in range(2):
                        fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))
                        fake_img.save(os.path.join(sub_dir, str(j) + '.png'))
                return data_dir

            temp_dir = make_fake_dir()

            pil_data_folder = DatasetFolder(temp_dir)

            for items in pil_data_folder:
                break

            # should get PIL.Image.Image
            print(type(items[0]))

            # use opencv as backend
            # set_image_backend('cv2')

            # cv2_data_folder = DatasetFolder(temp_dir)

            # for items in cv2_data_folder:
            #     break

            # should get numpy.ndarray
            # print(type(items[0]))

            shutil.rmtree(temp_dir)
    """
    global _image_backend
    # Validate before assigning so an invalid name never replaces a good one.
    if backend not in ['pil', 'cv2', 'tensor']:
        raise ValueError(
            "Expected backend are one of ['pil', 'cv2', 'tensor'], but got {}"
            .format(backend))
    _image_backend = backend


def get_image_backend():
    """
    Returns the name of the backend currently used to load images.

    The value is the name most recently passed to ``set_image_backend``,
    or 'pil' if that function has not been called.

    Returns:
        str: Name of the active image load backend.

    Examples:

        .. code-block:: python

            from paddle.vision import get_image_backend

            backend = get_image_backend()
            print(backend)

    """
    current_backend = _image_backend
    return current_backend


def image_load(path, backend=None):
    """Load an image.

    Args:
        path (str): Path of the image.
        backend (str, optional): The image decoding backend type. Options are
            `cv2`, `pil`, `None`. If backend is None, the module-level default
            (``_image_backend``, configured via ``paddle.vision.set_image_backend``)
            will be used. `tensor` passes validation but has no loader in this
            function. Default: None.

    Returns:
        PIL.Image.Image or numpy.ndarray: Loaded image. A ``PIL.Image.Image``
        for the 'pil' backend, the result of ``cv2.imread`` for the 'cv2'
        backend.

    Raises:
        ValueError: If the backend is not one of 'pil', 'cv2', 'tensor'.
        NotImplementedError: If the backend is 'tensor'. Previously this case
            fell through every branch and silently returned None.

    Examples:

        .. code-block:: python

            import os
            import numpy as np
            from PIL import Image
            from paddle.vision import image_load, set_image_backend

            fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))

            path = 'temp.png'
            fake_img.save(path)

            set_image_backend('pil')

            pil_img = image_load(path).convert('RGB')

            # should be PIL.Image.Image
            print(type(pil_img))

            # use opencv as backend
            # set_image_backend('cv2')

            # np_img = image_load(path)
            # # should get numpy.ndarray
            # print(type(np_img))

            os.remove(path)

    """
    if backend is None:
        backend = _image_backend
    if backend not in ['pil', 'cv2', 'tensor']:
        raise ValueError(
            "Expected backend are one of ['pil', 'cv2', 'tensor'], but got {}"
            .format(backend))

    if backend == 'pil':
        return Image.open(path)
    if backend == 'cv2':
        # OpenCV is an optional dependency, so it is imported only on demand.
        cv2 = try_import('cv2')
        return cv2.imread(path)
    # Only 'tensor' can reach this point. It is an accepted backend name, but
    # there is no decoder for it here, so fail loudly instead of returning None.
    raise NotImplementedError(
        "image_load does not support the 'tensor' backend; "
        "use 'pil' or 'cv2' instead")