提交 559efcdc,作者:Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/clean_mnist_v2

""" """
CIFAR Dataset. CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html
URL: https://www.cs.toronto.edu/~kriz/cifar.html
the default train_creator, test_creator used for CIFAR-10 dataset.
""" """
import cPickle import cPickle
import itertools import itertools
import tarfile
import numpy import numpy
import paddle.v2.dataset.common
import tarfile
from common import download __all__ = ['train100', 'test100', 'train10', 'test10']
__all__ = [
'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator',
'test_creator'
]
CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def __read_batch__(filename, sub_name): def reader_creator(filename, sub_name):
def reader(): def read_batch(batch):
def __read_one_batch_impl__(batch): data = batch['data']
data = batch['data'] labels = batch.get('labels', batch.get('fine_labels', None))
labels = batch.get('labels', batch.get('fine_labels', None)) assert labels is not None
assert labels is not None for sample, label in itertools.izip(data, labels):
for sample, label in itertools.izip(data, labels): yield (sample / 255.0).astype(numpy.float32), int(label)
yield (sample / 255.0).astype(numpy.float32), int(label)
def reader():
with tarfile.open(filename, mode='r') as f: with tarfile.open(filename, mode='r') as f:
names = (each_item.name for each_item in f names = (each_item.name for each_item in f
if sub_name in each_item.name) if sub_name in each_item.name)
for name in names: for name in names:
batch = cPickle.load(f.extractfile(name)) batch = cPickle.load(f.extractfile(name))
for item in __read_one_batch_impl__(batch): for item in read_batch(batch):
yield item yield item
return reader return reader
def cifar_100_train_creator(): def train100():
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) return reader_creator(
return __read_batch__(fn, 'train') paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
'train')
def cifar_100_test_creator():
fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5)
return __read_batch__(fn, 'test')
def train_creator():
"""
Default train reader creator. Use CIFAR-10 dataset.
"""
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
return __read_batch__(fn, 'data_batch')
def test_creator(): def test100():
""" return reader_creator(
Default test reader creator. Use CIFAR-10 dataset. paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
""" 'test')
fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5)
return __read_batch__(fn, 'test_batch')
def unittest(): def train10():
for _ in train_creator()(): return reader_creator(
pass paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
for _ in test_creator()(): 'data_batch')
pass
if __name__ == '__main__': def test10():
unittest() return reader_creator(
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch')
...@@ -27,7 +27,6 @@ def download(url, module_name, md5sum): ...@@ -27,7 +27,6 @@ def download(url, module_name, md5sum):
filename = os.path.join(dirname, url.split('/')[-1]) filename = os.path.join(dirname, url.split('/')[-1])
if not (os.path.exists(filename) and md5file(filename) == md5sum): if not (os.path.exists(filename) and md5file(filename) == md5sum):
# If file doesn't exist or MD5 doesn't match, then download.
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
with open(filename, 'w') as f: with open(filename, 'w') as f:
shutil.copyfileobj(r.raw, f) shutil.copyfileobj(r.raw, f)
......
"""
MNIST dataset.
"""
import paddle.v2.dataset.common import paddle.v2.dataset.common
import subprocess import subprocess
import numpy import numpy
import platform import platform
__all__ = ['train', 'test'] __all__ = ['train', 'test']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6' TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz' TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
...@@ -48,7 +49,7 @@ def reader_creator(image_filename, label_filename, buffer_size): ...@@ -48,7 +49,7 @@ def reader_creator(image_filename, label_filename, buffer_size):
images = images / 255.0 * 2.0 - 1.0 images = images / 255.0 * 2.0 - 1.0
for i in xrange(buffer_size): for i in xrange(buffer_size):
yield images[i, :], labels[i] yield images[i, :], int(labels[i])
m.terminate() m.terminate()
l.terminate() l.terminate()
......
import paddle.v2.dataset.cifar
import unittest
class TestCIFAR(unittest.TestCase):
    """Sanity checks for the readers exposed by paddle.v2.dataset.cifar.

    Each test pulls every sample from one reader creator and verifies the
    number of samples and the largest label value that was observed.
    """

    def check_reader(self, reader):
        """Iterate over a reader and validate each sample it yields.

        Args:
            reader: zero-argument callable returning an iterable of
                indexable samples, where item 0 exposes ``.size`` and
                item 1 is the label.

        Returns:
            A ``(instances, max_label)`` tuple: the number of samples seen
            and the largest label encountered (never below the initial 0).
        """
        instances = 0
        max_label = 0
        for sample in reader():
            # Every image must be flattened to 3072 values
            # (presumably 3 channels x 32 x 32 pixels -- TODO confirm).
            self.assertEqual(sample[0].size, 3072)
            if sample[1] > max_label:
                max_label = sample[1]
            instances += 1
        return instances, max_label

    def test_test10(self):
        """test10() yields 10000 samples with labels up to 9."""
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.cifar.test10())
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 9)

    def test_train10(self):
        """train10() yields 50000 samples with labels up to 9."""
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.cifar.train10())
        self.assertEqual(instances, 50000)
        self.assertEqual(max_label_value, 9)

    def test_test100(self):
        """test100() yields 10000 samples with labels up to 99."""
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.cifar.test100())
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 99)

    def test_train100(self):
        """train100() yields 50000 samples with labels up to 99."""
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.cifar.train100())
        self.assertEqual(instances, 50000)
        self.assertEqual(max_label_value, 99)
# Run the test cases when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -5,21 +5,25 @@ import unittest ...@@ -5,21 +5,25 @@ import unittest
class TestMNIST(unittest.TestCase): class TestMNIST(unittest.TestCase):
def check_reader(self, reader): def check_reader(self, reader):
sum = 0 sum = 0
for l in reader: label = 0
for l in reader():
self.assertEqual(l[0].size, 784) self.assertEqual(l[0].size, 784)
self.assertEqual(l[1].size, 1) if l[1] > label:
self.assertLess(l[1], 10) label = l[1]
self.assertGreaterEqual(l[1], 0)
sum += 1 sum += 1
return sum return sum, label
def test_train(self): def test_train(self):
self.assertEqual( instances, max_label_value = self.check_reader(
self.check_reader(paddle.v2.dataset.mnist.train()), 60000) paddle.v2.dataset.mnist.train())
self.assertEqual(instances, 60000)
self.assertEqual(max_label_value, 9)
def test_test(self): def test_test(self):
self.assertEqual( instances, max_label_value = self.check_reader(
self.check_reader(paddle.v2.dataset.mnist.test()), 10000) paddle.v2.dataset.mnist.test())
self.assertEqual(instances, 10000)
self.assertEqual(max_label_value, 9)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册