# mnist_test.py (668 bytes); every line in this blame view was last committed by Yi Wang.
import paddle.v2.dataset.mnist
import unittest

Y
Yi Wang 已提交
4

Y
Yi Wang 已提交
5 6 7 8 9 10 11 12 13 14 15 16 17
class TestMNIST(unittest.TestCase):
    def check_reader(self, reader):
        sum = 0
        for l in reader:
            self.assertEqual(l[0].size, 784)
            self.assertEqual(l[1].size, 1)
            self.assertLess(l[1], 10)
            self.assertGreaterEqual(l[1], 0)
            sum += 1
        return sum

    def test_train(self):
        self.assertEqual(
Y
Yi Wang 已提交
18
            self.check_reader(paddle.v2.dataset.mnist.train()), 60000)
Y
Yi Wang 已提交
19 20 21

    def test_test(self):
        self.assertEqual(
Y
Yi Wang 已提交
22
            self.check_reader(paddle.v2.dataset.mnist.test()), 10000)
Y
Yi Wang 已提交
23 24 25 26


# When executed as a script (not imported), hand control to unittest's
# command-line runner.
if __name__ == '__main__':
    unittest.main()