cifar.py 8.5 KB
Newer Older
K
Kaipeng Deng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import tarfile
import numpy as np
import six
20
from PIL import Image
K
Kaipeng Deng 已提交
21 22
from six.moves import cPickle as pickle

23
import paddle
K
Kaipeng Deng 已提交
24
from paddle.io import Dataset
25
from paddle.dataset.common import _check_exists_and_download
K
Kaipeng Deng 已提交
26

27
__all__ = []
K
Kaipeng Deng 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49

URL_PREFIX = 'https://dataset.bj.bcebos.com/cifar/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'

MODE_FLAG_MAP = {
    'train10': 'data_batch',
    'test10': 'test_batch',
    'train100': 'train',
    'test100': 'test'
}


class Cifar10(Dataset):
    """
    Implementation of `Cifar-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
    dataset, which has 10 categories.

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None, default data path: ~/.cache/paddle/dataset/cifar
        mode(str): 'train', 'test' mode (case-insensitive). Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): download dataset automatically if :attr:`data_file` is None. Default True
        backend(str, optional): Specifies which type of image to be returned:
            PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
            If this option is not set, will get backend from ``paddle.vision.get_image_backend`` ,
            default backend is 'pil'. Default: None.

    Returns:
        Dataset: instance of cifar-10 dataset

    Examples:

        .. code-block:: python

            import paddle
            import paddle.nn as nn
            from paddle.vision.datasets import Cifar10
            from paddle.vision.transforms import Normalize

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.fc = nn.Sequential(
                        nn.Linear(3072, 10),
                        nn.Softmax())

                def forward(self, image, label):
                    image = paddle.reshape(image, (1, -1))
                    return self.fc(image), label


            normalize = Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5],
                                  data_format='HWC')
            cifar10 = Cifar10(mode='train', transform=normalize)

            for i in range(10):
                image, label = cifar10[i]
                image = paddle.to_tensor(image)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                image, label = model(image, label)
                print(image.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 transform=None,
                 download=True,
                 backend=None):
        # Only 'train' / 'test' are accepted from callers; the '10' / '100'
        # suffix used as a MODE_FLAG_MAP key is appended by
        # ``_init_url_md5_flag``.
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        if backend is None:
            backend = paddle.vision.get_image_backend()
        if backend not in ['pil', 'cv2']:
            raise ValueError(
                "Expected backend are one of ['pil', 'cv2'], but got {}".format(
                    backend))
        self.backend = backend

        # Subclasses override this hook to select a different archive
        # (URL, MD5, member-name flag); it must run after self.mode is set.
        self._init_url_md5_flag()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(data_file,
                                                        self.data_url,
                                                        self.data_md5, 'cifar',
                                                        download)

        self.transform = transform

        # read dataset into memory
        self._load_data()

        # dtype used to cast ndarray images when backend == 'cv2'.
        self.dtype = paddle.get_default_dtype()

    def _init_url_md5_flag(self):
        """Select the CIFAR-10 archive URL, its MD5 and the batch-name flag
        for ``self.mode``."""
        self.data_url = CIFAR10_URL
        self.data_md5 = CIFAR10_MD5
        self.flag = MODE_FLAG_MAP[self.mode + '10']

    def _load_data(self):
        """Load every archive member whose name contains ``self.flag`` into
        ``self.data`` as a list of ``(sample, label)`` tuples."""
        self.data = []
        with tarfile.open(self.data_file, mode='r') as f:
            # Only regular files can be extracted (extractfile returns None
            # for directories and other non-file members). Sorting gives a
            # deterministic load order independent of archive member order.
            names = sorted(each_item.name for each_item in f
                           if each_item.isfile() and self.flag in each_item.name)

            for name in names:
                # NOTE(review): unpickling can execute arbitrary code. An
                # auto-downloaded archive is fetched via
                # _check_exists_and_download together with its MD5 (presumably
                # verified there); a user-supplied ``data_file`` is not checked
                # here, so only load trusted files.
                with f.extractfile(name) as member:
                    batch = pickle.load(member, encoding='bytes')

                data = batch[six.b('data')]
                # Fall back to b'fine_labels' when b'labels' is absent
                # (presumably the key used by the 100-class archive).
                labels = batch.get(six.b('labels'),
                                   batch.get(six.b('fine_labels'), None))
                assert labels is not None
                self.data.extend(six.moves.zip(data, labels))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a PIL image (backend 'pil') or an HWC ndarray cast to
        ``self.dtype`` (backend 'cv2'), after ``self.transform`` if set. The
        label is an int64 numpy array.
        """
        image, label = self.data[idx]
        # Samples are stored flat; restore (3, 32, 32) and move the channel
        # axis last (HWC layout).
        image = np.reshape(image, [3, 32, 32])
        image = image.transpose([1, 2, 0])

        if self.backend == 'pil':
            image = Image.fromarray(image.astype('uint8'))
        if self.transform is not None:
            image = self.transform(image)

        label = np.array(label).astype('int64')
        if self.backend == 'pil':
            return image, label
        return image.astype(self.dtype), label

    def __len__(self):
        """Return the number of samples loaded for the selected mode."""
        return len(self.data)


class Cifar100(Cifar10):
    """
    Implementation of `Cifar-100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
    dataset, which has 100 categories.

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None, default data path: ~/.cache/paddle/dataset/cifar
        mode(str): 'train', 'test' mode (case-insensitive). Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): download dataset automatically if :attr:`data_file` is None. Default True
        backend(str, optional): Specifies which type of image to be returned:
            PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
            If this option is not set, will get backend from ``paddle.vision.get_image_backend`` ,
            default backend is 'pil'. Default: None.

    Returns:
        Dataset: instance of cifar-100 dataset

    Examples:

        .. code-block:: python

            import paddle
            import paddle.nn as nn
            from paddle.vision.datasets import Cifar100
            from paddle.vision.transforms import Normalize

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.fc = nn.Sequential(
                        nn.Linear(3072, 100),
                        nn.Softmax())

                def forward(self, image, label):
                    image = paddle.reshape(image, (1, -1))
                    return self.fc(image), label


            normalize = Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5],
                                  data_format='HWC')
            cifar100 = Cifar100(mode='train', transform=normalize)

            for i in range(10):
                image, label = cifar100[i]
                image = paddle.to_tensor(image)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                image, label = model(image, label)
                print(image.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 transform=None,
                 download=True,
                 backend=None):
        # All loading logic lives in Cifar10; only the archive selection in
        # _init_url_md5_flag differs.
        super(Cifar100, self).__init__(data_file, mode, transform, download,
                                       backend)

    def _init_url_md5_flag(self):
        """Select the CIFAR-100 archive URL, its MD5 and the batch-name flag
        for ``self.mode`` (called from Cifar10.__init__)."""
        self.data_url = CIFAR100_URL
        self.data_md5 = CIFAR100_MD5
        self.flag = MODE_FLAG_MAP[self.mode + '100']