#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image

from paddle.utils import try_import

__all__ = []

# Name of the backend used by `image_load` when no explicit backend is passed.
# Changed through `set_image_backend` and read through `get_image_backend`.
_image_backend = 'pil'


def set_image_backend(backend):
    """
    Specifies the backend used to load images in class :ref:`api_paddle_datasets_ImageFolder`
    and :ref:`api_paddle_datasets_DatasetFolder` . Supported backends are 'pil' (Pillow)
    and 'cv2' (OpenCV); 'tensor' is also accepted.
    If the backend is never set, 'pil' is used by default.

    Args:
        backend (str): Name of the image load backend, should be one of {'pil', 'cv2', 'tensor'}.

    Raises:
        ValueError: If ``backend`` is not one of 'pil', 'cv2' or 'tensor'. The
            currently configured backend is left unchanged in that case.

    Examples:

        .. code-block:: python

            import os
            import shutil
            import tempfile
            import numpy as np
            from PIL import Image

            from paddle.vision import DatasetFolder
            from paddle.vision import set_image_backend

            set_image_backend('pil')

            def make_fake_dir():
                data_dir = tempfile.mkdtemp()

                for i in range(2):
                    sub_dir = os.path.join(data_dir, 'class_' + str(i))
                    if not os.path.exists(sub_dir):
                        os.makedirs(sub_dir)
                    for j in range(2):
                        fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))
                        fake_img.save(os.path.join(sub_dir, str(j) + '.png'))
                return data_dir

            temp_dir = make_fake_dir()

            pil_data_folder = DatasetFolder(temp_dir)

            for items in pil_data_folder:
                break

            # should get PIL.Image.Image
            print(type(items[0]))

            # use opencv as backend
            # set_image_backend('cv2')

            # cv2_data_folder = DatasetFolder(temp_dir)

            # for items in cv2_data_folder:
            #     break

            # should get numpy.ndarray
            # print(type(items[0]))

            shutil.rmtree(temp_dir)
    """
    global _image_backend
    # Validate before assigning so an invalid name never replaces the current backend.
    if backend not in ['pil', 'cv2', 'tensor']:
        raise ValueError(
            "Expected backend are one of ['pil', 'cv2', 'tensor'], but got {}".format(
                backend
            )
        )
    _image_backend = backend


def get_image_backend():
    """
    Gets the name of the backend currently used to load images.

    The value is the one most recently accepted by
    :ref:`api_paddle_vision_set_image_backend`, or 'pil' if it was never changed.

    Returns:
        str: backend of image load, normally one of 'pil', 'cv2' or 'tensor'.

    Examples:

        .. code-block:: python

            from paddle.vision import get_image_backend

            backend = get_image_backend()
            print(backend)

    """
    return _image_backend


def image_load(path, backend=None):
    """Load an image.

    Args:
        path (str): Path of the image.
        backend (str, optional): The image decoding backend type. Options are
            `cv2`, `pil`, `tensor`, `None`. If backend is None, the global backend
            specified by :ref:`api_paddle_vision_set_image_backend` will be used. Default: None.

    Returns:
        PIL.Image or np.array: Loaded image. ``Image.open`` is used for the `pil`
        backend and ``cv2.imread`` for the `cv2` backend. The `tensor` backend
        passes validation but has no loader in this function, so ``None`` is
        returned for it.

    Raises:
        ValueError: If the resolved backend is not one of 'pil', 'cv2' or 'tensor'.

    Examples:

        .. code-block:: python

            import numpy as np
            from PIL import Image
            from paddle.vision import image_load, set_image_backend

            fake_img = Image.fromarray((np.random.random((32, 32, 3)) * 255).astype('uint8'))

            path = 'temp.png'
            fake_img.save(path)

            set_image_backend('pil')

            pil_img = image_load(path).convert('RGB')

            # should be PIL.Image.Image
            print(type(pil_img))

            # use opencv as backend
            # set_image_backend('cv2')

            # np_img = image_load(path)
            # # should get numpy.ndarray
            # print(type(np_img))

    """

    # Fall back to the module-wide backend configured via set_image_backend.
    if backend is None:
        backend = _image_backend
    if backend not in ['pil', 'cv2', 'tensor']:
        raise ValueError(
            "Expected backend are one of ['pil', 'cv2', 'tensor'], but got {}".format(
                backend
            )
        )

    if backend == 'pil':
        return Image.open(path)
    if backend == 'cv2':
        # cv2 is imported lazily, so OpenCV is only needed when this backend is used.
        cv2 = try_import('cv2')
        return cv2.imread(path)
    # NOTE(review): 'tensor' is accepted by the validation above but no loader is
    # implemented here, so it falls through to None — confirm this is intended.
    return None