mnist_provider.py 916 字节
Newer Older
W
wangyang59 已提交
1
from paddle.trainer.PyDataProvider2 import *
W
wangyang59 已提交
2
import numpy
W
wangyang59 已提交
3

W
wangyang59 已提交
4

W
wangyang59 已提交
5
# Define a py data provider
@provider(
    input_types={'pixel': dense_vector(28 * 28),
                 'label': integer_value(10)},
    cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename):  # settings is not used currently.
    """Yield MNIST samples read from a pair of IDX files.

    ``filename`` is a path prefix: images are read from
    ``filename + "-images-idx3-ubyte"`` and labels from
    ``filename + "-labels-idx1-ubyte"``.

    The number of samples is taken from the IDX headers rather than guessed
    from whether "train" occurs in ``filename``. The guess breaks when, for
    example, a test-set prefix lives under a directory named "train...".

    Each yielded sample is a dict with:
      * ``'pixel'``: a float32 vector of 28 * 28 values, scaled from the
        raw byte range [0, 255] to [-1, 1];
      * ``'label'``: the integer class label read from the label file.

    Both files are fully read and closed before the first sample is
    yielded, so they are never left open if iteration stops early.

    Raises:
        ValueError: if a header is short or has the wrong magic number, the
            images are not 28x28, the two files disagree on the sample
            count, or a file holds less data than its header claims.
    """
    imgf = filename + "-images-idx3-ubyte"
    labelf = filename + "-labels-idx1-ubyte"

    with open(imgf, "rb") as img_file, open(labelf, "rb") as label_file:
        # IDX headers are big-endian uint32 fields:
        #   images: magic (2051), count, rows, cols  -> 16 bytes
        #   labels: magic (2049), count              -> 8 bytes
        # A short header makes frombuffer/unpacking raise ValueError.
        img_magic, n, rows, cols = (
            int(v) for v in numpy.frombuffer(img_file.read(16), dtype='>u4'))
        label_magic, n_labels = (
            int(v) for v in numpy.frombuffer(label_file.read(8), dtype='>u4'))

        if img_magic != 2051 or label_magic != 2049:
            raise ValueError("%s / %s are not MNIST IDX image/label files"
                             % (imgf, labelf))
        if (rows, cols) != (28, 28):
            # input_types above declares a fixed 28 * 28 dense vector.
            raise ValueError("expected 28x28 images in %s, got %dx%d"
                             % (imgf, rows, cols))
        if n != n_labels:
            raise ValueError("%s has %d images but %s has %d labels"
                             % (imgf, n, labelf, n_labels))

        images = numpy.fromfile(img_file, 'ubyte', count=n * 28 * 28)
        labels = numpy.fromfile(label_file, 'ubyte', count=n)

    # fromfile silently returns fewer items on a truncated file; report it
    # clearly instead of failing later inside reshape.
    if images.size != n * 28 * 28 or labels.size != n:
        raise ValueError("truncated MNIST data for prefix %s" % filename)

    images = images.reshape((n, 28 * 28)).astype('float32')
    images = images / 255.0 * 2.0 - 1.0
    labels = labels.astype("int")

    # Iterating a 2-D array yields its rows (same as images[i, :]).
    for image, label in zip(images, labels):
        yield {"pixel": image, 'label': label}