提交 dcdb6a00 编写于 作者: Y Yu Yang

Speed up NTM reader

* Load Data --> 64 sec
* Shuffle/Batch --> 14 sec
上级 0b48d785
...@@ -529,7 +529,6 @@ def fast_infer(test_data, trg_idx2word): ...@@ -529,7 +529,6 @@ def fast_infer(test_data, trg_idx2word):
def infer(args, inferencer=fast_infer): def infer(args, inferencer=fast_infer):
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
test_data = reader.DataReader( test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath, src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath,
......
...@@ -466,6 +466,7 @@ def transformer( ...@@ -466,6 +466,7 @@ def transformer(
sum_cost = layers.reduce_sum(weighted_cost) sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights) token_num = layers.reduce_sum(weights)
avg_cost = sum_cost / token_num avg_cost = sum_cost / token_num
avg_cost.stop_gradient = True
return sum_cost, avg_cost, predict, token_num return sum_cost, avg_cost, predict, token_num
......
import os
import tarfile
import glob import glob
import os
import random import random
import tarfile
import time
class SortType(object): class SortType(object):
...@@ -11,54 +11,84 @@ class SortType(object): ...@@ -11,54 +11,84 @@ class SortType(object):
NONE = "none" NONE = "none"
class EndEpoch(): class Converter(object):
pass def __init__(self, vocab, beg, end, unk):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
def __call__(self, sentence):
return [self._beg] + [
self._vocab.get(w, self._unk) for w in sentence.split()
] + [self._end]
class Pool(object):
def __init__(self, sample_generator, pool_size, sort): class ComposedConverter(object):
self._pool_size = pool_size def __init__(self, converters):
self._pool = [] self._converters = converters
self._sample_generator = sample_generator()
self._end = False def __call__(self, parallel_sentence):
self._sort = sort return [
self._converters[i](parallel_sentence[i])
def _fill(self): for i in range(len(self._converters))
while len(self._pool) < self._pool_size and not self._end: ]
try:
sample = self._sample_generator.next()
self._pool.append(sample) class SentenceBatchCreator(object):
except StopIteration as e: def __init__(self, batch_size):
self._end = True self.batch = []
break self._batch_size = batch_size
if self._sort: def append(self, info):
self._pool.sort( self.batch.append(info)
key=lambda sample: max(len(sample[0]), len(sample[1])) \ if len(self.batch) == self._batch_size:
if len(sample) > 1 else len(sample[0]) tmp = self.batch
) self.batch = []
return tmp
if self._end and len(self._pool) < self._pool_size:
self._pool.append(EndEpoch())
class TokenBatchCreator(object):
def push_back(self, samples): def __init__(self, batch_size):
if len(self._pool) != 0: self.batch = []
raise Exception("Pool should be empty.") self.max_len = -1
self._batch_size = batch_size
if len(samples) >= self._pool_size:
raise Exception("Capacity of pool should be greater than a batch. " def append(self, info):
"Please enlarge `pool_size`.") cur_len = info.max_len
max_len = max(self.max_len, cur_len)
for sample in samples: if max_len * (len(self.batch) + 1) > self._batch_size:
self._pool.append(sample) result = self.batch
self.batch = [info]
self._fill() self.max_len = cur_len
return result
def next(self, look=False): else:
if len(self._pool) == 0: self.max_len = max_len
return None self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, max_len, min_len):
self.i = i
self.min_len = min_len
self.max_len = max_len
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else: else:
return self._pool[0] if look else self._pool.pop(0) return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class DataReader(object): class DataReader(object):
...@@ -137,7 +167,7 @@ class DataReader(object): ...@@ -137,7 +167,7 @@ class DataReader(object):
fpattern, fpattern,
batch_size, batch_size,
pool_size, pool_size,
sort_type=SortType.NONE, sort_type=SortType.GLOBAL,
clip_last_batch=True, clip_last_batch=True,
tar_fname=None, tar_fname=None,
min_length=0, min_length=0,
...@@ -165,92 +195,61 @@ class DataReader(object): ...@@ -165,92 +195,61 @@ class DataReader(object):
self._min_length = min_length self._min_length = min_length
self._max_length = max_length self._max_length = max_length
self._delimiter = delimiter self._delimiter = delimiter
self._epoch_batches = [] self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
unk_mark)
src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname) self._random = random.Random(x=seed)
self._src_seq_ids = [[
self._src_vocab.get(word, self._src_vocab.get(unk_mark)) def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
for word in ([start_mark] + src_seq + [end_mark]) unk_mark):
] for src_seq in src_seq_words] converters = [
Converter(
self._sample_count = len(self._src_seq_ids) vocab=self._src_vocab,
beg=self._src_vocab[start_mark],
end=self._src_vocab[end_mark],
unk=self._src_vocab[unk_mark])
]
if not self._only_src: if not self._only_src:
self._trg_seq_ids = [[ converters.append(
self._trg_vocab.get(word, self._trg_vocab.get(unk_mark)) Converter(
for word in ([start_mark] + trg_seq + [end_mark]) vocab=self._trg_vocab,
] for trg_seq in trg_seq_words] beg=self._trg_vocab[start_mark],
if len(self._trg_seq_ids) != self._sample_count: end=self._trg_vocab[end_mark],
raise Exception("Inconsistent sample count between " unk=self._trg_vocab[unk_mark]))
"source sequences and target sequences.")
else: converters = ComposedConverter(converters)
self._trg_seq_ids = None
self._src_seq_ids = []
self._sample_idxs = [i for i in xrange(self._sample_count)] self._trg_seq_ids = None if self._only_src else []
self._sorted = False self._sample_infos = []
random.seed(seed) for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
src_trg_ids = converters(line)
def _parse_file(self, f_obj): self._src_seq_ids.append(src_trg_ids[0])
src_seq_words = [] lens = [len(src_trg_ids[0])]
trg_seq_words = []
for line in f_obj:
fields = line.strip().split(self._delimiter)
if (not self._only_src and len(fields) != 2) or (self._only_src and
len(fields) != 1):
continue
sample_words = []
is_valid_sample = True
max_len = -1
for i, seq in enumerate(fields):
seq_words = seq.split()
max_len = max(max_len, len(seq_words))
if len(seq_words) == 0 or \
len(seq_words) < self._min_length or \
len(seq_words) > self._max_length or \
(self._use_token_batch and max_len > self._batch_size):
is_valid_sample = False
break
sample_words.append(seq_words)
if not is_valid_sample: continue
src_seq_words.append(sample_words[0])
if not self._only_src: if not self._only_src:
trg_seq_words.append(sample_words[1]) self._trg_seq_ids.append(src_trg_ids[1])
lens.append(len(src_trg_ids[1]))
return (src_seq_words, trg_seq_words) self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
def _load_data(self, fpattern, tar_fname): @staticmethod
def _load_lines(fpattern, tar_fname):
fpaths = glob.glob(fpattern) fpaths = glob.glob(fpattern)
src_seq_words = []
trg_seq_words = []
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]): if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
if tar_fname is None: if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.") raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], 'r') f = tarfile.open(fpaths[0], 'r')
part_file_data = self._parse_file(f.extractfile(tar_fname)) for line in f.extractfile(tar_fname):
src_seq_words = part_file_data[0] yield line.split()
trg_seq_words = part_file_data[1]
else: else:
for fpath in fpaths: for fpath in fpaths:
if not os.path.isfile(fpath): if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath) raise IOError("Invalid file: %s" % fpath)
part_file_data = self._parse_file(open(fpath, 'r')) with open(fpath, 'r') as f:
src_seq_words.extend(part_file_data[0]) for line in f:
trg_seq_words.extend(part_file_data[1]) yield line.split()
return src_seq_words, trg_seq_words
@staticmethod @staticmethod
def load_dict(dict_path, reverse=False): def load_dict(dict_path, reverse=False):
...@@ -263,95 +262,41 @@ class DataReader(object): ...@@ -263,95 +262,41 @@ class DataReader(object):
word_dict[line.strip()] = idx word_dict[line.strip()] = idx
return word_dict return word_dict
def _sample_generator(self): def batch_generator(self):
# global sort or global shuffle
beg = time.time()
if self._sort_type == SortType.GLOBAL: if self._sort_type == SortType.GLOBAL:
if not self._sorted: infos = sorted(
self._sample_idxs.sort( self._sample_infos,
key=lambda idx: max(len(self._src_seq_ids[idx]), key=lambda x: max(x[1], x[2]) if not self._only_src else x[1])
len(self._trg_seq_ids[idx] if not self._only_src else 0))
)
self._sorted = True
elif self._shuffle: elif self._shuffle:
random.shuffle(self._sample_idxs) infos = self._sample_infos
self._random.shuffle(infos)
for sample_idx in self._sample_idxs: else:
if self._only_src: infos = self._sample_infos
yield (self._src_seq_ids[sample_idx], )
else:
yield (self._src_seq_ids[sample_idx],
self._trg_seq_ids[sample_idx][:-1],
self._trg_seq_ids[sample_idx][1:])
def batch_generator(self):
pool = Pool(self._sample_generator, self._pool_size, True
if self._sort_type == SortType.POOL else False)
def next_batch():
batch_data = []
max_len = -1
batch_max_seq_len = -1
while True:
sample = pool.next(look=True)
if sample is None: # concat batch
pool.push_back(batch_data) batches = []
batch_data = [] batch_creator = TokenBatchCreator(
continue self._batch_size
) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
if isinstance(sample, EndEpoch): for info in infos:
return batch_data, batch_max_seq_len, True batch = batch_creator.append(info)
if batch is not None:
batches.append([info.i for info in batch])
max_len = max(max_len, len(sample[0])) if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append([info.i for info in batch_creator.batch])
if not self._only_src: if self._shuffle:
max_len = max(max_len, len(sample[1])) self._random.shuffle(batches)
if self._use_token_batch: for batch_ids in batches:
if max_len * (len(batch_data) + 1) < self._batch_size: if self._only_src:
batch_max_seq_len = max_len yield [[self._src_seq_ids[idx]] for idx in batch_ids]
batch_data.append(pool.next()) else:
else: yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
return batch_data, batch_max_seq_len, False self._trg_seq_ids[idx][1:]) for idx in batch_ids]
else:
if len(batch_data) < self._batch_size:
batch_max_seq_len = max_len
batch_data.append(pool.next())
else:
return batch_data, batch_max_seq_len, False
if not self._shuffle_batch:
batch_data, batch_max_seq_len, last_batch = next_batch()
while not last_batch:
yield batch_data
batch_data, batch_max_seq_len, last_batch = next_batch()
batch_size = len(batch_data)
if self._use_token_batch:
batch_size *= batch_max_seq_len
if (not self._clip_last_batch and len(batch_data) > 0) \
or (batch_size == self._batch_size):
yield batch_data
else:
# should re-generate batches
if self._sort_type == SortType.POOL \
or len(self._epoch_batches) == 0:
self._epoch_batches = []
batch_data, batch_max_seq_len, last_batch = next_batch()
while not last_batch:
self._epoch_batches.append(batch_data)
batch_data, batch_max_seq_len, last_batch = next_batch()
batch_size = len(batch_data)
if self._use_token_batch:
batch_size *= batch_max_seq_len
if (not self._clip_last_batch and len(batch_data) > 0) \
or (batch_size == self._batch_size):
self._epoch_batches.append(batch_data)
random.shuffle(self._epoch_batches)
for batch_data in self._epoch_batches:
yield batch_data
import os
import time
import argparse import argparse
import ast import ast
import numpy as np
import multiprocessing import multiprocessing
import os
import time
import paddle import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import reader
from config import *
from model import transformer, position_encoding_init from model import transformer, position_encoding_init
from optim import LearningRateScheduler from optim import LearningRateScheduler
from config import *
import reader
def parse_args(): def parse_args():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册