提交 ee4b0bcc 作者: Oleg Zabluda 提交者: François Chollet

Improve synchronized shuffle in datasets (#8325)

* Improve synchronized shuffle

* pep8
上级 de17ff3b
......@@ -25,9 +25,10 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
 f.close()
 np.random.seed(seed)
-np.random.shuffle(x)
-np.random.seed(seed)
-np.random.shuffle(y)
+indices = np.arange(len(x))
+np.random.shuffle(indices)
+x = x[indices]
+y = y[indices]
 x_train = np.array(x[:int(len(x) * (1 - test_split))])
 y_train = np.array(y[:int(len(x) * (1 - test_split))])
......
......@@ -55,15 +55,15 @@ def load_data(path='imdb.npz', num_words=None, skip_top=0,
 x_train, labels_train = f['x_train'], f['y_train']
 x_test, labels_test = f['x_test'], f['y_test']
 np.random.seed(seed)
-np.random.shuffle(x_train)
-np.random.seed(seed)
-np.random.shuffle(labels_train)
-np.random.seed(seed * 2)
-np.random.shuffle(x_test)
-np.random.seed(seed * 2)
-np.random.shuffle(labels_test)
+indices = np.arange(len(x_train))
+np.random.shuffle(indices)
+x_train = x_train[indices]
+labels_train = labels_train[indices]
+indices = np.arange(len(x_test))
+np.random.shuffle(indices)
+x_test = x_test[indices]
+labels_test = labels_test[indices]
 xs = np.concatenate([x_train, x_test])
 labels = np.concatenate([labels_train, labels_test])
......
......@@ -53,9 +53,10 @@ def load_data(path='reuters.npz', num_words=None, skip_top=0,
 xs, labels = f['x'], f['y']
 np.random.seed(seed)
-np.random.shuffle(xs)
-np.random.seed(seed)
-np.random.shuffle(labels)
+indices = np.arange(len(xs))
+np.random.shuffle(indices)
+xs = xs[indices]
+labels = labels[indices]
 if start_char is not None:
 xs = [[start_char] + [w + index_from for w in x] for x in xs]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册