mnist.py 1.5 KB
Newer Older
Y
Yu Yang 已提交
1 2 3
import sklearn.datasets.mldata
import sklearn.model_selection
import numpy
Y
Yu Yang 已提交
4
from config import DATA_HOME
Y
Yu Yang 已提交
5

Y
Yu Yang 已提交
6
__all__ = ['MNIST', 'train_creator', 'test_creator']
Y
Yu Yang 已提交
7 8


Y
Yu Yang 已提交
9 10 11 12 13
def __mnist_reader_creator__(data, target):
    def reader():
        n_samples = data.shape[0]
        for i in xrange(n_samples):
            yield (data[i] / 255.0).astype(numpy.float32), int(target[i])
Y
Yu Yang 已提交
14

Y
Yu Yang 已提交
15
    return reader
Y
Yu Yang 已提交
16 17


Y
Yu Yang 已提交
18
class MNIST(object):
Y
Yu Yang 已提交
19 20 21 22 23 24
    """
    mnist dataset reader. The `train_reader` and `test_reader` method returns
    a iterator of each sample. Each sample is combined by 784-dim float and a
    one-dim label
    """

Y
Yu Yang 已提交
25
    def __init__(self, random_state=0, test_size=10000, **options):
Y
Yu Yang 已提交
26 27 28
        data = sklearn.datasets.mldata.fetch_mldata(
            "MNIST original", data_home=DATA_HOME)
        self.X_train, self.X_test, self.y_train, self.y_test = sklearn.model_selection.train_test_split(
Y
Yu Yang 已提交
29 30 31 32 33
            data.data,
            data.target,
            test_size=test_size,
            random_state=random_state,
            **options)
Y
Yu Yang 已提交
34

Y
Yu Yang 已提交
35 36
    def train_creator(self):
        return __mnist_reader_creator__(self.X_train, self.y_train)
Y
Yu Yang 已提交
37

Y
Yu Yang 已提交
38 39
    def test_creator(self):
        return __mnist_reader_creator__(self.X_test, self.y_test)
Y
Yu Yang 已提交
40 41


Y
Yu Yang 已提交
42 43 44
__default_instance__ = MNIST()
train_creator = __default_instance__.train_creator
test_creator = __default_instance__.test_creator
Y
Yu Yang 已提交
45 46 47


def unittest():
Y
Yu Yang 已提交
48 49 50
    size = 12045
    mnist = MNIST(test_size=size)
    assert len(list(mnist.test_creator()())) == size
Y
Yu Yang 已提交
51 52 53 54


if __name__ == '__main__':
    unittest()