diff --git a/demo/mnist/mnist_provider.py b/demo/mnist/mnist_provider.py index c435e1681d6254aff99351661e57cd70f96a258a..4635833d36b9f21c992d96910f3ac9094ccefd2c 100644 --- a/demo/mnist/mnist_provider.py +++ b/demo/mnist/mnist_provider.py @@ -1,6 +1,7 @@ from paddle.trainer.PyDataProvider2 import * import numpy + # Define a py data provider @provider( input_types={'pixel': dense_vector(28 * 28), @@ -20,13 +21,14 @@ def process(settings, filename): # settings is not used currently. n = 60000 else: n = 10000 - - images = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28)).astype('float32') - images = images / 255.0 * 2.0 - 1.0 + + images = numpy.fromfile( + f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32') + images = images / 255.0 * 2.0 - 1.0 labels = numpy.fromfile(l, 'ubyte', count=n).astype("int") - + for i in xrange(n): yield {"pixel": images[i, :], 'label': labels[i]} - + f.close() l.close()