From ff4e046378244eacb025c16eb10dbe20b86037c3 Mon Sep 17 00:00:00 2001
From: wangyang59 <wangyang59@baidu.com>
Date: Fri, 2 Dec 2016 11:20:09 -0800
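Subject: [PATCH] improve demo/mnist dataProvider speed

Read the image and label data with numpy.fromfile instead of one byte at
a time, and set cache=CacheType.CACHE_PASS_IN_MEM on the provider.
Note: pixel values are now scaled to [-1, 1]; they were [0, 1] before.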

---
 demo/mnist/mnist_provider.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/demo/mnist/mnist_provider.py b/demo/mnist/mnist_provider.py
index 6df4676da3b..c435e1681d6 100644
--- a/demo/mnist/mnist_provider.py
+++ b/demo/mnist/mnist_provider.py
@@ -1,10 +1,11 @@
 from paddle.trainer.PyDataProvider2 import *
-
+import numpy
 
 # Define a py data provider
 @provider(
     input_types={'pixel': dense_vector(28 * 28),
-                 'label': integer_value(10)})
+                 'label': integer_value(10)},
+    cache=CacheType.CACHE_PASS_IN_MEM)
 def process(settings, filename):  # settings is not used currently.
     imgf = filename + "-images-idx3-ubyte"
     labelf = filename + "-labels-idx1-ubyte"
@@ -19,13 +20,13 @@ def process(settings, filename):  # settings is not used currently.
         n = 60000
     else:
         n = 10000
-
-    for i in range(n):
-        label = ord(l.read(1))
-        pixels = []
-        for j in range(28 * 28):
-            pixels.append(float(ord(f.read(1))) / 255.0)
-        yield {"pixel": pixels, 'label': label}
-
+    # Bulk-read n images of 28*28 bytes each; assumes f and l are already past their file headers - TODO confirm.
+    images = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28)).astype('float32')
+    images = images / 255.0 * 2.0 - 1.0  # rescale bytes 0..255 to [-1, 1]; the replaced loop produced [0, 1]
+    labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
+    # Yield one sample per image: row i of images paired with label i.
+    for i in xrange(n):
+        yield {"pixel": images[i, :], 'label': labels[i]}
+    # NOTE(review): the close() calls below run only after the generator has been fully consumed.
     f.close()
     l.close()
-- 
GitLab