未验证 提交 982fd0f3 编写于 作者: L LielinJiang 提交者: GitHub

fix mnist fmnist (#29018)

上级 887a3511
...@@ -179,6 +179,13 @@ class TestFASHIONMNISTTrain(unittest.TestCase): ...@@ -179,6 +179,13 @@ class TestFASHIONMNISTTrain(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mnist = FashionMNIST(mode='train', transform=transform, backend=1) mnist = FashionMNIST(mode='train', transform=transform, backend=1)
def test_dataset_value(self):
fmnist = FashionMNIST(mode='train')
value = np.mean([np.array(x[0]) for x in fmnist])
# 72.94035223214286 was getted from competitive products
np.testing.assert_allclose(value, 72.94035223214286)
class TestFlowersTrain(unittest.TestCase): class TestFlowersTrain(unittest.TestCase):
def test_main(self): def test_main(self):
......
...@@ -163,7 +163,7 @@ class MNIST(Dataset): ...@@ -163,7 +163,7 @@ class MNIST(Dataset):
image = np.reshape(image, [28, 28]) image = np.reshape(image, [28, 28])
if self.backend == 'pil': if self.backend == 'pil':
image = Image.fromarray(image, mode='L') image = Image.fromarray(image.astype('uint8'), mode='L')
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册