# mnist_provider.py
from paddle.trainer.PyDataProvider2 import *
import numpy

# PyDataProvider2 data provider that streams MNIST samples from IDX files.
@provider(
    input_types={'pixel': dense_vector(28 * 28),
                 'label': integer_value(10)},
    cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename):  # settings is not used currently.
    """Yield MNIST samples read from a pair of IDX files.

    Args:
        settings: provider settings object passed in by PyDataProvider2;
            not used here.
        filename: path prefix of the data set. Images are read from
            ``filename + "-images-idx3-ubyte"`` and labels from
            ``filename + "-labels-idx1-ubyte"``.

    Yields:
        dict with ``'pixel'``: a float32 vector of 28 * 28 values rescaled
        from [0, 255] to [-1, 1], and ``'label'``: the sample's label as
        stored in the label file.

    Raises:
        ValueError: if either file has the wrong IDX magic number or a short
            header, the images are not 28 * 28 pixels, the two files disagree
            on the number of samples, or a file is truncated.
    """
    imgf = filename + "-images-idx3-ubyte"
    labelf = filename + "-labels-idx1-ubyte"

    # Read both files completely inside ``with`` blocks so the handles are
    # closed even if the consumer stops iterating before the last sample.
    with open(imgf, "rb") as img_file:
        # IDX3 header: magic (2051), count, rows, cols as big-endian uint32.
        header = numpy.frombuffer(img_file.read(16), dtype='>u4')
        if len(header) != 4 or int(header[0]) != 2051:
            raise ValueError("%s is not an IDX3 image file" % imgf)
        n, rows, cols = [int(v) for v in header[1:]]
        if rows * cols != 28 * 28:
            raise ValueError("%s: expected 28x28 images, got %dx%d"
                             % (imgf, rows, cols))
        # The sample count comes from the header instead of being guessed
        # from whether "train" appears in the file name.
        images = numpy.fromfile(img_file, 'ubyte', count=n * rows * cols)
    if images.size != n * rows * cols:
        raise ValueError("%s is truncated" % imgf)

    with open(labelf, "rb") as label_file:
        # IDX1 header: magic (2049), count as big-endian uint32.
        header = numpy.frombuffer(label_file.read(8), dtype='>u4')
        if len(header) != 2 or int(header[0]) != 2049:
            raise ValueError("%s is not an IDX1 label file" % labelf)
        if int(header[1]) != n:
            raise ValueError("%s has %d labels but %s has %d images"
                             % (labelf, int(header[1]), imgf, n))
        labels = numpy.fromfile(label_file, 'ubyte', count=n)
    if labels.size != n:
        raise ValueError("%s is truncated" % labelf)

    # Scale pixel intensities from [0, 255] to [-1, 1].
    images = images.reshape((n, 28 * 28)).astype('float32')
    images = images / 255.0 * 2.0 - 1.0
    labels = labels.astype("int")

    # range (not xrange) keeps this working on both Python 2 and 3.
    for i in range(n):
        yield {"pixel": images[i, :], 'label': labels[i]}