提交 a6028d79 编写于 作者: Y Yu Yang

Clean mnist reader

上级 befc3e06
...@@ -15,39 +15,24 @@ def __mnist_reader_creator__(data, target): ...@@ -15,39 +15,24 @@ def __mnist_reader_creator__(data, target):
return reader return reader
class MNIST(object): TEST_SIZE = 10000
"""
mnist dataset reader. The `train_creator` and `test_creator` methods each return
a reader that iterates over the samples. Each sample consists of a 784-dim float
vector and a one-dim label
"""
def __init__(self, random_state=0, test_size=10000, **options): data = sklearn.datasets.mldata.fetch_mldata(
data = sklearn.datasets.mldata.fetch_mldata( "MNIST original", data_home=DATA_HOME)
"MNIST original", data_home=DATA_HOME) X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
self.X_train, self.X_test, self.y_train, self.y_test = sklearn.model_selection.train_test_split( data.data, data.target, test_size=TEST_SIZE, random_state=0)
data.data,
data.target,
test_size=test_size,
random_state=random_state,
**options)
def train_creator(self):
return __mnist_reader_creator__(self.X_train, self.y_train)
def test_creator(self): def train_creator():
return __mnist_reader_creator__(self.X_test, self.y_test) return __mnist_reader_creator__(X_train, y_train)
__default_instance__ = MNIST() def test_creator():
train_creator = __default_instance__.train_creator return __mnist_reader_creator__(X_test, y_test)
test_creator = __default_instance__.test_creator
def unittest(): def unittest():
size = 12045 assert len(list(test_creator()())) == TEST_SIZE
mnist = MNIST(test_size=size)
assert len(list(mnist.test_creator()())) == size
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册