# mnist_provider.py (782 bytes) -- committed by wangyang59
from paddle.trainer.PyDataProvider2 import *


# Define a PyDataProvider2 data provider for the MNIST files.
@provider(
    input_types={'pixel': dense_vector(28 * 28),
                 'label': integer_value(10)})
def process(settings, filename):  # settings is not used currently.
    """Yield samples read from a pair of image/label ubyte files.

    Args:
        settings: Provider settings object; not used currently.
        filename: Path prefix. Images are read from
            ``filename + "-images-idx3-ubyte"`` and labels from
            ``filename + "-labels-idx1-ubyte"``.

    Yields:
        dict: ``{"pixel": [...], "label": int}`` where ``pixel`` is a list
        of 28 * 28 floats (each raw byte divided by 255.0) and ``label`` is
        the raw label byte as an int.

    Raises:
        ValueError: If either file ends before the expected number of
            samples has been read.
    """
    image_size = 28 * 28

    # Define number of samples for train/test
    n = 60000 if "train" in filename else 10000

    # ``with`` closes both files even when the consumer abandons the
    # generator early or a read fails part-way through.
    with open(filename + "-images-idx3-ubyte", "rb") as image_file, \
            open(filename + "-labels-idx1-ubyte", "rb") as label_file:
        # Skip the leading header bytes of each file.
        image_file.read(16)
        label_file.read(8)

        for i in range(n):
            # bytearray yields ints on both Python 2 and Python 3, and one
            # chunked read replaces 28 * 28 single-byte reads per image.
            label_byte = bytearray(label_file.read(1))
            image_bytes = bytearray(image_file.read(image_size))
            if len(label_byte) != 1 or len(image_bytes) != image_size:
                raise ValueError(
                    "%s: data ended at sample %d of %d" % (filename, i, n))
            yield {"pixel": [b / 255.0 for b in image_bytes],
                   'label': label_byte[0]}