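# mnist_test.py: unit tests for the paddle.v2.dataset.mnist train/test readers.
# Per the repository blame view, all 30 original lines were committed by Yi Wang.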
import paddle.v2.dataset.mnist
import unittest


class TestMNIST(unittest.TestCase):
    """Sanity checks for the MNIST readers in paddle.v2.dataset.mnist."""

    def check_reader(self, reader):
        """Drain ``reader`` and check the size of every sample's image.

        Args:
            reader: A reader creator. Calling it must return an iterable of
                samples where ``sample[0]`` exposes a ``size`` attribute (the
                flattened image) and ``sample[1]`` is the label.

        Returns:
            A ``(count, max_label)`` tuple: the number of samples yielded and
            the largest label seen (0 if the reader yields nothing).
        """
        count = 0
        max_label = 0
        for sample in reader():
            # Every image must hold 784 values (28 * 28 MNIST pixels).
            self.assertEqual(sample[0].size, 784)
            max_label = max(max_label, sample[1])
            count += 1
        return count, max_label

    def test_train(self):
        """The training reader should yield 60000 samples with labels up to 9."""
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.mnist.train())
        self.assertEqual(instances, 60000)
        self.assertEqual(max_label_value, 9)

    def test_test(self):
        """The test reader should yield 10000 samples with labels up to 9."""
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.mnist.test())
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 9)


if __name__ == "__main__":
    # Run as a script: hand this module's test cases to unittest's CLI runner.
    unittest.main()