flowers.py 6.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tarfile
import numpy as np
from PIL import Image

20
import paddle
21
from paddle.io import Dataset
L
LielinJiang 已提交
22
from paddle.utils import try_import
23
from paddle.dataset.common import _check_exists_and_download
24

25
# Empty on purpose: no names are exported through `from <module> import *`.
__all__ = []
26 27 28 29 30 31 32 33 34 35 36

# Download locations of the image archive, the label file and the
# subset-index (setid) file.
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
# MD5 digests handed to _check_exists_and_download together with the URLs above.
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'

# In the official 'readme', tstid marks the test split and trnid marks the
# train split. However, the test split contains more images than the train
# split, so the two are swapped here: 'train' reads tstid and 'test' reads trnid.
K
Kaipeng Deng 已提交
37
# Maps the user-facing mode name to the key looked up in the setid file.
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}
38 39 40 41


class Flowers(Dataset):
    """
    Implementation of `Flowers102 <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
    dataset.

    Args:
        data_file (str, optional): Path to data file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        label_file (str, optional): Path to label file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        setid_file (str, optional): Path to subset index file, can be set
            None if :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        mode (str, optional): One of 'train', 'valid' or 'test'
            (case-insensitive). Note that 'train' reads the ``tstid`` subset and
            'test' reads the ``trnid`` subset of the subset index file.
            Default 'train'.
        transform (Callable, optional): transform to perform on image, None for no transform. Default: None.
        download (bool, optional): download dataset automatically if :attr:`data_file` is None. Default: True.
        backend (str, optional): Specifies which type of image to be returned:
            PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
            If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_vision_image_get_image_backend>`,
            default backend is 'pil'. Default: None.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of Flowers dataset.

    Raises:
        AssertionError: If :attr:`mode` is not one of the supported values, or
            if a file path is not given while :attr:`download` is False.
        ValueError: If :attr:`backend` is neither 'pil' nor 'cv2'.

    Examples:

        .. code-block:: python

            import itertools
            import paddle.vision.transforms as T
            from paddle.vision.datasets import Flowers


            flowers = Flowers()
            print(len(flowers))
            # 6149

            for i in range(5):  # only show first 5 images
                img, label = flowers[i]
                # do something with img and label
                print(type(img), img.size, label)
                # <class 'PIL.JpegImagePlugin.JpegImageFile'> (523, 500) [1]


            transform = T.Compose(
                [
                    T.Resize(64),
                    T.ToTensor(),
                    T.Normalize(
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5],
                        to_rgb=True,
                    ),
                ]
            )

            flowers_test = Flowers(
                mode="test",
                transform=transform,  # apply transform to every image
                backend="cv2",  # use OpenCV as image transform backend
            )
            print(len(flowers_test))
            # 1020

            for img, label in itertools.islice(iter(flowers_test), 5):  # only show first 5 images
                # do something with img and label
                print(type(img), img.shape, label)
                # <class 'paddle.Tensor'> [3, 64, 96] [1]
    """

    def __init__(self,
                 data_file=None,
                 label_file=None,
                 setid_file=None,
                 mode='train',
                 transform=None,
                 download=True,
                 backend=None):
        # NOTE: assertions are kept (rather than raising ValueError) so that
        # callers relying on AssertionError keep working.
        assert mode.lower() in ['train', 'valid', 'test'], \
                "mode should be 'train', 'valid' or 'test', but got {}".format(mode)

        if backend is None:
            backend = paddle.vision.get_image_backend()
        if backend not in ['pil', 'cv2']:
            raise ValueError(
                "Expected backend are one of ['pil', 'cv2'], but got {}".format(
                    backend))
        self.backend = backend

        # Key of the subset to read from the setid file for this mode.
        flag = MODE_FLAG_MAP[mode.lower()]

        if not data_file:
            assert download, "data_file is not set and downloading automatically is disabled"
            data_file = _check_exists_and_download(data_file, DATA_URL,
                                                   DATA_MD5, 'flowers',
                                                   download)

        if not label_file:
            assert download, "label_file is not set and downloading automatically is disabled"
            label_file = _check_exists_and_download(label_file, LABEL_URL,
                                                    LABEL_MD5, 'flowers',
                                                    download)

        if not setid_file:
            assert download, "setid_file is not set and downloading automatically is disabled"
            setid_file = _check_exists_and_download(setid_file, SETID_URL,
                                                    SETID_MD5, 'flowers',
                                                    download)

        self.transform = transform

        # Unpack the archive next to itself ("foo.tgz" -> "foo/"). The context
        # manager guarantees the archive handle is closed even when extraction
        # fails part-way.
        # NOTE(review): extractall() trusts member paths inside the archive;
        # only pass a data_file that comes from a trusted source.
        with tarfile.open(data_file) as data_tar:
            self.data_path = data_file.replace(".tgz", "/")
            if not os.path.exists(self.data_path):
                # makedirs(exist_ok=True) also creates missing parents and does
                # not fail if another process creates the directory first.
                os.makedirs(self.data_path, exist_ok=True)
            data_tar.extractall(self.data_path)

        # scipy is an optional dependency, hence the lazy import.
        scio = try_import('scipy.io')
        self.labels = scio.loadmat(label_file)['labels'][0]
        self.indexes = scio.loadmat(setid_file)[flag][0]

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx`` of the subset.

        The image is a ``PIL.Image`` for the 'pil' backend and a
        ``numpy.ndarray`` for the 'cv2' backend, passed through
        ``self.transform`` when one is set. The label is an ``int64`` array
        holding a single element.
        """
        index = self.indexes[idx]
        # Image ids taken from the setid file are 1-based, while self.labels
        # is indexed from 0.
        label = np.array([self.labels[index - 1]])
        img_name = "jpg/image_%05d.jpg" % index
        image_path = os.path.join(self.data_path, img_name)

        if self.backend == 'pil':
            # Returned as-is; PIL loads the pixel data on demand.
            image = Image.open(image_path)
        elif self.backend == 'cv2':
            # Copy the pixels into an ndarray, then release the file handle
            # right away instead of waiting for garbage collection.
            with Image.open(image_path) as pil_image:
                image = np.array(pil_image)

        if self.transform is not None:
            image = self.transform(image)

        if self.backend == 'pil':
            return image, label.astype('int64')

        return image.astype(paddle.get_default_dtype()), label.astype('int64')

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.indexes)