Unverified · Commit c34bb5f1 authored by Yu Yang, committed by GitHub

Merge pull request #1060 from reyoung/speed_up_transformer_python_reader

Speed up NMT reader
 .DS_Store
 *.pyc
 .*~
+fluid/neural_machine_translation/transformer/deps
+fluid/neural_machine_translation/transformer/train.data
+fluid/neural_machine_translation/transformer/train.pkl
+fluid/neural_machine_translation/transformer/train.sh
+fluid/neural_machine_translation/transformer/train.tok.clean.bpe.32000.en-de
+fluid/neural_machine_translation/transformer/vocab.bpe.32000.refined
...@@ -553,7 +553,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece): ...@@ -553,7 +553,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
def infer(args, inferencer=fast_infer): def infer(args, inferencer=fast_infer):
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
test_data = reader.DataReader( test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath, src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath,
......
@@ -474,6 +474,7 @@ def transformer(
     sum_cost = layers.reduce_sum(weighted_cost)
     token_num = layers.reduce_sum(weights)
     avg_cost = sum_cost / token_num
+    avg_cost.stop_gradient = True
     return sum_cost, avg_cost, predict, token_num
......
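The single added line above marks avg_cost as a variable that gradients should not flow through. Below is a minimal sketch of the same idiom, assuming a PaddlePaddle Fluid 1.x install; the tensor names and shapes are invented for illustration and are not taken from the transformer model itself.

import paddle.fluid as fluid
import paddle.fluid.layers as layers

# Toy program; "feat" and "weights" are hypothetical inputs.
feat = layers.data(name="feat", shape=[8], dtype="float32")
weights = layers.data(name="weights", shape=[1], dtype="float32")

weighted_cost = layers.fc(input=feat, size=1) * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
avg_cost = sum_cost / token_num

# stop_gradient marks avg_cost as a leaf for autodiff: the backward pass
# built from sum_cost ignores it, and avg_cost is computed only for reporting.
avg_cost.stop_gradient = True

fluid.backward.append_backward(loss=sum_cost)

Presumably the transformer keeps optimizing sum_cost and exposes avg_cost only for logging, so cutting avg_cost out of the backward pass avoids building gradient ops that are never used.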
-import os
-import tarfile
 import glob
+import os
 import random
+import tarfile
-import cPickle

 class SortType(object):
@@ -11,54 +11,86 @@ class SortType(object):
     NONE = "none"
 
-class EndEpoch():
-    pass
-
-
-class Pool(object):
-    def __init__(self, sample_generator, pool_size, sort):
-        self._pool_size = pool_size
-        self._pool = []
-        self._sample_generator = sample_generator()
-        self._end = False
-        self._sort = sort
-
-    def _fill(self):
-        while len(self._pool) < self._pool_size and not self._end:
-            try:
-                sample = self._sample_generator.next()
-                self._pool.append(sample)
-            except StopIteration as e:
-                self._end = True
-                break
-
-        if self._sort:
-            self._pool.sort(
-                key=lambda sample: max(len(sample[0]), len(sample[1])) \
-                if len(sample) > 1 else len(sample[0])
-            )
-
-        if self._end and len(self._pool) < self._pool_size:
-            self._pool.append(EndEpoch())
-
-    def push_back(self, samples):
-        if len(self._pool) != 0:
-            raise Exception("Pool should be empty.")
-
-        if len(samples) >= self._pool_size:
-            raise Exception("Capacity of pool should be greater than a batch. "
-                            "Please enlarge `pool_size`.")
-
-        for sample in samples:
-            self._pool.append(sample)
-
-        self._fill()
-
-    def next(self, look=False):
-        if len(self._pool) == 0:
-            return None
-        else:
-            return self._pool[0] if look else self._pool.pop(0)
+class Converter(object):
+    def __init__(self, vocab, beg, end, unk, delimiter):
+        self._vocab = vocab
+        self._beg = beg
+        self._end = end
+        self._unk = unk
+        self._delimiter = delimiter
+
+    def __call__(self, sentence):
+        return [self._beg] + [
+            self._vocab.get(w, self._unk)
+            for w in sentence.split(self._delimiter)
+        ] + [self._end]
+
+
+class ComposedConverter(object):
+    def __init__(self, converters):
+        self._converters = converters
+
+    def __call__(self, parallel_sentence):
+        return [
+            self._converters[i](parallel_sentence[i])
+            for i in range(len(self._converters))
+        ]
+
+
+class SentenceBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self._batch_size = batch_size
+
+    def append(self, info):
+        self.batch.append(info)
+        if len(self.batch) == self._batch_size:
+            tmp = self.batch
+            self.batch = []
+            return tmp
+
+
+class TokenBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self.max_len = -1
+        self._batch_size = batch_size
+
+    def append(self, info):
+        cur_len = info.max_len
+        max_len = max(self.max_len, cur_len)
+        if max_len * (len(self.batch) + 1) > self._batch_size:
+            result = self.batch
+            self.batch = [info]
+            self.max_len = cur_len
+            return result
+        else:
+            self.max_len = max_len
+            self.batch.append(info)
+
+
+class SampleInfo(object):
+    def __init__(self, i, max_len, min_len):
+        self.i = i
+        self.min_len = min_len
+        self.max_len = max_len
+
+
+class MinMaxFilter(object):
+    def __init__(self, max_len, min_len, underlying_creator):
+        self._min_len = min_len
+        self._max_len = max_len
+        self._creator = underlying_creator
+
+    def append(self, info):
+        if info.max_len > self._max_len or info.min_len < self._min_len:
+            return
+        else:
+            return self._creator.append(info)
+
+    @property
+    def batch(self):
+        return self._creator.batch
+
+
 class DataReader(object):
@@ -140,7 +172,7 @@ class DataReader(object):
                  fpattern,
                  batch_size,
                  pool_size,
-                 sort_type=SortType.NONE,
+                 sort_type=SortType.GLOBAL,
                  clip_last_batch=True,
                  tar_fname=None,
                  min_length=0,
@@ -170,92 +202,68 @@ class DataReader(object):
         self._max_length = max_length
         self._field_delimiter = field_delimiter
         self._token_delimiter = token_delimiter
-        self._epoch_batches = []
-
-        src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
-        self._src_seq_ids = [[
-            self._src_vocab.get(word, self._src_vocab.get(unk_mark))
-            for word in ([start_mark] + src_seq + [end_mark])
-        ] for src_seq in src_seq_words]
-
-        self._sample_count = len(self._src_seq_ids)
-
-        if not self._only_src:
-            self._trg_seq_ids = [[
-                self._trg_vocab.get(word, self._trg_vocab.get(unk_mark))
-                for word in ([start_mark] + trg_seq + [end_mark])
-            ] for trg_seq in trg_seq_words]
-            if len(self._trg_seq_ids) != self._sample_count:
-                raise Exception("Inconsistent sample count between "
-                                "source sequences and target sequences.")
-        else:
-            self._trg_seq_ids = None
-
-        self._sample_idxs = [i for i in xrange(self._sample_count)]
-        self._sorted = False
-
-        random.seed(seed)
-
-    def _parse_file(self, f_obj):
-        src_seq_words = []
-        trg_seq_words = []
-        for line in f_obj:
-            fields = line.strip().split(self._field_delimiter)
-            if (not self._only_src and len(fields) != 2) or (self._only_src and
-                                                             len(fields) != 1):
-                continue
-
-            sample_words = []
-            is_valid_sample = True
-            max_len = -1
-
-            for i, seq in enumerate(fields):
-                seq_words = seq.split(self._token_delimiter)
-                max_len = max(max_len, len(seq_words))
-                if len(seq_words) == 0 or \
-                        len(seq_words) < self._min_length or \
-                        len(seq_words) > self._max_length or \
-                        (self._use_token_batch and max_len > self._batch_size):
-                    is_valid_sample = False
-                    break
-                sample_words.append(seq_words)
-
-            if not is_valid_sample: continue
-
-            src_seq_words.append(sample_words[0])
-            if not self._only_src:
-                trg_seq_words.append(sample_words[1])
-
-        return (src_seq_words, trg_seq_words)
-
-    def _load_data(self, fpattern, tar_fname):
-        fpaths = glob.glob(fpattern)
-        src_seq_words = []
-        trg_seq_words = []
-
-        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
-            if tar_fname is None:
-                raise Exception("If tar file provided, please set tar_fname.")
-
-            f = tarfile.open(fpaths[0], 'r')
-            part_file_data = self._parse_file(f.extractfile(tar_fname))
-            src_seq_words = part_file_data[0]
-            trg_seq_words = part_file_data[1]
-        else:
-            for fpath in fpaths:
-                if not os.path.isfile(fpath):
-                    raise IOError("Invalid file: %s" % fpath)
-
-                part_file_data = self._parse_file(open(fpath, 'r'))
-                src_seq_words.extend(part_file_data[0])
-                trg_seq_words.extend(part_file_data[1])
-
-        return src_seq_words, trg_seq_words
+        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
+                              unk_mark)
+        self._random = random.Random(x=seed)
+
+    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
+                         unk_mark):
+        converters = [
+            Converter(
+                vocab=self._src_vocab,
+                beg=self._src_vocab[start_mark],
+                end=self._src_vocab[end_mark],
+                unk=self._src_vocab[unk_mark],
+                delimiter=self._token_delimiter)
+        ]
+        if not self._only_src:
+            converters.append(
+                Converter(
+                    vocab=self._trg_vocab,
+                    beg=self._trg_vocab[start_mark],
+                    end=self._trg_vocab[end_mark],
+                    unk=self._trg_vocab[unk_mark],
+                    delimiter=self._token_delimiter))
+
+        converters = ComposedConverter(converters)
+
+        self._src_seq_ids = []
+        self._trg_seq_ids = None if self._only_src else []
+        self._sample_infos = []
+
+        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
+            src_trg_ids = converters(line)
+            self._src_seq_ids.append(src_trg_ids[0])
+            lens = [len(src_trg_ids[0])]
+            if not self._only_src:
+                self._trg_seq_ids.append(src_trg_ids[1])
+                lens.append(len(src_trg_ids[1]))
+            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
+
+    def _load_lines(self, fpattern, tar_fname):
+        fpaths = glob.glob(fpattern)
+
+        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
+            if tar_fname is None:
+                raise Exception("If tar file provided, please set tar_fname.")
+
+            f = tarfile.open(fpaths[0], "r")
+            for line in f.extractfile(tar_fname):
+                fields = line.strip("\n").split(self._field_delimiter)
+                if (not self._only_src and len(fields) == 2) or (
+                        self._only_src and len(fields) == 1):
+                    yield fields
+        else:
+            for fpath in fpaths:
+                if not os.path.isfile(fpath):
+                    raise IOError("Invalid file: %s" % fpath)
+
+                with open(fpath, "r") as f:
+                    for line in f:
+                        fields = line.strip("\n").split(self._field_delimiter)
+                        if (not self._only_src and len(fields) == 2) or (
+                                self._only_src and len(fields) == 1):
+                            yield fields
 
     @staticmethod
     def load_dict(dict_path, reverse=False):
@@ -263,100 +271,52 @@ class DataReader(object):
         with open(dict_path, "r") as fdict:
             for idx, line in enumerate(fdict):
                 if reverse:
-                    word_dict[idx] = line.strip('\n')
+                    word_dict[idx] = line.strip("\n")
                 else:
-                    word_dict[line.strip('\n')] = idx
+                    word_dict[line.strip("\n")] = idx
         return word_dict
 
-    def _sample_generator(self):
-        if self._sort_type == SortType.GLOBAL:
-            if not self._sorted:
-                self._sample_idxs.sort(
-                    key=lambda idx: max(len(self._src_seq_ids[idx]),
-                        len(self._trg_seq_ids[idx] if not self._only_src else 0))
-                )
-                self._sorted = True
-        elif self._shuffle:
-            random.shuffle(self._sample_idxs)
-
-        for sample_idx in self._sample_idxs:
-            if self._only_src:
-                yield (self._src_seq_ids[sample_idx], )
-            else:
-                yield (self._src_seq_ids[sample_idx],
-                       self._trg_seq_ids[sample_idx][:-1],
-                       self._trg_seq_ids[sample_idx][1:])
-
-    def batch_generator(self):
-        pool = Pool(self._sample_generator, self._pool_size, True
-                    if self._sort_type == SortType.POOL else False)
-
-        def next_batch():
-            batch_data = []
-            max_len = -1
-            batch_max_seq_len = -1
-            while True:
-                sample = pool.next(look=True)
-
-                if sample is None:
-                    pool.push_back(batch_data)
-                    batch_data = []
-                    continue
-
-                if isinstance(sample, EndEpoch):
-                    return batch_data, batch_max_seq_len, True
-
-                max_len = max(max_len, len(sample[0]))
-
-                if not self._only_src:
-                    max_len = max(max_len, len(sample[1]))
-
-                if self._use_token_batch:
-                    if max_len * (len(batch_data) + 1) < self._batch_size:
-                        batch_max_seq_len = max_len
-                        batch_data.append(pool.next())
-                    else:
-                        return batch_data, batch_max_seq_len, False
-                else:
-                    if len(batch_data) < self._batch_size:
-                        batch_max_seq_len = max_len
-                        batch_data.append(pool.next())
-                    else:
-                        return batch_data, batch_max_seq_len, False
-
-        if not self._shuffle_batch:
-            batch_data, batch_max_seq_len, last_batch = next_batch()
-            while not last_batch:
-                yield batch_data
-                batch_data, batch_max_seq_len, last_batch = next_batch()
-
-            batch_size = len(batch_data)
-            if self._use_token_batch:
-                batch_size *= batch_max_seq_len
-
-            if (not self._clip_last_batch and len(batch_data) > 0) \
-                    or (batch_size == self._batch_size):
-                yield batch_data
-        else:
-            # should re-generate batches
-            if self._sort_type == SortType.POOL \
-                    or len(self._epoch_batches) == 0:
-                self._epoch_batches = []
-                batch_data, batch_max_seq_len, last_batch = next_batch()
-                while not last_batch:
-                    self._epoch_batches.append(batch_data)
-                    batch_data, batch_max_seq_len, last_batch = next_batch()
-
-                batch_size = len(batch_data)
-                if self._use_token_batch:
-                    batch_size *= batch_max_seq_len
-
-                if (not self._clip_last_batch and len(batch_data) > 0) \
-                        or (batch_size == self._batch_size):
-                    self._epoch_batches.append(batch_data)
-
-            random.shuffle(self._epoch_batches)
-
-            for batch_data in self._epoch_batches:
-                yield batch_data
+    def batch_generator(self):
+        # global sort or global shuffle
+        if self._sort_type == SortType.GLOBAL:
+            infos = sorted(
+                self._sample_infos, key=lambda x: x.max_len, reverse=True)
+        else:
+            if self._shuffle:
+                infos = self._sample_infos
+                self._random.shuffle(infos)
+            else:
+                infos = self._sample_infos
+
+            if self._sort_type == SortType.POOL:
+                for i in range(0, len(infos), self._pool_size):
+                    infos[i:i + self._pool_size] = sorted(
+                        infos[i:i + self._pool_size], key=lambda x: x.max_len)
+
+        # concat batch
+        batches = []
+        batch_creator = TokenBatchCreator(
+            self._batch_size
+        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
+        batch_creator = MinMaxFilter(self._max_length, self._min_length,
+                                     batch_creator)
+
+        for info in infos:
+            batch = batch_creator.append(info)
+            if batch is not None:
+                batches.append(batch)
+
+        if not self._clip_last_batch and len(batch_creator.batch) != 0:
+            batches.append(batch_creator.batch)
+
+        if self._shuffle_batch:
+            self._random.shuffle(batches)
+
+        for batch in batches:
+            batch_ids = [info.i for info in batch]
+
+            if self._only_src:
+                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
+            else:
+                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
+                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]
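The rewritten reader converts each sentence to token ids exactly once, keeps only a SampleInfo(index, max_len, min_len) record per sample, and then assembles batches by wrapping a batch creator in a MinMaxFilter. Below is a minimal sketch of how these pieces are intended to compose, assuming the reader module from this diff is importable; the lengths and the 64-token budget are made up for illustration, and this is not a copy of DataReader.batch_generator.

from reader import MinMaxFilter, SampleInfo, TokenBatchCreator

# Pretend corpus: per-sample (src_len, trg_len) pairs, standing in for the
# real token-id lists that DataReader.load_src_trg_ids builds.
lengths = [(3, 7), (50, 48), (12, 11), (9, 8), (120, 100), (30, 28)]
infos = [SampleInfo(i, max(l), min(l)) for i, l in enumerate(lengths)]

# TokenBatchCreator closes a batch once max_len * n_samples would exceed the
# token budget; MinMaxFilter drops samples outside [min_len, max_len].
creator = MinMaxFilter(
    max_len=100, min_len=5, underlying_creator=TokenBatchCreator(64))

batches = []
for info in infos:
    batch = creator.append(info)
    if batch is not None:
        batches.append(batch)
# Keep the last partial batch, mirroring clip_last_batch=False.
if creator.batch:
    batches.append(creator.batch)

for batch in batches:
    print([info.i for info in batch])

With a 64-token budget this prints [1], [2, 3] and the partial batch [5]; sample 0 is dropped by the min-length filter and sample 4 by the max-length filter. Combined with the global or per-pool sort in batch_generator, grouping samples of similar length this way keeps padding inside each batch small.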
-import os
-import time
 import argparse
 import ast
-import numpy as np
 import multiprocessing
+import os
+import time
 from functools import partial
 
-import paddle
+import numpy as np
 import paddle.fluid as fluid
 
+import reader
+from config import *
 from model import transformer, position_encoding_init
 from optim import LearningRateScheduler
-from config import *
-import reader
 
 
 def parse_args():
......