Commit dcdb6a00 authored by Yu Yang

Speed up NMT reader

* Load Data --> 64 sec
* Shuffle/Batch --> 14 sec
Parent 0b48d785
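For context on the two timings in the commit message: below is a minimal sketch of how such numbers could be measured, assuming the module changed here is importable as `reader`; the paths, pattern, and sizes are hypothetical, the constructor keywords mirror the signature visible in this diff, and defaults are assumed for everything else.

```python
# Hypothetical timing harness; paths and sizes are made up for illustration.
import time

import reader  # the module changed in this commit

beg = time.time()
train_data = reader.DataReader(
    src_vocab_fpath="data/vocab.src",   # hypothetical path
    trg_vocab_fpath="data/vocab.trg",   # hypothetical path
    fpattern="data/train.part-*",       # hypothetical file pattern
    batch_size=2048,
    pool_size=10000)
print("Load Data --> %d sec" % (time.time() - beg))

beg = time.time()
for batch in train_data.batch_generator():
    pass  # consume one epoch of batches
print("Shuffle/Batch --> %d sec" % (time.time() - beg))
```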
@@ -529,7 +529,6 @@ def fast_infer(test_data, trg_idx2word):

def infer(args, inferencer=fast_infer):
    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    test_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
......
@@ -466,6 +466,7 @@ def transformer(
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
+    avg_cost.stop_gradient = True
    return sum_cost, avg_cost, predict, token_num
......
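The one added line above matters for speed: `avg_cost` is only reported, so marking it `stop_gradient` keeps backward ops from being appended for a value that is never optimized. A minimal sketch of the same pattern under the old `paddle.fluid` layers API; the toy network below is not from this repo.

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=pred, label=y)
sum_cost = fluid.layers.reduce_sum(cost)
avg_cost = fluid.layers.mean(cost)
# avg_cost is only logged; stopping its gradient keeps the backward
# pass from generating extra ops for a value we never optimize.
avg_cost.stop_gradient = True
fluid.optimizer.SGD(learning_rate=0.01).minimize(sum_cost)
```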
-import os
-import tarfile
+import glob
+import os
+import random
+import tarfile
+import time

class SortType(object):
@@ -11,54 +11,84 @@ class SortType(object):
    NONE = "none"
-class EndEpoch():
-    pass
-
+class Converter(object):
+    def __init__(self, vocab, beg, end, unk):
+        self._vocab = vocab
+        self._beg = beg
+        self._end = end
+        self._unk = unk
+
+    def __call__(self, sentence):
+        return [self._beg] + [
+            self._vocab.get(w, self._unk) for w in sentence.split()
+        ] + [self._end]
+
-class Pool(object):
-    def __init__(self, sample_generator, pool_size, sort):
-        self._pool_size = pool_size
-        self._pool = []
-        self._sample_generator = sample_generator()
-        self._end = False
-        self._sort = sort
-
-    def _fill(self):
-        while len(self._pool) < self._pool_size and not self._end:
-            try:
-                sample = self._sample_generator.next()
-                self._pool.append(sample)
-            except StopIteration as e:
-                self._end = True
-                break
-        if self._sort:
-            self._pool.sort(
-                key=lambda sample: max(len(sample[0]), len(sample[1])) \
-                    if len(sample) > 1 else len(sample[0]))
-        if self._end and len(self._pool) < self._pool_size:
-            self._pool.append(EndEpoch())
-
-    def push_back(self, samples):
-        if len(self._pool) != 0:
-            raise Exception("Pool should be empty.")
-        if len(samples) >= self._pool_size:
-            raise Exception("Capacity of pool should be greater than a batch. "
-                            "Please enlarge `pool_size`.")
-        for sample in samples:
-            self._pool.append(sample)
-        self._fill()
-
-    def next(self, look=False):
-        if len(self._pool) == 0:
-            return None
+class ComposedConverter(object):
+    def __init__(self, converters):
+        self._converters = converters
+
+    def __call__(self, parallel_sentence):
+        return [
+            self._converters[i](parallel_sentence[i])
+            for i in range(len(self._converters))
+        ]
+
+class SentenceBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self._batch_size = batch_size
+
+    def append(self, info):
+        self.batch.append(info)
+        if len(self.batch) == self._batch_size:
+            tmp = self.batch
+            self.batch = []
+            return tmp
+
+class TokenBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self.max_len = -1
+        self._batch_size = batch_size
+
+    def append(self, info):
+        cur_len = info.max_len
+        max_len = max(self.max_len, cur_len)
+        if max_len * (len(self.batch) + 1) > self._batch_size:
+            result = self.batch
+            self.batch = [info]
+            self.max_len = cur_len
+            return result
+        else:
+            self.max_len = max_len
+            self.batch.append(info)
+
+class SampleInfo(object):
+    def __init__(self, i, max_len, min_len):
+        self.i = i
+        self.min_len = min_len
+        self.max_len = max_len
+
+class MinMaxFilter(object):
+    def __init__(self, max_len, min_len, underlying_creator):
+        self._min_len = min_len
+        self._max_len = max_len
+        self._creator = underlying_creator
+
+    def append(self, info):
+        if info.max_len > self._max_len or info.min_len < self._min_len:
+            return
         else:
-            return self._pool[0] if look else self._pool.pop(0)
+            return self._creator.append(info)
+
+    @property
+    def batch(self):
+        return self._creator.batch
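A hypothetical usage sketch of the new batching helpers added above, assuming they are importable from this module as `reader`; the lengths and token budget are made up. `TokenBatchCreator.append` returns a finished batch only when the next sample would overflow the token budget, and `MinMaxFilter` silently drops out-of-range samples before they reach the underlying creator.

```python
from reader import MinMaxFilter, SampleInfo, TokenBatchCreator

# (index, max_len, min_len) for four made-up parallel samples
infos = [SampleInfo(0, 12, 10), SampleInfo(1, 30, 28),
         SampleInfo(2, 5, 4), SampleInfo(3, 45, 44)]

creator = MinMaxFilter(40, 3, TokenBatchCreator(64))  # 64-token budget
batches = []
for info in infos:
    batch = creator.append(info)  # sample 3 is filtered out (45 > 40)
    if batch is not None:
        batches.append([i.i for i in batch])
if creator.batch:  # leftover partial batch
    batches.append([i.i for i in creator.batch])
print(batches)  # [[0, 1], [2]]: 30 tokens * 3 sentences would exceed 64
```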
class DataReader(object):
@@ -137,7 +167,7 @@ class DataReader(object):
                 fpattern,
                 batch_size,
                 pool_size,
-                 sort_type=SortType.NONE,
+                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
@@ -165,92 +195,61 @@ class DataReader(object):
        self._min_length = min_length
        self._max_length = max_length
        self._delimiter = delimiter
-        self._epoch_batches = []
-        src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
-        self._src_seq_ids = [[
-            self._src_vocab.get(word, self._src_vocab.get(unk_mark))
-            for word in ([start_mark] + src_seq + [end_mark])
-        ] for src_seq in src_seq_words]
-        self._sample_count = len(self._src_seq_ids)
+        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
+                              unk_mark)
+        self._random = random.Random(x=seed)
+
+    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
+                         unk_mark):
+        converters = [
+            Converter(
+                vocab=self._src_vocab,
+                beg=self._src_vocab[start_mark],
+                end=self._src_vocab[end_mark],
+                unk=self._src_vocab[unk_mark])
+        ]
        if not self._only_src:
-            self._trg_seq_ids = [[
-                self._trg_vocab.get(word, self._trg_vocab.get(unk_mark))
-                for word in ([start_mark] + trg_seq + [end_mark])
-            ] for trg_seq in trg_seq_words]
-            if len(self._trg_seq_ids) != self._sample_count:
-                raise Exception("Inconsistent sample count between "
-                                "source sequences and target sequences.")
-        else:
-            self._trg_seq_ids = None
-        self._sample_idxs = [i for i in xrange(self._sample_count)]
-        self._sorted = False
-        random.seed(seed)
-    def _parse_file(self, f_obj):
-        src_seq_words = []
-        trg_seq_words = []
-        for line in f_obj:
-            fields = line.strip().split(self._delimiter)
-            if (not self._only_src and len(fields) != 2) or (self._only_src and
-                                                             len(fields) != 1):
-                continue
-            sample_words = []
-            is_valid_sample = True
-            max_len = -1
-            for i, seq in enumerate(fields):
-                seq_words = seq.split()
-                max_len = max(max_len, len(seq_words))
-                if len(seq_words) == 0 or \
-                        len(seq_words) < self._min_length or \
-                        len(seq_words) > self._max_length or \
-                        (self._use_token_batch and max_len > self._batch_size):
-                    is_valid_sample = False
-                    break
-                sample_words.append(seq_words)
-            if not is_valid_sample: continue
-            src_seq_words.append(sample_words[0])
+            converters.append(
+                Converter(
+                    vocab=self._trg_vocab,
+                    beg=self._trg_vocab[start_mark],
+                    end=self._trg_vocab[end_mark],
+                    unk=self._trg_vocab[unk_mark]))
+
+        converters = ComposedConverter(converters)
+
+        self._src_seq_ids = []
+        self._trg_seq_ids = None if self._only_src else []
+        self._sample_infos = []
+
+        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
+            src_trg_ids = converters(line)
+            self._src_seq_ids.append(src_trg_ids[0])
+            lens = [len(src_trg_ids[0])]
            if not self._only_src:
-                trg_seq_words.append(sample_words[1])
-        return (src_seq_words, trg_seq_words)
+                self._trg_seq_ids.append(src_trg_ids[1])
+                lens.append(len(src_trg_ids[1]))
+            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
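A toy run of the `Converter`/`ComposedConverter` pair used above, assuming importability from this module; the vocabularies are made up. Each converter wraps one side of a parallel sentence with begin/end ids and maps out-of-vocabulary words to `unk`.

```python
from reader import ComposedConverter, Converter

src_vocab = {"<s>": 0, "<e>": 1, "<unk>": 2, "hello": 3}
trg_vocab = {"<s>": 0, "<e>": 1, "<unk>": 2, "bonjour": 3}

converter = ComposedConverter([
    Converter(vocab=src_vocab, beg=0, end=1, unk=2),
    Converter(vocab=trg_vocab, beg=0, end=1, unk=2),
])
# one parallel sentence: source field, then target field
print(converter(["hello world", "bonjour monde"]))
# [[0, 3, 2, 1], [0, 3, 2, 1]] -- "world"/"monde" fall back to <unk>=2
```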
-    def _load_data(self, fpattern, tar_fname):
+    @staticmethod
+    def _load_lines(fpattern, tar_fname):
        fpaths = glob.glob(fpattern)
-        src_seq_words = []
-        trg_seq_words = []
        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")
            f = tarfile.open(fpaths[0], 'r')
-            part_file_data = self._parse_file(f.extractfile(tar_fname))
-            src_seq_words = part_file_data[0]
-            trg_seq_words = part_file_data[1]
+            for line in f.extractfile(tar_fname):
+                yield line.split()
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)
-                part_file_data = self._parse_file(open(fpath, 'r'))
-                src_seq_words.extend(part_file_data[0])
-                trg_seq_words.extend(part_file_data[1])
-        return src_seq_words, trg_seq_words
+                with open(fpath, 'r') as f:
+                    for line in f:
+                        yield line.split()
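The switch from `_load_data` (materialize full source/target word lists, then convert) to the `_load_lines` generator appears to be a key part of the load-time win: lines are tokenized and consumed one at a time instead of being accumulated. A toy illustration of the same streaming shape with an in-memory file (not repo code):

```python
import io

def load_lines(f):
    # mirrors the generator above: tokenize and yield one line at a time
    for line in f:
        yield line.split()

f = io.StringIO(u"hello world\tbonjour monde\n")
for fields in load_lines(f):
    print(fields)  # ['hello', 'world', 'bonjour', 'monde']
```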
    @staticmethod
    def load_dict(dict_path, reverse=False):
@@ -263,95 +262,41 @@ class DataReader(object):
                    word_dict[line.strip()] = idx
        return word_dict
-    def _sample_generator(self):
+    def batch_generator(self):
+        # global sort or global shuffle
+        beg = time.time()
        if self._sort_type == SortType.GLOBAL:
-            if not self._sorted:
-                self._sample_idxs.sort(
-                    key=lambda idx: max(len(self._src_seq_ids[idx]),
-                        len(self._trg_seq_ids[idx] if not self._only_src else 0))
-                )
-                self._sorted = True
+            infos = sorted(
+                self._sample_infos,
+                key=lambda x: max(x[1], x[2]) if not self._only_src else x[1])
        elif self._shuffle:
-            random.shuffle(self._sample_idxs)
-        for sample_idx in self._sample_idxs:
-            if self._only_src:
-                yield (self._src_seq_ids[sample_idx], )
-            else:
-                yield (self._src_seq_ids[sample_idx],
-                       self._trg_seq_ids[sample_idx][:-1],
-                       self._trg_seq_ids[sample_idx][1:])
-
-    def batch_generator(self):
-        pool = Pool(self._sample_generator, self._pool_size, True
-                    if self._sort_type == SortType.POOL else False)
-
-        def next_batch():
-            batch_data = []
-            max_len = -1
-            batch_max_seq_len = -1
-            while True:
-                sample = pool.next(look=True)
+            infos = self._sample_infos
+            self._random.shuffle(infos)
+        else:
+            infos = self._sample_infos
-                if sample is None:
-                    pool.push_back(batch_data)
-                    batch_data = []
-                    continue
+        # concat batch
+        batches = []
+        batch_creator = TokenBatchCreator(
+            self._batch_size
+        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
+        batch_creator = MinMaxFilter(self._max_length, self._min_length,
+                                     batch_creator)
-                if isinstance(sample, EndEpoch):
-                    return batch_data, batch_max_seq_len, True
+        for info in infos:
+            batch = batch_creator.append(info)
+            if batch is not None:
+                batches.append([info.i for info in batch])
-                max_len = max(max_len, len(sample[0]))
+        if not self._clip_last_batch and len(batch_creator.batch) != 0:
+            batches.append([info.i for info in batch_creator.batch])
-                if not self._only_src:
-                    max_len = max(max_len, len(sample[1]))
+        if self._shuffle:
+            self._random.shuffle(batches)
-                if self._use_token_batch:
-                    if max_len * (len(batch_data) + 1) < self._batch_size:
-                        batch_max_seq_len = max_len
-                        batch_data.append(pool.next())
-                    else:
-                        return batch_data, batch_max_seq_len, False
-                else:
-                    if len(batch_data) < self._batch_size:
-                        batch_max_seq_len = max_len
-                        batch_data.append(pool.next())
-                    else:
-                        return batch_data, batch_max_seq_len, False
-
-        if not self._shuffle_batch:
-            batch_data, batch_max_seq_len, last_batch = next_batch()
-            while not last_batch:
-                yield batch_data
-                batch_data, batch_max_seq_len, last_batch = next_batch()
-            batch_size = len(batch_data)
-            if self._use_token_batch:
-                batch_size *= batch_max_seq_len
-            if (not self._clip_last_batch and len(batch_data) > 0) \
-                    or (batch_size == self._batch_size):
-                yield batch_data
-        else:
-            # should re-generate batches
-            if self._sort_type == SortType.POOL \
-                    or len(self._epoch_batches) == 0:
-                self._epoch_batches = []
-                batch_data, batch_max_seq_len, last_batch = next_batch()
-                while not last_batch:
-                    self._epoch_batches.append(batch_data)
-                    batch_data, batch_max_seq_len, last_batch = next_batch()
-                batch_size = len(batch_data)
-                if self._use_token_batch:
-                    batch_size *= batch_max_seq_len
-                if (not self._clip_last_batch and len(batch_data) > 0) \
-                        or (batch_size == self._batch_size):
-                    self._epoch_batches.append(batch_data)
-            random.shuffle(self._epoch_batches)
-            for batch_data in self._epoch_batches:
-                yield batch_data
+        for batch_ids in batches:
+            if self._only_src:
+                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
+            else:
+                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
+                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]
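The yielded tuples end with `trg[:-1]` and `trg[1:]`; a tiny worked example of that shift (ids made up): the decoder consumes the sequence without the final end mark and is trained to predict the same sequence without the leading start mark.

```python
trg_seq_ids = [0, 7, 9, 4, 1]     # <s> w1 w2 w3 <e>, hypothetical ids
decoder_input = trg_seq_ids[:-1]  # [0, 7, 9, 4]
labels = trg_seq_ids[1:]          # [7, 9, 4, 1]
```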
-import os
-import time
import argparse
import ast
+import numpy as np
import multiprocessing
+import os
+import time
import paddle
-import numpy as np
import paddle.fluid as fluid
+import reader
+from config import *
from model import transformer, position_encoding_init
from optim import LearningRateScheduler
-from config import *
-import reader
def parse_args():
......