Commit 59f7778b, authored by Yu Yang, committed by GitHub

Merge pull request #1476 from wangkuiyi/dataset

Simplify CIFAR/MNIST Data Package, Remove Scipy/sklearn package dependencies.
""" """
CIFAR Dataset. CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html
URL: https://www.cs.toronto.edu/~kriz/cifar.html
the default train_creator, test_creator used for CIFAR-10 dataset.
""" """
import cPickle import cPickle
import itertools import itertools
import tarfile
import numpy import numpy
import paddle.v2.dataset.common
import tarfile
from config import download __all__ = ['train100', 'test100', 'train10', 'test10']
__all__ = [
'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
'test_creator'
]
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def __read_batch__(filename, sub_name): def reader_creator(filename, sub_name):
def reader(): def read_batch(batch):
def __read_one_batch_impl__(batch): data = batch['data']
data = batch['data'] labels = batch.get('labels', batch.get('fine_labels', None))
labels = batch.get('labels', batch.get('fine_labels', None)) assert labels is not None
assert labels is not None for sample, label in itertools.izip(data, labels):
for sample, label in itertools.izip(data, labels): yield (sample / 255.0).astype(numpy.float32), int(label)
yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f: with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f names = (each_item.name for each_item in f
if sub_name in each_item.name) if sub_name in each_item.name)
for name in names: for name in names:
batch = cPickle.load(f.extractfile(name)) batch = cPickle.load(f.extractfile(name))
for item in __read_one_batch_impl__(batch): for item in read_batch(batch):
yield item yield item
return reader return reader
def cifar_100_train_creator(): def train100():
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) return reader_creator(
return __read_batch__(fn, 'train') paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
'train')
def cifar_100_test_creator():
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
return __read_batch__(fn, 'test')
def train_creator():
"""
Default train reader creator. Use CIFAR-10 dataset.
"""
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
return __read_batch__(fn, 'data_batch')
def test_creator(): def test100():
""" return reader_creator(
Default test reader creator. Use CIFAR-10 dataset. paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
""" 'test')
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
return __read_batch__(fn, 'test_batch')
def unittest(): def train10():
for _ in train_creator()(): return reader_creator(
pass paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
for _ in test_creator()(): 'data_batch')
pass
if __name__ == '__main__': def test10():
unittest() return reader_creator(
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch')
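train10() and its siblings return a reader function; calling that function yields a fresh generator over (image, label) pairs. A minimal consumption sketch, not part of this commit (the module path paddle.v2.dataset.cifar follows the imports in the unit tests below):

import paddle.v2.dataset.cifar as cifar

# Each sample is a length-3072 float32 vector (32x32x3, scaled to [0, 1])
# paired with an int label; train10/test10 cover CIFAR-10, train100/test100
# cover CIFAR-100.
reader = cifar.train10()
for i, (image, label) in enumerate(reader()):
    if i >= 5:
        break
    print image.shape, label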
common.py (new file with shared download/caching helpers):

import requests
import hashlib
import os
import shutil
__all__ = ['DATA_HOME', 'download', 'md5file']
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
if not os.path.exists(DATA_HOME):
os.makedirs(DATA_HOME)
def md5file(fname):
hash_md5 = hashlib.md5()
f = open(fname, "rb")
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
f.close()
return hash_md5.hexdigest()
def download(url, module_name, md5sum):
dirname = os.path.join(DATA_HOME, module_name)
if not os.path.exists(dirname):
os.makedirs(dirname)
filename = os.path.join(dirname, url.split('/')[-1])
if not (os.path.exists(filename) and md5file(filename) == md5sum):
r = requests.get(url, stream=True)
with open(filename, 'w') as f:
shutil.copyfileobj(r.raw, f)
return filename
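A short usage sketch for these helpers, not part of the commit itself (the URL and md5 are the CIFAR-10 constants defined in cifar.py above):

import paddle.v2.dataset.common as common

# The first call downloads into ~/.cache/paddle/dataset/cifar/; later calls
# return the cached file as long as its md5 still matches.
path = common.download(
    'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 'cifar',
    'c58f30108f718f92721af3b95e74349a')
print path, common.md5file(path)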
config.py (the previous download helper, removed by this commit and superseded by common.py):

import hashlib
import os
import shutil
import urllib2
__all__ = ['DATA_HOME', 'download']
DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set')
if not os.path.exists(DATA_HOME):
os.makedirs(DATA_HOME)
def download(url, md5):
filename = os.path.split(url)[-1]
assert DATA_HOME is not None
filepath = os.path.join(DATA_HOME, md5)
if not os.path.exists(filepath):
os.makedirs(filepath)
__full_file__ = os.path.join(filepath, filename)
def __file_ok__():
if not os.path.exists(__full_file__):
return False
md5_hash = hashlib.md5()
with open(__full_file__, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest() == md5
while not __file_ok__():
response = urllib2.urlopen(url)
with open(__full_file__, mode='wb') as of:
shutil.copyfileobj(fsrc=response, fdst=of)
return __full_file__
mnist.py (before this change):

import sklearn.datasets.mldata
import sklearn.model_selection
import numpy
from config import DATA_HOME

__all__ = ['train_creator', 'test_creator']


def __mnist_reader_creator__(data, target):
    def reader():
        n_samples = data.shape[0]
        for i in xrange(n_samples):
            yield (data[i] / 255.0).astype(numpy.float32), int(target[i])

    return reader


TEST_SIZE = 10000

data = sklearn.datasets.mldata.fetch_mldata(
    "MNIST original", data_home=DATA_HOME)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    data.data, data.target, test_size=TEST_SIZE, random_state=0)


def train_creator():
    return __mnist_reader_creator__(X_train, y_train)


def test_creator():
    return __mnist_reader_creator__(X_test, y_test)


def unittest():
    assert len(list(test_creator()())) == TEST_SIZE


if __name__ == '__main__':
    unittest()


mnist.py (after this change):

"""
MNIST dataset.
"""
import numpy
import paddle.v2.dataset.common
import subprocess

__all__ = ['train', 'test']

URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'


def reader_creator(image_filename, label_filename, buffer_size):
    def reader():
        # According to http://stackoverflow.com/a/38061619/724872, we
        # cannot use standard package gzip here.
        m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE)
        m.stdout.read(16)  # skip some magic bytes

        l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE)
        l.stdout.read(8)  # skip some magic bytes

        while True:
            labels = numpy.fromfile(
                l.stdout, 'ubyte', count=buffer_size).astype("int")

            if labels.size != buffer_size:
                break  # numpy.fromfile returns empty slice after EOF.

            images = numpy.fromfile(
                m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
                    (buffer_size, 28 * 28)).astype('float32')

            images = images / 255.0 * 2.0 - 1.0

            for i in xrange(buffer_size):
                yield images[i, :], int(labels[i])

        m.terminate()
        l.terminate()

    return reader


def train():
    return reader_creator(
        paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist',
                                          TRAIN_IMAGE_MD5),
        paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist',
                                          TRAIN_LABEL_MD5), 100)


def test():
    return reader_creator(
        paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist',
                                          TEST_IMAGE_MD5),
        paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist',
                                          TEST_LABEL_MD5), 100)
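The MNIST readers follow the same convention as the CIFAR ones. A minimal sketch of consuming them, not part of this commit:

import paddle.v2.dataset.mnist as mnist

# Each sample is a length-784 float32 vector scaled to [-1, 1] plus an int
# label in [0, 9]; the reader streams from the cached gzip files in buffers
# of 100 samples.
for i, (image, label) in enumerate(mnist.train()()):
    if i >= 3:
        break
    print image.shape, image.min(), image.max(), label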
Another dataset module (truncated in this view) keeps its imports of zipfile, re, random, and functools unchanged, and only replaces `from config import download` with `from common import download`.
......
New unit test for the CIFAR readers:

import paddle.v2.dataset.cifar
import unittest
class TestCIFAR(unittest.TestCase):
def check_reader(self, reader):
sum = 0
label = 0
for l in reader():
self.assertEqual(l[0].size, 3072)
if l[1] > label:
label = l[1]
sum += 1
return sum, label
def test_test10(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.cifar.test10())
self.assertEqual(instances, 10000)
self.assertEqual(max_label_value, 9)
def test_train10(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.cifar.train10())
self.assertEqual(instances, 50000)
self.assertEqual(max_label_value, 9)
def test_test100(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.cifar.test100())
self.assertEqual(instances, 10000)
self.assertEqual(max_label_value, 99)
def test_train100(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.cifar.train100())
self.assertEqual(instances, 50000)
self.assertEqual(max_label_value, 99)
if __name__ == '__main__':
unittest.main()
New unit test for the common download helpers:

import paddle.v2.dataset.common
import unittest
import tempfile
class TestCommon(unittest.TestCase):
def test_md5file(self):
_, temp_path = tempfile.mkstemp()
with open(temp_path, 'w') as f:
f.write("Hello\n")
self.assertEqual('09f7e02f1290be211da707a266f153b3',
paddle.v2.dataset.common.md5file(temp_path))
def test_download(self):
yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
self.assertEqual(
paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
paddle.v2.dataset.common.download(
yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))
if __name__ == '__main__':
unittest.main()
New unit test for the MNIST readers:

import paddle.v2.dataset.mnist
import unittest
class TestMNIST(unittest.TestCase):
def check_reader(self, reader):
sum = 0
label = 0
for l in reader():
self.assertEqual(l[0].size, 784)
if l[1] > label:
label = l[1]
sum += 1
return sum, label
def test_train(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.mnist.train())
self.assertEqual(instances, 60000)
self.assertEqual(max_label_value, 9)
def test_test(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.mnist.test())
self.assertEqual(instances, 10000)
self.assertEqual(max_label_value, 9)
if __name__ == '__main__':
unittest.main()