flowers.py 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import io
import tarfile
import numpy as np
from PIL import Image

21
import paddle
22
from paddle.io import Dataset
L
LielinJiang 已提交
23
from paddle.utils import try_import
24
from paddle.dataset.common import _check_exists_and_download
25

26
__all__ = []
27 28 29 30 31 32 33 34 35 36 37

# Download locations of the image archive, the label file and the subset
# index file, each paired with the MD5 digest that is handed to
# `_check_exists_and_download` together with the URL.
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'

# Maps the user-facing mode name to the key read from the setid file.
# In the official readme, 'tstid' marks the test split and 'trnid' marks the
# train split, but the test split contains more images than the train split.
# The two are therefore swapped on purpose: 'train' reads 'tstid' and 'test'
# reads 'trnid'.
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}


class Flowers(Dataset):
    """
    Implementation of `Flowers102 <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
    dataset.

    Args:
        data_file (str, optional): Path to data file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        label_file (str, optional): Path to label file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        setid_file (str, optional): Path to subset index file, can be set
            None if :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        mode (str, optional): Which subset to load, one of 'train', 'valid'
            or 'test' (case-insensitive). Default: 'train'.
        transform (Callable, optional): transform to perform on image, None for no transform. Default: None.
        download (bool, optional): download automatically every file among
            :attr:`data_file`, :attr:`label_file` and :attr:`setid_file`
            that is not set. Default: True.
        backend (str, optional): Specifies which type of image to be returned:
            PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
            If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_vision_image_get_image_backend>`,
            default backend is 'pil'. Default: None.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of Flowers dataset.

    Raises:
        AssertionError: If :attr:`mode` is not 'train', 'valid' or 'test', or
            if one of the files is not set while :attr:`download` is False.
        ValueError: If the resolved backend is neither 'pil' nor 'cv2'.

    Examples:

        .. code-block:: python

            import itertools
            import paddle.vision.transforms as T
            from paddle.vision.datasets import Flowers


            flowers = Flowers()
            print(len(flowers))
            # 6149

            for i in range(5):  # only show first 5 images
                img, label = flowers[i]
                # do something with img and label
                print(type(img), img.size, label)
                # <class 'PIL.JpegImagePlugin.JpegImageFile'> (523, 500) [1]


            transform = T.Compose(
                [
                    T.Resize(64),
                    T.ToTensor(),
                    T.Normalize(
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5],
                        to_rgb=True,
                    ),
                ]
            )

            flowers_test = Flowers(
                mode="test",
                transform=transform,  # apply transform to every image
                backend="cv2",  # use OpenCV as image transform backend
            )
            print(len(flowers_test))
            # 1020

            for img, label in itertools.islice(iter(flowers_test), 5):  # only show first 5 images
                # do something with img and label
                print(type(img), img.shape, label)
                # <class 'paddle.Tensor'> [3, 64, 96] [1]
    """

    def __init__(self,
                 data_file=None,
                 label_file=None,
                 setid_file=None,
                 mode='train',
                 transform=None,
                 download=True,
                 backend=None):
        # The checks below stay as `assert` so that callers keep seeing
        # AssertionError, exactly as before.
        assert mode.lower() in ['train', 'valid', 'test'], \
            "mode should be 'train', 'valid' or 'test', but got {}".format(mode)

        if backend is None:
            backend = paddle.vision.get_image_backend()
        if backend not in ['pil', 'cv2']:
            raise ValueError(
                "Expected backend are one of ['pil', 'cv2'], but got {}".format(
                    backend))
        self.backend = backend

        # Key under which the requested subset is stored in the setid file.
        flag = MODE_FLAG_MAP[mode.lower()]

        if not data_file:
            assert download, "data_file is not set and downloading automatically is disabled"
            data_file = _check_exists_and_download(data_file, DATA_URL,
                                                   DATA_MD5, 'flowers',
                                                   download)

        if not label_file:
            assert download, "label_file is not set and downloading automatically is disabled"
            label_file = _check_exists_and_download(label_file, LABEL_URL,
                                                    LABEL_MD5, 'flowers',
                                                    download)

        if not setid_file:
            assert download, "setid_file is not set and downloading automatically is disabled"
            setid_file = _check_exists_and_download(setid_file, SETID_URL,
                                                    SETID_MD5, 'flowers',
                                                    download)

        self.transform = transform

        # The archive is unpacked next to itself: "<name>.tgz" -> "<name>/".
        self.data_path = data_file.replace(".tgz", "/")
        # The context manager closes the archive handle even when extraction
        # fails; the archive is still unpacked on every construction.
        with tarfile.open(data_file) as data_tar:
            os.makedirs(self.data_path, exist_ok=True)
            # NOTE(review): extractall() trusts the member paths stored in the
            # archive. A user-supplied `data_file` is not checked here, so
            # consider an extraction filter if untrusted archives are possible.
            data_tar.extractall(self.data_path)

        scio = try_import('scipy.io')
        self.labels = scio.loadmat(label_file)['labels'][0]
        self.indexes = scio.loadmat(setid_file)[flag][0]

    def __getitem__(self, idx):
        """
        Load the ``idx``-th sample of the selected subset.

        Args:
            idx (int): Position inside the subset index array.

        Returns:
            tuple: ``(image, label)``. ``label`` is a one-element int64
            ``numpy.ndarray``. With the 'pil' backend ``image`` is the
            (optionally transformed) PIL image; otherwise it is the
            (optionally transformed) array cast to paddle's default dtype.
        """
        # `index` is the image number from the setid file. It names the jpg
        # directly, while the label array is read at `index - 1`.
        index = self.indexes[idx]
        label = np.array([self.labels[index - 1]])
        img_name = "jpg/image_%05d.jpg" % index
        image = os.path.join(self.data_path, img_name)
        if self.backend == 'pil':
            # Returned as opened; the caller owns the image object.
            image = Image.open(image)
        elif self.backend == 'cv2':
            # The pixels are copied into an ndarray, so the file handle can be
            # released right away instead of lingering until garbage collection.
            with Image.open(image) as pil_image:
                image = np.array(pil_image)

        if self.transform is not None:
            image = self.transform(image)

        if self.backend == 'pil':
            return image, label.astype('int64')

        return image.astype(paddle.get_default_dtype()), label.astype('int64')

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.indexes)