import glob
import os
import tarfile
import numpy as np


class SortType(object):
    """Granularities at which DataReader may sort samples by length."""

    # 'global': sort all instances; 'pool': sort inside each pool;
    # 'none': keep the loaded (or shuffled) order.
    GLOBAL, POOL, NONE = "global", "pool", "none"


class Converter(object):
    """Turn a delimited sentence into a list of vocabulary ids.

    The ids are framed by ``beg`` and ``end``; tokens that are missing from
    ``vocab`` map to ``unk``.
    """

    def __init__(self, vocab, beg, end, unk, delimiter):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter

    def __call__(self, sentence):
        """Return ``[beg] + ids(tokens) + [end]`` for ``sentence``."""
        lookup = self._vocab.get
        unk = self._unk
        ids = [self._beg]
        ids.extend(lookup(tok, unk) for tok in sentence.split(self._delimiter))
        ids.append(self._end)
        return ids


class ComposedConverter(object):
    """Apply the i-th converter to the i-th field of a parallel sentence."""

    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
        """Return one converted item per converter.

        Extra fields in ``parallel_sentence`` are ignored; too few fields
        raise IndexError.
        """
        converted = []
        for idx, convert in enumerate(self._converters):
            converted.append(convert(parallel_sentence[idx]))
        return converted


class SentenceBatchCreator(object):
    """Group samples into batches that hold exactly ``batch_size`` samples."""

    def __init__(self, batch_size):
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info``; return the batch once it is full, otherwise None."""
        self.batch.append(info)
        if len(self.batch) != self._batch_size:
            return None
        full, self.batch = self.batch, []
        return full


class TokenBatchCreator(object):
    """Group samples into batches bounded by a padded-token budget.

    The cost of a batch is ``max_len * number_of_samples``, i.e. every sample
    counted at the length of the longest one. A batch is closed before adding
    a sample would push that cost above ``batch_size``.
    """

    def __init__(self, batch_size):
        self.batch = []
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info`` (any object with a ``max_len`` attribute).

        :return: The previous batch if adding ``info`` would exceed the token
            budget (``info`` then starts a new batch); otherwise None.
        """
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        # Only flush a non-empty batch. Without this guard, a sample that is
        # longer than the whole budget on its own made this return an empty
        # list, which callers treat as a real (empty) batch. Such a sample now
        # forms a batch by itself and is flushed by the next append.
        if self.batch and max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        self.max_len = max_len
        self.batch.append(info)
        return None


class SampleInfo(object):
    """Index of one sample plus the longest and shortest length of its sides."""

    def __init__(self, i, max_len, min_len):
        self.i, self.max_len, self.min_len = i, max_len, min_len


class MinMaxFilter(object):
    """Drop samples whose lengths fall outside ``[min_len, max_len]``.

    Samples that pass are forwarded to ``underlying_creator``.
    """

    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        """Forward ``info`` if its lengths fit and return the creator's result.

        Rejected samples return None.
        """
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return None
        return self._creator.append(info)

    @property
    def batch(self):
        """The wrapped creator's in-progress batch."""
        return self._creator.batch


class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in the way corresponding to settings.

    An example of returning a generator producing data batches whose data
    is shuffled in each pass and sorted in each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of the vocabulary file of the source
        language. Each line holds one token; its 0-based line number is the
        token id.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of the vocabulary file of the target
        language, or None to read source sentences only.
    :type trg_vocab_fpath: basestring
    :param fpattern: The glob pattern to match data files.
    :type fpattern: basestring
    :param batch_size: The number of sequences contained in a mini-batch or,
        if ``use_token_batch`` is True, the maximum number of tokens
        (including paddings) contained in a mini-batch.
    :type batch_size: int
    :param pool_size: The size of the pool buffer used by 'pool' sorting.
    :type pool_size: int
    :param sort_type: The granularity of sorting by length: 'global' for all
        instances (``shuffle`` is then ignored); 'pool' for instances within
        each pool; 'none' for no sorting.
    :type sort_type: basestring
    :param clip_last_batch: Whether to drop the last incomplete batch.
    :type clip_last_batch: bool
    :param tar_fname: The member file to read from the tar archive if
        ``fpattern`` matches a single tar file.
    :type tar_fname: basestring
    :param min_length: Sequences shorter than this are filtered out
        (lengths include the start and end marks).
    :type min_length: int
    :param max_length: Sequences longer than this are filtered out
        (lengths include the start and end marks).
    :type max_length: int
    :param shuffle: Whether to shuffle all instances before batching
        (ignored when ``sort_type`` is 'global').
    :type shuffle: bool
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batch data according to
        token number.
    :type use_token_batch: bool
    :param field_delimiter: The delimiter used to split source and target in
        each line of data file.
    :type field_delimiter: basestring
    :param token_delimiter: The delimiter used to split tokens in source or
        target sentences.
    :type token_delimiter: basestring
    :param start_mark: The token representing the beginning of sentences
        in the dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing the end of sentences in the
        dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing unknown words in the dictionary.
    :type unk_mark: basestring
    :param seed: The seed of the reader's private random number generator.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        # Use a private generator. Calling np.random.seed() here would reset
        # the process-wide RNG (and the stream of any other reader) every
        # time a reader is constructed. RandomState(seed) produces the same
        # sequence that seeding the global state would.
        self._random = np.random.RandomState(seed)

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
        """Convert every loaded line into id lists and per-sample lengths.

        Fills ``self._src_seq_ids``, ``self._trg_seq_ids`` (None in
        source-only mode) and ``self._sample_infos``. The parameter order is
        kept as-is for backward compatibility.
        """
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._src_vocab[start_mark],
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark],
                delimiter=self._token_delimiter)
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark],
                    delimiter=self._token_delimiter))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0])
            lens = [len(src_trg_ids[0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        """Yield the field lists of all well-formed lines matched by fpattern.

        If ``fpattern`` matches exactly one tar file, only its member
        ``tar_fname`` is read. Otherwise every matched path is read as a
        text file.

        :raises Exception: If a tar file is matched but ``tar_fname`` is None.
        :raises IOError: If a matched path is not a regular file, or
            ``tar_fname`` is not a regular file inside the archive.
        """
        # Sorted so that sample indices, and therefore seeded shuffles, do
        # not depend on the file system's directory order.
        fpaths = sorted(glob.glob(fpattern))

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            # The context manager closes the archive (previously leaked).
            with tarfile.open(fpaths[0], "r") as tar:
                member = tar.extractfile(tar_fname)
                if member is None:
                    raise IOError("Not a regular file in tar: %s" % tar_fname)
                for fields in self._parse_lines(member):
                    yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "r") as f:
                    for fields in self._parse_lines(f):
                        yield fields

    def _parse_lines(self, lines):
        """Split lines into fields, keeping only lines with the expected
        number of fields (1 in source-only mode, otherwise 2)."""
        num_fields = 1 if self._only_src else 2
        for line in lines:
            if isinstance(line, bytes) and not isinstance(line, str):
                # Tar members yield bytes on Python 3, where str.strip would
                # fail. Assumes the data is UTF-8 (TODO confirm).
                line = line.decode("utf-8")
            fields = line.strip("\n").split(self._field_delimiter)
            if len(fields) == num_fields:
                yield fields

    @staticmethod
    def load_dict(dict_path, reverse=False):
        """Load a vocabulary file that holds one token per line.

        :param dict_path: The path of the vocabulary file.
        :param reverse: If False, map token -> line index; if True, map
            line index -> token.
        :return: The resulting dict.
        """
        word_dict = {}
        with open(dict_path, "r") as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = line.strip("\n")
                else:
                    word_dict[line.strip("\n")] = idx
        return word_dict

    def batch_generator(self):
        """Yield the mini-batches of one pass over the loaded data.

        In source-only mode each batch is a list of ``[src_ids]``. Otherwise
        it is a list of ``(src_ids, trg_ids[:-1], trg_ids[1:])`` tuples: the
        target without its end mark, and the target without its start mark.
        """
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(
                self._sample_infos, key=lambda x: x.max_len, reverse=True)
        else:
            # NOTE: shuffling and pool sorting act on self._sample_infos in
            # place, so each pass starts from the previous pass's order.
            infos = self._sample_infos
            if self._shuffle:
                self._random.shuffle(infos)

            if self._sort_type == SortType.POOL:
                for i in range(0, len(infos), self._pool_size):
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size], key=lambda x: x.max_len)

        # concat batch
        batches = []
        if self._use_token_batch:
            batch_creator = TokenBatchCreator(self._batch_size)
        else:
            batch_creator = SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            # None means "no batch completed yet". An empty list carries no
            # samples and would only yield an empty mini-batch, so skip both.
            if batch:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]
            if self._only_src:
                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
            else:
                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]