reader.py 11.6 KB
Newer Older
1
import glob
2
import six
Y
Yu Yang 已提交
3 4
import os
import tarfile
5

6 7
import numpy as np

8 9 10 11 12 13 14

class SortType(object):
    """Granularity at which DataReader sorts samples by length."""

    # Sort every loaded sample as a whole.
    GLOBAL = "global"
    # Sort only within each consecutive pool of `pool_size` samples.
    POOL = "pool"
    # Keep the (possibly shuffled) order; no sorting.
    NONE = "none"


Y
Yu Yang 已提交
15
class Converter(object):
    """Turn a delimited sentence into a list of vocabulary ids.

    Tokens missing from ``vocab`` map to ``unk``. ``end`` is always appended;
    ``beg`` is prepended only when ``add_beg`` is true.
    """

    def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
        self._vocab, self._beg, self._end, self._unk = vocab, beg, end, unk
        self._delimiter = delimiter
        self._add_beg = add_beg

    def __call__(self, sentence):
        ids = [self._beg] if self._add_beg else []
        for token in sentence.split(self._delimiter):
            ids.append(self._vocab.get(token, self._unk))
        ids.append(self._end)
        return ids
29

Y
Yu Yang 已提交
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89

class ComposedConverter(object):
    """Apply the i-th converter to the i-th field of a parallel sentence."""

    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
        converted = []
        for idx, convert in enumerate(self._converters):
            converted.append(convert(parallel_sentence[idx]))
        return converted


class SentenceBatchCreator(object):
    """Group samples into batches of exactly ``batch_size`` samples."""

    def __init__(self, batch_size):
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info``; return the batch once it is full, otherwise None."""
        self.batch.append(info)
        if len(self.batch) != self._batch_size:
            return None
        full, self.batch = self.batch, []
        return full


class TokenBatchCreator(object):
    """Group samples into batches bounded by a padded-token budget.

    A batch's cost is ``max_len * number_of_samples``, i.e. its token count
    once every sample is padded to the longest one. A new sample is added to
    the current batch only while that cost stays within ``batch_size``.
    """

    def __init__(self, batch_size):
        self.batch = []
        # Longest sample length in the current batch (-1 while empty).
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info`` (anything with a ``max_len`` attribute).

        Returns the previous batch when adding ``info`` would exceed the
        token budget (``info`` then starts a new batch); otherwise None.

        A single sample that exceeds the budget by itself still forms its own
        (oversized) batch. An empty batch is never returned.
        """
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        # Only flush a non-empty batch: flushing while empty would hand the
        # caller an empty batch when the very first sample is over budget.
        if self.batch and max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        self.max_len = max_len
        self.batch.append(info)
        return None


class SampleInfo(object):
    """Index and length bounds of one sample, used for sorting and batching."""

    def __init__(self, i, max_len, min_len):
        # Position of the sample among the loaded samples.
        self.i = i
        # Longest / shortest sequence length across the sample's fields.
        self.max_len, self.min_len = max_len, min_len


class MinMaxFilter(object):
    """Wrap a batch creator and drop samples outside a length window.

    A sample is forwarded only if ``min_len <= info.min_len`` and
    ``info.max_len <= max_len``; rejected samples make ``append`` return None.
    """

    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        within_bounds = (info.max_len <= self._max_len and
                         info.min_len >= self._min_len)
        if not within_bounds:
            return None
        return self._creator.append(info)

    @property
    def batch(self):
        """The wrapped creator's in-progress batch."""
        return self._creator.batch
96 97 98 99 100


class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in the way corresponding to settings.

    An example of returning a generator producing data batches whose data
    is shuffled in each pass and sorted in each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of vocabulary file of source language.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of vocabulary file of target language,
        or None to read source sentences only (one field per line).
    :type trg_vocab_fpath: basestring
    :param fpattern: The glob pattern to match data files.
    :type fpattern: basestring
    :param batch_size: The number of sequences contained in a mini-batch,
        or, when `use_token_batch` is True, the maximum number of tokens
        (including paddings) contained in a mini-batch.
    :type batch_size: int
    :param pool_size: The size of pool buffer.
    :type pool_size: int
    :param sort_type: The grain to sort by length: 'global' for all
        instances (shuffle is then not applied); 'pool' for instances in
        pool; 'none' for no sort.
    :type sort_type: basestring
    :param clip_last_batch: Whether to drop the last incomplete batch.
    :type clip_last_batch: bool
    :param tar_fname: The data file in tar if fpattern matches a tar file.
    :type tar_fname: basestring
    :param min_length: Samples whose shortest sequence is shorter than this
        are filtered out.
    :type min_length: int
    :param max_length: Samples whose longest sequence is longer than this
        are filtered out.
    :type max_length: int
    :param shuffle: Whether to shuffle all instances.
    :type shuffle: bool
    :param shuffle_seed: If not None, the random generator is re-seeded with
        this value before every shuffle in `batch_generator`.
    :type shuffle_seed: int
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batch data according to
        token number.
    :type use_token_batch: bool
    :param field_delimiter: The delimiter used to split source and target in
        each line of data file.
    :type field_delimiter: basestring
    :param token_delimiter: The delimiter used to split tokens in source or
        target sentences.
    :type token_delimiter: basestring
    :param start_mark: The token representing the beginning of sentences
        in dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing the end of sentences in
        dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing unknown words in dictionary.
    :type unk_mark: basestring
    :param seed: The seed for random. Note it is passed to numpy's global
        random state (`numpy.random.seed`).
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_seed=None,
                 shuffle_batch=False,
                 use_token_batch=False,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        # Without a target vocabulary the reader works in source-only mode.
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_seed = shuffle_seed
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        # Must run after the delimiters and _only_src are set above.
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        # NOTE: this is numpy's module-level random state, so seeding it here
        # also affects any other code using np.random.
        self._random = np.random
        self._random.seed(seed)

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
        """Load every data line and convert it to id sequences.

        Source sequences end with the end mark. Target sequences are wrapped
        with both start and end marks. Fills `_src_seq_ids`, `_trg_seq_ids`
        (None in source-only mode) and `_sample_infos`.

        Raises KeyError if a mark is missing from a vocabulary.
        """
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._src_vocab[start_mark],
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark],
                delimiter=self._token_delimiter,
                add_beg=False)
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark],
                    delimiter=self._token_delimiter,
                    add_beg=True))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0])
            lens = [len(src_trg_ids[0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        """Yield the split fields of every well-formed line in the data files.

        If `fpattern` matches exactly one tar archive, lines are read from
        its member `tar_fname`. Otherwise every matched path is read as a
        plain file. Lines with the wrong number of fields are skipped.

        Raises Exception if a tar archive is matched but `tar_fname` is None.
        Raises IOError if a matched path is not a regular file.
        """
        # glob returns paths in arbitrary, filesystem-dependent order. Sort
        # them so sample indices (and thus seeded shuffles) are reproducible.
        fpaths = sorted(glob.glob(fpattern))

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            # The context manager closes the archive once the generator is
            # exhausted or discarded.
            with tarfile.open(fpaths[0], "r") as tar:
                for line in tar.extractfile(tar_fname):
                    fields = self._parse_line(line)
                    if fields is not None:
                        yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "rb") as f:
                    for line in f:
                        fields = self._parse_line(line)
                        if fields is not None:
                            yield fields

    def _parse_line(self, line):
        """Decode (on Python 3) and split one raw data line into fields.

        Returns the list of fields when the count matches the reader's mode
        (1 field for source-only, 2 for source/target), otherwise None.
        """
        # Both plain files (opened "rb") and tar members yield bytes on
        # Python 3, so decode before doing any str operations.
        if six.PY3:
            line = line.decode("utf8", errors="ignore")
        fields = line.strip("\n").split(self._field_delimiter)
        expected = 1 if self._only_src else 2
        return fields if len(fields) == expected else None

    @staticmethod
    def load_dict(dict_path, reverse=False):
        """Load a vocabulary file with one token per line.

        Returns a dict mapping token -> line index, or line index -> token
        when `reverse` is True.
        """
        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, line in enumerate(fdict):
                if six.PY3:
                    line = line.decode("utf8", errors="ignore")
                if reverse:
                    word_dict[idx] = line.strip("\n")
                else:
                    word_dict[line.strip("\n")] = idx
        return word_dict

    def batch_generator(self):
        """Yield mini-batches covering all loaded samples (one pass per call).

        In source-only mode each batch is a list of `[src_ids]`. Otherwise it
        is a list of `(src_ids, trg_ids[:-1], trg_ids[1:])` tuples: the
        target without its last token and the target shifted left by one.
        """
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(self._sample_infos, key=lambda x: x.max_len)
        else:
            # NOTE: shuffling and pool sorting below work in place on
            # self._sample_infos, so each call starts from the order left by
            # the previous call.
            infos = self._sample_infos
            if self._shuffle:
                if self._shuffle_seed is not None:
                    self._random.seed(self._shuffle_seed)
                self._random.shuffle(infos)

            if self._sort_type == SortType.POOL:
                reverse = True
                for i in range(0, len(infos), self._pool_size):
                    # to avoid placing short next to long sentences
                    reverse = not reverse
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size],
                        key=lambda x: x.max_len,
                        reverse=reverse)

        # concat batch
        if self._use_token_batch:
            batch_creator = TokenBatchCreator(self._batch_size)
        else:
            batch_creator = SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        batches = []
        for info in infos:
            batch = batch_creator.append(info)
            # Skip None (no batch completed) as well as empty batches, which
            # a creator may emit when a lone sample exceeds its budget.
            if batch:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]

            if self._only_src:
                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
            else:
                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]