#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import io
import tarfile
import numpy as np
import scipy.io as scio
from PIL import Image

from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download

__all__ = ["Flowers"]

# Download locations and expected MD5 checksums of the three dataset files.
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'

# Maps a dataset mode to the key of the split-index array in setid.mat.
# The official readme marks test samples with 'tstid' and training samples
# with 'trnid'. However, the official test split is larger than the training
# split, so the two are deliberately swapped here: 'train' mode reads the
# 'tstid' indexes and 'test' mode reads the 'trnid' indexes.
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}


class Flowers(Dataset):
    """
    Implementation of `Flowers <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
    dataset

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        label_file(str): path to label file, can be set None if
            :attr:`download` is True. Default None
        setid_file(str): path to subset index file, can be set
            None if :attr:`download` is True. Default None
        mode(str): 'train', 'valid' or 'test' mode (case-insensitive).
            Default 'train'.
        transform(callable): transform to perform on image, None for no
            transform. Default None
        download(bool): whether to download a dataset file automatically
            when its path (:attr:`data_file`, :attr:`label_file` or
            :attr:`setid_file`) is not set. Default True

    Each sample is a tuple ``(image, label)``: ``image`` is the decoded JPEG
    as a numpy array (or whatever :attr:`transform` returns for it), and
    ``label`` is an int64 numpy array of shape ``(1,)``.

    Examples:

        .. code-block:: python

            from paddle.vision.datasets import Flowers

            flowers = Flowers(mode='test')

            for i in range(len(flowers)):
                sample = flowers[i]
                print(sample[0].shape, sample[1])

    """

    def __init__(self,
                 data_file=None,
                 label_file=None,
                 setid_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        assert mode.lower() in ['train', 'valid', 'test'], \
                "mode should be 'train', 'valid' or 'test', but got {}".format(mode)
        # Key of the split-index array to read from setid.mat for this mode.
        self.flag = MODE_FLAG_MAP[mode.lower()]

        self.data_file = self._resolve_file(data_file, 'data_file', DATA_URL,
                                            DATA_MD5, download)
        self.label_file = self._resolve_file(label_file, 'label_file',
                                             LABEL_URL, LABEL_MD5, download)
        self.setid_file = self._resolve_file(setid_file, 'setid_file',
                                             SETID_URL, SETID_MD5, download)

        self.transform = transform

        # Index the image archive and load labels / split indexes; the images
        # themselves are read and decoded lazily in __getitem__.
        self._load_anno()

    @staticmethod
    def _resolve_file(path, attr_name, url, md5, download):
        """Return ``path`` if given, otherwise locate/download the file.

        Args:
            path(str|None): user-supplied file path.
            attr_name(str): argument name, used in the assertion message.
            url(str): download URL used when ``path`` is None.
            md5(str): expected MD5 checksum of the downloaded file.
            download(bool): whether automatic download is allowed.

        Raises:
            AssertionError: if ``path`` is None and ``download`` is False.
        """
        if path is not None:
            return path
        assert download, \
            "{} is not set and downloading automatically is disabled".format(attr_name)
        return _check_exists_and_download(path, url, md5, 'flowers', download)

    def _load_anno(self):
        """Open the image archive and load label and split metadata."""
        # Map each archive member name (e.g. "jpg/image_00001.jpg") to its
        # TarInfo so __getitem__ can extract single images without rescanning.
        # NOTE(review): self.data_tar stays open for the dataset's lifetime and
        # is one shared file handle; if the dataset is used from forked worker
        # processes, confirm concurrent reads through it are safe.
        self.name2mem = {}
        self.data_tar = tarfile.open(self.data_file)
        for ele in self.data_tar.getmembers():
            self.name2mem[ele.name] = ele

        self.labels = scio.loadmat(self.label_file)['labels'][0]
        # Image ids belonging to the split selected by `mode`.
        self.indexes = scio.loadmat(self.setid_file)[self.flag][0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th sample of the split."""
        index = self.indexes[idx]
        # Image ids are used 1-based in the file names below, hence the
        # ``index - 1`` lookup into the labels array.
        label = np.array([self.labels[index - 1]])
        img_name = "jpg/image_%05d.jpg" % index
        img_ele = self.name2mem[img_name]
        # Close the per-member file object once its bytes are read.
        with self.data_tar.extractfile(img_ele) as img_file:
            image = img_file.read()
        image = np.array(Image.open(io.BytesIO(image)))

        if self.transform is not None:
            image = self.transform(image)

        return image, label.astype('int64')

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.indexes)