From a6028d79dcaba69f6f95c7ebf9c12c33ad42b82e Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 27 Feb 2017 10:39:17 +0800 Subject: [PATCH] Clean mnist reader --- python/paddle/v2/data_set/mnist.py | 35 +++++++++--------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/python/paddle/v2/data_set/mnist.py b/python/paddle/v2/data_set/mnist.py index 6f35acf6836..4b392af400a 100644 --- a/python/paddle/v2/data_set/mnist.py +++ b/python/paddle/v2/data_set/mnist.py @@ -15,39 +15,24 @@ def __mnist_reader_creator__(data, target): return reader -class MNIST(object): - """ - mnist dataset reader. The `train_reader` and `test_reader` method returns - a iterator of each sample. Each sample is combined by 784-dim float and a - one-dim label - """ +TEST_SIZE = 10000 - def __init__(self, random_state=0, test_size=10000, **options): - data = sklearn.datasets.mldata.fetch_mldata( - "MNIST original", data_home=DATA_HOME) - self.X_train, self.X_test, self.y_train, self.y_test = sklearn.model_selection.train_test_split( - data.data, - data.target, - test_size=test_size, - random_state=random_state, - **options) +data = sklearn.datasets.mldata.fetch_mldata( + "MNIST original", data_home=DATA_HOME) +X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + data.data, data.target, test_size=TEST_SIZE, random_state=0) - def train_creator(self): - return __mnist_reader_creator__(self.X_train, self.y_train) - def test_creator(self): - return __mnist_reader_creator__(self.X_test, self.y_test) +def train_creator(): + return __mnist_reader_creator__(X_train, y_train) -__default_instance__ = MNIST() -train_creator = __default_instance__.train_creator -test_creator = __default_instance__.test_creator +def test_creator(): + return __mnist_reader_creator__(X_test, y_test) def unittest(): - size = 12045 - mnist = MNIST(test_size=size) - assert len(list(mnist.test_creator()())) == size + assert len(list(test_creator()())) == TEST_SIZE if __name__ == '__main__': -- GitLab