From 792875e3eaa0467e40748f0ed97f022fe7fdcd0b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 28 Feb 2017 09:56:27 +0800 Subject: [PATCH] Lazy initialize mnist dataset. Fix unittest --- python/paddle/v2/dataset/mnist.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index db84f37aa4..faae818a5d 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -16,18 +16,29 @@ def __mnist_reader_creator__(data, target): TEST_SIZE = 10000 +X_train = None +X_test = None +y_train = None +y_test = None -data = sklearn.datasets.mldata.fetch_mldata( - "MNIST original", data_home=DATA_HOME) -X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( - data.data, data.target, test_size=TEST_SIZE, random_state=0) + +def __initialize_dataset__(): + global X_train, X_test, y_train, y_test + if X_train is not None: + return + data = sklearn.datasets.mldata.fetch_mldata( + "MNIST original", data_home=DATA_HOME) + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + data.data, data.target, test_size=TEST_SIZE, random_state=0) def train_creator(): + __initialize_dataset__() return __mnist_reader_creator__(X_train, y_train) def test_creator(): + __initialize_dataset__() return __mnist_reader_creator__(X_test, y_test) -- GitLab