From dcbfbb15338e0ca0f195e12ce0e0275995622ca1 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 28 Feb 2017 02:46:31 +0000 Subject: [PATCH] yapf format --- python/paddle/v2/dataset/mnist.py | 34 +++++++++---------- python/paddle/v2/dataset/tests/common_test.py | 9 ++--- python/paddle/v2/dataset/tests/mnist_test.py | 7 ++-- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index 29fc20eae9b..ec334d39e6c 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -22,23 +22,21 @@ def reader_creator(image_filename, label_filename, buffer_size): # According to http://stackoverflow.com/a/38061619/724872, we # cannot use standard package gzip here. m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE) - m.stdout.read(16) # skip some magic bytes + m.stdout.read(16) # skip some magic bytes l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE) - l.stdout.read(8) # skip some magic bytes + l.stdout.read(8) # skip some magic bytes while True: labels = numpy.fromfile( - l.stdout, 'ubyte', count=buffer_size - ).astype("int") + l.stdout, 'ubyte', count=buffer_size).astype("int") if labels.size != buffer_size: - break # numpy.fromfile returns empty slice after EOF. + break # numpy.fromfile returns empty slice after EOF. images = numpy.fromfile( - m.stdout, 'ubyte', count=buffer_size * 28 * 28 - ).reshape((buffer_size, 28 * 28) - ).astype('float32') + m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape( + (buffer_size, 28 * 28)).astype('float32') images = images / 255.0 * 2.0 - 1.0 @@ -50,18 +48,18 @@ def reader_creator(image_filename, label_filename, buffer_size): return reader() + def train(): return reader_creator( - paddle.v2.dataset.common.download( - TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5), - paddle.v2.dataset.common.download( - TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5), - 100) + paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', + TRAIN_IMAGE_MD5), + paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', + TRAIN_LABEL_MD5), 100) + def test(): return reader_creator( - paddle.v2.dataset.common.download( - TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5), - paddle.v2.dataset.common.download( - TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5), - 100) + paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', + TEST_IMAGE_MD5), + paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', + TEST_LABEL_MD5), 100) diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py index 0672a467143..7d8406171b8 100644 --- a/python/paddle/v2/dataset/tests/common_test.py +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -2,14 +2,14 @@ import paddle.v2.dataset.common import unittest import tempfile + class TestCommon(unittest.TestCase): def test_md5file(self): - _, temp_path =tempfile.mkstemp() + _, temp_path = tempfile.mkstemp() with open(temp_path, 'w') as f: f.write("Hello\n") - self.assertEqual( - '09f7e02f1290be211da707a266f153b3', - paddle.v2.dataset.common.md5file(temp_path)) + self.assertEqual('09f7e02f1290be211da707a266f153b3', + paddle.v2.dataset.common.md5file(temp_path)) def test_download(self): yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460' @@ -18,5 +18,6 @@ class TestCommon(unittest.TestCase): paddle.v2.dataset.common.download( yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d')) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/dataset/tests/mnist_test.py b/python/paddle/v2/dataset/tests/mnist_test.py index 23ed2eaba8a..e4f0b33d520 100644 --- a/python/paddle/v2/dataset/tests/mnist_test.py +++ b/python/paddle/v2/dataset/tests/mnist_test.py @@ -1,6 +1,7 @@ import paddle.v2.dataset.mnist import unittest + class TestMNIST(unittest.TestCase): def check_reader(self, reader): sum = 0 @@ -14,13 +15,11 @@ class TestMNIST(unittest.TestCase): def test_train(self): self.assertEqual( - self.check_reader(paddle.v2.dataset.mnist.train()), - 60000) + self.check_reader(paddle.v2.dataset.mnist.train()), 60000) def test_test(self): self.assertEqual( - self.check_reader(paddle.v2.dataset.mnist.test()), - 10000) + self.check_reader(paddle.v2.dataset.mnist.test()), 10000) if __name__ == '__main__': -- GitLab