未验证 提交 6c110344 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix cifar label dimension. test=develop (#33475)

上级 80614429
......@@ -32,6 +32,8 @@ class TestCifar10Train(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
......@@ -49,6 +51,8 @@ class TestCifar10Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend
......@@ -63,6 +67,8 @@ class TestCifar10Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError):
......@@ -83,6 +89,8 @@ class TestCifar100Train(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)
......@@ -100,6 +108,8 @@ class TestCifar100Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)
# test cv2 backend
......@@ -114,6 +124,8 @@ class TestCifar100Test(unittest.TestCase):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError):
......
......@@ -151,7 +151,8 @@ class Cifar10(Dataset):
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in six.moves.zip(data, labels):
self.data.append((sample, label))
self.data.append((sample,
np.array([label]).astype('int64')))
def __getitem__(self, idx):
image, label = self.data[idx]
......@@ -164,9 +165,9 @@ class Cifar10(Dataset):
image = self.transform(image)
if self.backend == 'pil':
return image, np.array(label).astype('int64')
return image, label.astype('int64')
return image.astype(self.dtype), np.array(label).astype('int64')
return image.astype(self.dtype), label.astype('int64')
def __len__(self):
return len(self.data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册