mnist_provider.py 773 字节
Newer Older
W
wangyang59 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
from paddle.trainer.PyDataProvider2 import *


# Data provider for the MNIST IDX files, registered via the PyDataProvider2 @provider decorator.
@provider(input_types=[
    dense_vector(28 * 28),
    integer_value(10)
])
def process(settings, filename):  # settings is not used currently.
    """Yield MNIST samples read from a pair of IDX files.

    Args:
        settings: Provider settings object; not used by this function.
        filename: Path prefix of the data set. The suffixes
            "-images-idx3-ubyte" and "-labels-idx1-ubyte" are appended to
            locate the image file and the label file.

    Yields:
        A dict with two entries: "pixel", a list of 28 * 28 floats obtained
        by dividing each image byte by 255.0, and "label", the label byte
        as an int.

    Raises:
        ValueError: If a header is shorter than expected, or if either file
            ends before the number of samples announced by the headers has
            been read.
    """
    import struct  # function-scope import keeps this block self-contained

    image_size = 28 * 28
    image_path = filename + "-images-idx3-ubyte"
    label_path = filename + "-labels-idx1-ubyte"

    # The with-statement closes both files even when the consumer stops
    # iterating early or a read fails part-way through.
    with open(image_path, "rb") as images, open(label_path, "rb") as labels:
        # IDX headers are big-endian 32-bit fields: the image file has
        # magic, count, rows, cols (16 bytes); the label file has magic,
        # count (8 bytes).
        image_header = images.read(16)
        label_header = labels.read(8)
        if len(image_header) != 16 or len(label_header) != 8:
            raise ValueError(
                "truncated IDX header in %s or %s" % (image_path, label_path))

        # Take the sample count from the files themselves rather than
        # guessing it from the path: a substring test such as
        # `"train" in filename` also matches directory names.
        num_images = struct.unpack(">I", image_header[4:8])[0]
        num_labels = struct.unpack(">I", label_header[4:8])[0]
        num_samples = min(num_images, num_labels)

        # A counter loop avoids materialising a huge range() list on
        # Python 2 if a corrupt header announces an absurd count.
        index = 0
        while index < num_samples:
            label_byte = labels.read(1)
            image_bytes = images.read(image_size)  # one read per image
            if len(label_byte) != 1 or len(image_bytes) != image_size:
                raise ValueError(
                    "unexpected end of MNIST data at sample %d of %d (%s)"
                    % (index, num_samples, filename))

            # bytearray yields ints on both Python 2 and Python 3.
            label = bytearray(label_byte)[0]
            pixels = [value / 255.0 for value in bytearray(image_bytes)]
            yield {"pixel": pixels, 'label': label}
            index += 1