mnist_test.py 691 字节
Newer Older
Y
Yi Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
import paddle.v2.dataset.mnist
import unittest

class TestMNIST(unittest.TestCase):
    def check_reader(self, reader):
        sum = 0
        for l in reader:
            self.assertEqual(l[0].size, 784)
            self.assertEqual(l[1].size, 1)
            self.assertLess(l[1], 10)
            self.assertGreaterEqual(l[1], 0)
            sum += 1
        return sum

    def test_train(self):
        self.assertEqual(
            self.check_reader(paddle.v2.dataset.mnist.train()),
            60000)

    def test_test(self):
        self.assertEqual(
            self.check_reader(paddle.v2.dataset.mnist.test()),
            10000)


if __name__ == "__main__":
    # Run this module's test cases when it is executed as a script.
    unittest.main()