From 38a792f20ed9e65d2920ded6ad42a5b68f2146ee Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 24 Feb 2017 13:52:31 +0800 Subject: [PATCH] Clean mnist code --- python/paddle/v2/data_set/config.py | 8 ++++ python/paddle/v2/data_set/mnist.py | 58 +++++++++++++---------------- 2 files changed, 33 insertions(+), 33 deletions(-) create mode 100644 python/paddle/v2/data_set/config.py diff --git a/python/paddle/v2/data_set/config.py b/python/paddle/v2/data_set/config.py new file mode 100644 index 00000000000..69e96d65ef1 --- /dev/null +++ b/python/paddle/v2/data_set/config.py @@ -0,0 +1,8 @@ +import os + +__all__ = ['DATA_HOME'] + +DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set') + +if not os.path.exists(DATA_HOME): + os.makedirs(DATA_HOME) diff --git a/python/paddle/v2/data_set/mnist.py b/python/paddle/v2/data_set/mnist.py index 34f61bb9f63..6f35acf6836 100644 --- a/python/paddle/v2/data_set/mnist.py +++ b/python/paddle/v2/data_set/mnist.py @@ -1,61 +1,53 @@ import sklearn.datasets.mldata import sklearn.model_selection import numpy +from config import DATA_HOME -__all__ = ['MNISTReader', 'train_reader_creator', 'test_reader_creator'] +__all__ = ['MNIST', 'train_creator', 'test_creator'] -DATA_HOME = None +def __mnist_reader_creator__(data, target): + def reader(): + n_samples = data.shape[0] + for i in xrange(n_samples): + yield (data[i] / 255.0).astype(numpy.float32), int(target[i]) -def __mnist_reader__(data, target): - n_samples = data.shape[0] - for i in xrange(n_samples): - yield data[i].astype(numpy.float32), int(target[i]) + return reader -class MNISTReader(object): +class MNIST(object): """ mnist dataset reader. The `train_reader` and `test_reader` method returns a iterator of each sample. Each sample is combined by 784-dim float and a one-dim label """ - def __init__(self, random_state): + def __init__(self, random_state=0, test_size=10000, **options): data = sklearn.datasets.mldata.fetch_mldata( "MNIST original", data_home=DATA_HOME) - n_train = 60000 self.X_train, self.X_test, self.y_train, self.y_test = sklearn.model_selection.train_test_split( - data.data / 255.0, - data.target.astype("int"), - train_size=n_train, - random_state=random_state) + data.data, + data.target, + test_size=test_size, + random_state=random_state, + **options) - def train_reader(self): - return __mnist_reader__(self.X_train, self.y_train) + def train_creator(self): + return __mnist_reader_creator__(self.X_train, self.y_train) - def test_reader(self): - return __mnist_reader__(self.X_test, self.y_test) + def test_creator(self): + return __mnist_reader_creator__(self.X_test, self.y_test) -__default_instance__ = MNISTReader(0) - - -def train_reader_creator(): - """ - Default train set reader creator. - """ - return __default_instance__.train_reader - - -def test_reader_creator(): - """ - Default test set reader creator. - """ - return __default_instance__.test_reader +__default_instance__ = MNIST() +train_creator = __default_instance__.train_creator +test_creator = __default_instance__.test_creator def unittest(): - assert len(list(train_reader_creator()())) == 60000 + size = 12045 + mnist = MNIST(test_size=size) + assert len(list(mnist.test_creator()())) == size if __name__ == '__main__': -- GitLab