# reader.py
import os
import tarfile
import glob

import random


class SortType(object):
    """String constants naming the granularity at which samples are sorted
    by length."""

    # Sort all instances of the data set.
    GLOBAL = "global"
    # Sort only the instances currently held in the pool.
    POOL = "pool"
    # Do not sort.
    NONE = "none"


class EndEpoch(object):
    """Sentinel placed at the tail of the pool once the sample generator is
    exhausted; consumers detect it with ``isinstance``."""
    pass


class Pool(object):
    """
    A bounded buffer of samples drawn lazily from a sample generator.

    The pool is refilled through :meth:`push_back` and drained through
    :meth:`next`. When the underlying generator is exhausted, an
    ``EndEpoch`` marker is appended after the remaining samples.

    :param sample_generator: A callable returning an iterator over samples.
        Each sample is a tuple whose first (and, if present, second)
        element is a sequence.
    :type sample_generator: callable
    :param pool_size: The maximum number of samples held by the pool.
    :type pool_size: int
    :param sort: Whether to sort the pool by sample length after each fill.
    :type sort: bool
    """

    def __init__(self, sample_generator, pool_size, sort):
        self._pool_size = pool_size
        self._pool = []
        self._sample_generator = sample_generator()
        self._end = False
        self._sort = sort

    @staticmethod
    def _sample_length(sample):
        """Return the sort key of a sample: the longer of its first two
        sequences, or the length of its only sequence."""
        if len(sample) > 1:
            return max(len(sample[0]), len(sample[1]))
        return len(sample[0])

    def _fill(self):
        """Top the pool up to ``pool_size`` samples, sort it if requested
        and append the ``EndEpoch`` marker once the generator has ended."""
        while len(self._pool) < self._pool_size and not self._end:
            try:
                # The builtin next() works on both Python 2 and Python 3
                # iterators, unlike the Python 2 only `.next()` method.
                self._pool.append(next(self._sample_generator))
            except StopIteration:
                self._end = True

        if self._sort:
            self._pool.sort(key=self._sample_length)

        # The marker is appended after sorting so that it stays last.
        if self._end and len(self._pool) < self._pool_size:
            self._pool.append(EndEpoch())

    def push_back(self, samples):
        """
        Return the given samples to the (empty) pool and refill it.

        :param samples: Samples to put back before refilling.
        :type samples: list
        :raises Exception: If the pool is not empty, or if ``samples`` does
            not leave room for at least one more item in the pool.
        """
        if len(self._pool) != 0:
            raise Exception("Pool should be empty.")

        if len(samples) >= self._pool_size:
            raise Exception("Capacity of pool should be greater than a batch. "
                            "Please enlarge `pool_size`.")

        self._pool.extend(samples)
        self._fill()

    def next(self, look=False):
        """
        Return the first item of the pool, or None if the pool is empty.

        :param look: If True, the item is returned without being removed.
        :type look: bool
        """
        if len(self._pool) == 0:
            return None
        if look:
            return self._pool[0]
        return self._pool.pop(0)


class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in the way corresponding to settings.

    An example of returning a generator producing data batches whose data
    is shuffled in each pass and sorted in each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of vocabulary file of source language.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of vocabulary file of target language.
        If None, only source sequences are read.
    :type trg_vocab_fpath: basestring
    :param fpattern: The pattern to match data files.
    :type fpattern: basestring
    :param batch_size: The number of sequences contained in a mini-batch,
        or the maximum number of tokens (include paddings) contained in a
        mini-batch.
    :type batch_size: int
    :param pool_size: The size of pool buffer.
    :type pool_size: int
    :param sort_type: The grain to sort by length: 'global' for all
        instances; 'pool' for instances in pool; 'none' for no sort.
    :type sort_type: basestring
    :param clip_last_batch: Whether to clip the last uncompleted batch.
    :type clip_last_batch: bool
    :param tar_fname: The data file in tar if fpattern matches a tar file.
    :type tar_fname: basestring
    :param min_length: The minimum length used to filter sequences.
    :type min_length: int
    :param max_length: The maximum length used to filter sequences.
    :type max_length: int
    :param shuffle: Whether to shuffle all instances.
    :type shuffle: bool
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batch data according to
        token number.
    :type use_token_batch: bool
    :param field_delimiter: The delimiter used to split source and target in
        each line of data file.
    :type field_delimiter: basestring
    :param token_delimiter: The delimiter used to split tokens in source or
        target sentences.
    :type token_delimiter: basestring
    :param start_mark: The token representing for the beginning of
        sentences in dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing for the end of sentences
        in dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing for unknown word in dictionary.
    :type unk_mark: basestring
    :param seed: The seed for random.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 batch_size,
                 pool_size,
                 sort_type=SortType.NONE,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        # Batches cached between epochs when `shuffle_batch` is enabled.
        self._epoch_batches = []

        src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
        self._src_seq_ids = self._seqs_to_ids(
            src_seq_words, self._src_vocab, start_mark, end_mark, unk_mark)

        self._sample_count = len(self._src_seq_ids)

        if not self._only_src:
            self._trg_seq_ids = self._seqs_to_ids(
                trg_seq_words, self._trg_vocab, start_mark, end_mark,
                unk_mark)
            if len(self._trg_seq_ids) != self._sample_count:
                raise Exception("Inconsistent sample count between "
                                "source sequences and target sequences.")
        else:
            self._trg_seq_ids = None

        # `range` (not the Python 2 only `xrange`) materialised as a list,
        # because the indices are sorted / shuffled in place later.
        self._sample_idxs = list(range(self._sample_count))
        self._sorted = False

        # NOTE: this seeds the module-level `random` generator.
        random.seed(seed)

    @staticmethod
    def _seqs_to_ids(seqs, vocab, start_mark, end_mark, unk_mark):
        """Wrap every word sequence with the start/end marks and map each
        word to its id, falling back to the id of `unk_mark`."""
        unk_id = vocab.get(unk_mark)
        return [[vocab.get(word, unk_id)
                 for word in ([start_mark] + seq + [end_mark])]
                for seq in seqs]

    def _is_valid_length(self, length):
        """Tell whether a sequence of `length` words passes the length
        filters. The count excludes the start/end marks added later."""
        if length == 0 or length < self._min_length \
                or length > self._max_length:
            return False
        if self._use_token_batch and length > self._batch_size:
            return False
        return True

    def _parse_file(self, f_obj):
        """
        Parse the lines of an opened data file into word sequences.

        Lines whose number of fields does not match the reader mode (one
        field for source only, two for source and target) and lines with a
        sequence rejected by the length filters are skipped.

        :param f_obj: An iterable of text lines.
        :return: A tuple `(src_seq_words, trg_seq_words)`; the second list
            is empty when only source sequences are read.
        """
        expected_fields = 1 if self._only_src else 2
        src_seq_words = []
        trg_seq_words = []

        for line in f_obj:
            fields = line.strip().split(self._field_delimiter)
            if len(fields) != expected_fields:
                continue

            sample_words = [
                seq.split(self._token_delimiter) for seq in fields
            ]
            if not all(
                    self._is_valid_length(len(seq_words))
                    for seq_words in sample_words):
                continue

            src_seq_words.append(sample_words[0])
            if not self._only_src:
                trg_seq_words.append(sample_words[1])

        return (src_seq_words, trg_seq_words)

    def _load_data(self, fpattern, tar_fname):
        """
        Load the word sequences of every data file matching `fpattern`.

        If the pattern matches exactly one file and that file is a tar
        archive, the member `tar_fname` of the archive is parsed instead.

        :raises Exception: If a tar file is matched but `tar_fname` is None.
        :raises IOError: If a matched path is not a regular file.
        :return: A tuple `(src_seq_words, trg_seq_words)`.
        """
        fpaths = glob.glob(fpattern)

        src_seq_words = []
        trg_seq_words = []

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            tar = tarfile.open(fpaths[0], 'r')
            try:
                src_seq_words, trg_seq_words = self._parse_file(
                    tar.extractfile(tar_fname))
            finally:
                # Close the archive even if parsing fails.
                tar.close()
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, 'r') as f_obj:
                    part_src, part_trg = self._parse_file(f_obj)
                src_seq_words.extend(part_src)
                trg_seq_words.extend(part_trg)

        return src_seq_words, trg_seq_words

    @staticmethod
    def load_dict(dict_path, reverse=False):
        """
        Load a vocabulary file holding one word per line.

        :param dict_path: The path of the vocabulary file.
        :type dict_path: basestring
        :param reverse: If True, map line index to word; otherwise map
            word to line index.
        :type reverse: bool
        :return: The loaded dictionary.
        :rtype: dict
        """
        word_dict = {}
        with open(dict_path, "r") as fdict:
            for idx, line in enumerate(fdict):
                word = line.strip('\n')
                if reverse:
                    word_dict[idx] = word
                else:
                    word_dict[word] = idx
        return word_dict

    def _sample_length(self, idx):
        """Return the length used to sort sample `idx`: the longer of its
        source and (if present) target id sequences."""
        src_len = len(self._src_seq_ids[idx])
        if self._only_src:
            return src_len
        return max(src_len, len(self._trg_seq_ids[idx]))

    def _sample_generator(self):
        """
        Yield the samples of one epoch.

        With global sorting the sample order is sorted by length once and
        reused; otherwise it is reshuffled when `shuffle` is set. A sample
        is `(src_ids,)` in source-only mode, else
        `(src_ids, trg_ids[:-1], trg_ids[1:])`, i.e. the target without
        its last id and the target without its first id.
        """
        if self._sort_type == SortType.GLOBAL:
            if not self._sorted:
                self._sample_idxs.sort(key=self._sample_length)
                self._sorted = True
        elif self._shuffle:
            random.shuffle(self._sample_idxs)

        for sample_idx in self._sample_idxs:
            if self._only_src:
                yield (self._src_seq_ids[sample_idx], )
            else:
                yield (self._src_seq_ids[sample_idx],
                       self._trg_seq_ids[sample_idx][:-1],
                       self._trg_seq_ids[sample_idx][1:])

    def batch_generator(self):
        """
        Yield the batches of one epoch; each batch is a list of samples.

        With `use_token_batch`, a sample is added to a batch while
        `max_len * (batch length + 1)` stays below `batch_size`; otherwise
        while the batch holds fewer than `batch_size` samples. With
        `shuffle_batch`, all batches of the epoch are built first (or
        reused from the previous epoch unless pool sorting is on), then
        shuffled.

        :raises Exception: If a single sample cannot fit into an empty
            batch (previously this produced empty batches forever).
        """
        pool = Pool(self._sample_generator, self._pool_size,
                    self._sort_type == SortType.POOL)

        def next_batch():
            """Return `(batch_data, batch_max_seq_len, is_last_batch)`."""
            batch_data = []
            max_len = -1
            batch_max_seq_len = -1

            while True:
                sample = pool.next(look=True)

                if sample is None:
                    # The pool is drained: hand the partial batch back and
                    # let the pool refill, then rebuild the batch from
                    # scratch, including its length statistics.
                    pool.push_back(batch_data)
                    batch_data = []
                    max_len = -1
                    batch_max_seq_len = -1
                    continue

                if isinstance(sample, EndEpoch):
                    return batch_data, batch_max_seq_len, True

                max_len = max(max_len, len(sample[0]))
                if not self._only_src:
                    max_len = max(max_len, len(sample[1]))

                if self._use_token_batch:
                    has_room = \
                        max_len * (len(batch_data) + 1) < self._batch_size
                else:
                    has_room = len(batch_data) < self._batch_size

                if not has_room:
                    if not batch_data:
                        # The sample would never be consumed; fail loudly
                        # instead of emitting empty batches endlessly.
                        raise Exception(
                            "A single sample cannot fit into a batch. "
                            "Please enlarge `batch_size`.")
                    return batch_data, batch_max_seq_len, False

                batch_max_seq_len = max_len
                batch_data.append(pool.next())

        def epoch_batches():
            """Yield every batch of the epoch, applying the last-batch
            clipping rule to the final one."""
            batch_data, batch_max_seq_len, last_batch = next_batch()
            while not last_batch:
                yield batch_data
                batch_data, batch_max_seq_len, last_batch = next_batch()

            batch_size = len(batch_data)
            if self._use_token_batch:
                batch_size *= batch_max_seq_len

            if (not self._clip_last_batch and len(batch_data) > 0) \
                    or (batch_size == self._batch_size):
                yield batch_data

        if not self._shuffle_batch:
            for batch_data in epoch_batches():
                yield batch_data
        else:
            # should re-generate batches
            if self._sort_type == SortType.POOL \
                    or len(self._epoch_batches) == 0:
                self._epoch_batches = list(epoch_batches())

            random.shuffle(self._epoch_batches)

            for batch_data in self._epoch_batches:
                yield batch_data