mnist.py 892 字节
Newer Older
Y
Yu Yang 已提交
1 2 3
import sklearn.datasets.mldata
import sklearn.model_selection
import numpy
Y
Yu Yang 已提交
4
from config import DATA_HOME
Y
Yu Yang 已提交
5

6
__all__ = ['train_creator', 'test_creator']
Y
Yu Yang 已提交
7 8


Y
Yu Yang 已提交
9 10 11 12 13
def __mnist_reader_creator__(data, target):
    def reader():
        n_samples = data.shape[0]
        for i in xrange(n_samples):
            yield (data[i] / 255.0).astype(numpy.float32), int(target[i])
Y
Yu Yang 已提交
14

Y
Yu Yang 已提交
15
    return reader
Y
Yu Yang 已提交
16 17


Y
Yu Yang 已提交
18
TEST_SIZE = 10000
Y
Yu Yang 已提交
19

Y
Yu Yang 已提交
20 21 22 23
data = sklearn.datasets.mldata.fetch_mldata(
    "MNIST original", data_home=DATA_HOME)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    data.data, data.target, test_size=TEST_SIZE, random_state=0)
Y
Yu Yang 已提交
24 25


Y
Yu Yang 已提交
26 27
def train_creator():
    return __mnist_reader_creator__(X_train, y_train)
Y
Yu Yang 已提交
28 29


Y
Yu Yang 已提交
30 31
def test_creator():
    return __mnist_reader_creator__(X_test, y_test)
Y
Yu Yang 已提交
32 33 34


def unittest():
Y
Yu Yang 已提交
35
    assert len(list(test_creator()())) == TEST_SIZE
Y
Yu Yang 已提交
36 37 38 39


if __name__ == '__main__':
    unittest()