reader.py 10.3 KB
Newer Older
1
import glob
Y
Yu Yang 已提交
2
import os
3
import random
Y
Yu Yang 已提交
4 5
import tarfile
import time
6 7 8 9 10 11 12 13


class SortType(object):
    """String constants naming the granularity at which samples are sorted."""

    # Sort all instances together.
    GLOBAL = 'global'
    # Sort instances only within each pool.
    POOL = 'pool'
    # Do not sort.
    NONE = 'none'


Y
Yu Yang 已提交
14 15 16 17 18 19
class Converter(object):
    """Turn a whitespace-separated sentence into a list of vocabulary ids.

    :param vocab: Mapping from token to id; only its ``get`` method is used.
    :param beg: Id prepended to every converted sentence.
    :param end: Id appended to every converted sentence.
    :param unk: Id substituted for tokens that are missing from ``vocab``.
    """

    def __init__(self, vocab, beg, end, unk):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk

    def __call__(self, sentence):
        """Return ``[beg] + <ids of the tokens in sentence> + [end]``.

        Tokens are obtained with ``sentence.split()``, so any run of
        whitespace (including a trailing newline) acts as a separator.
        """
        lookup = self._vocab.get
        ids = [self._beg]
        ids.extend(lookup(token, self._unk) for token in sentence.split())
        ids.append(self._end)
        return ids
25

Y
Yu Yang 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85

class ComposedConverter(object):
    """Apply a sequence of converters position-wise to a parallel sentence.

    :param converters: Callables; the one at index ``k`` is applied to the
        field at index ``k`` of the parallel sentence.
    """

    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
        """Return a list with one converted field per converter.

        Fields beyond the number of converters are ignored.
        """
        converted = []
        for position, convert in enumerate(self._converters):
            converted.append(convert(parallel_sentence[position]))
        return converted


class SentenceBatchCreator(object):
    """Group samples into batches holding a fixed number of samples.

    :param batch_size: Number of samples that completes a batch.
    """

    def __init__(self, batch_size):
        # Samples collected so far for the batch under construction.
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info``; return the batch once it is complete, else ``None``."""
        self.batch.append(info)
        if len(self.batch) != self._batch_size:
            return None
        # Hand out the completed batch and start collecting a fresh one.
        completed, self.batch = self.batch, []
        return completed


class TokenBatchCreator(object):
    """Group samples into batches bounded by a padded-token budget.

    The cost of a batch is ``longest sample length * number of samples``.

    :param batch_size: Maximum cost a batch may reach.
    """

    def __init__(self, batch_size):
        # Samples collected so far for the batch under construction.
        self.batch = []
        # Longest ``max_len`` among the samples in ``self.batch``.
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info``; return the previous batch if ``info`` overflows it.

        When adding ``info`` would push the cost above the budget, the
        samples collected so far are returned and ``info`` starts the next
        batch. Otherwise ``info`` joins the current batch and ``None`` is
        returned.

        If nothing was collected yet (``info`` alone exceeds the budget),
        ``None`` is returned instead of an empty list, so that callers which
        test ``is not None`` never receive an empty batch.
        """
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        if max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            # An empty pending batch is not a batch: report "nothing ready".
            return result if result else None

        self.max_len = max_len
        self.batch.append(info)
        return None


class SampleInfo(object):
    """Index and length bounds of one sample.

    :param i: Index of the sample.
    :param max_len: Largest length associated with the sample.
    :param min_len: Smallest length associated with the sample.
    """

    def __init__(self, i, max_len, min_len):
        self.i = i
        self.max_len, self.min_len = max_len, min_len


class MinMaxFilter(object):
    """Drop samples outside a length range before they reach a batch creator.

    :param max_len: Samples whose ``max_len`` exceeds this are dropped.
    :param min_len: Samples whose ``min_len`` is below this are dropped.
    :param underlying_creator: Object with an ``append(info)`` method and a
        ``batch`` attribute that receives the samples that pass the filter.
    """

    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        """Forward ``info`` to the wrapped creator if it is within range.

        Returns whatever the wrapped creator returns, or ``None`` when the
        sample is filtered out.
        """
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return None
        return self._creator.append(info)

    @property
    def batch(self):
        """The wrapped creator's batch under construction."""
        return self._creator.batch
92 93 94 95 96


class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in the way corresponding to settings.

    An example of returning a generator producing data batches whose data
    is shuffled in each pass and sorted in each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of vocabulary file of source language.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of vocabulary file of target language.
        If None, only the source side is loaded and produced.
    :type trg_vocab_fpath: basestring
    :param fpattern: The pattern to match data files.
    :type fpattern: basestring
    :param batch_size: The number of sequences contained in a mini-batch,
        or the maximum number of tokens (including paddings) contained in a
        mini-batch when use_token_batch is set.
    :type batch_size: int
    :param pool_size: The size of pool buffer.
    :type pool_size: int
    :param sort_type: The grain to sort by length: 'global' for all
        instances; 'pool' for instances in pool; 'none' for no sort.
        With 'global' the instances are not shuffled.
    :type sort_type: basestring
    :param clip_last_batch: Whether to clip the last uncompleted batch.
    :type clip_last_batch: bool
    :param tar_fname: The data file in tar if fpattern matches a tar file.
    :type tar_fname: basestring
    :param min_length: The minimum length used to filter sequences.
    :type min_length: int
    :param max_length: The maximum length used to filter sequences.
    :type max_length: int
    :param shuffle: Whether to shuffle all instances.
    :type shuffle: bool
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batch data according to
        token number.
    :type use_token_batch: bool
    :param delimiter: The delimiter used to split source and target in each
        line of data file.
    :type delimiter: basestring
    :param start_mark: The token representing for the beginning of
        sentences in dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing for the end of sentences
        in dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing for unknown word in dictionary.
    :type unk_mark: basestring
    :param seed: The seed for random.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 delimiter="\t",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        # Without a target vocabulary only the source side is handled.
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._delimiter = delimiter
        # Reads every matched line eagerly and keeps the ids in memory.
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        # Private generator so shuffling does not touch the global state.
        self._random = random.Random(seed)

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
        """Read all data lines and store them as lists of vocabulary ids.

        Fills ``self._src_seq_ids``, ``self._trg_seq_ids`` (``None`` when
        only the source side is used) and ``self._sample_infos``.

        :raises KeyError: If one of the marks is missing from a vocabulary.
        """
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._src_vocab[start_mark],
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark])
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark]))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0])
            lens = [len(src_trg_ids[0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            # ``i`` is the position in the id lists above, so a sample can
            # still be found after the infos have been reordered.
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        """Yield each data line split on the delimiter.

        If ``fpattern`` matches exactly one file and that file is a tar
        archive, lines are read from the member ``tar_fname``; otherwise
        every matched path is read as a plain text file.

        :raises Exception: If a tar archive is matched but ``tar_fname`` is
            None.
        :raises IOError: If a matched path is not a regular file.
        """
        fpaths = glob.glob(fpattern)

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            # The context manager closes the archive once iteration ends.
            with tarfile.open(fpaths[0], 'r') as archive:
                for line in archive.extractfile(tar_fname):
                    # Tar members are read as bytes on Python 3; decode so
                    # that splitting on a text delimiter works.
                    # NOTE(review): UTF-8 is assumed here - confirm.
                    if not isinstance(line, str):
                        line = line.decode('utf-8')
                    yield line.split(self._delimiter)
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, 'r') as f:
                    for line in f:
                        yield line.split(self._delimiter)

    @staticmethod
    def load_dict(dict_path, reverse=False):
        """Load a vocabulary file holding one entry per line.

        :param dict_path: Path of the vocabulary file.
        :param reverse: If True, map line index to word; otherwise map word
            to line index.
        :return: The mapping as a dict. Words are stripped of surrounding
            whitespace.
        """
        word_dict = {}
        with open(dict_path, "r") as fdict:
            for idx, line in enumerate(fdict):
                word = line.strip()
                if reverse:
                    word_dict[idx] = word
                else:
                    word_dict[word] = idx
        return word_dict

    def batch_generator(self):
        """Yield the batches of one pass over the data.

        Each batch is a list with one entry per sample: ``[src_ids]`` when
        only the source side is loaded, otherwise the tuple
        ``(src_ids, trg_ids[:-1], trg_ids[1:])``.
        """

        def by_length(info):
            return info.max_len

        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(self._sample_infos, key=by_length, reverse=True)
        else:
            # Shuffling and pool sorting reorder self._sample_infos in
            # place.
            infos = self._sample_infos
            if self._shuffle:
                self._random.shuffle(infos)

            if self._sort_type == SortType.POOL:
                # ``start`` already advances one pool per step, so each pool
                # is the slice [start, start + pool_size).
                for start in range(0, len(infos), self._pool_size):
                    stop = start + self._pool_size
                    infos[start:stop] = sorted(
                        infos[start:stop], key=by_length, reverse=True)

        # concat batch
        batches = []
        if self._use_token_batch:
            batch_creator = TokenBatchCreator(self._batch_size)
        else:
            batch_creator = SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            # Skip both "nothing ready" (None) and empty batches.
            if batch:
                batches.append(batch)

        # Keep the trailing incomplete batch unless asked to clip it.
        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]

            if self._only_src:
                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
            else:
                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]