未验证 提交 6c110344 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix cifar label dimension. test=develop (#33475)

上级 80614429
...@@ -32,6 +32,8 @@ class TestCifar10Train(unittest.TestCase): ...@@ -32,6 +32,8 @@ class TestCifar10Train(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
...@@ -49,6 +51,8 @@ class TestCifar10Test(unittest.TestCase): ...@@ -49,6 +51,8 @@ class TestCifar10Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend # test cv2 backend
...@@ -63,6 +67,8 @@ class TestCifar10Test(unittest.TestCase): ...@@ -63,6 +67,8 @@ class TestCifar10Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -83,6 +89,8 @@ class TestCifar100Train(unittest.TestCase): ...@@ -83,6 +89,8 @@ class TestCifar100Train(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
...@@ -100,6 +108,8 @@ class TestCifar100Test(unittest.TestCase): ...@@ -100,6 +108,8 @@ class TestCifar100Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
# test cv2 backend # test cv2 backend
...@@ -114,6 +124,8 @@ class TestCifar100Test(unittest.TestCase): ...@@ -114,6 +124,8 @@ class TestCifar100Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
......
...@@ -151,7 +151,8 @@ class Cifar10(Dataset): ...@@ -151,7 +151,8 @@ class Cifar10(Dataset):
six.b('labels'), batch.get(six.b('fine_labels'), None)) six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None assert labels is not None
for sample, label in six.moves.zip(data, labels): for sample, label in six.moves.zip(data, labels):
self.data.append((sample, label)) self.data.append((sample,
np.array([label]).astype('int64')))
def __getitem__(self, idx): def __getitem__(self, idx):
image, label = self.data[idx] image, label = self.data[idx]
...@@ -164,9 +165,9 @@ class Cifar10(Dataset): ...@@ -164,9 +165,9 @@ class Cifar10(Dataset):
image = self.transform(image) image = self.transform(image)
if self.backend == 'pil': if self.backend == 'pil':
return image, np.array(label).astype('int64') return image, label.astype('int64')
return image.astype(self.dtype), np.array(label).astype('int64') return image.astype(self.dtype), label.astype('int64')
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册