提交 93e8ae0b 编写于 作者: K Kaipeng Deng 提交者: GitHub

Revert "fix cifar label dimension. test=develop (#33475)"

This reverts commit 6c110344.
上级 1dfd857b
......@@ -32,8 +32,6 @@ class TestCifar10Train(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
......@@ -51,8 +49,6 @@ class TestCifar10Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend
......@@ -67,8 +63,6 @@ class TestCifar10Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError):
......@@ -89,8 +83,6 @@ class TestCifar100Train(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)
......@@ -108,8 +100,6 @@ class TestCifar100Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)
# test cv2 backend
......@@ -124,8 +114,6 @@ class TestCifar100Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError):
......
......@@ -148,8 +148,7 @@ class Cifar10(Dataset):
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in six.moves.zip(data, labels):
self.data.append((sample,
np.array([label]).astype('int64')))
self.data.append((sample, label))
def __getitem__(self, idx):
image, label = self.data[idx]
......@@ -162,9 +161,9 @@ class Cifar10(Dataset):
image = self.transform(image)
if self.backend == 'pil':
return image, label.astype('int64')
return image, np.array(label).astype('int64')
return image.astype(self.dtype), label.astype('int64')
return image.astype(self.dtype), np.array(label).astype('int64')
def __len__(self):
return len(self.data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册