diff --git a/python/paddle/tests/test_dataset_cifar.py b/python/paddle/tests/test_dataset_cifar.py index 2e9efddf9712e35423e19dad02f738c40dbc8b51..abf79fb1e3974ce0c1d9de4efd1df05056ff3821 100644 --- a/python/paddle/tests/test_dataset_cifar.py +++ b/python/paddle/tests/test_dataset_cifar.py @@ -32,8 +32,6 @@ class TestCifar10Train(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 9) @@ -51,8 +49,6 @@ class TestCifar10Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 9) # test cv2 backend @@ -67,8 +63,6 @@ class TestCifar10Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) with self.assertRaises(ValueError): @@ -89,8 +83,6 @@ class TestCifar100Train(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) @@ -108,8 +100,6 @@ class TestCifar100Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) # test cv2 backend @@ -124,8 +114,6 @@ class TestCifar100Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) with self.assertRaises(ValueError): diff --git a/python/paddle/vision/datasets/cifar.py b/python/paddle/vision/datasets/cifar.py index 97ffb239fe7adf6d9482f765d64aaf460926c566..74ae8ef11e3de0464518ff2311d8f6e12b21f8e7 100644 --- a/python/paddle/vision/datasets/cifar.py +++ b/python/paddle/vision/datasets/cifar.py @@ -148,8 +148,7 @@ class Cifar10(Dataset): six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None for sample, label in six.moves.zip(data, labels): - self.data.append((sample, - np.array([label]).astype('int64'))) + self.data.append((sample, label)) def __getitem__(self, idx): image, label = self.data[idx] @@ -162,9 +161,9 @@ class Cifar10(Dataset): image = self.transform(image) if self.backend == 'pil': - return image, label.astype('int64') + return image, np.array(label).astype('int64') - return image.astype(self.dtype), label.astype('int64') + return image.astype(self.dtype), np.array(label).astype('int64') def __len__(self): return len(self.data)