Unverified commit c34bb5f1, authored by Yu Yang, committed by GitHub

Merge pull request #1060 from reyoung/speed_up_transformer_python_reader

Speed up NMT reader
.DS_Store
*.pyc
.*~
fluid/neural_machine_translation/transformer/deps
fluid/neural_machine_translation/transformer/train.data
fluid/neural_machine_translation/transformer/train.pkl
fluid/neural_machine_translation/transformer/train.sh
fluid/neural_machine_translation/transformer/train.tok.clean.bpe.32000.en-de
fluid/neural_machine_translation/transformer/vocab.bpe.32000.refined
......@@ -553,7 +553,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
def infer(args, inferencer=fast_infer):
    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    test_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
......
......@@ -474,6 +474,7 @@ def transformer(
    sum_cost = layers.reduce_sum(weighted_cost)
    token_num = layers.reduce_sum(weights)
    avg_cost = sum_cost / token_num
    # avg_cost is derived from sum_cost for reporting; marking it
    # stop_gradient keeps it out of the backward graph.
    avg_cost.stop_gradient = True
    return sum_cost, avg_cost, predict, token_num
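For context, the stop_gradient flag excludes a variable from the backward pass. A minimal sketch of the pattern with a made-up toy network, not this repo's model:

import paddle.fluid as fluid

x = fluid.layers.data(name="x", shape=[4], dtype="float32")
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.reduce_sum(y)   # the cost actually minimized
monitor = loss / 10.0               # derived value, used for logging only
monitor.stop_gradient = True        # exclude it from the backward graph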
......
import glob
import os
import random
import tarfile
import cPickle
class SortType(object):
......@@ -11,54 +11,86 @@ class SortType(object):
    NONE = "none"

# Removed by this change:
class EndEpoch():
    pass
# Added by this change:
class Converter(object):
    def __init__(self, vocab, beg, end, unk, delimiter):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter

    def __call__(self, sentence):
        return [self._beg] + [
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter)
        ] + [self._end]
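A quick illustration of what Converter does to one raw sentence; the vocabulary and marker ids below are made-up values, not from the repo:

vocab = {"<s>": 0, "<e>": 1, "<unk>": 2, "hello": 3, "world": 4}
conv = Converter(vocab=vocab, beg=0, end=1, unk=2, delimiter=" ")
print(conv("hello brave world"))  # [0, 3, 2, 4, 1]; "brave" falls back to <unk>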
# Removed by this change (the new reader batches via the creator classes
# below instead of a look-ahead pool):
class Pool(object):
    def __init__(self, sample_generator, pool_size, sort):
        self._pool_size = pool_size
        self._pool = []
        self._sample_generator = sample_generator()
        self._end = False
        self._sort = sort

    def _fill(self):
        while len(self._pool) < self._pool_size and not self._end:
            try:
                sample = self._sample_generator.next()
                self._pool.append(sample)
            except StopIteration:
                self._end = True
                break

        if self._sort:
            self._pool.sort(
                key=lambda sample: max(len(sample[0]), len(sample[1]))
                if len(sample) > 1 else len(sample[0]))

        if self._end and len(self._pool) < self._pool_size:
            self._pool.append(EndEpoch())

    def push_back(self, samples):
        if len(self._pool) != 0:
            raise Exception("Pool should be empty.")
        if len(samples) >= self._pool_size:
            raise Exception("Capacity of pool should be greater than a batch. "
                            "Please enlarge `pool_size`.")
        for sample in samples:
            self._pool.append(sample)
        self._fill()

    def next(self, look=False):
        if len(self._pool) == 0:
            return None
        return self._pool[0] if look else self._pool.pop(0)
# Added by this change:
class ComposedConverter(object):
    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
        return [
            self._converters[i](parallel_sentence[i])
            for i in range(len(self._converters))
        ]
class SentenceBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        self.batch.append(info)
        if len(self.batch) == self._batch_size:
            tmp = self.batch
            self.batch = []
            return tmp
class TokenBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        # A batch is full once its padded size (longest sequence times
        # number of samples) would exceed the token budget.
        if max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        else:
            self.max_len = max_len
            self.batch.append(info)
class SampleInfo(object):
    def __init__(self, i, max_len, min_len):
        self.i = i
        self.min_len = min_len
        self.max_len = max_len
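A toy run of TokenBatchCreator under an assumed budget of 6 tokens. SampleInfo takes (index, max_len, min_len), so each batch's padded size is its longest length times its sample count:

creator = TokenBatchCreator(6)
for idx, length in enumerate([2, 3, 2, 5]):
    full = creator.append(SampleInfo(idx, length, length))
    if full is not None:
        print([info.i for info in full])
# prints [0, 1], then [2]: adding a third sample to the 3-token-long
# batch would cost 3 * 3 = 9 > 6 padded tokens, so the batch is emitted.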
class MinMaxFilter(object):
    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        # Drop samples outside [min_len, max_len]; delegate the rest
        # to the wrapped batch creator.
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return
        else:
            return self._creator.append(info)

    @property
    def batch(self):
        return self._creator.batch
class DataReader(object):
......@@ -140,7 +172,7 @@ class DataReader(object):
                 fpattern,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,  # default was SortType.NONE before this change
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
......@@ -170,92 +202,68 @@ class DataReader(object):
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter

    # Removed by this change: the old __init__ converted every sequence to
    # ids eagerly and kept parallel index bookkeeping.
        self._epoch_batches = []

        src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
        self._src_seq_ids = [[
            self._src_vocab.get(word, self._src_vocab.get(unk_mark))
            for word in ([start_mark] + src_seq + [end_mark])
        ] for src_seq in src_seq_words]
        self._sample_count = len(self._src_seq_ids)

        if not self._only_src:
            self._trg_seq_ids = [[
                self._trg_vocab.get(word, self._trg_vocab.get(unk_mark))
                for word in ([start_mark] + trg_seq + [end_mark])
            ] for trg_seq in trg_seq_words]
            if len(self._trg_seq_ids) != self._sample_count:
                raise Exception("Inconsistent sample count between "
                                "source sequences and target sequences.")
        else:
            self._trg_seq_ids = None

        self._sample_idxs = [i for i in xrange(self._sample_count)]
        self._sorted = False
        random.seed(seed)

    def _parse_file(self, f_obj):
        src_seq_words = []
        trg_seq_words = []
        for line in f_obj:
            fields = line.strip().split(self._field_delimiter)
            if (not self._only_src and len(fields) != 2) or (
                    self._only_src and len(fields) != 1):
                continue

            sample_words = []
            is_valid_sample = True
            max_len = -1
            for i, seq in enumerate(fields):
                seq_words = seq.split(self._token_delimiter)
                max_len = max(max_len, len(seq_words))
                if len(seq_words) == 0 or \
                        len(seq_words) < self._min_length or \
                        len(seq_words) > self._max_length or \
                        (self._use_token_batch and max_len > self._batch_size):
                    is_valid_sample = False
                    break
                sample_words.append(seq_words)

            if not is_valid_sample:
                continue

            src_seq_words.append(sample_words[0])
            if not self._only_src:
                trg_seq_words.append(sample_words[1])
        return (src_seq_words, trg_seq_words)

    def _load_data(self, fpattern, tar_fname):
        fpaths = glob.glob(fpattern)
        src_seq_words = []
        trg_seq_words = []

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")
            f = tarfile.open(fpaths[0], 'r')
            part_file_data = self._parse_file(f.extractfile(tar_fname))
            src_seq_words = part_file_data[0]
            trg_seq_words = part_file_data[1]
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)
                part_file_data = self._parse_file(open(fpath, 'r'))
                src_seq_words.extend(part_file_data[0])
                trg_seq_words.extend(part_file_data[1])
        return src_seq_words, trg_seq_words

    # Added by this change: ids are built through the Converter pipeline
    # and per-sample lengths are recorded once, up front.
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        self._random = random.Random(x=seed)

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._src_vocab[start_mark],
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark],
                delimiter=self._token_delimiter)
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark],
                    delimiter=self._token_delimiter))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0])
            lens = [len(src_trg_ids[0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        fpaths = glob.glob(fpattern)

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            f = tarfile.open(fpaths[0], "r")
            for line in f.extractfile(tar_fname):
                fields = line.strip("\n").split(self._field_delimiter)
                if (not self._only_src and len(fields) == 2) or (
                        self._only_src and len(fields) == 1):
                    yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "r") as f:
                    for line in f:
                        fields = line.strip("\n").split(self._field_delimiter)
                        if (not self._only_src and len(fields) == 2) or (
                                self._only_src and len(fields) == 1):
                            yield fields
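For reference, _load_lines yields one fields list per well-formed input line. A minimal sketch of the expected on-disk layout, assuming tab as the field delimiter and space as the token delimiter (the actual defaults are set outside this excerpt):

# Each line carries one sample; when a target side is expected, a line
# must split into exactly two fields on self._field_delimiter:
#
#   thanks a lot .<TAB>vielen dank .
#
# Lines with the wrong field count are silently skipped.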
    @staticmethod
    def load_dict(dict_path, reverse=False):
......@@ -263,100 +271,52 @@ class DataReader(object):
        with open(dict_path, "r") as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = line.strip("\n")
                else:
                    word_dict[line.strip("\n")] = idx
        return word_dict
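A small illustration of load_dict on a hypothetical three-line vocab file:

# vocab.txt contains the lines:  <s>  <e>  <unk>
# DataReader.load_dict("vocab.txt")                -> {"<s>": 0, "<e>": 1, "<unk>": 2}
# DataReader.load_dict("vocab.txt", reverse=True)  -> {0: "<s>", 1: "<e>", 2: "<unk>"}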
    # Removed by this change: a generator over single samples plus a
    # Pool-based batcher that rebuilt batches on every epoch.
    def _sample_generator(self):
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            if not self._sorted:
                self._sample_idxs.sort(
                    key=lambda idx: max(len(self._src_seq_ids[idx]),
                        len(self._trg_seq_ids[idx] if not self._only_src else 0))
                )
                self._sorted = True
        elif self._shuffle:
            random.shuffle(self._sample_idxs)

        for sample_idx in self._sample_idxs:
            if self._only_src:
                yield (self._src_seq_ids[sample_idx], )
            else:
                yield (self._src_seq_ids[sample_idx],
                       self._trg_seq_ids[sample_idx][:-1],
                       self._trg_seq_ids[sample_idx][1:])

    def batch_generator(self):
        pool = Pool(self._sample_generator, self._pool_size,
                    True if self._sort_type == SortType.POOL else False)

        def next_batch():
            batch_data = []
            max_len = -1
            batch_max_seq_len = -1
            while True:
                sample = pool.next(look=True)

                if sample is None:
                    pool.push_back(batch_data)
                    batch_data = []
                    continue

                if isinstance(sample, EndEpoch):
                    return batch_data, batch_max_seq_len, True

                max_len = max(max_len, len(sample[0]))
                if not self._only_src:
                    max_len = max(max_len, len(sample[1]))

                if self._use_token_batch:
                    if max_len * (len(batch_data) + 1) < self._batch_size:
                        batch_max_seq_len = max_len
                        batch_data.append(pool.next())
                    else:
                        return batch_data, batch_max_seq_len, False
                else:
                    if len(batch_data) < self._batch_size:
                        batch_max_seq_len = max_len
                        batch_data.append(pool.next())
                    else:
                        return batch_data, batch_max_seq_len, False

        if not self._shuffle_batch:
            batch_data, batch_max_seq_len, last_batch = next_batch()
            while not last_batch:
                yield batch_data
                batch_data, batch_max_seq_len, last_batch = next_batch()
            batch_size = len(batch_data)
            if self._use_token_batch:
                batch_size *= batch_max_seq_len
            if (not self._clip_last_batch and len(batch_data) > 0) \
                    or (batch_size == self._batch_size):
                yield batch_data
        else:
            # should re-generate batches
            if self._sort_type == SortType.POOL \
                    or len(self._epoch_batches) == 0:
                self._epoch_batches = []
                batch_data, batch_max_seq_len, last_batch = next_batch()
                while not last_batch:
                    self._epoch_batches.append(batch_data)
                    batch_data, batch_max_seq_len, last_batch = next_batch()
                batch_size = len(batch_data)
                if self._use_token_batch:
                    batch_size *= batch_max_seq_len
                if (not self._clip_last_batch and len(batch_data) > 0) \
                        or (batch_size == self._batch_size):
                    self._epoch_batches.append(batch_data)
            random.shuffle(self._epoch_batches)
            for batch_data in self._epoch_batches:
                yield batch_data

    # Added by this change: one pass over the precomputed SampleInfo list,
    # batched by the creator classes above.
    def batch_generator(self):
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(
                self._sample_infos, key=lambda x: x.max_len, reverse=True)
        else:
            if self._shuffle:
                infos = self._sample_infos
                self._random.shuffle(infos)
            else:
                infos = self._sample_infos

            if self._sort_type == SortType.POOL:
                for i in range(0, len(infos), self._pool_size):
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size], key=lambda x: x.max_len)

        # concat batch
        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size
        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            if batch is not None:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]

            if self._only_src:
                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
            else:
                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]
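Putting the pieces together, a hedged usage sketch of the new reader. The file names are taken from the .gitignore entries at the top of this diff; the constructor arguments not shown in this excerpt (shuffle, shuffle_batch, use_token_batch, and the values chosen here) are assumptions:

train_data = DataReader(
    src_vocab_fpath="vocab.bpe.32000.refined",
    trg_vocab_fpath="vocab.bpe.32000.refined",
    fpattern="train.tok.clean.bpe.32000.en-de",
    batch_size=4096,          # a token budget, since use_token_batch=True
    pool_size=10000,
    sort_type=SortType.POOL,  # sort within pools, then shuffle batches
    shuffle=True,
    shuffle_batch=True,
    use_token_batch=True)

for batch in train_data.batch_generator():
    # each sample in the batch is (src_ids, trg_ids[:-1], trg_ids[1:])
    pass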
import argparse
import ast
import multiprocessing
import os
import time
from functools import partial

import numpy as np
import paddle
import paddle.fluid as fluid

import reader
from config import *
from model import transformer, position_encoding_init
from optim import LearningRateScheduler
def parse_args():
......