#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import io
import tarfile
import numpy as np
import scipy.io as scio
from PIL import Image

import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download

__all__ = ["Flowers"]

# Download locations of the three dataset files and the MD5 checksums they
# are verified against.
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'

# Maps each dataset mode to the key of its index array in the setid file.
# In the official readme, 'tstid' marks the test split and 'trnid' marks the
# training split. However, the official test split is larger than the
# training split, so the two are deliberately swapped here.
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}


class Flowers(Dataset):
    """
    Implementation of `Flowers <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
    dataset

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        label_file(str): path to label file, can be set None if
            :attr:`download` is True. Default None
        setid_file(str): path to subset index file, can be set
            None if :attr:`download` is True. Default None
        mode(str): 'train', 'valid' or 'test' mode (case-insensitive).
            Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): whether to download dataset files automatically if
            any of :attr:`data_file`, :attr:`label_file` or
            :attr:`setid_file` is not set. Default True

    Returns:
        Each sample is a tuple ``(image, label)``: ``image`` is the decoded
        image (after :attr:`transform`, if any) cast to
        ``paddle.get_default_dtype()``; ``label`` is an int64 numpy array
        holding a single class id.

    Examples:

        .. code-block:: python

            from paddle.vision.datasets import Flowers

            flowers = Flowers(mode='test')

            for i in range(len(flowers)):
                sample = flowers[i]
                print(sample[0].shape, sample[1])

    """

    def __init__(self,
                 data_file=None,
                 label_file=None,
                 setid_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        assert mode.lower() in ['train', 'valid', 'test'], \
                "mode should be 'train', 'valid' or 'test', but got {}".format(mode)
        # Key of this split's index array inside the setid file.
        self.flag = MODE_FLAG_MAP[mode.lower()]

        self.data_file = self._resolve_file(data_file, 'data_file', DATA_URL,
                                            DATA_MD5, download)
        self.label_file = self._resolve_file(label_file, 'label_file',
                                             LABEL_URL, LABEL_MD5, download)
        self.setid_file = self._resolve_file(setid_file, 'setid_file',
                                             SETID_URL, SETID_MD5, download)

        self.transform = transform

        # read dataset into memory
        self._load_anno()

        self.dtype = paddle.get_default_dtype()

    @staticmethod
    def _resolve_file(path, name, url, md5, download):
        """Return ``path`` if given, otherwise fetch the file from ``url``.

        Args:
            path(str|None): user-supplied path, returned unchanged if not None.
            name(str): argument name, used only in the assertion message.
            url(str): download URL used when ``path`` is None.
            md5(str): checksum passed to ``_check_exists_and_download``.
            download(bool): whether automatic downloading is allowed.

        Raises:
            AssertionError: if ``path`` is None and ``download`` is False.
        """
        if path is not None:
            return path
        assert download, \
            "{} is not set and downloading automatically is disabled".format(name)
        return _check_exists_and_download(path, url, md5, 'flowers', download)

    def _load_anno(self):
        """Open the image archive and load labels and split indexes."""
        self.data_tar = tarfile.open(self.data_file)
        # Remember which process owns this handle (see _get_data_tar).
        self._tar_pid = os.getpid()
        # Index archive members by name so __getitem__ can look images up
        # directly instead of rescanning the archive.
        self.name2mem = {
            member.name: member
            for member in self.data_tar.getmembers()
        }

        self.labels = scio.loadmat(self.label_file)['labels'][0]
        self.indexes = scio.loadmat(self.setid_file)[self.flag][0]

    def _get_data_tar(self):
        """Return an archive handle that is safe to read in this process.

        A process created by fork() inherits the parent's open archive, and
        the two share a single OS-level file offset. Seek/read pairs issued
        by several processes (e.g. multi-worker data loading) could then
        interleave and yield corrupted image bytes. The archive is therefore
        reopened once in every process other than the one that opened it.
        The cached ``TarInfo`` members stay valid because they only record
        offsets into the same file.
        """
        pid = os.getpid()
        if getattr(self, '_tar_pid', None) != pid:
            self.data_tar = tarfile.open(self.data_file)
            self._tar_pid = pid
        return self.data_tar

    def __getitem__(self, idx):
        index = self.indexes[idx]
        # The image id is used as-is in the file name but shifted by one to
        # index `labels`, i.e. ids appear to be 1-based.
        label = np.array([self.labels[index - 1]])
        img_member = self.name2mem["jpg/image_%05d.jpg" % index]
        with self._get_data_tar().extractfile(img_member) as img_file:
            raw = img_file.read()
        image = np.array(Image.open(io.BytesIO(raw)))

        if self.transform is not None:
            image = self.transform(image)

        return image.astype(self.dtype), label.astype('int64')

    def __len__(self):
        return len(self.indexes)