From 982fd0f3c259ffc39c01a043fb01d5670bbf5b65 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Tue, 24 Nov 2020 14:04:25 +0800 Subject: [PATCH] fix mnist fmnist (#29018) --- python/paddle/tests/test_datasets.py | 7 +++++++ python/paddle/vision/datasets/mnist.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/paddle/tests/test_datasets.py b/python/paddle/tests/test_datasets.py index d119d2c5cc..3aa21ae2db 100644 --- a/python/paddle/tests/test_datasets.py +++ b/python/paddle/tests/test_datasets.py @@ -179,6 +179,13 @@ class TestFASHIONMNISTTrain(unittest.TestCase): with self.assertRaises(ValueError): mnist = FashionMNIST(mode='train', transform=transform, backend=1) + def test_dataset_value(self): + fmnist = FashionMNIST(mode='train') + value = np.mean([np.array(x[0]) for x in fmnist]) + + # 72.94035223214286 was getted from competitive products + np.testing.assert_allclose(value, 72.94035223214286) + class TestFlowersTrain(unittest.TestCase): def test_main(self): diff --git a/python/paddle/vision/datasets/mnist.py b/python/paddle/vision/datasets/mnist.py index 3d752ece34..0f4d4947aa 100644 --- a/python/paddle/vision/datasets/mnist.py +++ b/python/paddle/vision/datasets/mnist.py @@ -163,7 +163,7 @@ class MNIST(Dataset): image = np.reshape(image, [28, 28]) if self.backend == 'pil': - image = Image.fromarray(image, mode='L') + image = Image.fromarray(image.astype('uint8'), mode='L') if self.transform is not None: image = self.transform(image) -- GitLab