# mnist.py (1.5 KB) -- committed by Yu Yang.
import sklearn.datasets.mldata
import sklearn.model_selection
import numpy

# Names exported by `from <this module> import *`.
__all__ = ['MNISTReader', 'train_reader_creator', 'test_reader_creator']

# Passed as `data_home` when MNISTReader fetches the dataset. None leaves the
# choice of cache directory to scikit-learn.
DATA_HOME = None


def __mnist_reader__(data, target):
    """
    Generate the samples of a dataset one by one.

    :param data: array whose first axis indexes the samples; it must expose
        `shape`, and each `data[i]` must support `astype`.
    :param target: labels, indexable by sample position.
    :return: a generator yielding one `(features, label)` tuple per sample,
        where `features` is `data[i]` converted to numpy.float32 and `label`
        is `target[i]` converted to a Python int.
    """
    # `range` instead of `xrange`: the latter only exists on Python 2 and
    # raises NameError on Python 3. Indexing by position keeps the original
    # semantics (sample count taken from data.shape[0]).
    for i in range(data.shape[0]):
        yield data[i].astype(numpy.float32), int(target[i])


class MNISTReader(object):
    """
    Reader for the MNIST dataset.

    Constructing an instance fetches the "MNIST original" dataset through
    scikit-learn, scales the pixel data by 1/255, casts the labels to int and
    splits the samples into a train part and a test part. `train_reader` and
    `test_reader` each return an iterator over the samples of one part; every
    sample is a (float32 feature vector, int label) pair.
    """

    def __init__(self, random_state):
        """
        Fetch the dataset and split it into train and test parts.

        :param random_state: forwarded to
            `sklearn.model_selection.train_test_split` to control the split.
        """
        n_train = 60000
        mnist = sklearn.datasets.mldata.fetch_mldata(
            "MNIST original", data_home=DATA_HOME)
        features = mnist.data / 255.0
        labels = mnist.target.astype("int")
        # Only train_size is given; the size of the test part is left to
        # train_test_split.
        parts = sklearn.model_selection.train_test_split(
            features,
            labels,
            train_size=n_train,
            random_state=random_state)
        self.X_train, self.X_test, self.y_train, self.y_test = parts

    def train_reader(self):
        """
        Return a new iterator over the (features, label) train samples.
        """
        features, labels = self.X_train, self.y_train
        return __mnist_reader__(features, labels)

    def test_reader(self):
        """
        Return a new iterator over the (features, label) test samples.
        """
        features, labels = self.X_test, self.y_test
        return __mnist_reader__(features, labels)


# Shared reader used by the module-level creator functions. It is constructed
# at import time with random_state 0, so importing this module fetches and
# splits the dataset.
__default_instance__ = MNISTReader(0)


def train_reader_creator():
    """
    Return the train set reader of the module's default MNISTReader.

    :return: the bound `train_reader` method of `__default_instance__`; each
        call to it produces a new iterator over the train samples.
    """
    reader = __default_instance__.train_reader
    return reader


def test_reader_creator():
    """
    Return the test set reader of the module's default MNISTReader.

    :return: the bound `test_reader` method of `__default_instance__`; each
        call to it produces a new iterator over the test samples.
    """
    reader = __default_instance__.test_reader
    return reader


def unittest():
    """
    Self-check: the default train reader must yield exactly 60000 samples.
    """
    make_reader = train_reader_creator()
    samples = list(make_reader())
    assert len(samples) == 60000


# Run the self-check when this file is executed as a script.
if __name__ == '__main__':
    unittest()