# mnist_provider.py (committed by wangyang59)
from paddle.trainer.PyDataProvider2 import *


# Define a Python data provider.
@provider(input_types={
    'pixel': dense_vector(28 * 28),
    'label': integer_value(10)
})
# The generator below is registered through the @provider decorator above.
def process(settings, filename):  # settings is not used currently.
    """Yield MNIST samples read from a pair of IDX files.

    Images are read from ``filename + "-images-idx3-ubyte"`` and labels from
    ``filename + "-labels-idx1-ubyte"``.  Both files are closed as soon as the
    generator finishes, is closed early, or raises.

    Args:
        settings: Not used by this function.
        filename: Common path prefix of the image and label files.  When it
            contains ``"train"`` 60000 samples are read, otherwise 10000.

    Yields:
        dict: ``"pixel"`` maps to a list of 28 * 28 floats (each raw byte
        divided by 255.0) and ``'label'`` maps to the label byte as an int.

    Raises:
        ValueError: If either file ends before the expected number of
            samples has been read.
    """
    image_size = 28 * 28

    # Define number of samples for train/test
    n = 60000 if "train" in filename else 10000

    with open(filename + "-images-idx3-ubyte", "rb") as image_file, \
            open(filename + "-labels-idx1-ubyte", "rb") as label_file:
        # Discard the fixed-size headers that precede the raw data.
        image_file.read(16)
        label_file.read(8)

        for i in range(n):
            # bytearray yields ints when indexed/iterated on Python 2 and 3.
            label_byte = bytearray(label_file.read(1))
            # One read per image instead of one read per pixel.
            image_bytes = bytearray(image_file.read(image_size))
            if len(label_byte) != 1 or len(image_bytes) != image_size:
                raise ValueError(
                    "unexpected end of data at sample %d (expected %d samples)"
                    % (i, n))

            pixels = [value / 255.0 for value in image_bytes]
            yield {"pixel": pixels, 'label': label_byte[0]}