flowers.py 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tarfile
17

18 19 20
import numpy as np
from PIL import Image

21
import paddle
22
from paddle.dataset.common import _check_exists_and_download
23
from paddle.io import Dataset
L
LielinJiang 已提交
24
from paddle.utils import try_import
25

26
__all__ = []
27 28 29 30 31 32 33 34 35 36 37

DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'

# Maps a dataset ``mode`` to the key of the subset-index array in setid.mat.
# The official README marks test images with ``tstid`` and training images
# with ``trnid``, but the official test split is larger than the training
# split, so the two flags are deliberately swapped here.
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}
39 40 41 42


class Flowers(Dataset):
    """
    Implementation of `Flowers102 <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
    dataset.

    Note that the ``train`` and ``test`` splits are the reverse of the official
    ``trnid``/``tstid`` flags (see ``MODE_FLAG_MAP``), because the official test
    split is larger than the official training split. The image archive is
    extracted next to :attr:`data_file` every time the dataset is constructed.

    Args:
        data_file (str, optional): Path to the image archive (``.tgz``), can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        label_file (str, optional): Path to label file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        setid_file (str, optional): Path to subset index file, can be set
            None if :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        mode (str, optional): One of 'train', 'valid' or 'test' (case-insensitive). Default 'train'.
        transform (Callable, optional): transform to perform on image, None for no transform. Default: None.
        download (bool, optional): download automatically any of :attr:`data_file`,
            :attr:`label_file` and :attr:`setid_file` that is not set. Default: True.
        backend (str, optional): Specifies which type of image to be returned:
            PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
            If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_vision_image_get_image_backend>`,
            default backend is 'pil'. Default: None.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of Flowers dataset.

    Raises:
        AssertionError: if ``mode`` is invalid, or a file path is not set while
            ``download`` is False.
        ValueError: if ``backend`` is not 'pil' or 'cv2'.

    Examples:

        .. code-block:: python

            >>> # doctest: +TIMEOUT(60)
            >>> import itertools
            >>> import paddle.vision.transforms as T
            >>> from paddle.vision.datasets import Flowers

            >>> flowers = Flowers()
            >>> print(len(flowers))
            6149

            >>> for i in range(5):  # only show first 5 images
            ...     img, label = flowers[i]
            ...     # do something with img and label
            ...     print(type(img), img.size, label)
            ...     # <class 'PIL.JpegImagePlugin.JpegImageFile'> (523, 500) [1]

            >>> transform = T.Compose(
            ...     [
            ...         T.Resize(64),
            ...         T.ToTensor(),
            ...         T.Normalize(
            ...             mean=[0.5, 0.5, 0.5],
            ...             std=[0.5, 0.5, 0.5],
            ...             to_rgb=True,
            ...         ),
            ...     ]
            ... )
            >>> flowers_test = Flowers(
            ...     mode="test",
            ...     transform=transform,  # apply transform to every image
            ...     backend="cv2",  # use OpenCV as image transform backend
            ... )
            >>> print(len(flowers_test))
            1020

            >>> for img, label in itertools.islice(iter(flowers_test), 5):  # only show first 5 images
            ...     # do something with img and label
            ...     print(type(img), img.shape, label)
            ...     # <class 'paddle.Tensor'> [3, 64, 96] [1]
    """

    def __init__(
        self,
        data_file=None,
        label_file=None,
        setid_file=None,
        mode='train',
        transform=None,
        download=True,
        backend=None,
    ):
        mode_key = mode.lower()
        assert mode_key in [
            'train',
            'valid',
            'test',
        ], f"mode should be 'train', 'valid' or 'test', but got {mode}"

        if backend is None:
            backend = paddle.vision.get_image_backend()
        if backend not in ['pil', 'cv2']:
            raise ValueError(
                f"Expected backend are one of ['pil', 'cv2'], but got {backend}"
            )
        self.backend = backend

        flag = MODE_FLAG_MAP[mode_key]

        # Resolved in the same order as before: data, labels, subset indexes.
        data_file = self._resolve_file(
            data_file, 'data_file', DATA_URL, DATA_MD5, download
        )
        label_file = self._resolve_file(
            label_file, 'label_file', LABEL_URL, LABEL_MD5, download
        )
        setid_file = self._resolve_file(
            setid_file, 'setid_file', SETID_URL, SETID_MD5, download
        )

        self.transform = transform

        self.data_path = self._extract_archive(data_file)

        scio = try_import('scipy.io')
        self.labels = scio.loadmat(label_file)['labels'][0]
        self.indexes = scio.loadmat(setid_file)[flag][0]

    @staticmethod
    def _resolve_file(path, name, url, md5, download):
        """Return ``path`` if set, otherwise download it from ``url`` (checked against ``md5``)."""
        if path:
            return path
        assert (
            download
        ), f"{name} is not set and downloading automatically is disabled"
        return _check_exists_and_download(path, url, md5, 'flowers', download)

    @staticmethod
    def _extract_archive(data_file):
        """Extract the image archive next to itself and return the extraction directory."""
        data_file = os.fspath(data_file)
        with tarfile.open(data_file) as data_tar:
            # Strip only the trailing suffix; a plain str.replace would also
            # rewrite ".tgz" occurring in parent directory names.
            if data_file.endswith(".tgz"):
                data_path = data_file[: -len(".tgz")] + "/"
            else:
                data_path = data_file
            if not os.path.exists(data_path):
                os.mkdir(data_path)
            # The archive is downloaded from the network: when the interpreter
            # supports extraction filters, reject members that would escape
            # data_path (absolute paths, "..", unsafe links).
            if hasattr(tarfile, "data_filter"):
                data_tar.extractall(data_path, filter="data")
            else:
                data_tar.extractall(data_path)
        return data_path

    def __getitem__(self, idx):
        # Entries of self.indexes are used both as 1-based positions into
        # self.labels and as the number in the image file name.
        index = self.indexes[idx]
        label = np.array([self.labels[index - 1]])
        image_path = os.path.join(self.data_path, "jpg/image_%05d.jpg" % index)
        if self.backend == 'pil':
            # Returned lazily-loaded, as before (callers may rely on the
            # JpegImageFile type).
            image = Image.open(image_path)
        elif self.backend == 'cv2':
            # Close the file handle once the pixels are copied into the array.
            with Image.open(image_path) as pil_image:
                image = np.array(pil_image)

        if self.transform is not None:
            image = self.transform(image)

        if self.backend == 'pil':
            return image, label.astype('int64')

        return image.astype(paddle.get_default_dtype()), label.astype('int64')

    def __len__(self):
        return len(self.indexes)