未验证 提交 86d81af5 编写于 作者: L LielinJiang 提交者: GitHub

reduce unittest time of test_datasets (#30275)

上级 a0ee0914
...@@ -94,13 +94,13 @@ class TestMNISTTest(unittest.TestCase): ...@@ -94,13 +94,13 @@ class TestMNISTTest(unittest.TestCase):
mnist = MNIST(mode='test', transform=transform) mnist = MNIST(mode='test', transform=transform)
self.assertTrue(len(mnist) == 10000) self.assertTrue(len(mnist) == 10000)
for i in range(len(mnist)): i = np.random.randint(0, len(mnist) - 1)
image, label = mnist[i] image, label = mnist[i]
self.assertTrue(image.shape[0] == 1) self.assertTrue(image.shape[0] == 1)
self.assertTrue(image.shape[1] == 28) self.assertTrue(image.shape[1] == 28)
self.assertTrue(image.shape[2] == 28) self.assertTrue(image.shape[2] == 28)
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
class TestMNISTTrain(unittest.TestCase): class TestMNISTTrain(unittest.TestCase):
...@@ -109,13 +109,13 @@ class TestMNISTTrain(unittest.TestCase): ...@@ -109,13 +109,13 @@ class TestMNISTTrain(unittest.TestCase):
mnist = MNIST(mode='train', transform=transform) mnist = MNIST(mode='train', transform=transform)
self.assertTrue(len(mnist) == 60000) self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)): i = np.random.randint(0, len(mnist) - 1)
image, label = mnist[i] image, label = mnist[i]
self.assertTrue(image.shape[0] == 1) self.assertTrue(image.shape[0] == 1)
self.assertTrue(image.shape[1] == 28) self.assertTrue(image.shape[1] == 28)
self.assertTrue(image.shape[2] == 28) self.assertTrue(image.shape[2] == 28)
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend # test cv2 backend
mnist = MNIST(mode='train', transform=transform, backend='cv2') mnist = MNIST(mode='train', transform=transform, backend='cv2')
...@@ -140,13 +140,13 @@ class TestFASHIONMNISTTest(unittest.TestCase): ...@@ -140,13 +140,13 @@ class TestFASHIONMNISTTest(unittest.TestCase):
mnist = FashionMNIST(mode='test', transform=transform) mnist = FashionMNIST(mode='test', transform=transform)
self.assertTrue(len(mnist) == 10000) self.assertTrue(len(mnist) == 10000)
for i in range(len(mnist)): i = np.random.randint(0, len(mnist) - 1)
image, label = mnist[i] image, label = mnist[i]
self.assertTrue(image.shape[0] == 1) self.assertTrue(image.shape[0] == 1)
self.assertTrue(image.shape[1] == 28) self.assertTrue(image.shape[1] == 28)
self.assertTrue(image.shape[2] == 28) self.assertTrue(image.shape[2] == 28)
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
class TestFASHIONMNISTTrain(unittest.TestCase): class TestFASHIONMNISTTrain(unittest.TestCase):
...@@ -155,13 +155,13 @@ class TestFASHIONMNISTTrain(unittest.TestCase): ...@@ -155,13 +155,13 @@ class TestFASHIONMNISTTrain(unittest.TestCase):
mnist = FashionMNIST(mode='train', transform=transform) mnist = FashionMNIST(mode='train', transform=transform)
self.assertTrue(len(mnist) == 60000) self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)): i = np.random.randint(0, len(mnist) - 1)
image, label = mnist[i] image, label = mnist[i]
self.assertTrue(image.shape[0] == 1) self.assertTrue(image.shape[0] == 1)
self.assertTrue(image.shape[1] == 28) self.assertTrue(image.shape[1] == 28)
self.assertTrue(image.shape[2] == 28) self.assertTrue(image.shape[2] == 28)
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend # test cv2 backend
mnist = FashionMNIST(mode='train', transform=transform, backend='cv2') mnist = FashionMNIST(mode='train', transform=transform, backend='cv2')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册