diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index db84f37aa4fc3477b17599a48a4de9b45cfb6c1f..faae818a5d7b78b22a30e719411feee94a2cc883 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -16,18 +16,29 @@ def __mnist_reader_creator__(data, target): TEST_SIZE = 10000 +X_train = None +X_test = None +y_train = None +y_test = None -data = sklearn.datasets.mldata.fetch_mldata( - "MNIST original", data_home=DATA_HOME) -X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( - data.data, data.target, test_size=TEST_SIZE, random_state=0) + +def __initialize_dataset__(): + global X_train, X_test, y_train, y_test + if X_train is not None: + return + data = sklearn.datasets.mldata.fetch_mldata( + "MNIST original", data_home=DATA_HOME) + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + data.data, data.target, test_size=TEST_SIZE, random_state=0) def train_creator(): + __initialize_dataset__() return __mnist_reader_creator__(X_train, y_train) def test_creator(): + __initialize_dataset__() return __mnist_reader_creator__(X_test, y_test)