diff --git a/python/paddle/tests/test_datasets.py b/python/paddle/tests/test_datasets.py index d119d2c5ccad62f85929327e4783b08b7108b63f..3aa21ae2db267d40d4483b27ed89b0da349ce0db 100644 --- a/python/paddle/tests/test_datasets.py +++ b/python/paddle/tests/test_datasets.py @@ -179,6 +179,13 @@ class TestFASHIONMNISTTrain(unittest.TestCase): with self.assertRaises(ValueError): mnist = FashionMNIST(mode='train', transform=transform, backend=1) + def test_dataset_value(self): + fmnist = FashionMNIST(mode='train') + value = np.mean([np.array(x[0]) for x in fmnist]) + + # 72.94035223214286 was getted from competitive products + np.testing.assert_allclose(value, 72.94035223214286) + class TestFlowersTrain(unittest.TestCase): def test_main(self): diff --git a/python/paddle/vision/datasets/mnist.py b/python/paddle/vision/datasets/mnist.py index 3d752ece346b7a477c385f2a06f96d696b4b9eb0..0f4d4947aa5f8b5a9cdcbac8642e0ac73e14a905 100644 --- a/python/paddle/vision/datasets/mnist.py +++ b/python/paddle/vision/datasets/mnist.py @@ -163,7 +163,7 @@ class MNIST(Dataset): image = np.reshape(image, [28, 28]) if self.backend == 'pil': - image = Image.fromarray(image, mode='L') + image = Image.fromarray(image.astype('uint8'), mode='L') if self.transform is not None: image = self.transform(image)