提交 93e8ae0b 编写于 作者: K Kaipeng Deng 提交者: GitHub

Revert "fix cifar label dimension. test=develop (#33475)"

This reverts commit 6c110344.
上级 1dfd857b
...@@ -32,8 +32,6 @@ class TestCifar10Train(unittest.TestCase): ...@@ -32,8 +32,6 @@ class TestCifar10Train(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
...@@ -51,8 +49,6 @@ class TestCifar10Test(unittest.TestCase): ...@@ -51,8 +49,6 @@ class TestCifar10Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend # test cv2 backend
...@@ -67,8 +63,6 @@ class TestCifar10Test(unittest.TestCase): ...@@ -67,8 +63,6 @@ class TestCifar10Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -89,8 +83,6 @@ class TestCifar100Train(unittest.TestCase): ...@@ -89,8 +83,6 @@ class TestCifar100Train(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
...@@ -108,8 +100,6 @@ class TestCifar100Test(unittest.TestCase): ...@@ -108,8 +100,6 @@ class TestCifar100Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
# test cv2 backend # test cv2 backend
...@@ -124,8 +114,6 @@ class TestCifar100Test(unittest.TestCase): ...@@ -124,8 +114,6 @@ class TestCifar100Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
......
...@@ -148,8 +148,7 @@ class Cifar10(Dataset): ...@@ -148,8 +148,7 @@ class Cifar10(Dataset):
six.b('labels'), batch.get(six.b('fine_labels'), None)) six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None assert labels is not None
for sample, label in six.moves.zip(data, labels): for sample, label in six.moves.zip(data, labels):
self.data.append((sample, self.data.append((sample, label))
np.array([label]).astype('int64')))
def __getitem__(self, idx): def __getitem__(self, idx):
image, label = self.data[idx] image, label = self.data[idx]
...@@ -162,9 +161,9 @@ class Cifar10(Dataset): ...@@ -162,9 +161,9 @@ class Cifar10(Dataset):
image = self.transform(image) image = self.transform(image)
if self.backend == 'pil': if self.backend == 'pil':
return image, label.astype('int64') return image, np.array(label).astype('int64')
return image.astype(self.dtype), label.astype('int64') return image.astype(self.dtype), np.array(label).astype('int64')
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册