reader.py 11.1 KB
Newer Older
1
import glob
2
import six
Y
Yu Yang 已提交
3 4
import os
import tarfile
5

6 7
import numpy as np

8 9 10 11 12 13 14

class SortType(object):
    """Granularity at which samples are ordered by length before batching.

    The values are plain strings so they can be compared against a
    user-supplied ``sort_type`` setting.
    """

    # Sort all instances together.
    GLOBAL = 'global'
    # Sort only the instances inside each pool.
    POOL = 'pool'
    # Do not sort by length.
    NONE = "none"


Y
Yu Yang 已提交
15
class Converter(object):
    """Convert a delimited sentence string into a list of vocabulary ids.

    The id list is wrapped with ``beg`` at the front and ``end`` at the
    back; tokens missing from ``vocab`` are mapped to ``unk``.
    """

    def __init__(self, vocab, beg, end, unk, delimiter):
        """
        :param vocab: Mapping from token string to id.
        :param beg: Id prepended to every converted sentence.
        :param end: Id appended to every converted sentence.
        :param unk: Id used for tokens not found in ``vocab``.
        :param delimiter: String used to split a sentence into tokens.
        """
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter

    def __call__(self, sentence):
        """Return ``[beg] + ids_of_tokens + [end]`` for ``sentence``.

        Note that an empty string splits into one empty token, which is
        mapped to ``unk`` unless the vocabulary contains ``""``.
        """
        return [self._beg] + [
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter)
        ] + [self._end]
28

Y
Yu Yang 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88

class ComposedConverter(object):
    """Apply the k-th converter to the k-th field of a parallel sentence."""

    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
        """Return one converted result per converter, in converter order.

        ``parallel_sentence`` is indexed by converter position, so it must
        have at least as many fields as there are converters.
        """
        return [
            convert(parallel_sentence[pos])
            for pos, convert in enumerate(self._converters)
        ]


class SentenceBatchCreator(object):
    """Group samples into batches holding a fixed number of samples."""

    def __init__(self, batch_size):
        # Samples collected for the batch currently being filled.
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info``; return the full batch once it reaches the size.

        Returns None while the current batch is still being filled.
        """
        self.batch.append(info)
        if len(self.batch) != self._batch_size:
            return None
        completed, self.batch = self.batch, []
        return completed


class TokenBatchCreator(object):
    """Group samples into batches bounded by a padded-token budget.

    The cost of a batch is ``longest_sample_len * number_of_samples``,
    i.e. the number of tokens once every sample is padded to the longest
    one. A sample that alone exceeds the budget still forms its own batch.
    """

    def __init__(self, batch_size):
        # Samples collected for the batch currently being filled.
        self.batch = []
        # Largest ``max_len`` among the samples in ``self.batch``.
        self.max_len = -1
        # Token budget per batch.
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info``; return the previous batch if adding would overflow.

        :param info: Object exposing a ``max_len`` attribute.
        :return: The completed (non-empty) batch when ``info`` does not fit
            and starts a new batch, otherwise None.
        """
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        # Only flush a non-empty batch: previously, a first sample that
        # alone exceeded the budget caused an empty list to be returned,
        # which callers then emitted as an empty batch.
        if self.batch and max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        self.max_len = max_len
        self.batch.append(info)
        return None


class SampleInfo(object):
    """Lightweight record describing one loaded sample.

    :param i: Index of the sample in the loaded sequence lists.
    :param max_len: Length of the longest sequence in the sample.
    :param min_len: Length of the shortest sequence in the sample.
    """

    def __init__(self, i, max_len, min_len):
        self.max_len = max_len
        self.min_len = min_len
        self.i = i


class MinMaxFilter(object):
    """Filter samples by length before passing them to a batch creator.

    A sample is dropped when its ``max_len`` exceeds ``max_len`` or its
    ``min_len`` is below ``min_len``; otherwise it is forwarded to the
    wrapped creator, whose return value is passed through.
    """

    def __init__(self, max_len, min_len, underlying_creator):
        """
        :param max_len: Inclusive upper bound on ``info.max_len``.
        :param min_len: Inclusive lower bound on ``info.min_len``.
        :param underlying_creator: Object with ``append(info)`` and a
            ``batch`` attribute, e.g. a sentence/token batch creator.
        """
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        """Forward ``info`` if it is within bounds; return None if dropped."""
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return None
        return self._creator.append(info)

    @property
    def batch(self):
        """The wrapped creator's batch that is still being filled."""
        return self._creator.batch
95 96 97 98 99


class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in the way corresponding to settings.

    An example of returning a generator producing data batches whose data
    is shuffled in each pass and sorted in each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of vocabulary file of source language.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of vocabulary file of target language.
        If None, only source sequences are loaded and yielded.
    :type trg_vocab_fpath: basestring
    :param fpattern: The glob pattern to match data files.
    :type fpattern: basestring
    :param batch_size: The number of sequences contained in a mini-batch,
        or, when ``use_token_batch`` is True, the maximum number of tokens
        (including padding) contained in a mini-batch.
    :type batch_size: int
    :param pool_size: The size of pool buffer used when sorting per pool.
    :type pool_size: int
    :param sort_type: The grain to sort by length: 'global' for all
        instances; 'pool' for instances in pool; 'none' for no sort.
    :type sort_type: basestring
    :param clip_last_batch: Whether to drop the last incomplete batch.
    :type clip_last_batch: bool
    :param tar_fname: The data file in tar if fpattern matches a tar file.
    :type tar_fname: basestring
    :param min_length: The minimum length used to filter sequences.
    :type min_length: int
    :param max_length: The maximum length used to filter sequences.
    :type max_length: int
    :param shuffle: Whether to shuffle all instances. Has no effect when
        ``sort_type`` is 'global'.
    :type shuffle: bool
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batch data according to
        token number.
    :type use_token_batch: bool
    :param field_delimiter: The delimiter used to split source and target in
        each line of data file.
    :type field_delimiter: basestring
    :param token_delimiter: The delimiter used to split tokens in source or
        target sentences.
    :type token_delimiter: basestring
    :param start_mark: The token representing the beginning of sentences
        in dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing the end of sentences in
        dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing unknown words in dictionary.
    :type unk_mark: basestring
    :param seed: The seed for random. Note that this seeds numpy's global
        random state (``np.random``), not a private generator.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        # Without a target vocabulary the reader works in source-only mode.
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        # NOTE: this seeds numpy's module-level RNG, which affects any other
        # code using np.random in the same process.
        self._random = np.random
        self._random.seed(seed)

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
        """Read all data lines and convert them into id sequences.

        Populates ``self._src_seq_ids``, ``self._trg_seq_ids`` (None in
        source-only mode) and ``self._sample_infos`` (one ``SampleInfo``
        per line, holding the longest and shortest sequence length).
        """
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._src_vocab[start_mark],
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark],
                delimiter=self._token_delimiter)
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark],
                    delimiter=self._token_delimiter))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0])
            lens = [len(src_trg_ids[0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        """Yield the field list of every well-formed line in the data.

        If ``fpattern`` matches exactly one tar archive, lines are read
        from its member ``tar_fname``; otherwise every matched path is read
        as a plain file. Lines whose field count does not match the mode
        (1 for source-only, 2 otherwise) are skipped.

        :raises Exception: If a tar archive is matched but ``tar_fname``
            is None.
        :raises IOError: If a matched path is not a regular file.
        """
        fpaths = glob.glob(fpattern)

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            # The archive is closed once the generator finishes or is
            # closed; previously the handle was never closed.
            with tarfile.open(fpaths[0], "r") as f:
                for line in f.extractfile(tar_fname):
                    fields = self._split_fields(line)
                    if fields is not None:
                        yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "rb") as f:
                    for line in f:
                        fields = self._split_fields(line)
                        if fields is not None:
                            yield fields

    def _split_fields(self, line):
        """Decode one raw line and split it; return None if malformed."""
        # Both plain files ("rb") and tar members yield bytes on Python 3;
        # the tar path previously skipped this decode and crashed on strip.
        if six.PY3:
            line = line.decode()
        fields = line.strip("\n").split(self._field_delimiter)
        expected_fields = 1 if self._only_src else 2
        return fields if len(fields) == expected_fields else None

    @staticmethod
    def load_dict(dict_path, reverse=False):
        """Load a vocabulary file with one token per line.

        :param dict_path: Path of the vocabulary file.
        :param reverse: If False, map token -> line index; if True, map
            line index -> token.
        :return: The resulting dict.
        """
        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, line in enumerate(fdict):
                if six.PY3:
                    line = line.decode()
                if reverse:
                    word_dict[idx] = line.strip("\n")
                else:
                    word_dict[line.strip("\n")] = idx
        return word_dict

    def batch_generator(self):
        """Yield mini-batches for one pass over the loaded data.

        Each batch is a list with one entry per sample: ``[src_ids]`` in
        source-only mode, otherwise the tuple
        ``(src_ids, trg_ids[:-1], trg_ids[1:])``.
        """
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            # Longest samples first; shuffling is not applied in this mode.
            infos = sorted(
                self._sample_infos, key=lambda x: x.max_len, reverse=True)
        else:
            # NOTE: shuffling and pool sorting act on self._sample_infos in
            # place, so the resulting order carries over to the next pass.
            infos = self._sample_infos
            if self._shuffle:
                self._random.shuffle(infos)

            if self._sort_type == SortType.POOL:
                for i in range(0, len(infos), self._pool_size):
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size], key=lambda x: x.max_len)

        # concat batch
        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size
        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            # Skip both "no batch yet" (None) and empty batches; an
            # underlying creator may hand back [] (e.g. a token-based
            # creator whose first sample alone exceeds the budget).
            if batch:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]

            if self._only_src:
                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
            else:
                # The target is yielded twice: once without its last id
                # (the end mark) and once without its first id (the start
                # mark) -- presumably decoder input and labels; confirm
                # against the consumer.
                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]