提交 38a792f2 编写于 作者: Y Yu Yang

Clean mnist code

上级 d1ab3c80
import os
__all__ = ['DATA_HOME']

# Directory under the user's home where downloaded data sets are cached.
DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set')

# Create the cache directory at import time.  Attempting the creation and
# tolerating "already exists" (instead of checking first) avoids the
# check-then-create race when two processes import this module at once.
try:
    os.makedirs(DATA_HOME)
except OSError:
    # Only swallow the error when the directory is really there; anything
    # else (permissions, a regular file in the way, ...) must surface.
    if not os.path.isdir(DATA_HOME):
        raise
import sklearn.datasets.mldata import sklearn.datasets.mldata
import sklearn.model_selection import sklearn.model_selection
import numpy import numpy
from config import DATA_HOME
__all__ = ['MNISTReader', 'train_reader_creator', 'test_reader_creator'] __all__ = ['MNIST', 'train_creator', 'test_creator']
DATA_HOME = None
def __mnist_reader_creator__(data, target):
    """
    Build a reader function over ``data`` / ``target``.

    The returned ``reader`` takes no arguments; each call starts a fresh
    generator that yields one ``(features, label)`` pair per entry of
    ``data``.  Features are divided by 255.0 and cast to ``numpy.float32``;
    the label is ``int(target[i])``.
    """

    def reader():
        # enumerate() instead of xrange(): works on both Python 2 and 3 and
        # still indexes ``target`` explicitly, so a too-short ``target``
        # raises instead of being silently truncated.
        for i, row in enumerate(data):
            yield (row / 255.0).astype(numpy.float32), int(target[i])

    return reader


def __mnist_reader__(data, target):
    """
    Generator yielding ``(features, label)`` for every entry of ``data``.

    Features are cast to ``numpy.float32`` without rescaling; the label is
    ``int(target[i])``.
    """
    for i, row in enumerate(data):
        yield row.astype(numpy.float32), int(target[i])
class MNISTReader(object):
    """
    mnist dataset reader. The `train_reader` and `test_reader` methods return
    an iterator over the samples. Each sample is combined by 784-dim float
    and a one-dim label.
    """

    def __init__(self, random_state):
        # NOTE(review): fetch_mldata was removed from recent scikit-learn
        # releases -- confirm the scikit-learn version this file targets.
        data = sklearn.datasets.mldata.fetch_mldata(
            "MNIST original", data_home=DATA_HOME)
        n_train = 60000
        # Pixels are rescaled and labels cast to int here, before the split;
        # the reader used by this class does no rescaling of its own.
        self.X_train, self.X_test, self.y_train, self.y_test = sklearn.model_selection.train_test_split(
            data.data / 255.0,
            data.target.astype("int"),
            train_size=n_train,
            random_state=random_state)

    def train_reader(self):
        """Return a generator over the training samples."""
        return __mnist_reader__(self.X_train, self.y_train)

    def test_reader(self):
        """Return a generator over the test samples."""
        return __mnist_reader__(self.X_test, self.y_test)


class MNIST(object):
    """
    mnist dataset. The `train_creator` and `test_creator` methods return a
    reader function; calling that function gives an iterator over the
    samples. Each sample is combined by 784-dim float and a one-dim label.
    """

    def __init__(self, random_state=0, test_size=10000, **options):
        """
        Fetch the data set and split it into train and test parts.

        :param random_state: forwarded to ``train_test_split``.
        :param test_size: forwarded to ``train_test_split``.
        :param options: extra keyword arguments forwarded unchanged to
            ``train_test_split``.
        """
        # NOTE(review): fetch_mldata was removed from recent scikit-learn
        # releases -- confirm the scikit-learn version this file targets.
        data = sklearn.datasets.mldata.fetch_mldata(
            "MNIST original", data_home=DATA_HOME)
        # Raw values are split as-is; rescaling and the int cast of the
        # label are left to the reader built by __mnist_reader_creator__.
        self.X_train, self.X_test, self.y_train, self.y_test = sklearn.model_selection.train_test_split(
            data.data,
            data.target,
            test_size=test_size,
            random_state=random_state,
            **options)

    def train_creator(self):
        """Return a reader function over the training split."""
        return __mnist_reader_creator__(self.X_train, self.y_train)

    def test_creator(self):
        """Return a reader function over the test split."""
        return __mnist_reader_creator__(self.X_test, self.y_test)
__default_instance__ = MNISTReader(0) __default_instance__ = MNIST()
train_creator = __default_instance__.train_creator
test_creator = __default_instance__.test_creator
def train_reader_creator():
    """
    Default train set reader creator.

    Returns the ``train_reader`` attribute of the module-level
    ``__default_instance__`` without calling it; the caller invokes the
    returned object to obtain the samples.
    """
    return __default_instance__.train_reader
def test_reader_creator():
    """
    Default test set reader creator.

    Returns the ``test_reader`` attribute of the module-level
    ``__default_instance__`` without calling it; the caller invokes the
    returned object to obtain the samples.
    """
    return __default_instance__.test_reader
def unittest(): def unittest():
assert len(list(train_reader_creator()())) == 60000 size = 12045
mnist = MNIST(test_size=size)
assert len(list(mnist.test_creator()())) == size
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册