#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import io
import tarfile
import numpy as np
from PIL import Image

import paddle
from paddle.io import Dataset
from paddle.utils import try_import
from paddle.dataset.common import _check_exists_and_download

# Empty on purpose: `from <this module> import *` exports nothing.
__all__ = []

# Remote locations of the image archive, the per-image label file and the
# split (subset index) file, plus the MD5 checksums the downloader verifies.
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'

# Maps a user-facing split name to the key read from setid.mat.
# The official readme marks test images with 'tstid' and train images with
# 'trnid'. However, the official test split is larger than the train split,
# so the two are deliberately swapped here: 'train' reads 'tstid' and 'test'
# reads 'trnid'.
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}


class Flowers(Dataset):
    """
    Implementation of `Flowers102 <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
    dataset.

    Args:
        data_file (str, optional): Path to data file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        label_file (str, optional): Path to label file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        setid_file (str, optional): Path to subset index file, can be set
            None if :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
        mode (str, optional): One of 'train', 'valid' or 'test' (case-insensitive). Default 'train'.
        transform (Callable, optional): transform to perform on image, None for no transform. Default: None.
        download (bool, optional): download a dataset file automatically if its path
            (:attr:`data_file`, :attr:`label_file` or :attr:`setid_file`) is None. Default: True.
        backend (str, optional): Specifies which type of image to be returned:
            PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
            If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_vision_image_get_image_backend>`,
            default backend is 'pil'. Default: None.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of Flowers dataset.

    Examples:

        .. code-block:: python

            import itertools
            import paddle.vision.transforms as T
            from paddle.vision.datasets import Flowers


            flowers = Flowers()
            print(len(flowers))
            # 6149

            for i in range(5):  # only show first 5 images
                img, label = flowers[i]
                # do something with img and label
                print(type(img), img.size, label)
                # <class 'PIL.JpegImagePlugin.JpegImageFile'> (523, 500) [1]


            transform = T.Compose(
                [
                    T.Resize(64),
                    T.ToTensor(),
                    T.Normalize(
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5],
                        to_rgb=True,
                    ),
                ]
            )

            flowers_test = Flowers(
                mode="test",
                transform=transform,  # apply transform to every image
                backend="cv2",  # use OpenCV as image transform backend
            )
            print(len(flowers_test))
            # 1020

            for img, label in itertools.islice(iter(flowers_test), 5):  # only show first 5 images
                # do something with img and label
                print(type(img), img.shape, label)
                # <class 'paddle.Tensor'> [3, 64, 96] [1]
    """

    def __init__(self,
                 data_file=None,
                 label_file=None,
                 setid_file=None,
                 mode='train',
                 transform=None,
                 download=True,
                 backend=None):
        assert mode.lower() in ['train', 'valid', 'test'], \
                "mode should be 'train', 'valid' or 'test', but got {}".format(mode)

        if backend is None:
            backend = paddle.vision.get_image_backend()
        if backend not in ['pil', 'cv2']:
            raise ValueError(
                "Expected backend are one of ['pil', 'cv2'], but got {}".format(
                    backend))
        self.backend = backend

        # Key of the requested split inside setid.mat (train/test are
        # intentionally swapped in MODE_FLAG_MAP).
        flag = MODE_FLAG_MAP[mode.lower()]

        # Each missing path is resolved by downloading (with MD5 check) into
        # the 'flowers' cache directory; that is only allowed when download=True.
        if not data_file:
            assert download, "data_file is not set and downloading automatically is disabled"
            data_file = _check_exists_and_download(data_file, DATA_URL,
                                                   DATA_MD5, 'flowers',
                                                   download)

        if not label_file:
            assert download, "label_file is not set and downloading automatically is disabled"
            label_file = _check_exists_and_download(label_file, LABEL_URL,
                                                    LABEL_MD5, 'flowers',
                                                    download)

        if not setid_file:
            assert download, "setid_file is not set and downloading automatically is disabled"
            setid_file = _check_exists_and_download(setid_file, SETID_URL,
                                                    SETID_MD5, 'flowers',
                                                    download)

        self.transform = transform

        # The `with` block guarantees the archive handle is closed.
        with tarfile.open(data_file) as data_tar:
            # Extract next to the archive, into a directory named after it
            # (".../102flowers.tgz" -> ".../102flowers/"). Only the trailing
            # extension is stripped, so a ".tgz" elsewhere in the path is
            # left untouched.
            self.data_path = os.path.splitext(data_file)[0] + "/"
            if not os.path.exists(self.data_path):
                os.mkdir(self.data_path)
            # NOTE(review): the archive is re-extracted on every construction,
            # and extractall trusts member paths inside it. For a
            # user-supplied data_file, consider tarfile's extraction
            # `filter="data"` (Python 3.12+).
            data_tar.extractall(self.data_path)

        scio = try_import('scipy.io')
        # `labels` is indexed per image in __getitem__; `indexes` holds the
        # image ids that belong to the selected split.
        self.labels = scio.loadmat(label_file)['labels'][0]
        self.indexes = scio.loadmat(setid_file)[flag][0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th sample of the split.

        ``label`` is an int64 numpy array of shape (1,). With the 'pil'
        backend the image is returned as-is (or as produced by ``transform``).
        With 'cv2' the result is additionally cast to
        ``paddle.get_default_dtype()``.
        """
        index = self.indexes[idx]
        # Split entries are treated as 1-based image ids: they name the file
        # directly and are shifted by one to index the label array.
        label = np.array([self.labels[index - 1]])
        img_name = "jpg/image_%05d.jpg" % index
        image = os.path.join(self.data_path, img_name)
        if self.backend == 'pil':
            image = Image.open(image)
        elif self.backend == 'cv2':
            # NOTE(review): decoded with PIL, so the array is in PIL channel
            # order (RGB), not OpenCV's BGR. Confirm transforms expect this.
            image = np.array(Image.open(image))

        if self.transform is not None:
            image = self.transform(image)

        if self.backend == 'pil':
            return image, label.astype('int64')

        return image.astype(paddle.get_default_dtype()), label.astype('int64')

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.indexes)