diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py index 5c13d89ef8e2ea0141e5d8ffcc8cb9320a7bf00d..4d61affee857b93986425658eae92f4f5ef9bc05 100644 --- a/fluid/neural_machine_translation/transformer/reader.py +++ b/fluid/neural_machine_translation/transformer/reader.py @@ -1,8 +1,9 @@ import glob import os -import random import tarfile +import numpy as np + class SortType(object): GLOBAL = 'global' @@ -203,7 +204,8 @@ class DataReader(object): self._token_delimiter = token_delimiter self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname, unk_mark) - self._random = random.Random(x=seed) + self._random = np.random + self._random.seed(seed) def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname, unk_mark):