mnist_provider.py 389 字节
Newer Older
L
Luo Tao 已提交
1 2 3 4 5 6 7 8 9 10 11 12
from paddle.trainer.PyDataProvider2 import *
from mnist_util import read_from_mnist


# Define a py data provider
@provider(
    input_types={
        'pixel': dense_vector(28 * 28),
        'label': integer_value(10),
    },
    cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename):
    """Yield MNIST samples read from ``filename`` for the trainer.

    The ``@provider`` decorator declares two input slots:
      * ``'pixel'`` -- a dense vector of 28 * 28 (= 784) values;
      * ``'label'`` -- an integer value with 10 possible classes.
    It also requests ``CacheType.CACHE_PASS_IN_MEM`` caching (presumably
    the data is kept in memory after the first pass -- confirm against the
    PyDataProvider2 docs).

    :param settings: provider settings object; not used by this provider.
    :param filename: path handed to ``read_from_mnist``.
    :return: a generator that re-yields, unchanged and in order, every item
        produced by ``read_from_mnist(filename)``.  The structure of each item
        is assumed to match the declared input slots -- TODO confirm in
        mnist_util.
    """
    samples = read_from_mnist(filename)
    # Plain for/yield rather than ``yield from``: this legacy API targets
    # Python 2, where ``yield from`` does not exist.
    for sample in samples:
        yield sample