import glob
import os
import random
import tarfile


class SortType(object):
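    # Granularity of the length-based sort applied to the loaded samples.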
    GLOBAL = 'global'
    POOL = 'pool'
    NONE = "none"


class Converter(object):
    def __init__(self, vocab, beg, end, unk, delimiter):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter

    def __call__(self, sentence):
        return [self._beg] + [
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter)
        ] + [self._end]
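
    # A minimal usage sketch (hypothetical vocabulary, for illustration):
    #   conv = Converter(vocab={'hello': 3}, beg=0, end=1, unk=2, delimiter=' ')
    #   conv('hello world')  # -> [0, 3, 2, 1]: beg, 'hello', unknown, end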


class ComposedConverter(object):
    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
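        # Apply the i-th converter to the i-th field, e.g. mapping a
        # [source, target] sentence pair to a pair of word-id lists.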
        return [
            self._converters[i](parallel_sentence[i])
            for i in range(len(self._converters))
        ]


class SentenceBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
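        # Collect samples; once batch_size of them have accumulated, return
        # the finished batch and start a new one (returns None otherwise).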
        self.batch.append(info)
        if len(self.batch) == self._batch_size:
            tmp = self.batch
            self.batch = []
            return tmp


class TokenBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
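        # batch_size is a token budget including padding: emit the batch once
        # padding every sentence to the longest one would exceed the budget,
        # e.g. a budget of 2000 holds at most 50 sentences of max length 40.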
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        if max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        else:
            self.max_len = max_len
            self.batch.append(info)


class SampleInfo(object):
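    # A lightweight record of one sample: its index plus the lengths of its
    # longest and shortest fields, used for length filtering and sorting.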
    def __init__(self, i, max_len, min_len):
        self.i = i
        self.min_len = min_len
        self.max_len = max_len


class MinMaxFilter(object):
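    # Wraps a batch creator, dropping samples whose longest field exceeds
    # max_len or whose shortest field falls below min_len.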
    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return
        return self._creator.append(info)

    @property
    def batch(self):
        return self._creator.batch


class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in ways that correspond to the given settings.

    An example of creating a generator that produces data batches which are
    shuffled in each pass and sorted within each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of vocabulary file of source language.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of vocabulary file of target language.
    :type trg_vocab_fpath: basestring
    :param fpattern: The pattern to match data files.
    :type fpattern: basestring
    :param batch_size: The number of sequences contained in a mini-batch,
        or the maximum number of tokens (including paddings) contained in
        a mini-batch.
    :type batch_size: int
    :param pool_size: The size of the pool buffer.
    :type pool_size: int
    :param sort_type: The granularity of sorting by length: 'global' for all
        instances; 'pool' for instances in pool; 'none' for no sort.
    :type sort_type: basestring
    :param clip_last_batch: Whether to clip the last incomplete batch.
    :type clip_last_batch: bool
    :param tar_fname: The data file within the tar archive, if fpattern
        matches a tar file.
    :type tar_fname: basestring
    :param min_length: The minimum length used to filter sequences.
    :type min_length: int
    :param max_length: The maximum length used to filter sequences.
    :type max_length: int
    :param shuffle: Whether to shuffle all instances.
    :type shuffle: bool
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batches according to the
        number of tokens rather than the number of sequences.
    :type use_token_batch: bool
    :param field_delimiter: The delimiter used to split source and target in
        each line of data file.
    :type field_delimiter: basestring
    :param token_delimiter: The delimiter used to split tokens in source or
        target sentences.
    :type token_delimiter: basestring
    :param start_mark: The token representing the beginning of
        sentences in the dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing the end of sentences
        in the dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing unknown words in the dictionary.
    :type unk_mark: basestring
    :param seed: The seed for the random number generator.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        self._random = random.Random(x=seed)

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
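        # Convert every line of every data file to word ids and record each
        # sample's index and min/max field lengths for later filter/sort steps.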
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._src_vocab[start_mark],
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark],
                delimiter=self._token_delimiter)
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark],
                    delimiter=self._token_delimiter))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0])
            lens = [len(src_trg_ids[0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
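        # Yield one [source] or [source, target] field list per well-formed
        # line, reading either from a tar archive or from plain text files.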
        fpaths = glob.glob(fpattern)

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception(
                    "If a tar file is provided, please set tar_fname.")

            f = tarfile.open(fpaths[0], "r")
            for line in f.extractfile(tar_fname):
                fields = line.strip("\n").split(self._field_delimiter)
                if (not self._only_src and len(fields) == 2) or (
                        self._only_src and len(fields) == 1):
                    yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "r") as f:
                    for line in f:
                        fields = line.strip("\n").split(self._field_delimiter)
                        if (not self._only_src and len(fields) == 2) or (
                                self._only_src and len(fields) == 1):
                            yield fields

    @staticmethod
    def load_dict(dict_path, reverse=False):
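        # Build a word->id dict from a one-token-per-line vocabulary file;
        # with reverse=True, build the inverse id->word mapping instead.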
        word_dict = {}
        with open(dict_path, "r") as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = line.strip("\n")
                else:
                    word_dict[line.strip("\n")] = idx
        return word_dict

    def batch_generator(self):
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(
                self._sample_infos, key=lambda x: x.max_len, reverse=True)
        else:
            infos = self._sample_infos
            if self._shuffle:
                self._random.shuffle(infos)

            if self._sort_type == SortType.POOL:
                for i in range(0, len(infos), self._pool_size):
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size], key=lambda x: x.max_len)

        # pack the samples into batches
        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size
        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            if batch is not None:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]

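            # Target sequences are split into decoder input (ids[:-1]) and
            # labels (ids[1:]), i.e. shifted by one position.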
            if self._only_src:
                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
            else:
                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]