"""
CIFAR Dataset.

URL: https://www.cs.toronto.edu/~kriz/cifar.html

The default train_creator and test_creator read the CIFAR-10 dataset; the
cifar_100_train_creator and cifar_100_test_creator read CIFAR-100.
"""
from config import DATA_HOME
import os
import hashlib
import urllib2
import shutil
import tarfile
import cPickle
import itertools
import numpy

# Public reader creators exported by this module: the CIFAR-100 pair and the
# default pair (which read CIFAR-10).
__all__ = [
    'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
    'test_creator'
]

# Download location of each CIFAR archive, paired with the MD5 hex digest the
# downloaded file must hash to before download() accepts it.
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'


def __read_batch__(filename, sub_name):
    """
    Create a reader over the pickled batches stored in a CIFAR tar archive.

    :param filename: path of the tar archive. It is opened every time the
                     returned reader is called, not when this function runs.
    :param sub_name: substring used to select archive members; every member
                     whose name contains it is loaded as one pickled batch.
    :return: a zero-argument generator function. Each call yields
             ``(sample, label)`` pairs, where ``sample`` is one row of the
             batch's ``'data'`` divided by 255.0 and cast to
             ``numpy.float32`` and ``label`` is an ``int``.
    """

    def reader():
        def __read_one_batch_impl__(batch):
            data = batch['data']
            # Batches without a 'labels' entry are expected to carry
            # 'fine_labels' instead (the key used by the CIFAR-100 files).
            labels = batch.get('labels', batch.get('fine_labels', None))
            assert labels is not None
            for sample, label in zip(data, labels):
                yield (sample / 255.0).astype(numpy.float32), int(label)

        with tarfile.open(filename, mode='r') as f:
            for member in f:
                if sub_name not in member.name:
                    continue
                # Pass the TarInfo itself so tarfile does not have to look
                # the member up by name a second time.
                member_file = f.extractfile(member)
                if member_file is None:
                    # Not a regular file (e.g. a directory entry whose name
                    # happens to match): there is nothing to unpickle.
                    continue
                try:
                    # NOTE: unpickling can execute arbitrary code. Only feed
                    # this function archives from a trusted, verified source.
                    batch = cPickle.load(member_file)
                finally:
                    member_file.close()
                for item in __read_one_batch_impl__(batch):
                    yield item

    return reader


def download(url, md5, retry_limit=3):
    """
    Fetch ``url`` into the local cache and return the path of the cached file.

    The file is stored as ``DATA_HOME/<md5>/<last path component of url>``.
    A cached copy whose MD5 digest already equals ``md5`` is reused without
    touching the network.

    :param url: location of the file to download.
    :param md5: expected MD5 hex digest of the file; also used as the name of
                the cache sub-directory.
    :param retry_limit: maximum number of download attempts before giving up.
    :return: path of the verified local file.
    :raises RuntimeError: if the file still fails the digest check after
                          ``retry_limit`` download attempts.
    """
    filename = os.path.split(url)[-1]
    assert DATA_HOME is not None
    filepath = os.path.join(DATA_HOME, md5)
    try:
        os.makedirs(filepath)
    except OSError:
        # An already existing directory is fine, including one created by a
        # concurrent process; anything else is a real error.
        if not os.path.isdir(filepath):
            raise
    __full_file__ = os.path.join(filepath, filename)

    def __file_ok__():
        if not os.path.exists(__full_file__):
            return False
        md5_hash = hashlib.md5()
        with open(__full_file__, 'rb') as f:
            # Hash in fixed-size chunks so the archive is never held in
            # memory as a whole.
            for chunk in iter(lambda: f.read(4096), b""):
                md5_hash.update(chunk)

        return md5_hash.hexdigest() == md5

    attempts = 0
    while not __file_ok__():
        if attempts >= retry_limit:
            # Without a bound, a digest that can never match would make this
            # loop download the file forever.
            raise RuntimeError('Cannot download {0} within retry limit {1}'
                               .format(url, retry_limit))
        attempts += 1
        response = urllib2.urlopen(url)
        try:
            with open(__full_file__, mode='wb') as of:
                shutil.copyfileobj(fsrc=response, fdst=of)
        finally:
            response.close()
    return __full_file__


def cifar_100_train_creator():
    """
    CIFAR-100 train reader creator.

    Obtains the CIFAR-100 archive through download() and returns a reader over
    the archive members whose name contains 'train'.
    """
    archive_path = download(md5=CIFAR100_MD5, url=CIFAR100_URL)
    return __read_batch__(archive_path, 'train')


def cifar_100_test_creator():
    """
    CIFAR-100 test reader creator.

    Obtains the CIFAR-100 archive through download() and returns a reader over
    the archive members whose name contains 'test'.
    """
    archive_path = download(md5=CIFAR100_MD5, url=CIFAR100_URL)
    return __read_batch__(archive_path, 'test')


def train_creator():
    """
    Default train reader creator. Use CIFAR-10 dataset.

    Obtains the CIFAR-10 archive through download() and returns a reader over
    the archive members whose name contains 'data_batch'.
    """
    fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
    return __read_batch__(fn, 'data_batch')


def test_creator():
    """
    Default test reader creator. Use CIFAR-10 dataset.

    Obtains the CIFAR-10 archive through download() and returns a reader over
    the archive members whose name contains 'test_batch'.
    """
    fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
    return __read_batch__(fn, 'test_batch')


def unittest():
    """
    Smoke test: iterate the default train and test readers to exhaustion.

    Nothing is asserted about the samples; the run only checks that both
    readers can be created and fully consumed without raising.
    """
    for _ in train_creator()():
        pass
    for _ in test_creator()():
        pass


# Run the smoke test when this module is executed as a script.
if __name__ == '__main__':
    unittest()