reader.py 11.2 KB
Newer Older
1
import glob
2
import six
Y
Yu Yang 已提交
3 4
import os
import tarfile
5

6 7
import numpy as np

8 9 10 11 12 13 14

class SortType(object):
    """Granularity at which samples are sorted by length before batching.

    ``GLOBAL`` sorts all instances, ``POOL`` sorts instances within each
    pool-sized window, and ``NONE`` applies no sorting.
    """
    GLOBAL, POOL, NONE = 'global', 'pool', 'none'


Y
Yu Yang 已提交
15
class Converter(object):
    """Convert a delimited sentence string into a list of vocabulary ids.

    Every token produced by ``sentence.split(delimiter)`` is looked up in
    ``vocab``; tokens not found map to ``unk``. The ``end`` id is always
    appended, and the ``beg`` id is prepended only when ``add_beg`` is true.

    :param vocab: Mapping from token to id.
    :param beg: Id of the begin-of-sentence token.
    :param end: Id of the end-of-sentence token.
    :param unk: Id used for tokens missing from ``vocab``.
    :param delimiter: Separator used to split a sentence into tokens.
    :param add_beg: Whether to prepend ``beg`` to the output.
    """

    def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter
        self._add_beg = add_beg

    def __call__(self, sentence):
        """Return the id list for ``sentence`` (a delimited string)."""
        ids = [self._beg] if self._add_beg else []
        ids.extend(
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter))
        ids.append(self._end)
        return ids
29

Y
Yu Yang 已提交
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89

class ComposedConverter(object):
    """Apply one converter per field of a parallel sentence.

    The i-th converter is applied to the i-th field; fields beyond the
    number of converters are ignored.
    """

    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
        """Return the list of converted fields, one per converter."""
        return [
            convert(parallel_sentence[pos])
            for pos, convert in enumerate(self._converters)
        ]


class SentenceBatchCreator(object):
    """Group samples into batches containing a fixed number of sentences."""

    def __init__(self, batch_size):
        self._batch_size = batch_size
        self.batch = []

    def append(self, info):
        """Add ``info``; return the batch once it reaches ``batch_size``.

        Returns None while the current batch is still being filled.
        """
        self.batch.append(info)
        if len(self.batch) != self._batch_size:
            return None
        full_batch, self.batch = self.batch, []
        return full_batch


class TokenBatchCreator(object):
    """Group samples into batches bounded by a padded-token budget.

    A batch's cost is ``max_len * number_of_samples``, i.e. the token count
    once every sample is padded to the longest one. When adding a sample
    would push that cost above ``batch_size``, the current batch is
    returned and a new batch is started with the incoming sample.
    """

    def __init__(self, batch_size):
        self.batch = []
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info`` (needs a ``max_len`` attribute).

        Returns the completed batch when ``info`` does not fit into the
        current one, otherwise None. A sample that alone exceeds the budget
        is placed in a batch of its own. An empty list is never returned.
        """
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        # Only flush a non-empty batch: previously an oversized first sample
        # caused an empty list to be returned and emitted as a batch.
        if self.batch and max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        self.max_len = max_len
        self.batch.append(info)
        return None


class SampleInfo(object):
    """Lightweight record for one loaded sample.

    :param i: Index of the sample in the reader's loaded sequence lists.
    :param max_len: Largest sequence length among the sample's fields.
    :param min_len: Smallest sequence length among the sample's fields.
    """

    def __init__(self, i, max_len, min_len):
        self.i, self.max_len, self.min_len = i, max_len, min_len


class MinMaxFilter(object):
    """Filter samples by length before forwarding them to a batch creator.

    A sample is dropped when its ``max_len`` exceeds ``max_len`` or its
    ``min_len`` is below ``min_len`` (both bounds are inclusive). Accepted
    samples are passed to ``underlying_creator.append``.
    """

    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        """Forward ``info`` if it passes the filter.

        Returns whatever the underlying creator returns, or None when
        ``info`` is dropped.
        """
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return None
        return self._creator.append(info)

    @property
    def batch(self):
        """The underlying creator's in-progress batch."""
        return self._creator.batch
96 97 98 99 100


class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in the way corresponding to settings.

    An example of returning a generator producing data batches whose data
    is shuffled in each pass and sorted in each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of vocabulary file of source language.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of vocabulary file of target language,
        or None to load source-only data.
    :type trg_vocab_fpath: basestring
    :param fpattern: The glob pattern to match data files.
    :type fpattern: basestring
    :param batch_size: The number of sequences contained in a mini-batch,
        or, when ``use_token_batch`` is True, the maximum number of tokens
        (including paddings) contained in a mini-batch.
    :type batch_size: int
    :param pool_size: The size of pool buffer.
    :type pool_size: int
    :param sort_type: The grain to sort by length: 'global' for all
        instances; 'pool' for instances in pool; 'none' for no sort.
    :type sort_type: basestring
    :param clip_last_batch: Whether to clip the last uncompleted batch.
    :type clip_last_batch: bool
    :param tar_fname: The name of the member file to read when fpattern
        matches exactly one tar file.
    :type tar_fname: basestring
    :param min_length: The minimum length used to filter sequences.
    :type min_length: int
    :param max_length: The maximum length used to filter sequences.
    :type max_length: int
    :param shuffle: Whether to shuffle all instances (ignored when
        sort_type is 'global').
    :type shuffle: bool
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batch data according to
        token number.
    :type use_token_batch: bool
    :param field_delimiter: The delimiter used to split source and target in
        each line of data file.
    :type field_delimiter: basestring
    :param token_delimiter: The delimiter used to split tokens in source or
        target sentences.
    :type token_delimiter: basestring
    :param start_mark: The token representing the beginning of sentences
        in the dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing the end of sentences in the
        dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing unknown words in the dictionary.
    :type unk_mark: basestring
    :param seed: The seed for random.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        # NOTE(review): this seeds NumPy's *global* RNG, so constructing a
        # reader also resets np.random for any other code in the process.
        # Kept for backward compatibility; a dedicated
        # np.random.RandomState(seed) would avoid the global side effect.
        self._random = np.random
        self._random.seed(seed)

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
        """Load and convert every data line into id sequences.

        Fills ``self._src_seq_ids``, ``self._trg_seq_ids`` (None in
        source-only mode) and ``self._sample_infos`` (one ``SampleInfo`` per
        loaded line, recording the longest and shortest side).
        Source sequences get only the end mark; target sequences get both
        the start and end marks.
        """
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._src_vocab[start_mark],
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark],
                delimiter=self._token_delimiter,
                add_beg=False)
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark],
                    delimiter=self._token_delimiter,
                    add_beg=True))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0])
            lens = [len(src_trg_ids[0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        """Yield the fields of every well-formed line of the data files.

        If ``fpattern`` matches exactly one tar file, lines are read from its
        member ``tar_fname``; otherwise every matched path is read as a
        plain file. Lines whose field count does not match the mode (2 for
        source+target, 1 for source-only) are skipped.

        :raises Exception: if a tar file is matched but tar_fname is None.
        :raises IOError: if a matched path is not a regular file, or the tar
            member is not a regular file.
        """
        # glob order depends on the filesystem; sort it so that the sample
        # order, and therefore seeded shuffling, is reproducible.
        fpaths = sorted(glob.glob(fpattern))

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            # ``with`` closes the archive once iteration finishes or the
            # generator is closed; previously it was never closed.
            with tarfile.open(fpaths[0], "r") as tar:
                member = tar.extractfile(tar_fname)
                # extractfile returns None for non-regular members.
                if member is None:
                    raise IOError("Invalid file in tar: %s" % tar_fname)
                for line in member:
                    fields = self._split_fields(line)
                    if fields is not None:
                        yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "rb") as f:
                    for line in f:
                        fields = self._split_fields(line)
                        if fields is not None:
                            yield fields

    def _split_fields(self, line):
        """Split one raw line into fields; return None if the count is wrong.

        Lines come from binary-mode reads (plain files opened with "rb" and
        tar members), so they are decoded to text on Python 3 before
        splitting. The tar branch used to skip this and failed with a
        TypeError on Python 3.
        """
        if six.PY3:
            line = line.decode()
        fields = line.strip("\n").split(self._field_delimiter)
        expected = 1 if self._only_src else 2
        return fields if len(fields) == expected else None

    @staticmethod
    def load_dict(dict_path, reverse=False):
        """Load a vocabulary file with one token per line.

        :param dict_path: Path of the vocabulary file.
        :param reverse: If True, map line index -> token; otherwise
            token -> line index.
        :return: The resulting dict.
        """
        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, line in enumerate(fdict):
                if six.PY3:
                    line = line.decode()
                token = line.strip("\n")
                if reverse:
                    word_dict[idx] = token
                else:
                    word_dict[token] = idx
        return word_dict

    def batch_generator(self):
        """Yield mini-batches for one pass over the loaded data.

        Each batch is a list of samples. In source-only mode a sample is
        ``[src_ids]``. Otherwise it is ``(src_ids, trg_ids[:-1],
        trg_ids[1:])``, i.e. the target shifted to form decoder input and
        label.
        """
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(self._sample_infos, key=lambda x: x.max_len)
        else:
            # Shuffle and pool-sort operate in place on self._sample_infos,
            # so each pass starts from the previous pass's order.
            infos = self._sample_infos
            if self._shuffle:
                self._random.shuffle(infos)

            if self._sort_type == SortType.POOL:
                for i in range(0, len(infos), self._pool_size):
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size], key=lambda x: x.max_len)

        # concat batch
        batches = []
        if self._use_token_batch:
            batch_creator = TokenBatchCreator(self._batch_size)
        else:
            batch_creator = SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            # Skip both "no batch completed" (None) and empty flushed
            # batches, which would otherwise be yielded as empty batches.
            if batch:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]

            if self._only_src:
                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
            else:
                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]