diff --git a/demo/mnist/mnist_provider.py b/demo/mnist/mnist_provider.py index 6df4676da3bdc2e6949cc911fa3720cb51ddc568..c435e1681d6254aff99351661e57cd70f96a258a 100644 --- a/demo/mnist/mnist_provider.py +++ b/demo/mnist/mnist_provider.py @@ -1,10 +1,11 @@ from paddle.trainer.PyDataProvider2 import * - +import numpy # Define a py data provider @provider( input_types={'pixel': dense_vector(28 * 28), - 'label': integer_value(10)}) + 'label': integer_value(10)}, + cache=CacheType.CACHE_PASS_IN_MEM) def process(settings, filename): # settings is not used currently. imgf = filename + "-images-idx3-ubyte" labelf = filename + "-labels-idx1-ubyte" @@ -19,13 +20,13 @@ def process(settings, filename): # settings is not used currently. n = 60000 else: n = 10000 - - for i in range(n): - label = ord(l.read(1)) - pixels = [] - for j in range(28 * 28): - pixels.append(float(ord(f.read(1))) / 255.0) - yield {"pixel": pixels, 'label': label} - + + images = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28)).astype('float32') + images = images / 255.0 * 2.0 - 1.0 + labels = numpy.fromfile(l, 'ubyte', count=n).astype("int") + + for i in xrange(n): + yield {"pixel": images[i, :], 'label': labels[i]} + f.close() l.close()