From dcdb6a00f86f733eb136ed02fe7df26f9a3c87b1 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Fri, 20 Jul 2018 11:13:58 +0000
Subject: [PATCH] Speed up NMT reader

* Load Data --> 64 sec
* Shuffle/Batch --> 14 sec

---
 .../transformer/infer.py  |   1 -
 .../transformer/model.py  |   1 +
 .../transformer/reader.py | 355 ++++++++----------
 .../transformer/train.py  |  11 +-
 4 files changed, 156 insertions(+), 212 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 505bf0b0..e3cc75d0 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -529,7 +529,6 @@ def fast_infer(test_data, trg_idx2word):
 
 def infer(args, inferencer=fast_infer):
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
-    exe = fluid.Executor(place)
     test_data = reader.DataReader(
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 46c9f7a9..39b7ca6c 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -466,6 +466,7 @@ def transformer(
     sum_cost = layers.reduce_sum(weighted_cost)
     token_num = layers.reduce_sum(weights)
     avg_cost = sum_cost / token_num
+    avg_cost.stop_gradient = True
     return sum_cost, avg_cost, predict, token_num
 
 
diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py
index 27bd82b1..6170c57f 100644
--- a/fluid/neural_machine_translation/transformer/reader.py
+++ b/fluid/neural_machine_translation/transformer/reader.py
@@ -1,8 +1,8 @@
-import os
-import tarfile
 import glob
-
+import os
 import random
+import tarfile
+import time
 
 
 class SortType(object):
@@ -11,54 +11,84 @@ class SortType(object):
     NONE = "none"
 
 
-class EndEpoch():
-    pass
+class Converter(object):
+    def __init__(self, vocab, beg, end, unk):
+        self._vocab = vocab
+        self._beg = beg
+        self._end = end
+        self._unk = unk
 
+    def __call__(self, sentence):
+        return [self._beg] + [
+            self._vocab.get(w, self._unk) for w in sentence.split()
+        ] + [self._end]
 
-class Pool(object):
-    def __init__(self, sample_generator, pool_size, sort):
-        self._pool_size = pool_size
-        self._pool = []
-        self._sample_generator = sample_generator()
-        self._end = False
-        self._sort = sort
-
-    def _fill(self):
-        while len(self._pool) < self._pool_size and not self._end:
-            try:
-                sample = self._sample_generator.next()
-                self._pool.append(sample)
-            except StopIteration as e:
-                self._end = True
-                break
-
-        if self._sort:
-            self._pool.sort(
-                key=lambda sample: max(len(sample[0]), len(sample[1])) \
-                if len(sample) > 1 else len(sample[0])
-            )
-
-        if self._end and len(self._pool) < self._pool_size:
-            self._pool.append(EndEpoch())
-
-    def push_back(self, samples):
-        if len(self._pool) != 0:
-            raise Exception("Pool should be empty.")
-
-        if len(samples) >= self._pool_size:
-            raise Exception("Capacity of pool should be greater than a batch. "
-                            "Please enlarge `pool_size`.")
-
-        for sample in samples:
-            self._pool.append(sample)
-
-        self._fill()
-
-    def next(self, look=False):
-        if len(self._pool) == 0:
-            return None
+
+class ComposedConverter(object):
+    def __init__(self, converters):
+        self._converters = converters
+
+    def __call__(self, parallel_sentence):
+        return [
+            self._converters[i](parallel_sentence[i])
+            for i in range(len(self._converters))
+        ]
+
+
+class SentenceBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self._batch_size = batch_size
+
+    def append(self, info):
+        self.batch.append(info)
+        if len(self.batch) == self._batch_size:
+            tmp = self.batch
+            self.batch = []
+            return tmp
+
+
+class TokenBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self.max_len = -1
+        self._batch_size = batch_size
+
+    def append(self, info):
+        cur_len = info.max_len
+        max_len = max(self.max_len, cur_len)
+        if max_len * (len(self.batch) + 1) > self._batch_size:
+            result = self.batch
+            self.batch = [info]
+            self.max_len = cur_len
+            return result
+        else:
+            self.max_len = max_len
+            self.batch.append(info)
+
+
+class SampleInfo(object):
+    def __init__(self, i, max_len, min_len):
+        self.i = i
+        self.min_len = min_len
+        self.max_len = max_len
+
+
+class MinMaxFilter(object):
+    def __init__(self, max_len, min_len, underlying_creator):
+        self._min_len = min_len
+        self._max_len = max_len
+        self._creator = underlying_creator
+
+    def append(self, info):
+        if info.max_len > self._max_len or info.min_len < self._min_len:
+            return
         else:
-            return self._pool[0] if look else self._pool.pop(0)
+            return self._creator.append(info)
+
+    @property
+    def batch(self):
+        return self._creator.batch
 
 
 class DataReader(object):
@@ -137,7 +167,7 @@ class DataReader(object):
                  fpattern,
                  batch_size,
                  pool_size,
-                 sort_type=SortType.NONE,
+                 sort_type=SortType.GLOBAL,
                  clip_last_batch=True,
                  tar_fname=None,
                  min_length=0,
@@ -165,92 +195,61 @@ class DataReader(object):
         self._min_length = min_length
         self._max_length = max_length
         self._delimiter = delimiter
-        self._epoch_batches = []
-
-        src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
-        self._src_seq_ids = [[
-            self._src_vocab.get(word, self._src_vocab.get(unk_mark))
-            for word in ([start_mark] + src_seq + [end_mark])
-        ] for src_seq in src_seq_words]
-
-        self._sample_count = len(self._src_seq_ids)
-
+        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
+                              unk_mark)
+        self._random = random.Random(x=seed)
+
+    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
+                         unk_mark):
+        converters = [
+            Converter(
+                vocab=self._src_vocab,
+                beg=self._src_vocab[start_mark],
+                end=self._src_vocab[end_mark],
+                unk=self._src_vocab[unk_mark])
+        ]
         if not self._only_src:
-            self._trg_seq_ids = [[
-                self._trg_vocab.get(word, self._trg_vocab.get(unk_mark))
-                for word in ([start_mark] + trg_seq + [end_mark])
-            ] for trg_seq in trg_seq_words]
-            if len(self._trg_seq_ids) != self._sample_count:
-                raise Exception("Inconsistent sample count between "
-                                "source sequences and target sequences.")
-        else:
-            self._trg_seq_ids = None
-
-        self._sample_idxs = [i for i in xrange(self._sample_count)]
-        self._sorted = False
-
-        random.seed(seed)
-
-    def _parse_file(self, f_obj):
-        src_seq_words = []
-        trg_seq_words = []
-
-        for line in f_obj:
-            fields = line.strip().split(self._delimiter)
-
-            if (not self._only_src and len(fields) != 2) or (self._only_src and
-                                                             len(fields) != 1):
-                continue
-
-            sample_words = []
-            is_valid_sample = True
-            max_len = -1
-
-            for i, seq in enumerate(fields):
-                seq_words = seq.split()
-                max_len = max(max_len, len(seq_words))
-                if len(seq_words) == 0 or \
-                    len(seq_words) < self._min_length or \
-                    len(seq_words) > self._max_length or \
-                    (self._use_token_batch and max_len > self._batch_size):
-                    is_valid_sample = False
-                    break
-
-                sample_words.append(seq_words)
-
-            if not is_valid_sample: continue
-
-            src_seq_words.append(sample_words[0])
-
+            converters.append(
+                Converter(
+                    vocab=self._trg_vocab,
+                    beg=self._trg_vocab[start_mark],
+                    end=self._trg_vocab[end_mark],
+                    unk=self._trg_vocab[unk_mark]))
+
+        converters = ComposedConverter(converters)
+
+        self._src_seq_ids = []
+        self._trg_seq_ids = None if self._only_src else []
+        self._sample_infos = []
+
+        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
+            src_trg_ids = converters(line)
+            self._src_seq_ids.append(src_trg_ids[0])
+            lens = [len(src_trg_ids[0])]
             if not self._only_src:
-                trg_seq_words.append(sample_words[1])
-
-        return (src_seq_words, trg_seq_words)
+                self._trg_seq_ids.append(src_trg_ids[1])
+                lens.append(len(src_trg_ids[1]))
+            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
 
-    def _load_data(self, fpattern, tar_fname):
+    @staticmethod
+    def _load_lines(fpattern, tar_fname):
         fpaths = glob.glob(fpattern)
 
-        src_seq_words = []
-        trg_seq_words = []
-
         if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
             if tar_fname is None:
                 raise Exception("If tar file provided, please set tar_fname.")
 
             f = tarfile.open(fpaths[0], 'r')
-            part_file_data = self._parse_file(f.extractfile(tar_fname))
-            src_seq_words = part_file_data[0]
-            trg_seq_words = part_file_data[1]
+            for line in f.extractfile(tar_fname):
+                yield line.split()
         else:
             for fpath in fpaths:
                 if not os.path.isfile(fpath):
                     raise IOError("Invalid file: %s" % fpath)
 
-                part_file_data = self._parse_file(open(fpath, 'r'))
-                src_seq_words.extend(part_file_data[0])
-                trg_seq_words.extend(part_file_data[1])
-
-        return src_seq_words, trg_seq_words
+                with open(fpath, 'r') as f:
+                    for line in f:
+                        yield line.split()
 
     @staticmethod
     def load_dict(dict_path, reverse=False):
@@ -263,95 +262,41 @@ class DataReader(object):
                     word_dict[line.strip()] = idx
         return word_dict
 
-    def _sample_generator(self):
+    def batch_generator(self):
+        # global sort or global shuffle
+        beg = time.time()
         if self._sort_type == SortType.GLOBAL:
-            if not self._sorted:
-                self._sample_idxs.sort(
-                    key=lambda idx: max(len(self._src_seq_ids[idx]),
-                        len(self._trg_seq_ids[idx] if not self._only_src else 0))
-                )
-                self._sorted = True
+            infos = sorted(
+                self._sample_infos,
+                key=lambda x: x.max_len)
         elif self._shuffle:
-            random.shuffle(self._sample_idxs)
-
-        for sample_idx in self._sample_idxs:
-            if self._only_src:
-                yield (self._src_seq_ids[sample_idx], )
-            else:
-                yield (self._src_seq_ids[sample_idx],
-                       self._trg_seq_ids[sample_idx][:-1],
-                       self._trg_seq_ids[sample_idx][1:])
-
-    def batch_generator(self):
-        pool = Pool(self._sample_generator, self._pool_size, True
-                    if self._sort_type == SortType.POOL else False)
-
-        def next_batch():
-            batch_data = []
-            max_len = -1
-            batch_max_seq_len = -1
-
-            while True:
-                sample = pool.next(look=True)
+            infos = self._sample_infos
+            self._random.shuffle(infos)
+        else:
+            infos = self._sample_infos
 
-                if sample is None:
-                    pool.push_back(batch_data)
-                    batch_data = []
-                    continue
+        # concat batch
+        batches = []
+        batch_creator = TokenBatchCreator(
+            self._batch_size
+        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
+        batch_creator = MinMaxFilter(self._max_length, self._min_length,
+                                     batch_creator)
 
-                if isinstance(sample, EndEpoch):
-                    return batch_data, batch_max_seq_len, True
+        for info in infos:
+            batch = batch_creator.append(info)
+            if batch is not None:
+                batches.append([info.i for info in batch])
 
-                max_len = max(max_len, len(sample[0]))
+        if not self._clip_last_batch and len(batch_creator.batch) != 0:
+            batches.append([info.i for info in batch_creator.batch])
 
-                if not self._only_src:
-                    max_len = max(max_len, len(sample[1]))
+        if self._shuffle:
+            self._random.shuffle(batches)
 
-                if self._use_token_batch:
-                    if max_len * (len(batch_data) + 1) < self._batch_size:
-                        batch_max_seq_len = max_len
-                        batch_data.append(pool.next())
-                    else:
-                        return batch_data, batch_max_seq_len, False
-                else:
-                    if len(batch_data) < self._batch_size:
-                        batch_max_seq_len = max_len
-                        batch_data.append(pool.next())
-                    else:
-                        return batch_data, batch_max_seq_len, False
-
-        if not self._shuffle_batch:
-            batch_data, batch_max_seq_len, last_batch = next_batch()
-            while not last_batch:
-                yield batch_data
-                batch_data, batch_max_seq_len, last_batch = next_batch()
-
-            batch_size = len(batch_data)
-            if self._use_token_batch:
-                batch_size *= batch_max_seq_len
-
-            if (not self._clip_last_batch and len(batch_data) > 0) \
-                    or (batch_size == self._batch_size):
-                yield batch_data
-        else:
-            # should re-generate batches
-            if self._sort_type == SortType.POOL \
-                    or len(self._epoch_batches) == 0:
-                self._epoch_batches = []
-                batch_data, batch_max_seq_len, last_batch = next_batch()
-                while not last_batch:
-                    self._epoch_batches.append(batch_data)
-                    batch_data, batch_max_seq_len, last_batch = next_batch()
-
-                batch_size = len(batch_data)
-                if self._use_token_batch:
-                    batch_size *= batch_max_seq_len
-
-                if (not self._clip_last_batch and len(batch_data) > 0) \
-                        or (batch_size == self._batch_size):
-                    self._epoch_batches.append(batch_data)
-
-            random.shuffle(self._epoch_batches)
-
-            for batch_data in self._epoch_batches:
-                yield batch_data
+        for batch_ids in batches:
+            if self._only_src:
+                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
+            else:
+                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
+                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index d2cd5a18..15f2fe86 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -1,17 +1,16 @@
-import os
-import time
 import argparse
 import ast
-import numpy as np
 import multiprocessing
+import os
+import time
 
-import paddle
+import numpy as np
 import paddle.fluid as fluid
 
+import reader
+from config import *
 from model import transformer, position_encoding_init
 from optim import LearningRateScheduler
-from config import *
-import reader
 
 
 def parse_args():
--
GitLab
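
Usage note: a minimal sketch of how the reworked reader is driven after this patch. The vocabulary and data paths, the sizes, and the SortType choice below are illustrative placeholders, and constructor arguments not shown in the hunks above are assumed to keep their defaults.

# Sketch only -- paths and sizes are placeholders, not part of the patch.
import reader
from reader import SortType

train_data = reader.DataReader(
    src_vocab_fpath="data/vocab.src",  # placeholder vocabulary file
    trg_vocab_fpath="data/vocab.trg",  # placeholder vocabulary file
    fpattern="data/train.part-*",      # placeholder glob for training shards
    batch_size=2048,
    pool_size=10000,
    sort_type=SortType.GLOBAL)

for batch in train_data.batch_generator():
    # For parallel data each element is (src_ids, trg_ids[:-1], trg_ids[1:]);
    # with source-only data it is a one-element list of source ids.
    pass  # placeholder for the actual training step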