From 86d81af5ef43e9bc0d9684ba91f96871a6cd8d6f Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Mon, 11 Jan 2021 20:06:49 +0800 Subject: [PATCH] reduce unittest time of test_datasets (#30275) --- python/paddle/tests/test_datasets.py | 56 ++++++++++++++-------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/python/paddle/tests/test_datasets.py b/python/paddle/tests/test_datasets.py index 3aa21ae2db2..89fa01cbceb 100644 --- a/python/paddle/tests/test_datasets.py +++ b/python/paddle/tests/test_datasets.py @@ -94,13 +94,13 @@ class TestMNISTTest(unittest.TestCase): mnist = MNIST(mode='test', transform=transform) self.assertTrue(len(mnist) == 10000) - for i in range(len(mnist)): - image, label = mnist[i] - self.assertTrue(image.shape[0] == 1) - self.assertTrue(image.shape[1] == 28) - self.assertTrue(image.shape[2] == 28) - self.assertTrue(label.shape[0] == 1) - self.assertTrue(0 <= int(label) <= 9) + i = np.random.randint(0, len(mnist) - 1) + image, label = mnist[i] + self.assertTrue(image.shape[0] == 1) + self.assertTrue(image.shape[1] == 28) + self.assertTrue(image.shape[2] == 28) + self.assertTrue(label.shape[0] == 1) + self.assertTrue(0 <= int(label) <= 9) class TestMNISTTrain(unittest.TestCase): @@ -109,13 +109,13 @@ class TestMNISTTrain(unittest.TestCase): mnist = MNIST(mode='train', transform=transform) self.assertTrue(len(mnist) == 60000) - for i in range(len(mnist)): - image, label = mnist[i] - self.assertTrue(image.shape[0] == 1) - self.assertTrue(image.shape[1] == 28) - self.assertTrue(image.shape[2] == 28) - self.assertTrue(label.shape[0] == 1) - self.assertTrue(0 <= int(label) <= 9) + i = np.random.randint(0, len(mnist) - 1) + image, label = mnist[i] + self.assertTrue(image.shape[0] == 1) + self.assertTrue(image.shape[1] == 28) + self.assertTrue(image.shape[2] == 28) + self.assertTrue(label.shape[0] == 1) + self.assertTrue(0 <= int(label) <= 9) # test cv2 backend mnist = MNIST(mode='train', transform=transform, backend='cv2') @@ -140,13 +140,13 @@ class TestFASHIONMNISTTest(unittest.TestCase): mnist = FashionMNIST(mode='test', transform=transform) self.assertTrue(len(mnist) == 10000) - for i in range(len(mnist)): - image, label = mnist[i] - self.assertTrue(image.shape[0] == 1) - self.assertTrue(image.shape[1] == 28) - self.assertTrue(image.shape[2] == 28) - self.assertTrue(label.shape[0] == 1) - self.assertTrue(0 <= int(label) <= 9) + i = np.random.randint(0, len(mnist) - 1) + image, label = mnist[i] + self.assertTrue(image.shape[0] == 1) + self.assertTrue(image.shape[1] == 28) + self.assertTrue(image.shape[2] == 28) + self.assertTrue(label.shape[0] == 1) + self.assertTrue(0 <= int(label) <= 9) class TestFASHIONMNISTTrain(unittest.TestCase): @@ -155,13 +155,13 @@ class TestFASHIONMNISTTrain(unittest.TestCase): mnist = FashionMNIST(mode='train', transform=transform) self.assertTrue(len(mnist) == 60000) - for i in range(len(mnist)): - image, label = mnist[i] - self.assertTrue(image.shape[0] == 1) - self.assertTrue(image.shape[1] == 28) - self.assertTrue(image.shape[2] == 28) - self.assertTrue(label.shape[0] == 1) - self.assertTrue(0 <= int(label) <= 9) + i = np.random.randint(0, len(mnist) - 1) + image, label = mnist[i] + self.assertTrue(image.shape[0] == 1) + self.assertTrue(image.shape[1] == 28) + self.assertTrue(image.shape[2] == 28) + self.assertTrue(label.shape[0] == 1) + self.assertTrue(0 <= int(label) <= 9) # test cv2 backend mnist = FashionMNIST(mode='train', transform=transform, backend='cv2') -- GitLab