diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py
index 07553a81414343dbe232402dc1d4e429e47d46f4..7da5f35d90952f114b19982ddf951e1c2a006558 100644
--- a/fluid/neural_machine_translation/transformer/reader.py
+++ b/fluid/neural_machine_translation/transformer/reader.py
@@ -202,7 +202,6 @@ class DataReader(object):
         self._max_length = max_length
         self._field_delimiter = field_delimiter
         self._token_delimiter = token_delimiter
-        self._epoch_batches = []
         self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                               unk_mark)
         self._random = random.Random(x=seed)
@@ -248,17 +247,23 @@ class DataReader(object):
             if tar_fname is None:
                 raise Exception("If tar file provided, please set tar_fname.")
 
-            f = tarfile.open(fpaths[0], 'r')
+            f = tarfile.open(fpaths[0], "r")
             for line in f.extractfile(tar_fname):
-                yield line.split(self._field_delimiter)
+                fields = line.strip("\n").split(self._field_delimiter)
+                if (not self._only_src and len(fields) == 2) or (
+                        self._only_src and len(fields) == 1):
+                    yield fields
         else:
             for fpath in fpaths:
                 if not os.path.isfile(fpath):
                     raise IOError("Invalid file: %s" % fpath)
 
-                with open(fpath, 'r') as f:
+                with open(fpath, "r") as f:
                     for line in f:
-                        yield line.split(self._field_delimiter)
+                        fields = line.strip("\n").split(self._field_delimiter)
+                        if (not self._only_src and len(fields) == 2) or (
+                                self._only_src and len(fields) == 1):
+                            yield fields
 
     @staticmethod
     def load_dict(dict_path, reverse=False):
@@ -266,9 +271,9 @@ class DataReader(object):
         with open(dict_path, "r") as fdict:
             for idx, line in enumerate(fdict):
                 if reverse:
-                    word_dict[idx] = line.strip('\n')
+                    word_dict[idx] = line.strip("\n")
                 else:
-                    word_dict[line.strip('\n')] = idx
+                    word_dict[line.strip("\n")] = idx
         return word_dict
 
     def batch_generator(self):
@@ -285,12 +290,8 @@ class DataReader(object):
 
             if self._sort_type == SortType.POOL:
                 for i in range(0, len(infos), self._pool_size):
-                    infos[i * self._pool_size:(i + 1) *
-                          self._pool_size] = sorted(
-                              infos[i * self._pool_size:(i + 1) *
-                                    self._pool_size],
-                              key=lambda x: x.max_len,
-                              reverse=True)
+                    infos[i:i + self._pool_size] = sorted(
+                        infos[i:i + self._pool_size], key=lambda x: x.max_len)
 
         # concat batch
         batches = []
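
A note on the `_load_lines` change: the reader now strips the trailing newline before splitting (so the last field no longer carries a stray `"\n"`), and it only yields lines with the expected field count: two fields (source and target) normally, or one field when `_only_src` is set, silently skipping malformed lines. A minimal standalone sketch of that filter, assuming the same semantics; `parse_line` and the sample lines are illustrative, not part of the diff:

```python
def parse_line(line, field_delimiter="\t", only_src=False):
    # Strip the trailing newline before splitting, so the last field
    # does not keep a stray "\n"; return None for malformed lines.
    fields = line.strip("\n").split(field_delimiter)
    expected = 1 if only_src else 2
    return fields if len(fields) == expected else None

print(parse_line("hello\tbonjour\n"))     # ['hello', 'bonjour']
print(parse_line("no delimiter here\n"))  # None (wrong field count, skipped)
```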
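The `batch_generator` hunk fixes an indexing bug: `range()` already steps by `self._pool_size`, so multiplying `i` by `self._pool_size` again in the slice bounds pointed every iteration after the first far past the intended pool (usually past the end of the list, making those sorts silent no-ops). The rewrite slices `infos[i:i + self._pool_size]` directly, and also drops `reverse=True`, so pools are now sorted ascending by length. A runnable sketch of the corrected pool-wise sort; `SampleInfo`, `pool_sort`, and the sample lengths are illustrative stand-ins, not the reader's actual types:

```python
from collections import namedtuple

# Illustrative stand-in for the reader's per-sample metadata.
SampleInfo = namedtuple("SampleInfo", ["i", "max_len"])

def pool_sort(infos, pool_size):
    # Sort each consecutive pool of `pool_size` samples by length.
    # range() steps by pool_size, so the slice bounds are simply
    # i and i + pool_size; multiplying i by pool_size again (as the
    # old code did) would index far beyond the intended pool.
    for i in range(0, len(infos), pool_size):
        infos[i:i + pool_size] = sorted(
            infos[i:i + pool_size], key=lambda x: x.max_len)
    return infos

infos = [SampleInfo(i, l) for i, l in enumerate([7, 3, 9, 2, 5, 8])]
print([s.max_len for s in pool_sort(infos, pool_size=3)])
# [3, 7, 9, 2, 5, 8] -- each pool of 3 sorted independently
```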