提交 111e7710 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1438 from reyoung/feature/mnist_reader

MNIST dataset reader implementation
import os
__all__ = ['DATA_HOME']
DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set')
if not os.path.exists(DATA_HOME):
os.makedirs(DATA_HOME)
import sklearn.datasets.mldata
import sklearn.model_selection
import numpy
from config import DATA_HOME
__all__ = ['train_creator', 'test_creator']
def __mnist_reader_creator__(data, target):
def reader():
n_samples = data.shape[0]
for i in xrange(n_samples):
yield (data[i] / 255.0).astype(numpy.float32), int(target[i])
return reader
TEST_SIZE = 10000
data = sklearn.datasets.mldata.fetch_mldata(
"MNIST original", data_home=DATA_HOME)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
data.data, data.target, test_size=TEST_SIZE, random_state=0)
def train_creator():
return __mnist_reader_creator__(X_train, y_train)
def test_creator():
return __mnist_reader_creator__(X_test, y_test)
def unittest():
assert len(list(test_creator()())) == TEST_SIZE
if __name__ == '__main__':
unittest()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册