From 6c110344cd759338d7298d27b6b4fce967d93e0a Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 10 Jun 2021 12:51:08 +0800 Subject: [PATCH] fix cifar label dimension. test=develop (#33475) --- python/paddle/tests/test_dataset_cifar.py | 12 ++++++++++++ python/paddle/vision/datasets/cifar.py | 7 ++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/python/paddle/tests/test_dataset_cifar.py b/python/paddle/tests/test_dataset_cifar.py index abf79fb1e39..2e9efddf971 100644 --- a/python/paddle/tests/test_dataset_cifar.py +++ b/python/paddle/tests/test_dataset_cifar.py @@ -32,6 +32,8 @@ class TestCifar10Train(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) + self.assertTrue(len(label.shape) == 1) + self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 9) @@ -49,6 +51,8 @@ class TestCifar10Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) + self.assertTrue(len(label.shape) == 1) + self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 9) # test cv2 backend @@ -63,6 +67,8 @@ class TestCifar10Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) + self.assertTrue(len(label.shape) == 1) + self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) with self.assertRaises(ValueError): @@ -83,6 +89,8 @@ class TestCifar100Train(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) + self.assertTrue(len(label.shape) == 1) + self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) @@ -100,6 +108,8 @@ class TestCifar100Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) + self.assertTrue(len(label.shape) == 1) + self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) # test cv2 backend @@ -114,6 +124,8 @@ class TestCifar100Test(unittest.TestCase): self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[0] == 32) + self.assertTrue(len(label.shape) == 1) + self.assertTrue(label.shape[0] == 1) self.assertTrue(0 <= int(label) <= 99) with self.assertRaises(ValueError): diff --git a/python/paddle/vision/datasets/cifar.py b/python/paddle/vision/datasets/cifar.py index 0a0a48026af..ff3734bf7a0 100644 --- a/python/paddle/vision/datasets/cifar.py +++ b/python/paddle/vision/datasets/cifar.py @@ -151,7 +151,8 @@ class Cifar10(Dataset): six.b('labels'), batch.get(six.b('fine_labels'), None)) assert labels is not None for sample, label in six.moves.zip(data, labels): - self.data.append((sample, label)) + self.data.append((sample, + np.array([label]).astype('int64'))) def __getitem__(self, idx): image, label = self.data[idx] @@ -164,9 +165,9 @@ class Cifar10(Dataset): image = self.transform(image) if self.backend == 'pil': - return image, np.array(label).astype('int64') + return image, label.astype('int64') - return image.astype(self.dtype), np.array(label).astype('int64') + return image.astype(self.dtype), label.astype('int64') def __len__(self): return len(self.data) -- GitLab