diff --git a/python/paddle/v2/dataset/cifar.py b/python/paddle/v2/dataset/cifar.py index accb32f117720fdee7bef89d48ee23ef7a6024d2..77c54bd268b5d988b0802a3edca91605e56f730e 100644 --- a/python/paddle/v2/dataset/cifar.py +++ b/python/paddle/v2/dataset/cifar.py @@ -1,82 +1,61 @@ """ -CIFAR Dataset. - -URL: https://www.cs.toronto.edu/~kriz/cifar.html - -the default train_creator, test_creator used for CIFAR-10 dataset. +CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html """ import cPickle import itertools -import tarfile - import numpy +import paddle.v2.dataset.common +import tarfile -from common import download - -__all__ = [ - 'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator', - 'test_creator' -] +__all__ = ['train100', 'test100', 'train10', 'test10'] -CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' +URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/' +CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz' CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' -CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' +CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz' CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' -def __read_batch__(filename, sub_name): - def reader(): - def __read_one_batch_impl__(batch): - data = batch['data'] - labels = batch.get('labels', batch.get('fine_labels', None)) - assert labels is not None - for sample, label in itertools.izip(data, labels): - yield (sample / 255.0).astype(numpy.float32), int(label) +def reader_creator(filename, sub_name): + def read_batch(batch): + data = batch['data'] + labels = batch.get('labels', batch.get('fine_labels', None)) + assert labels is not None + for sample, label in itertools.izip(data, labels): + yield (sample / 255.0).astype(numpy.float32), int(label) + def reader(): with tarfile.open(filename, mode='r') as f: names = (each_item.name for each_item in f if sub_name in each_item.name) for name in names: batch = cPickle.load(f.extractfile(name)) - for item in __read_one_batch_impl__(batch): + for item in read_batch(batch): yield item return reader -def cifar_100_train_creator(): - fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) - return __read_batch__(fn, 'train') - - -def cifar_100_test_creator(): - fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) - return __read_batch__(fn, 'test') - - -def train_creator(): - """ - Default train reader creator. Use CIFAR-10 dataset. - """ - fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5) - return __read_batch__(fn, 'data_batch') +def train100(): + return reader_creator( + paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5), + 'train') -def test_creator(): - """ - Default test reader creator. Use CIFAR-10 dataset. - """ - fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5) - return __read_batch__(fn, 'test_batch') +def test100(): + return reader_creator( + paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5), + 'test') -def unittest(): - for _ in train_creator()(): - pass - for _ in test_creator()(): - pass +def train10(): + return reader_creator( + paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), + 'data_batch') -if __name__ == '__main__': - unittest() +def test10(): + return reader_creator( + paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), + 'test_batch') diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index b1831f38afb4f15796e4eaaacce6bc37f975578a..a5ffe25a116e9be039bdebaaaad435685e23d372 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -27,7 +27,6 @@ def download(url, module_name, md5sum): filename = os.path.join(dirname, url.split('/')[-1]) if not (os.path.exists(filename) and md5file(filename) == md5sum): - # If file doesn't exist or MD5 doesn't match, then download. r = requests.get(url, stream=True) with open(filename, 'w') as f: shutil.copyfileobj(r.raw, f) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index 8ba11ca5ec7943032ba5dbd5de48b1be38786010..a36c20e3fa3734bdc14c1f47779a61375f298511 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -1,11 +1,13 @@ +""" +MNIST dataset. +""" +import numpy import paddle.v2.dataset.common import subprocess -import numpy __all__ = ['train', 'test'] URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' - TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6' TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz' @@ -40,12 +42,12 @@ def reader_creator(image_filename, label_filename, buffer_size): images = images / 255.0 * 2.0 - 1.0 for i in xrange(buffer_size): - yield images[i, :], labels[i] + yield images[i, :], int(labels[i]) m.terminate() l.terminate() - return reader() + return reader def train(): diff --git a/python/paddle/v2/dataset/tests/cifar_test.py b/python/paddle/v2/dataset/tests/cifar_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a2af45ecf508462fe4b596b5d8d6401c5b974eff --- /dev/null +++ b/python/paddle/v2/dataset/tests/cifar_test.py @@ -0,0 +1,42 @@ +import paddle.v2.dataset.cifar +import unittest + + +class TestCIFAR(unittest.TestCase): + def check_reader(self, reader): + sum = 0 + label = 0 + for l in reader(): + self.assertEqual(l[0].size, 3072) + if l[1] > label: + label = l[1] + sum += 1 + return sum, label + + def test_test10(self): + instances, max_label_value = self.check_reader( + paddle.v2.dataset.cifar.test10()) + self.assertEqual(instances, 10000) + self.assertEqual(max_label_value, 9) + + def test_train10(self): + instances, max_label_value = self.check_reader( + paddle.v2.dataset.cifar.train10()) + self.assertEqual(instances, 50000) + self.assertEqual(max_label_value, 9) + + def test_test100(self): + instances, max_label_value = self.check_reader( + paddle.v2.dataset.cifar.test100()) + self.assertEqual(instances, 10000) + self.assertEqual(max_label_value, 99) + + def test_train100(self): + instances, max_label_value = self.check_reader( + paddle.v2.dataset.cifar.train100()) + self.assertEqual(instances, 50000) + self.assertEqual(max_label_value, 99) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/dataset/tests/mnist_test.py b/python/paddle/v2/dataset/tests/mnist_test.py index e4f0b33d5207b2590fbafa8969fdef741a5e2848..b4408cc2f590d4d8da4ce5e98213cf7b208cfc15 100644 --- a/python/paddle/v2/dataset/tests/mnist_test.py +++ b/python/paddle/v2/dataset/tests/mnist_test.py @@ -5,21 +5,25 @@ import unittest class TestMNIST(unittest.TestCase): def check_reader(self, reader): sum = 0 - for l in reader: + label = 0 + for l in reader(): self.assertEqual(l[0].size, 784) - self.assertEqual(l[1].size, 1) - self.assertLess(l[1], 10) - self.assertGreaterEqual(l[1], 0) + if l[1] > label: + label = l[1] sum += 1 - return sum + return sum, label def test_train(self): - self.assertEqual( - self.check_reader(paddle.v2.dataset.mnist.train()), 60000) + instances, max_label_value = self.check_reader( + paddle.v2.dataset.mnist.train()) + self.assertEqual(instances, 60000) + self.assertEqual(max_label_value, 9) def test_test(self): - self.assertEqual( - self.check_reader(paddle.v2.dataset.mnist.test()), 10000) + instances, max_label_value = self.check_reader( + paddle.v2.dataset.mnist.test()) + self.assertEqual(instances, 10000) + self.assertEqual(max_label_value, 9) if __name__ == '__main__':