From 1f6f2235e0751248ee1a684667023b78834f9d1f Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Tue, 20 Jul 2021 14:31:10 +0800 Subject: [PATCH] Revert "fix cifar label dimension. test=develop (#33475)" (#34242) This reverts commit 6c110344cd759338d7298d27b6b4fce967d93e0a. --- python/paddle/tests/test_dataset_cifar.py | 12 ------------ python/paddle/vision/datasets/cifar.py | 7 +++---- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/python/paddle/tests/test_dataset_cifar.py b/python/paddle/tests/test_dataset_cifar.py index 2e9efddf97..abf79fb1e3 100644 --- a/python/paddle/tests/test_dataset_cifar.py +++ b/python/paddle/tests/test_dataset_cifar.py @@ -32,8 +32,6 @@ class TestCifar10Train(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 9) @@ -51,8 +49,6 @@ class TestCifar10Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 9) # test cv2 backend @@ -67,8 +63,6 @@ class TestCifar10Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) with self.assertRaises(ValueError): @@ -89,8 +83,6 @@ class TestCifar100Train(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) @@ -108,8 +100,6 @@ class TestCifar100Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) # test cv2 backend @@ -124,8 +114,6 @@ class TestCifar100Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) - self.assertTrue(len(label.shape) == 1) - self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) with self.assertRaises(ValueError): diff --git a/python/paddle/vision/datasets/cifar.py b/python/paddle/vision/datasets/cifar.py index 97ffb239fe..74ae8ef11e 100644 --- a/python/paddle/vision/datasets/cifar.py +++ b/python/paddle/vision/datasets/cifar.py @@ -148,8 +148,7 @@ class Cifar10(Dataset): six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None for sample, label in six.moves.zip(data, labels): - self.data.append((sample, - np.array([label]).astype('int64'))) + self.data.append((sample, label)) def __getitem__(self, idx): image, label = self.data[idx] @@ -162,9 +161,9 @@ class Cifar10(Dataset): image = self.transform(image) if self.backend == 'pil': - return image, label.astype('int64') + return image, np.array(label).astype('int64') - return image.astype(self.dtype), label.astype('int64') + return image.astype(self.dtype), np.array(label).astype('int64') def __len__(self): return len(self.data) -- GitLab